Files
pyTrainSeg/V2_example_notebook.ipynb
2025-04-11 15:20:45 +02:00

602 KiB

In [1]:
# modules
import os
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
import dask
import dask.array
from scipy import ndimage
from skimage import filters, feature, io
from skimage.morphology import disk,ball
import sys
from itertools import combinations_with_replacement
import pickle
import imageio
import json
from dask.distributed import Client, LocalCluster
import socket
import subprocess
import gc
import h5py
import logging
import warnings
warnings.filterwarnings('ignore')


from dask import config as cfg

# Dask tuning:
# - worker-ttl None: workaround so dask does not kill workers while they are busy
#   fetching data, see
#   https://dask.discourse.group/t/dask-workers-killed-because-of-heartbeat-fail/856
# - low-level-log-length: potential workaround for ballooning scheduler memory, see
#   https://baumgartner.io/posts/how-to-reduce-memory-usage-of-dask-scheduler/
# NOTE(review): still relevant? re-check against the installed dask version.
cfg.set({'distributed.scheduler.worker-ttl': None,
        'distributed.admin.low-level-log-length': 100
        })

# host-specific scratch/data paths and dask memory limits
host = socket.gethostname()
if host == 'mpc2959.psi.ch':
    temppath = '/mnt/SSD/fische_r/tmp'  # fast SSD scratch for dask spill-to-disk
    training_path =  '/mnt/SSD/fische_r/Tomcat_2/'
    pytrainpath = '/mpc/homes/fische_r/lib/pytrainseg'
    # memlim = '840GB'
    memlim = '440GB'
    # memlim = '920GB'
elif host[:3] == 'ra-':
    temppath = '/das/home/fische_r/interlaces/Tomcat_2/tmp'
    training_path = '/das/home/fische_r/interlaces/Tomcat_2'
    pytrainpath = '/das/home/fische_r/lib/pytrainseg'
    memlim = '160GB'  # also fine on the small nodes, you can differentiate more if you want
else:
    # NOTE(review): execution continues with temppath/training_path/memlim
    # undefined, so later cells will fail with NameError; consider raising here.
    print('host '+host+' currently not supported')

# get the ML functions, TODO: make a library once it works/is in a stable state
# chdir into the pytrainseg checkout so the V2_* modules are importable, then chdir back
cwd = os.getcwd()
os.chdir(pytrainpath)
from V2_feature_stack import image_filter
import V2_training as tfs
from V2_training import training

# record the short git SHA of the pytrainseg checkout: used as a provenance
# prefix for pickled training dicts / feature names / classifiers below
pytrain_git_sha = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
os.chdir(cwd)

functionalities for interactive training

In [2]:
from ipywidgets import Image
from ipywidgets import ColorPicker, IntSlider, link, AppLayout, HBox
from ipycanvas import  hold_canvas,  MultiCanvas #RoughCanvas,Canvas,

def on_mouse_down(x, y):
    # Begin a free-hand stroke: remember the anchor point and open a new shape.
    global drawing, position, shape
    position = (x, y)
    shape = [position]
    drawing = True

def on_mouse_move(x, y):
    # Extend the active stroke to (x, y) while the mouse button is held down.
    global drawing, position, shape
    if drawing:
        with hold_canvas():
            canvas.stroke_line(position[0], position[1], x, y)
            position = (x, y)
        # record the vertex after the canvas update is flushed
        shape.append(position)

def on_mouse_up(x, y, fill=False):
    """Finish the current stroke: draw the last segment, optionally fill the
    traced polygon, and reset the stroke state.

    Parameters
    ----------
    x, y : float
        Canvas coordinates where the mouse button was released.
    fill : bool
        If True, fill the polygon defined by the collected stroke vertices.
    """
    global drawing
    global position  # fixed: was 'global positiondu' (typo)
    global shape
    drawing = False
    with hold_canvas():
        canvas.stroke_line(position[0], position[1], x, y)
        if fill:
            canvas.fill_polygon(shape)
    shape = []
    
def display_feature(i, TS, feat_stack):
    """Return feature plane ``i`` of ``feat_stack`` rescaled to 0..255 for display.

    Parameters
    ----------
    i : int
        Index of the feature along the last axis of ``feat_stack``.
    TS : training
        Unused; kept so existing callers keep working.
    feat_stack : array, shape (ny, nx, n_features)
        Computed feature stack of the current slice.

    Returns
    -------
    2D float array with values in ``[0, 255]``.
    """
    im = feat_stack[:, :, i]
    im8 = (im - im.min()).astype(float)
    peak = im8.max()
    # guard against a constant feature plane: the original 0/0 produced NaNs
    if peak > 0:
        im8 = im8 / peak * 255
    return im8

Fire up dask. A remote distributed Client is currently not usable; it is unclear how skipping the dask setup entirely would affect the computation.

In [3]:
# route dask spill/temporary files to the host-specific fast scratch directory
dask.config.config['temporary-directory'] = temppath
def boot_client(dashboard_address=':35000', memory_limit = memlim, n_workers=2):
    """Start a LocalCluster plus Client for this notebook.

    Two workers appear to be the optimum on these hosts; dask still distributes
    the work over the full machine.  A big SSD as temporary directory is a
    major advantage for efficient spill-to-disk: a large dataset may crash
    with a too-small SSD or be slow on a normal HDD.

    Returns
    -------
    (distributed.Client, distributed.LocalCluster)
    """
    # A virtual cluster (or slurm on ra) would also be an option here.
    # Settings were optimised for mpc2959; if nothing else uses the RAM you
    # can push memory_limit almost to the physical limit.
    cluster = LocalCluster(dashboard_address=dashboard_address,
                           memory_limit=memory_limit,
                           n_workers=n_workers,
                           silence_logs=logging.ERROR)
    client = Client(cluster)
    print('Dashboard at '+client.dashboard_link)
    return client, cluster

def reboot_client(client, dashboard_address=':35000', memory_limit = memlim, n_workers=2):
    """Shut the given client down and return a fresh one on a new LocalCluster."""
    client.shutdown()
    fresh_cluster = LocalCluster(dashboard_address=dashboard_address,
                                 memory_limit=memory_limit,
                                 n_workers=n_workers,
                                 silence_logs=logging.ERROR)
    return Client(fresh_cluster)
In [4]:
client, cluster = boot_client()
Dashboard at http://127.0.0.1:35000/status

Data preparation

let dask load the data

store data in a hdf5 (eg. xarray to .nc) as an entry 'image_data' containing a 4D array. There is potential in structuring the data on the disk (SSD recommended for fast data streaming)

In [5]:
sample = 'R_m7_33_200_1_II'
imagepath = os.path.join(training_path, '01_'+sample+'_cropped.nc')
In [6]:
file = h5py.File(imagepath)
In [7]:
# spatial chunk edge length; potential for optimisation by matching the chunk
# size with the planned image-filter kernels and the file structure on disk
# for fast data streaming
chunk_space = 36
# each chunk spans the full time axis (filters below operate along time)
chunks = (chunk_space,chunk_space,chunk_space,len(file['time']))
# lazy dask view onto the HDF5 dataset; nothing is read until computed
da = dask.array.from_array(file['image_data'], chunks= chunks)

get data into image filter class

In [8]:
# TODO: include this routine into pytrainseg

IF = image_filter(sigmas = [0,1,3,6]) 
IF.data = da
shp = da.shape
shp_raw = shp

prepare features

creates a dask graph with dependent calculations to allow on-demand and larger-than-memory calculation

In [9]:
# load a feature list if it exists
prefix = 'd001fb3'
feature_names_to_use = ['Gaussian_4D_Blur_0.0',
 'Gaussian_4D_Blur_1.0',
 'Gaussian_4D_Blur_6.0',
 'diff_of_gauss_4D_6.0_0.0',
 'diff_of_gauss_4D_6.0_1.0',
 'Gradient_sigma_1.0_0',
 'Gradient_sigma_1.0_1',
 'Gradient_sigma_1.0_3',
 'hessian_sigma_1.0_00',
 'hessian_sigma_1.0_01',
 'hessian_sigma_1.0_11',
 'Gradient_sigma_3.0_3',
 'hessian_sigma_3.0_00',
 'hessian_sigma_3.0_01',
 'hessian_sigma_3.0_02',
 'hessian_sigma_3.0_03',
 'hessian_sigma_3.0_11',
 'hessian_sigma_3.0_33',
 'Gradient_sigma_6.0_0',
 'Gradient_sigma_6.0_1',
 'Gradient_sigma_6.0_2',
 'Gradient_sigma_6.0_3',
 'hessian_sigma_6.0_01',
 'hessian_sigma_6.0_03',
 'hessian_sigma_6.0_11',
 'hessian_sigma_6.0_12',
 'hessian_sigma_6.0_13',
 'hessian_sigma_6.0_22',
 'hessian_sigma_6.0_23',
 'hessian_sigma_6.0_33',
 'Gradient_sigma_2.0_0',
 'Gradient_sigma_2.0_3',
 'hessian_sigma_2.0_00',
 'hessian_sigma_2.0_01',
 'hessian_sigma_2.0_02',
 'hessian_sigma_2.0_11',
 'hessian_sigma_2.0_13',
 'Gaussian_time_0.0',
 'Gaussian_time_1.0',
 'Gaussian_time_6.0',
 'Gaussian_time_2.0',
 'Gaussian_space_0.0',
 'Gaussian_space_6.0',
 'diff_of_gauss_space_3.0_0.0',
 'diff_of_gauss_space_6.0_0.0',
 'diff_of_gauss_space_2.0_0.0',
 'diff_of_gauss_space_3.0_1.0',
 'diff_of_gauss_space_6.0_1.0',
 'diff_of_gauss_space_2.0_1.0',
 'diff_of_gauss_space_2.0_3.0',
 'diff_temp_min_Gauss_2.0',
 'diff_to_first_',
 'full_temp_min_Gauss_2.0',
 'first_',
 'last_']
In [10]:
IF.prepare()
In [ ]:
IF.stack_features()
In [ ]:
IF.feature_stack
In [ ]:
# uncomment if you want to reduce the feature stack
IF.reduce_feature_stack(feature_names_to_use)
In [ ]:
IF.make_xarray()

Training

set up some objects

In [ ]:
# per-sample output directory for training artifacts
training_path_sample = os.path.join(training_path, sample)
# makedirs with exist_ok avoids the check-then-create race of
# os.path.exists() + os.mkdir() and also creates missing parents
os.makedirs(training_path_sample, exist_ok=True)
In [ ]:
TS = training(training_path=training_path_sample)
TS.client = client
IF.client = client
TS.cluster = cluster
IF.cluster = cluster
TS.memlim = memlim
TS.n_workers = 2

give the feature stack to the training class

In [ ]:
TS.feat_data = IF.feature_xarray
In [ ]:
IF.combined_feature_names = list(IF.feature_names) + list(IF.feature_names_time_independent)
TS.combined_feature_names = IF.combined_feature_names
In [19]:
for i in range(len(TS.combined_feature_names)):
    print(i, TS.combined_feature_names[i])
0 Gaussian_4D_Blur_0.0
1 Gaussian_4D_Blur_1.0
2 Gaussian_4D_Blur_3.0
3 Gaussian_4D_Blur_6.0
4 Gaussian_4D_Blur_2.0
5 diff_of_gauss_4D_1.0_0.0
6 diff_of_gauss_4D_3.0_0.0
7 diff_of_gauss_4D_6.0_0.0
8 diff_of_gauss_4D_2.0_0.0
9 diff_of_gauss_4D_3.0_1.0
10 diff_of_gauss_4D_6.0_1.0
11 diff_of_gauss_4D_2.0_1.0
12 diff_of_gauss_4D_6.0_3.0
13 diff_of_gauss_4D_2.0_3.0
14 diff_of_gauss_4D_2.0_6.0
15 Gradient_sigma_1.0_0
16 Gradient_sigma_1.0_1
17 Gradient_sigma_1.0_2
18 Gradient_sigma_1.0_3
19 hessian_sigma_1.0_00
20 hessian_sigma_1.0_01
21 hessian_sigma_1.0_02
22 hessian_sigma_1.0_03
23 hessian_sigma_1.0_11
24 hessian_sigma_1.0_12
25 hessian_sigma_1.0_13
26 hessian_sigma_1.0_22
27 hessian_sigma_1.0_23
28 hessian_sigma_1.0_33
29 Gradient_sigma_3.0_0
30 Gradient_sigma_3.0_1
31 Gradient_sigma_3.0_2
32 Gradient_sigma_3.0_3
33 hessian_sigma_3.0_00
34 hessian_sigma_3.0_01
35 hessian_sigma_3.0_02
36 hessian_sigma_3.0_03
37 hessian_sigma_3.0_11
38 hessian_sigma_3.0_12
39 hessian_sigma_3.0_13
40 hessian_sigma_3.0_22
41 hessian_sigma_3.0_23
42 hessian_sigma_3.0_33
43 Gradient_sigma_6.0_0
44 Gradient_sigma_6.0_1
45 Gradient_sigma_6.0_2
46 Gradient_sigma_6.0_3
47 hessian_sigma_6.0_00
48 hessian_sigma_6.0_01
49 hessian_sigma_6.0_02
50 hessian_sigma_6.0_03
51 hessian_sigma_6.0_11
52 hessian_sigma_6.0_12
53 hessian_sigma_6.0_13
54 hessian_sigma_6.0_22
55 hessian_sigma_6.0_23
56 hessian_sigma_6.0_33
57 Gradient_sigma_2.0_0
58 Gradient_sigma_2.0_1
59 Gradient_sigma_2.0_2
60 Gradient_sigma_2.0_3
61 hessian_sigma_2.0_00
62 hessian_sigma_2.0_01
63 hessian_sigma_2.0_02
64 hessian_sigma_2.0_03
65 hessian_sigma_2.0_11
66 hessian_sigma_2.0_12
67 hessian_sigma_2.0_13
68 hessian_sigma_2.0_22
69 hessian_sigma_2.0_23
70 hessian_sigma_2.0_33
71 Gaussian_time_0.0
72 Gaussian_time_1.0
73 Gaussian_time_3.0
74 Gaussian_time_6.0
75 Gaussian_time_2.0
76 diff_of_gauss_time_1.0_0.0
77 diff_of_gauss_time_3.0_0.0
78 diff_of_gauss_time_6.0_0.0
79 diff_of_gauss_time_2.0_0.0
80 diff_of_gauss_time_3.0_1.0
81 diff_of_gauss_time_6.0_1.0
82 diff_of_gauss_time_2.0_1.0
83 diff_of_gauss_time_6.0_3.0
84 diff_of_gauss_time_2.0_3.0
85 diff_of_gauss_time_2.0_6.0
86 Gaussian_space_0.0
87 Gaussian_space_1.0
88 Gaussian_space_3.0
89 Gaussian_space_6.0
90 Gaussian_space_2.0
91 diff_of_gauss_space_1.0_0.0
92 diff_of_gauss_space_3.0_0.0
93 diff_of_gauss_space_6.0_0.0
94 diff_of_gauss_space_2.0_0.0
95 diff_of_gauss_space_3.0_1.0
96 diff_of_gauss_space_6.0_1.0
97 diff_of_gauss_space_2.0_1.0
98 diff_of_gauss_space_6.0_3.0
99 diff_of_gauss_space_2.0_3.0
100 diff_of_gauss_space_2.0_6.0
101 diff_to_min_
102 diff_temp_min_Gauss_2.0
103 diff_to_first_
104 diff_to_last_
105 full_temp_mean_
106 full_temp_min_
107 full_temp_min_Gauss_2.0
108 first_
109 last_

interactive training

load training dict if you have one and want to use it

In [20]:
training_prefix = '093c73c'
# context managers so the pickle file handles are closed promptly
# (these pickles are produced by this pipeline itself, so unpickling is trusted)
with open(os.path.join(training_path, training_prefix+'_training_dict.p'), 'rb') as fh:
    TS.previous_training_dict = pickle.load(fh)
with open(os.path.join(training_path, training_prefix+'_feature_names.p'), 'rb') as fh:
    TS.previous_feature_names = pickle.load(fh)

reduce the previous training dict

In [21]:
# Build a boolean mask over the full feature list marking the desired features,
# then cut every stored training matrix down to those columns.
num_feat = len(IF.combined_feature_names)
num_feat_to_use = len(feature_names_to_use)
wanted = set(feature_names_to_use)
feature_ids = np.array([name in wanted for name in IF.combined_feature_names], dtype=bool)

TS.training_dict = {}
for key, (X, y) in TS.previous_training_dict.items():
    # boolean-mask indexing keeps only the columns of the desired features
    TS.training_dict[key] = X[:, feature_ids], y

print('Reduced the previous training dict to ',str(np.count_nonzero(feature_ids)),'/',str(num_feat),'features considering the ',str(num_feat_to_use),' desired features.')
print(str(num_feat_to_use-np.count_nonzero(feature_ids)), ' desired features were not found in the training dict.')
Reduced the previous training dict to  55 / 110 features considering the  55  desired features.
0  desired features were not found in the training dict.

re-train with existing label sets. clear the training dictionary if necessary (training_dict)

don't use this if you already have a pickled training_dict
TODO: reduce training dict with feature_ids instead of retraining with the label images

In [ ]:
 
In [22]:
# TS.training_dict = {}
In [23]:
# TS.train()

import training dict of other samples

(replace sample name and repeat for multiple samples), if necessary check features for overlap

In [24]:
# TODO: make better naming convention below when pickling the training_dict

# oldsample = '4'
# oldgitsha = 'e42ad75' #'109a7ce3' #retrain at one point
# # if oldsample == '4':
# #     training_dict_old = pickle.load(open(os.path.join(toppathSSD, '05_water_GDL_ML', '4', 'ec4415d_training_dict_without_loc_feat.p'), 'rb'))
# # else:
# training_dict_old = pickle.load(open(os.path.join(training_path, oldsample,  oldgitsha+'_training_dict.p'),'rb'))
# oldfeatures = pickle.load(open(os.path.join(training_path, oldsample,  oldgitsha+'_feature_names.p'),'rb'))
    
#     # pickle.dump(TS.training_dict, open(os.path.join(TS.training_path, pytrain_git_sha+'_training_dict.p'),'wb'))
# # pickle.dump(TS.feature_names, open(os.path.join(TS.training_path, pytrain_git_sha+'_feature_names.p'),'wb'))

# for key in training_dict_old.keys():
#     TS.training_dict[oldsample+key] = training_dict_old[key]

suggest a new training coordinate

Currently, retraining with a new feature stack is not properly implemented. Workaround: choose from the existing training sets and train with them (additional labeling optional).

In [25]:
TS.suggest_training_set()
You could try  x = 52  at time step  22
In [26]:
c1 = 'z'
p1 = 1304
c2 = 'time'  # c2 has always to be time currently . Removed option to chose two spatial coordinates because was not useful, but left syntax to keep the potential to add it again in the future
p2 = 51

activate the training set and load label images if existent

In [27]:
TS.load_training_set(c1, p1, c2, p2)
existing label set loaded
In [31]:
# reboot the dask client if it lost a worker
if not len(client.cluster.workers)>1:   
    client = reboot_client(client)
    TS.client = client
    IF.client = client
In [29]:
#  TODO: move the routine into the training class
# TODO: skip the work if the coordinates did not change and the stack is already computed

# Select the 2D feature slice for the chosen training coordinates.
# c2 is always 'time' currently; c1 picks the spatial axis that is held fixed.
# The time-independent stack carries a dummy 'time_0' axis of length 1.
feat_data = TS.feat_data
[c1,p1,c2,p2] = TS.current_coordinates
newslice = True

if c2 == 'time' and c1 in ('x', 'y', 'z'):
    spatial_sel = {c1: p1}
    feat_stack = feat_data['feature_stack'].sel(time=p2, **spatial_sel)
    feat_stack_t_idp = feat_data['feature_stack_time_independent'].sel(time_0=0, **spatial_sel)
In [ ]:
# reboot the dask client if it lost a worker
if not len(client.cluster.workers)>1:   
    client = reboot_client(client)
    TS.client = client
    IF.client = client

calculate the feature stack for the selected slice

time dependent features

In [32]:
#  TODO: move into training class and keep up to date with dask development for the best way to do this
# watch the dashboard for some colorful process tracing, having an eye on the "workers" can help to deal with memory issues
# Materialize the lazy (dask-backed) feature slice into a numpy array.
if type(feat_stack) is not np.ndarray:
        # scatter the lazy object to the workers, fetch the future, then compute;
        # NOTE(review): scatter -> result -> compute looks redundant -- presumably
        # chosen empirically for memory behavior, confirm before simplifying
        fut = client.scatter(feat_stack)
        fut = fut.result()
        fut = fut.compute()
        feat_stack = fut
        try:
            # restart dask client to wipe leaked memory
            client.restart()
        except:
            # do a full reboot if this fails
            # NOTE(review): bare 'except:' also swallows KeyboardInterrupt;
            # 'except Exception:' would be safer
            client = reboot_client(client)
            TS.client = client
            IF.client = client   

check if the cluster survived the calculation

In [33]:
# needs to stay to be interactive
client.cluster.workers
Out[33]:
{0: <Nanny: tcp://127.0.0.1:32821, threads: 28>,
 1: <Nanny: tcp://127.0.0.1:38097, threads: 28>}
In [34]:
# # reboot the dask client if it lost a worker
if not len(client.cluster.workers)>1:   
    client = reboot_client(client)
    TS.client = client
    IF.client = client

time independent features

In [35]:
# move into training class at one point
# Materialize the time-independent feature slice (same pattern as the
# time-dependent stack above).
if type(feat_stack_t_idp) is not np.ndarray:
        fut = client.scatter(feat_stack_t_idp)
        fut = fut.result()
        fut = fut.compute()
        feat_stack_t_idp = fut
        try:
            # restart the client to reclaim leaked worker memory
            client.restart()
        except:
            # NOTE(review): bare 'except:' -- consider 'except Exception:'
            client = reboot_client(client)
            TS.client = client
            IF.client = client   

check if the cluster survived the calculation

In [36]:
# needs to stay to be interactive
client.cluster.workers
Out[36]:
{0: <Nanny: tcp://127.0.0.1:42863, threads: 28>,
 1: <Nanny: tcp://127.0.0.1:44909, threads: 28>}
In [37]:
print('I am back from calculating and have still '+str(len(client.cluster.workers))+' workers')
I am back from calculating and have still 2 workers
In [38]:
# # reboot the dask client if it lost a worker
if not len(client.cluster.workers)>1:   
    client = reboot_client(client)
    TS.client = client
    IF.client = client

merge the two feature stacks

In [39]:
# needs to stay to be interactive
# feat_stack_full = np.concatenate([feat_stack, feat_stack_t_idp], axis = 2)
feat_stack = np.concatenate([feat_stack, feat_stack_t_idp], axis = 2) #this line to save a bit RAM
In [40]:
feat_stack.shape
Out[40]:
(700, 330, 55)
In [41]:
# this necessary ??
# TS.current_feat_stack_full = feat_stack_full
TS.current_feat_stack_full = feat_stack
if type(TS.current_feat_stack_full) is not np.ndarray:
    TS.current_computed = False
else:
    TS.current_computed = True
In [42]:
TS.current_feat_stack_full.shape
Out[42]:
(700, 330, 55)

canvas for labeling and training

give index of feature in case you want to use it for training

can be very useful, for example, to label static components

In [43]:
# TODO: consider reduced feature number
i  = 10
print(i, IF.combined_feature_names[i])
10 diff_of_gauss_4D_6.0_1.0
In [44]:
im8 = TS.current_im8 # execute this line to get the orignal display

set up the canvas

In [64]:
# move to training class at low prio
# needs interactive buttons if the code is hidden in class
# button for color and alpha
# for now, leave exposed

# Interactive labeling canvas.
# Layers of the MultiCanvas (bottom to top):
#   0 background  : grayscale image being labeled
#   1 truthdisplay: (reserved for label image)
#   2 resultdisplay: semi-transparent classification overlay
#   3 canvas      : the layer the user draws strokes on
# move to training class at low prio; needs interactive buttons if the code is
# hidden in a class -- for now, leave exposed
alpha = 0.35 # transparency of the result overlay

## in case you want to zoom, but it works better if you pan (not zoom) in the browser.
## panning can be done with the trackpad of a laptop; there are also key combinations (TODO look up)
# zoom1 = (-500,-1)
# zoom2 = (600,1400)
# zoom1 = (0, -1)
# zoom2 = (0, -1)
# trick: use a strongly time-smoothed feature to label static phases, e.g.
# im8 = display_feature(103, TS, feat_stack)

# print(IF.combined_feature_names[-20])
print('original shape: ',im8.shape)
im8_display = im8.copy() #[zoom1[0]:zoom1[1], zoom2[0]:zoom2[1]]
# print('diyplay shape : ',im8_display.shape,' at: ', (zoom1[0], zoom2[0]))

resultim = TS.current_result.copy()
resultim_display = resultim #[zoom1[0]:zoom1[1], zoom2[0]:zoom2[1]]


width = im8_display.shape[1]
height = im8_display.shape[0]
Mcanvas = MultiCanvas(4, width=width, height=height)
background = Mcanvas[0]
resultdisplay = Mcanvas[2]
truthdisplay = Mcanvas[1]
canvas = Mcanvas[3]
canvas.sync_image_data = True
# reset the stroke state shared with the on_mouse_* callbacks
drawing = False
position = None
shape = []
# grayscale -> RGB for the background layer
image_data = np.stack((im8_display, im8_display, im8_display), axis=2)
background.put_image_data(image_data, 0, 0)
# NOTE(review): IntSlider with a float value -- FloatSlider seems intended, confirm
slidealpha = IntSlider(description="Result overlay", value=0.15)
resultdisplay.global_alpha = alpha #slidealpha.value
if np.any(resultim>0):
    # encode classes 0/1/2 into the R/G/B channels; class 3 becomes yellow (R+G)
    result_data = np.stack(((resultim_display==0), (resultim_display==1),(resultim_display==2)), axis=2)*255
    mask3 = resultim_display==3
    result_data[mask3,0] = 255
    result_data[mask3,1] = 255
else:
    # no result yet: fully black overlay
    result_data = np.stack((0*resultim, 0*resultim, 0*resultim), axis=2)
resultdisplay.put_image_data(result_data, 0, 0)
# hook up the drawing callbacks defined earlier
canvas.on_mouse_down(on_mouse_down)
canvas.on_mouse_move(on_mouse_move)
canvas.on_mouse_up(on_mouse_up)
picker = ColorPicker(description="Color:", value="#ff0000") #red
# picker = ColorPicker(description="Color:", value="#0000ff") #blue
# picker = ColorPicker(description="Color:", value="#00ff00") #green
# picker = ColorPicker(description="Color:", value="#ffff00") #yellow
### all currently supported color options -> gives the possibility to label 4 different phases

link((picker, "value"), (canvas, "stroke_style"))
link((picker, "value"), (canvas, "fill_style"))
link((slidealpha, "value"), (resultdisplay, "global_alpha"))

HBox((Mcanvas,picker))
# HBox((Mcanvas,)) #picker 
original shape:  (700, 330)
Out[64]:
HBox(children=(MultiCanvas(height=700, width=330), ColorPicker(value='#ff0000', description='Color:')))

adjust grayscale range of the display of the image by playing with the lines below

has no effect on the data. If you mess up, there are some lines at the beginning of the canvas cell or below to get the previous display back

In [57]:
tfs.display.plot_im_histogram(im8)
# im8 = TS.current_im8 # uncomment this line to get the original back display (raw at original grayscale range)
# im8 = tfs.display.adjust_image_contrast(im8,0,100)
No description has been provided for this image

update training set if labels are ok or clear the current canvas by re-running the cell above if not

automatically updates the stored label image on the disk

In [58]:
# Convert the RGBA strokes drawn on the canvas into integer class labels and
# merge them into the current truth image.

label_set = canvas.get_image_data()

test = TS.current_truth.copy()

# Color -> label mapping (assignment order matters: later lines overwrite earlier):
#   red without green      -> 1
#   any green              -> 2
#   any blue               -> 4
#   red AND green (yellow) -> 3
test[np.bitwise_and(label_set[:,:,0]>0,np.bitwise_xor(label_set[:,:,0]>0,label_set[:,:,1]>0))] = 1
test[label_set[:,:,1]>0] = 2
test[label_set[:,:,2]>0] = 4 #order of 4&3 flipped for legacy reasons (existing training labels)
test[np.bitwise_and(label_set[:,:,0]>0,label_set[:,:,1]>0)] = 3

TS.current_truth = test.copy()
# persist the updated label image on disk
# NOTE(review): imageio.imsave is deprecated in imageio v3 in favour of
# imageio.imwrite -- confirm installed version
imageio.imsave(TS.current_truthpath, TS.current_truth)

inspect labels and training progress

can be sometimes useful

In [59]:
# Side-by-side inspection: current result, raw grayscale, label image, and one
# selected feature (index i from the cell further up).

fig, axes = plt.subplots(1,4, figsize=(20,10))
axes[0].imshow(TS.current_result, 'gray')
axes[0].set_title('current result')
axes[1].imshow(TS.current_im8, 'gray')
axes[1].set_title('original grayscale')

# optional alternative panels, kept for reference:
# TS.current_diff_im = TS.current_im-TS.current_first_im
# TS.current_diff_im = TS.current_diff_im/TS.current_diff_im.max()*255
# axes[2].imshow(-TS.current_diff_im)#,vmin=6e4)
# axes[3].imshow(im8old, 'gray')
# axes[3].imshow(TS.current_first_im, 'gray')
axes[2].imshow(TS.current_truth)
axes[2].set_title('label image')
# only show feature i if the slice stack has actually been computed
if TS.current_computed:
    axes[3].imshow(TS.current_feat_stack_full[:,:,i], 'gray')
    axes[3].set_title(str(i)+': '+IF.combined_feature_names[i])
else:
    axes[3].imshow(TS.current_result, 'gray')
    axes[3].set_title('current result')
# for ax in axes:
    # ax.set_xticks([])
    # ax.set_yticks([])
No description has been provided for this image

train!

In [60]:
TS.train_slice()
training and classifying

revise feature importance to decide on omiting a few to make calculation more efficient

TODO implementation to be done

In [61]:
# TODO consider reduced feature stack
plt.figure(figsize=(16,9))
plt.stem(np.array(IF.combined_feature_names)[feature_ids], TS.clf.feature_importances_,'x')
plt.xticks(rotation=90)
plt.ylabel('importance') 
Out[61]:
Text(0, 0.5, 'importance')
No description has been provided for this image
In [62]:
# persist the training dict and the matching feature names, prefixed with the
# pytrainseg git SHA for provenance; context managers close the files promptly
with open(os.path.join(training_path, pytrain_git_sha+'_training_dict.p'), 'wb') as fh:
    pickle.dump(TS.training_dict, fh)
with open(os.path.join(training_path, pytrain_git_sha+'_feature_names.p'), 'wb') as fh:
    pickle.dump(TS.combined_feature_names, fh)
In [63]:
training_path
Out[63]:
'/das/home/fische_r/interlaces/Tomcat_2'

Select features to keep to reduce feature stack

throw away 50% of the features by using the median of the feature importance
the yarn sample showed no perceptible reduction of segmentation quality

In [56]:
# keep roughly the top 50% of features: strictly above the median importance
importance_median = np.median(TS.clf.feature_importances_)
# NOTE(review): this reassigns feature_ids, which earlier in the notebook was a
# mask over the FULL feature list; here it is a mask over the classifier's
# (possibly reduced) feature set -- same name, different length/meaning
feature_ids = TS.clf.feature_importances_>importance_median
In [57]:
# collect the names of the features whose importance exceeded the median
# NOTE(review): feature_ids indexes the classifier's features, but
# combined_feature_names is the FULL feature list -- if the classifier was
# trained on a reduced stack the names picked here are misaligned; the mask
# should then index the reduced name list instead. TODO confirm.
features_to_use = []
for i in range(len(feature_ids)):
    if feature_ids[i]:
        features_to_use.append(IF.combined_feature_names[i])
    
In [58]:
# TODO write txt instead of pickle dump
093c73c
In [59]:
features_to_use
Out[59]:
['Gaussian_4D_Blur_0.0',
 'Gaussian_4D_Blur_1.0',
 'Gaussian_4D_Blur_6.0',
 'diff_of_gauss_4D_6.0_0.0',
 'diff_of_gauss_4D_6.0_1.0',
 'Gradient_sigma_1.0_0',
 'Gradient_sigma_1.0_1',
 'Gradient_sigma_1.0_3',
 'hessian_sigma_1.0_00',
 'hessian_sigma_1.0_01',
 'hessian_sigma_1.0_11',
 'Gradient_sigma_3.0_3',
 'hessian_sigma_3.0_00',
 'hessian_sigma_3.0_01',
 'hessian_sigma_3.0_02',
 'hessian_sigma_3.0_03',
 'hessian_sigma_3.0_11',
 'hessian_sigma_3.0_33',
 'Gradient_sigma_6.0_0',
 'Gradient_sigma_6.0_1',
 'Gradient_sigma_6.0_2',
 'Gradient_sigma_6.0_3',
 'hessian_sigma_6.0_01',
 'hessian_sigma_6.0_03',
 'hessian_sigma_6.0_11',
 'hessian_sigma_6.0_12',
 'hessian_sigma_6.0_13',
 'hessian_sigma_6.0_22',
 'hessian_sigma_6.0_23',
 'hessian_sigma_6.0_33',
 'Gradient_sigma_2.0_0',
 'Gradient_sigma_2.0_3',
 'hessian_sigma_2.0_00',
 'hessian_sigma_2.0_01',
 'hessian_sigma_2.0_02',
 'hessian_sigma_2.0_11',
 'hessian_sigma_2.0_13',
 'Gaussian_time_0.0',
 'Gaussian_time_1.0',
 'Gaussian_time_6.0',
 'Gaussian_time_2.0',
 'Gaussian_space_0.0',
 'Gaussian_space_6.0',
 'diff_of_gauss_space_3.0_0.0',
 'diff_of_gauss_space_6.0_0.0',
 'diff_of_gauss_space_2.0_0.0',
 'diff_of_gauss_space_3.0_1.0',
 'diff_of_gauss_space_6.0_1.0',
 'diff_of_gauss_space_2.0_1.0',
 'diff_of_gauss_space_2.0_3.0',
 'diff_temp_min_Gauss_2.0',
 'diff_to_first_',
 'full_temp_min_Gauss_2.0',
 'first_',
 'last_']

Save classifier

In [66]:
# persist the trained classifier, prefixed with the pytrainseg git SHA;
# the context manager closes the file promptly
clf = TS.clf
with open(os.path.join(training_path, pytrain_git_sha+'_clf.p'), 'wb') as fh:
    pickle.dump(clf, fh)

Segmentation

load classifier

In [22]:
prefix = 'e765007'
# context manager closes the file after loading; the pickle was written by this
# pipeline itself, so unpickling is trusted (never load pickles from untrusted sources)
with open(os.path.join(training_path, prefix+'_clf.p'), 'rb') as fh:
    clf = pickle.load(fh)
# clf.n_jobs = 40 #threads to be used for classifier. more is faster but needs more RAM, default is all available, tune if you get memory issues
In [23]:
feat = TS.feat_data['feature_stack']
feat_idp = TS.feat_data['feature_stack_time_independent']

figure out good dimensions for piecewise segmentation

the chunks contain the entire time series --> makes no sense to split in time
take the two biggest dimensions and don't split along the smallest spatial dimension --> saves one loop if you get away with it
TODO: automate somehow

In [24]:
def round_up(val, dec=1):
    """Round ``val`` upward to ``dec`` decimal places.

    np.round may round down; if it did, bump the result up by one unit in the
    last decimal place so the returned value is never below ``val``.
    """
    result = np.round(val, dec)
    if result < val:
        result = result + 10**(-dec)
    # final rounding removes floating-point residue from the addition
    return np.round(result, dec)

def calculate_part(part):
    """Materialize a lazy (dask-backed) part into a numpy array.

    If ``part`` is already a plain numpy array it is returned unchanged;
    otherwise it is scattered to the workers and computed there.
    """
    if type(part) is np.ndarray:
        return part
    fut = client.scatter(part)
    return fut.result().compute()

def get_the_client_back(client):
    """Return a healthy dask client, restarting or rebooting it as needed.

    Attempts a soft ``client.restart()`` first; if that raises, or the
    cluster comes back with at most one worker, falls back to
    ``reboot_client`` (defined elsewhere in the notebook) for a full
    teardown and rebuild.
    """
    try:
        client.restart()
    except Exception:
        # was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit — restrict to ordinary errors
        client = reboot_client(client)
    if not len(client.cluster.workers) > 1:
        client = reboot_client(client)
    return client
In [25]:
dim1 = 36#better use multiple of chunk size !?  <-- tune this parameter to minimize imax, jmax and the size of the result

# shp = feat.shape[:-1]
shp = shp_raw

# aspect ratio 
dimsize = np.sort(shp[:-1] )
# ratio of the two largest spatial dimensions, rounded up one decimal
aspect = round_up(dimsize[-1]/dimsize[-2])

# check length of loops to process entire dataset, estimate size of obtained sub-feature to avoid out-of-memory issues

# dim2 keeps the sub-block aspect ratio close to the volume's own
dim2 = int(round_up(aspect*dim1, 0))
# ceil-style loop counts so the i/j loops cover the whole volume
# NOTE(review): jmax pairs dim1 with the second-largest dimension and imax
# pairs dim2 with the largest, while the loop below slices axis 0 with dim1
# and axis 2 with dim2 — confirm the axis/dimension pairing is intended
jmax = int(round_up(dimsize[-2]/dim1, 0))
imax = int(round_up(dimsize[-1]/dim2, 0))

# preview a near-edge piece to sanity-check the size of one sub-feature slab
i = imax -2
j = jmax -2
print(imax,jmax)
feat[i*dim1:(i+1)*dim1,:,j*dim2:(j+1)*dim2,:,:]
20 20
Out[25]:
<style>/* CSS stylesheet for displaying xarray objects in jupyterlab. * */ :root { --xr-font-color0: var(--jp-content-font-color0, rgba(0, 0, 0, 1)); --xr-font-color2: var(--jp-content-font-color2, rgba(0, 0, 0, 0.54)); --xr-font-color3: var(--jp-content-font-color3, rgba(0, 0, 0, 0.38)); --xr-border-color: var(--jp-border-color2, #e0e0e0); --xr-disabled-color: var(--jp-layout-color3, #bdbdbd); --xr-background-color: var(--jp-layout-color0, white); --xr-background-color-row-even: var(--jp-layout-color1, white); --xr-background-color-row-odd: var(--jp-layout-color2, #eeeeee); } html[theme="dark"], html[data-theme="dark"], body[data-theme="dark"], body.vscode-dark { --xr-font-color0: rgba(255, 255, 255, 1); --xr-font-color2: rgba(255, 255, 255, 0.54); --xr-font-color3: rgba(255, 255, 255, 0.38); --xr-border-color: #1f1f1f; --xr-disabled-color: #515151; --xr-background-color: #111111; --xr-background-color-row-even: #111111; --xr-background-color-row-odd: #313131; } .xr-wrap { display: block !important; min-width: 300px; max-width: 700px; } .xr-text-repr-fallback { /* fallback to plain text repr when CSS is not injected (untrusted notebook) */ display: none; } .xr-header { padding-top: 6px; padding-bottom: 6px; margin-bottom: 4px; border-bottom: solid 1px var(--xr-border-color); } .xr-header > div, .xr-header > ul { display: inline; margin-top: 0; margin-bottom: 0; } .xr-obj-type, .xr-array-name { margin-left: 2px; margin-right: 10px; } .xr-obj-type { color: var(--xr-font-color2); } .xr-sections { padding-left: 0 !important; display: grid; grid-template-columns: 150px auto auto 1fr 0 20px 0 20px; } .xr-section-item { display: contents; } .xr-section-item input { display: inline-block; opacity: 0; height: 0; } .xr-section-item input + label { color: var(--xr-disabled-color); } .xr-section-item input:enabled + label { cursor: pointer; color: var(--xr-font-color2); } .xr-section-item input:focus + label { border: 2px solid var(--xr-font-color0); } .xr-section-item 
input:enabled + label:hover { color: var(--xr-font-color0); } .xr-section-summary { grid-column: 1; color: var(--xr-font-color2); font-weight: 500; } .xr-section-summary > span { display: inline-block; padding-left: 0.5em; } .xr-section-summary-in:disabled + label { color: var(--xr-font-color2); } .xr-section-summary-in + label:before { display: inline-block; content: "►"; font-size: 11px; width: 15px; text-align: center; } .xr-section-summary-in:disabled + label:before { color: var(--xr-disabled-color); } .xr-section-summary-in:checked + label:before { content: "▼"; } .xr-section-summary-in:checked + label > span { display: none; } .xr-section-summary, .xr-section-inline-details { padding-top: 4px; padding-bottom: 4px; } .xr-section-inline-details { grid-column: 2 / -1; } .xr-section-details { display: none; grid-column: 1 / -1; margin-bottom: 5px; } .xr-section-summary-in:checked ~ .xr-section-details { display: contents; } .xr-array-wrap { grid-column: 1 / -1; display: grid; grid-template-columns: 20px auto; } .xr-array-wrap > label { grid-column: 1; vertical-align: top; } .xr-preview { color: var(--xr-font-color3); } .xr-array-preview, .xr-array-data { padding: 0 5px !important; grid-column: 2; } .xr-array-data, .xr-array-in:checked ~ .xr-array-preview { display: none; } .xr-array-in:checked ~ .xr-array-data, .xr-array-preview { display: inline-block; } .xr-dim-list { display: inline-block !important; list-style: none; padding: 0 !important; margin: 0; } .xr-dim-list li { display: inline-block; padding: 0; margin: 0; } .xr-dim-list:before { content: "("; } .xr-dim-list:after { content: ")"; } .xr-dim-list li:not(:last-child):after { content: ","; padding-right: 5px; } .xr-has-index { font-weight: bold; } .xr-var-list, .xr-var-item { display: contents; } .xr-var-item > div, .xr-var-item label, .xr-var-item > .xr-var-name span { background-color: var(--xr-background-color-row-even); margin-bottom: 0; } .xr-var-item > .xr-var-name:hover span { padding-right: 5px; 
} .xr-var-list > li:nth-child(odd) > div, .xr-var-list > li:nth-child(odd) > label, .xr-var-list > li:nth-child(odd) > .xr-var-name span { background-color: var(--xr-background-color-row-odd); } .xr-var-name { grid-column: 1; } .xr-var-dims { grid-column: 2; } .xr-var-dtype { grid-column: 3; text-align: right; color: var(--xr-font-color2); } .xr-var-preview { grid-column: 4; } .xr-index-preview { grid-column: 2 / 5; color: var(--xr-font-color2); } .xr-var-name, .xr-var-dims, .xr-var-dtype, .xr-preview, .xr-attrs dt { white-space: nowrap; overflow: hidden; text-overflow: ellipsis; padding-right: 10px; } .xr-var-name:hover, .xr-var-dims:hover, .xr-var-dtype:hover, .xr-attrs dt:hover { overflow: visible; width: auto; z-index: 1; } .xr-var-attrs, .xr-var-data, .xr-index-data { display: none; background-color: var(--xr-background-color) !important; padding-bottom: 5px !important; } .xr-var-attrs-in:checked ~ .xr-var-attrs, .xr-var-data-in:checked ~ .xr-var-data, .xr-index-data-in:checked ~ .xr-index-data { display: block; } .xr-var-data > table { float: right; } .xr-var-name span, .xr-var-data, .xr-index-name div, .xr-index-data, .xr-attrs { padding-left: 25px !important; } .xr-attrs, .xr-var-attrs, .xr-var-data, .xr-index-data { grid-column: 1 / -1; } dl.xr-attrs { padding: 0; margin: 0; display: grid; grid-template-columns: 125px auto; } .xr-attrs dt, .xr-attrs dd { padding: 0; margin: 0; float: left; padding-right: 10px; width: auto; } .xr-attrs dt { font-weight: normal; grid-column: 1; } .xr-attrs dt:hover span { display: inline-block; background: var(--xr-background-color); padding-right: 10px; } .xr-attrs dd { grid-column: 2; white-space: pre-wrap; word-break: break-all; } .xr-icon-database, .xr-icon-file-text2, .xr-no-icon { display: inline-block; vertical-align: middle; width: 1em; height: 1.5em !important; stroke-width: 0; stroke: currentColor; fill: currentColor; } </style>
<xarray.DataArray 'feature_stack' (x: 36, y: 330, z: 105, time: 100, feature: 52)> Size: 52GB
dask.array<getitem, shape=(36, 330, 105, 100, 52), dtype=float64, chunksize=(28, 36, 36, 100, 1), chunktype=numpy.ndarray>
Coordinates:
  * x        (x) int64 288B 648 649 650 651 652 653 ... 678 679 680 681 682 683
  * y        (y) int64 3kB 0 1 2 3 4 5 6 7 8 ... 322 323 324 325 326 327 328 329
  * z        (z) int64 840B 1890 1891 1892 1893 1894 ... 1991 1992 1993 1994
  * time     (time) int64 800B 0 1 2 3 4 5 6 7 8 ... 91 92 93 94 95 96 97 98 99
  * feature  (feature) <U27 6kB 'Gaussian_4D_Blur_0.0' ... 'diff_to_first_'
xarray.DataArray
'feature_stack'
  • x: 36
  • y: 330
  • z: 105
  • time: 100
  • feature: 52
  • dask.array<chunksize=(28, 36, 18, 100, 1), meta=np.ndarray>
    Array Chunk
    Bytes 48.33 GiB 27.69 MiB
    Shape (36, 330, 105, 100, 52) (28, 36, 36, 100, 1)
    Dask graph 4992 chunks in 442 graph layers
    Data type float64 numpy.ndarray
    330 36 52 100 105
    • x
      (x)
      int64
      648 649 650 651 ... 680 681 682 683
      array([648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661,
             662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675,
             676, 677, 678, 679, 680, 681, 682, 683])
    • y
      (y)
      int64
      0 1 2 3 4 5 ... 325 326 327 328 329
      array([  0,   1,   2, ..., 327, 328, 329], shape=(330,))
    • z
      (z)
      int64
      1890 1891 1892 ... 1992 1993 1994
      array([1890, 1891, 1892, 1893, 1894, 1895, 1896, 1897, 1898, 1899, 1900, 1901,
             1902, 1903, 1904, 1905, 1906, 1907, 1908, 1909, 1910, 1911, 1912, 1913,
             1914, 1915, 1916, 1917, 1918, 1919, 1920, 1921, 1922, 1923, 1924, 1925,
             1926, 1927, 1928, 1929, 1930, 1931, 1932, 1933, 1934, 1935, 1936, 1937,
             1938, 1939, 1940, 1941, 1942, 1943, 1944, 1945, 1946, 1947, 1948, 1949,
             1950, 1951, 1952, 1953, 1954, 1955, 1956, 1957, 1958, 1959, 1960, 1961,
             1962, 1963, 1964, 1965, 1966, 1967, 1968, 1969, 1970, 1971, 1972, 1973,
             1974, 1975, 1976, 1977, 1978, 1979, 1980, 1981, 1982, 1983, 1984, 1985,
             1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994])
    • time
      (time)
      int64
      0 1 2 3 4 5 6 ... 94 95 96 97 98 99
      array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
             18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
             36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
             54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
             72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
             90, 91, 92, 93, 94, 95, 96, 97, 98, 99])
    • feature
      (feature)
      <U27
      'Gaussian_4D_Blur_0.0' ... 'diff...
      array(['Gaussian_4D_Blur_0.0', 'Gaussian_4D_Blur_1.0', 'Gaussian_4D_Blur_6.0',
             'diff_of_gauss_4D_6.0_0.0', 'diff_of_gauss_4D_6.0_1.0',
             'Gradient_sigma_1.0_0', 'Gradient_sigma_1.0_1', 'Gradient_sigma_1.0_3',
             'hessian_sigma_1.0_00', 'hessian_sigma_1.0_01', 'hessian_sigma_1.0_11',
             'Gradient_sigma_3.0_3', 'hessian_sigma_3.0_00', 'hessian_sigma_3.0_01',
             'hessian_sigma_3.0_02', 'hessian_sigma_3.0_03', 'hessian_sigma_3.0_11',
             'hessian_sigma_3.0_33', 'Gradient_sigma_6.0_0', 'Gradient_sigma_6.0_1',
             'Gradient_sigma_6.0_2', 'Gradient_sigma_6.0_3', 'hessian_sigma_6.0_01',
             'hessian_sigma_6.0_03', 'hessian_sigma_6.0_11', 'hessian_sigma_6.0_12',
             'hessian_sigma_6.0_13', 'hessian_sigma_6.0_22', 'hessian_sigma_6.0_23',
             'hessian_sigma_6.0_33', 'Gradient_sigma_2.0_0', 'Gradient_sigma_2.0_3',
             'hessian_sigma_2.0_00', 'hessian_sigma_2.0_01', 'hessian_sigma_2.0_02',
             'hessian_sigma_2.0_11', 'hessian_sigma_2.0_13', 'Gaussian_time_0.0',
             'Gaussian_time_1.0', 'Gaussian_time_6.0', 'Gaussian_time_2.0',
             'Gaussian_space_0.0', 'Gaussian_space_6.0',
             'diff_of_gauss_space_3.0_0.0', 'diff_of_gauss_space_6.0_0.0',
             'diff_of_gauss_space_2.0_0.0', 'diff_of_gauss_space_3.0_1.0',
             'diff_of_gauss_space_6.0_1.0', 'diff_of_gauss_space_2.0_1.0',
             'diff_of_gauss_space_2.0_3.0', 'diff_temp_min_Gauss_2.0',
             'diff_to_first_'], dtype='<U27')
    • x
      PandasIndex
      PandasIndex(Index([648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661,
             662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675,
             676, 677, 678, 679, 680, 681, 682, 683],
            dtype='int64', name='x'))
    • y
      PandasIndex
      PandasIndex(Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
             ...
             320, 321, 322, 323, 324, 325, 326, 327, 328, 329],
            dtype='int64', name='y', length=330))
    • z
      PandasIndex
      PandasIndex(Index([1890, 1891, 1892, 1893, 1894, 1895, 1896, 1897, 1898, 1899,
             ...
             1985, 1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994],
            dtype='int64', name='z', length=105))
    • time
      PandasIndex
      PandasIndex(Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
             18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
             36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
             54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
             72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
             90, 91, 92, 93, 94, 95, 96, 97, 98, 99],
            dtype='int64', name='time'))
    • feature
      PandasIndex
      PandasIndex(Index(['Gaussian_4D_Blur_0.0', 'Gaussian_4D_Blur_1.0', 'Gaussian_4D_Blur_6.0',
             'diff_of_gauss_4D_6.0_0.0', 'diff_of_gauss_4D_6.0_1.0',
             'Gradient_sigma_1.0_0', 'Gradient_sigma_1.0_1', 'Gradient_sigma_1.0_3',
             'hessian_sigma_1.0_00', 'hessian_sigma_1.0_01', 'hessian_sigma_1.0_11',
             'Gradient_sigma_3.0_3', 'hessian_sigma_3.0_00', 'hessian_sigma_3.0_01',
             'hessian_sigma_3.0_02', 'hessian_sigma_3.0_03', 'hessian_sigma_3.0_11',
             'hessian_sigma_3.0_33', 'Gradient_sigma_6.0_0', 'Gradient_sigma_6.0_1',
             'Gradient_sigma_6.0_2', 'Gradient_sigma_6.0_3', 'hessian_sigma_6.0_01',
             'hessian_sigma_6.0_03', 'hessian_sigma_6.0_11', 'hessian_sigma_6.0_12',
             'hessian_sigma_6.0_13', 'hessian_sigma_6.0_22', 'hessian_sigma_6.0_23',
             'hessian_sigma_6.0_33', 'Gradient_sigma_2.0_0', 'Gradient_sigma_2.0_3',
             'hessian_sigma_2.0_00', 'hessian_sigma_2.0_01', 'hessian_sigma_2.0_02',
             'hessian_sigma_2.0_11', 'hessian_sigma_2.0_13', 'Gaussian_time_0.0',
             'Gaussian_time_1.0', 'Gaussian_time_6.0', 'Gaussian_time_2.0',
             'Gaussian_space_0.0', 'Gaussian_space_6.0',
             'diff_of_gauss_space_3.0_0.0', 'diff_of_gauss_space_6.0_0.0',
             'diff_of_gauss_space_2.0_0.0', 'diff_of_gauss_space_3.0_1.0',
             'diff_of_gauss_space_6.0_1.0', 'diff_of_gauss_space_2.0_1.0',
             'diff_of_gauss_space_2.0_3.0', 'diff_temp_min_Gauss_2.0',
             'diff_to_first_'],
            dtype='object', name='feature'))
In [26]:
# segs = np.zeros(feat.shape[:4], dtype=np.uint8)
# segs = pickle.load(open(os.path.join(training_path, 'segs.p'),'rb')
In [27]:
# directory holding the pickled per-piece segmentation results
# (the original wrapped os.path.join twice for no effect)
piecepath = os.path.join(temppath, 'segmentation_pieces')

restart_i = 0
restart_j = 0  # replace with the iterations you could reach before dask crashed
# recover the restart coordinates from the piece files already on disk
if os.path.exists(piecepath):
    ij_mat = np.zeros((imax, jmax), dtype=bool)
    for filename in os.listdir(piecepath):
        # filenames look like 'seg_i_<i>_j_<j>_.p'
        i = int(filename.split('_')[2])
        j = int(filename.split('_')[4])
        ij_mat[i, j] = True
    # compute the restart point once, after all pieces are registered
    # (the original recomputed both argmin reductions inside the loop on
    # every file; the final values are identical)
    restart_i = int(np.min(np.argmin(ij_mat, axis=0)))
    restart_j = int(np.max(np.argmin(ij_mat, axis=1)))

print(restart_i, restart_j)
0 17
In [ ]:
# restart_i = 0
# restart_j = 14

# elapsed walltime: subprocess.check_output(['squeue','-u', 'fische_r']).decode().strip().split(' ')[-8]

# piecepath = os.path.join(os.path.join(temppath, 'segmentation_pieces'))
if not os.path.exists(piecepath):
    os.mkdir(piecepath)

# Piecewise segmentation: loop over (i, j) sub-blocks of the feature stack,
# classify each block with the loaded classifier, and pickle every piece so
# a crashed run can be resumed from (restart_i, restart_j).
for i in range(restart_i,imax):
    print(str(i+1)+'/'+str(imax))
    start_j = 0
    if i == restart_i:
        # resume mid-row only on the first (restarted) i iteration
        start_j = restart_j
    for j in range(start_j,jmax):
        print(j)
        # lazy slices of the time-dependent and time-independent features
        part = feat[i*dim1:(i+1)*dim1,:,j*dim2:(j+1)*dim2,:,:] 
        part_idp = feat_idp[i*dim1:(i+1)*dim1,:,j*dim2:(j+1)*dim2,:]  
        if 0 in part.shape:
            print('hit the edge (one dimension 0), ignore')
            continue
        # materialize each slab on the dask cluster, restarting the client
        # after every heavy computation (works around worker-memory issues
        # noted in the imports cell)
        part = calculate_part(part)
        client = get_the_client_back(client)
        part_idp = calculate_part(part_idp)
        client = get_the_client_back(client)
        
        # replicate the time-independent features along the time axis
        # (shp[-1] is presumably the number of time steps — confirm) so they
        # can be concatenated with the time-dependent stack
        part_idp = np.stack([part_idp]*shp[-1], axis=-2)[:,:,:,0,:,:] #expand in time, a bit ugly, could maybe more elegant
        part = np.concatenate([part, part_idp], axis = -1)
        del part_idp # drop the time independent part, is this garbage collected?

        shp_orig = part.shape
        num_feat = part.shape[-1]  
        # the classifier expects a flat (n_samples, n_features) table
        part = part.reshape(-1,num_feat)

        seg = clf.predict(part).astype(np.uint8)

        # put segs together when all calculated
        seg = seg.reshape(shp_orig[:4])

        # one pickle per piece, named 'seg_i_<i>_j_<j>_.p' for later reassembly
        pickle.dump(seg, open(os.path.join(piecepath, 'seg_i_'+str(i)+'_j_'+str(j)+'_.p'), 'wb'))
1/20
17
18
19
2/20
0
1
In [ ]:
print('You got until i',i,'& j',j)

Piece the segmentation back together

In [ ]:
segs = np.zeros(shp_raw, dtype=np.uint8)

# Reassemble the full segmentation volume from the pickled pieces.
# numpy clips slice stops that run past the array end, so the single
# interior-case assignment also handles the edge pieces — the original's
# four branch cases (which its own comment questioned) were all
# equivalent to this one statement.
for filename in os.listdir(piecepath):
    # filenames look like 'seg_i_<i>_j_<j>_.p'
    i = int(filename.split('_')[2])
    j = int(filename.split('_')[4])
    # NOTE(review): pickle files written by this notebook; do not load
    # pieces from an untrusted source
    seg = pickle.load(open(os.path.join(piecepath, filename),'rb'))
    segs[i*dim1:(i+1)*dim1, :, j*dim2:(j+1)*dim2, :] = seg
    

save result to disk when full volume has been processed

In [ ]:
# TODO: include metadata in segmented nc

shp = segs.shape
segdata = xr.Dataset({'segmented': (['x','y','z','timestep'], segs),
                     't_utc': ('timestep', t_utc),
                     'time': ('timestep', time)},
                               coords = {'x': np.arange(shp[0]),
                               'y': np.arange(shp[1]),
                               'z': np.arange(shp[2]),
                               'timestep': np.arange(shp[3]),
                               'feature': TS.combined_feature_names}
                     )
segdata.attrs = data.attrs.copy()
segdata.attrs['pytrain_git'] = pytrain_git_sha
In [ ]:
segpath = os.path.join(training_path_sample, sample+'_segmentation.nc')
In [ ]:
segdata.to_netcdf(segpath)
In [ ]: