wait for persist to finish
This commit is contained in:
@@ -16,6 +16,7 @@ import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import dask
|
||||
import pickle
|
||||
from dask.distributed import wait
|
||||
|
||||
#the classifier
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
@@ -326,10 +327,11 @@ class train_segmentation:
|
||||
if self.lazy and not self.current_computed and type(feat_stack) is not np.ndarray:
|
||||
print('now actually calculating the features')
|
||||
feat_stack = feat_stack.persist() #compute() persist may prevent an memory blow up https://stackoverflow.com/questions/73770527/dask-compute-uses-twice-the-expected-memory
|
||||
wait(feat_stack) #if you use persist(), you have to wait for the calculation to finish before passing the feat stack to sklearn
|
||||
self.current_computed = True
|
||||
if type(feat_stack) is not np.ndarray:
|
||||
print('feat_stack is not a numpy array! check why')
|
||||
feat_stack = feat_stack.persist() # compute()
|
||||
feat_stack = feat_stack.compute()
|
||||
|
||||
self.current_feat_stack = feat_stack
|
||||
#train
|
||||
|
||||
Reference in New Issue
Block a user