Files
pyTrainSeg/test_dask_framework-Copy1.ipynb
T

849 KiB

In [53]:
import dask
from dask.distributed import Client, LocalCluster
import xarray as xr
import numpy as np
import  matplotlib.pyplot as plt
import os
from ipywidgets import Image
from ipywidgets import ColorPicker, IntSlider, link, AppLayout, HBox
from ipycanvas import  hold_canvas,  MultiCanvas #RoughCanvas,Canvas,
import imageio
In [2]:
# local cluster on current machine
# cluster = LocalCluster() 
# client = Client(cluster)
# print('Dashboard at '+cluster.dashboard_link)
In [3]:
# home-made cluster
scheduler_address = "129.129.188.248:8786"
client = Client(scheduler_address)
print('Dashboard at '+client.dashboard_link)
Dashboard at http://129.129.188.248:8787/status
In [4]:
tempfolder = '/mnt/SSD_2TB_nvme0n1/Robert'
In [5]:
# dask.config.config['temporary-directory'] = tempfolder
In [6]:
client
Out[6]:

Client

Client-4c39d915-293a-11ed-98ad-901b0e50e2fd

Connection method: Direct
Dashboard: http://129.129.188.248:8787/status

Scheduler Info

Scheduler

Scheduler-6918c53a-afa4-4f8b-aa73-f19087617933

Comm: tcp://129.129.188.248:8786 Workers: 2
Dashboard: http://129.129.188.248:8787/status Total threads: 148
Started: 4 hours ago Total memory: 1.48 TiB

Workers

Worker: tcp://129.129.188.222:38131

Comm: tcp://129.129.188.222:38131 Total threads: 128
Dashboard: http://129.129.188.222:35403/status Memory: 0.98 TiB
Nanny: tcp://129.129.188.222:36643
Local directory: /tmp/dask-worker-space/worker-a5933tbq
GPU: NVIDIA RTX A4000 GPU memory: 15.99 GiB
Tasks executing: 0 Tasks in memory: 0
Tasks ready: 0 Tasks in flight: 0
CPU usage: 2.0% Last seen: Just now
Memory usage: 9.54 GiB Spilled bytes: 0 B
Read bytes: 1.01 kiB Write bytes: 1.24 kiB

Worker: tcp://129.129.188.248:35404

Comm: tcp://129.129.188.248:35404 Total threads: 20
Dashboard: http://129.129.188.248:40628/status Memory: 503.62 GiB
Nanny: tcp://129.129.188.248:37847
Local directory: /tmp/dask-worker-space/worker-8fzliihb
Tasks executing: 0 Tasks in memory: 0
Tasks ready: 0 Tasks in flight: 0
CPU usage: 4.0% Last seen: Just now
Memory usage: 2.06 GiB Spilled bytes: 0 B
Read bytes: 26.14 kiB Write bytes: 35.92 kiB
In [7]:
from filter_functions import image_filter
In [8]:
import training_functions as tfs
from training_functions import train_segmentation
In [9]:
def on_mouse_down(x, y):
    global drawing
    global position
    global shape
    drawing = True
    position = (x, y)
    shape = [position]

def on_mouse_move(x, y):
    global drawing
    global position
    global shape
    if not drawing:
        return
    with hold_canvas():
        canvas.stroke_line(position[0], position[1], x, y)
        position = (x, y)
    shape.append(position)

def on_mouse_up(x, y):
    global drawing
    global positiondu
    global shape
    drawing = False
    with hold_canvas():
        canvas.stroke_line(position[0], position[1], x, y)
        canvas.fill_polygon(shape)
    shape = []
In [10]:
path = '/home/fische_r/NAS/testing/Jeremy_tomo/tomodata.nc'
featpath = '/home/fische_r/NAS/testing/Jeremy_tomo/featdata.nc'
# path = r"C:\Zwischenlager\tomodata.nc"
# path = '/mpc/homes/fische_r/wood3/wood_tomo.nc'
# featpath = '/mpc/homes/fische_r/wood3/featdata.nc'

path = '/home/fische_r/NAS/testing/TIM_tomo/tomodata.nc'
featpath = '/home/fische_r/NAS/testing/TIM_tomo/featdata.nc'
In [11]:
IF = image_filter(data_path=path, outpath = featpath)
In [12]:
IF.open_raw_data()
In [11]:
# IF.data = IF.data[30:-20,15:-50,:100,:50] #cropping for wood
In [13]:
IF.data
Out[13]:
Array Chunk
Bytes 328.56 GiB 15.81 MiB
Shape (1781, 140, 956, 185) (40, 35, 40, 37)
Count 21600 Tasks 21600 Chunks
Type float64 numpy.ndarray
1781 1 185 956 140
In [14]:
IF.prepare()
In [15]:
IF.stack_features()
In [16]:
IF.feature_stack
Out[16]:
Array Chunk
Bytes 23.42 TiB 15.81 MiB
Shape (1781, 140, 956, 185, 73) (40, 35, 40, 37, 1)
Count 23945408 Tasks 1576800 Chunks
Type float64 numpy.ndarray
140 1781 73 185 956
In [ ]:
IF.compute() #not sure what is more efficient, but I would compute the features and even store them on disk
# had the impression that otherwise many redundant operations happen
In [22]:
IF.make_xarray_nc()
In [18]:
# IF.make_xarray_nc(store=True)
In [23]:
# training_path = r"C:\Zwischenlager\Jeremy_tomo"
training_path = '/home/fische_r/NAS/testing/Jeremy_tomo'
training_path = '/mpc/homes/fische_r/wood3/'
training_path = '/home/fische_r/NAS/testing/TIM_tomo'
if not os.path.exists(training_path):
    os.mkdir(training_path)
In [24]:
TS = train_segmentation(training_path=training_path)
In [26]:
TS.import_lazy_feature_data(IF.result)
In [21]:
# TS.import_feature_data(IF.result)

iterative training, if you have a training set, skip here

In [27]:
TS.suggest_training_set()
You could try  y = 84  and  z = 131
However, please sort it like the original xyztime
In [28]:
c1 = 'y'
p1 = 84
c2 = 'time'
p2 = 48
In [29]:
TS.load_training_set(c1, p1, c2, p2)
In [30]:
1
Out[30]:
1
In [680]:
# TS.current_im8 = TS.current_im8.compute()
In [31]:
alpha = 0.15
im8 = TS.current_im8
resultim = TS.current_result
width = im8.shape[1]
height = im8.shape[0]
Mcanvas = MultiCanvas(4, width=width, height=height)
background = Mcanvas[0]
resultdisplay = Mcanvas[2]
truthdisplay = Mcanvas[1]
canvas = Mcanvas[3]
canvas.sync_image_data = True
drawing = False
position = None
shape = []
image_data = np.stack((im8, im8, im8), axis=2)
background.put_image_data(image_data, 0, 0)
resultdisplay.global_alpha = alpha
if np.any(resultim>0):
    result_data = np.stack((255*(resultim==0), 255*(resultim==1), 255*(resultim==2)), axis=2)
else:
    result_data = np.stack((0*resultim, 0*resultim, 0*resultim), axis=2)
resultdisplay.put_image_data(result_data, 0, 0)
canvas.on_mouse_down(on_mouse_down)
canvas.on_mouse_move(on_mouse_move)
canvas.on_mouse_up(on_mouse_up)
picker = ColorPicker(description="Color:", value="#ff0000")
slidealpha = IntSlider(description="Result overlay", value=0.15)
link((picker, "value"), (canvas, "stroke_style"))
link((picker, "value"), (canvas, "fill_style"))
HBox((Mcanvas, picker, slidealpha))
HBox(children=(MultiCanvas(height=1781, width=956), ColorPicker(value='#ff0000', description='Color:'), IntSli…
In [32]:
# tfs.plot_im_histogram(TS.current_im8)
# TS.current_im8 = tfs.adjust_image_contrast(TS.current_im8, 50,255)
In [33]:
fig, axes = plt.subplots(1,5, figsize=(20,10))
axes[0].imshow(TS.current_result)
axes[1].imshow(TS.current_im8, 'gray')
axes[2].imshow(TS.current_diff_im)
# axes[3].imshow(m8old, 'gray')
axes[3].imshow(TS.current_first_im, 'gray')
axes[4].imshow(TS.current_truth)

for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
No description has been provided for this image
In [34]:
label_set = canvas.get_image_data()

TS.current_truth[label_set[:,:,0]>0] = 1
TS.current_truth[label_set[:,:,1]>0] = 2
TS.current_truth[label_set[:,:,2]>0] = 4

imageio.imsave(TS.current_truthpath, TS.current_truth)
In [35]:
TS.train_slice()
now actually calculating the features
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/miniconda3/lib/python3.9/_collections_abc.py:769, in Mapping.__contains__(self, key)
    768 try:
--> 769     self[key]
    770 except KeyError:

File ~/miniconda3/lib/python3.9/site-packages/dask/blockwise.py:547, in Blockwise.__getitem__(self, key)
    546 def __getitem__(self, key):
--> 547     return self._dict[key]

KeyError: ('concatenate-8008a02e16ee48e431854aa0d25f2eaa', 39, 1, 7, 0)

During handling of the above exception, another exception occurred:

KeyboardInterrupt                         Traceback (most recent call last)
Input In [35], in <cell line: 1>()
----> 1 TS.train_slice()

File /mnt/nas_Uwrite/fische_r/lib/pytrainseg/training_functions.py:360, in train_segmentation.train_slice(self)
    358     print('now actually calculating the features')
    359     # self.current_feat_stack.rechunk('auto') #why rechunk 'auto' ?! if anything should be something small fot massive parallel
--> 360     feat_stack = feat_stack.compute() 
    361     self.current_computed = True
    362 if type(feat_stack) is not np.ndarray:

File ~/miniconda3/lib/python3.9/site-packages/dask/base.py:315, in DaskMethodsMixin.compute(self, **kwargs)
    291 def compute(self, **kwargs):
    292     """Compute this dask collection
    293 
    294     This turns a lazy Dask collection into its in-memory equivalent.
   (...)
    313     dask.base.compute
    314     """
--> 315     (result,) = compute(self, traverse=False, **kwargs)
    316     return result

File ~/miniconda3/lib/python3.9/site-packages/dask/base.py:592, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    584     return args
    586 schedule = get_scheduler(
    587     scheduler=scheduler,
    588     collections=collections,
    589     get=get,
    590 )
--> 592 dsk = collections_to_dsk(collections, optimize_graph, **kwargs)
    593 keys, postcomputes = [], []
    594 for x in collections:

File ~/miniconda3/lib/python3.9/site-packages/dask/base.py:367, in collections_to_dsk(collections, optimize_graph, optimizations, **kwargs)
    365 for opt, val in groups.items():
    366     dsk, keys = _extract_graph_and_keys(val)
--> 367     dsk = opt(dsk, keys, **kwargs)
    369     for opt_inner in optimizations:
    370         dsk = opt_inner(dsk, keys, **kwargs)

File ~/miniconda3/lib/python3.9/site-packages/dask/array/optimization.py:57, in optimize(dsk, keys, fuse_keys, fast_functions, inline_functions_fast_functions, rename_fused_keys, **kwargs)
     54 if config.get("optimization.fuse.active") is False:
     55     return dsk
---> 57 dependencies = dsk.get_all_dependencies()
     58 dsk = ensure_dict(dsk)
     60 # Low level task optimizations

File ~/miniconda3/lib/python3.9/site-packages/dask/highlevelgraph.py:813, in HighLevelGraph.get_all_dependencies(self)
    811 if missing_keys:
    812     for layer in self.layers.values():
--> 813         for k in missing_keys & layer.keys():
    814             self.key_dependencies[k] = layer.get_dependencies(k, all_keys)
    815 return self.key_dependencies

File ~/miniconda3/lib/python3.9/_collections_abc.py:577, in Set.__and__(self, other)
    575 if not isinstance(other, Iterable):
    576     return NotImplemented
--> 577 return self._from_iterable(value for value in other if value in self)

File ~/miniconda3/lib/python3.9/_collections_abc.py:820, in KeysView._from_iterable(cls, it)
    818 @classmethod
    819 def _from_iterable(cls, it):
--> 820     return set(it)

File ~/miniconda3/lib/python3.9/_collections_abc.py:577, in <genexpr>(.0)
    575 if not isinstance(other, Iterable):
    576     return NotImplemented
--> 577 return self._from_iterable(value for value in other if value in self)

File ~/miniconda3/lib/python3.9/_collections_abc.py:823, in KeysView.__contains__(self, key)
    822 def __contains__(self, key):
--> 823     return key in self._mapping

File ~/miniconda3/lib/python3.9/_collections_abc.py:769, in Mapping.__contains__(self, key)
    767 def __contains__(self, key):
    768     try:
--> 769         self[key]
    770     except KeyError:
    771         return False

KeyboardInterrupt: 
In [704]:
# TS.current_im8 = tfs.adjust_image_contrast(TS.current_im8, 50,200)

when done, maybe save the classifier

In [706]:
TS.pickle_classifier()

use an existing trainingset to train classifier (adhere to label iamge naming convention)

In [24]:
# TS.feat_data = TS.feat_data.compute() #better option for retraining, but creates a numpy array, maybe you can avoid
In [22]:
# provide new feature data if necessary and say if it is a lazy dask array or not
# TS.feat_data = 
# TS.lazy = 

TS.train()
In [26]:
# TS.pickle_classifier()
In [27]:
from segmentation import segmentation
# import pickle
In [24]:
SM = segmentation(training_path = training_path, classifier_path=os.path.join(training_path, 'classifier.p'))
In [28]:
# SM.import_classifier(TS.clf)
# SM.clf = pickle.load(open(os.path.join(training_path, 'classifier.p'), 'rb'))
In [29]:
SM.import_feature_data(IF.result)
In [39]:
# SM.lazy = False
part2 = SM.feat_data.feature_stack[:,:,:,25:,:]
In [40]:
num_feat = part2.shape[-1]
clf = SM.clf
seg2 = clf.predict(part2.data.reshape(-1,num_feat))
seg2 = seg2.reshape(part2[...,0].shape).astype(np.uint8)
In [38]:
# seg1 = seg1.reshape(part1[...,0].shape).astype(np.uint8)
In [ ]:
# SM.classify_all()
classifying ...
In [707]:
# SM.store_segmented_data()
In [49]:
seg_data.size/1024**3
Out[49]:
0.2153683453798294
In [47]:
path = os.path.join(SM.training_path, 'segmented.nc')

#TODO: propagate labels from raw data
#TODO: if self.segmented_data is a dask array, rechunk for saving
shp = seg_data.shape
data = xr.Dataset({'segmented': (['x','y','z','time'], seg_data)},
                               coords = {'x': np.arange(shp[0]),
                               'y': np.arange(shp[1]),
                               'z': np.arange(shp[2]),
                               'time': np.arange(shp[3]),
                               'feature': SM.feature_names}
                     )
# data.to
In [50]:
data.to_netcdf(path)
In [54]:
test = xr.load_dataset(path)
In [62]:
test.segmented.sel(z=10, time=49).plot()
Out[62]:
<matplotlib.collections.QuadMesh at 0x7f4f34f447f0>
No description has been provided for this image
2022-08-24 15:53:00,051 - distributed.nanny - ERROR - Worker process died unexpectedly
2022-08-24 15:53:00,057 - distributed.nanny - ERROR - Worker process died unexpectedly
2022-08-24 15:53:00,066 - distributed.nanny - ERROR - Worker process died unexpectedly
2022-08-24 15:53:00,070 - distributed.nanny - ERROR - Worker process died unexpectedly
2022-08-24 15:53:00,074 - distributed.nanny - ERROR - Worker process died unexpectedly
2022-08-24 15:53:00,076 - distributed.nanny - ERROR - Worker process died unexpectedly
2022-08-24 15:53:00,080 - distributed.nanny - ERROR - Worker process died unexpectedly
2022-08-24 15:53:00,084 - distributed.nanny - ERROR - Worker process died unexpectedly
2022-08-24 15:53:00,084 - distributed.nanny - ERROR - Worker process died unexpectedly
2022-08-24 15:53:00,084 - distributed.nanny - ERROR - Worker process died unexpectedly
2022-08-24 15:53:00,093 - distributed.nanny - ERROR - Worker process died unexpectedly
2022-08-24 15:53:00,101 - distributed.nanny - ERROR - Worker process died unexpectedly
In [632]:
plt.figure(figsize=(16,9))
plt.stem(TS.feature_names, TS.clf.feature_importances_,'x')
plt.xticks(rotation=90)
plt.ylabel('importance') 
/tmp/ipykernel_835997/3278944968.py:2: MatplotlibDeprecationWarning: Passing the linefmt parameter positionally is deprecated since Matplotlib 3.5; the parameter will become keyword-only two minor releases later.
  plt.stem(TS.feature_names, TS.clf.feature_importances_,'x')
Out[632]:
Text(0, 0.5, 'importance')
No description has been provided for this image
In [52]:
plt.figure(figsize=(16,9))
plt.stem(SM.feature_names, clf.feature_importances_,'x')
plt.xticks(rotation=90)
plt.ylabel('importance') 
/tmp/ipykernel_870706/1913450262.py:2: MatplotlibDeprecationWarning: Passing the linefmt parameter positionally is deprecated since Matplotlib 3.5; the parameter will become keyword-only two minor releases later.
  plt.stem(SM.feature_names, clf.feature_importances_,'x')
Out[52]:
Text(0, 0.5, 'importance')
No description has been provided for this image
In [49]:
plt.imshow(SM.segmented_data[:,10,:,-1])
Out[49]:
<matplotlib.image.AxesImage at 0x7fb97046a700>
No description has been provided for this image
In [ ]: