Files
pyTrainSeg/V2_example_notebook.ipynb
T
2025-04-10 09:50:45 +02:00

564 KiB

In [1]:
# modules
import os
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
import dask
import dask.array
from scipy import ndimage
from skimage import filters, feature, io
from skimage.morphology import disk,ball
import sys
from itertools import combinations_with_replacement
import pickle
import imageio
import json
from dask.distributed import Client, LocalCluster
import socket
import subprocess
import gc
import h5py
import logging
import warnings
warnings.filterwarnings('ignore')


from dask import config as cfg
# --- dask configuration --------------------------------------------------
# (older option names kept below for reference; they changed in newer dask)
# cfg.set({'distributed.scheduler.worker-ttl': None, # Workaround so that dask does not kill workers while they are busy fetching data: https://dask.discourse.group/t/dask-workers-killed-because-of-heartbeat-fail/856, maybe this helps: https://www.youtube.com/watch?v=vF2VItVU5zg?
#         'distributed.scheduler.transition-log-length': 100, #potential workaround for ballooning scheduler memory https://baumgartner.io/posts/how-to-reduce-memory-usage-of-dask-scheduler/
#          'distributed.scheduler.events-log-length': 100
#         }) seems to be outdate

cfg.set({'distributed.scheduler.worker-ttl': None, # Workaround so that dask does not kill workers while they are busy fetching data: https://dask.discourse.group/t/dask-workers-killed-because-of-heartbeat-fail/856, maybe this helps: https://www.youtube.com/watch?v=vF2VItVU5zg?
        'distributed.admin.low-level-log-length': 100 #potential workaround for ballooning scheduler memory https://baumgartner.io/posts/how-to-reduce-memory-usage-of-dask-scheduler/
        }) # still relevant ?

# paths: choose temp dir, data dir, library path and worker memory limit per host
host = socket.gethostname()
if host == 'mpc2959.psi.ch':
    temppath = '/mnt/SSD/fische_r/tmp'
    training_path =  '/mnt/SSD/fische_r/Tomcat_2/'
    pytrainpath = '/mpc/homes/fische_r/lib/pytrainseg'
    # memlim = '840GB'
    memlim = '440GB'
    # memlim = '920GB'
elif host[:3] == 'ra-':
    temppath = '/das/home/fische_r/interlaces/Tomcat_2/tmp'
    training_path = '/das/home/fische_r/interlaces/Tomcat_2'
    pytrainpath = '/das/home/fische_r/lib/pytrainseg'
    memlim = '220GB'
else:
    # NOTE(review): on an unsupported host temppath/training_path/... stay
    # undefined and later cells fail with NameError — consider raising here
    print('host '+host+' currently not supported')
    
# get the ML functions, TODO: make a library once it works/is in a stable state

cwd = os.getcwd()
# temporarily chdir into the library directory so the V2_* modules import
os.chdir(pytrainpath)
from V2_feature_stack import image_filter
import V2_training as tfs
from V2_training import training

# record the library's short git SHA for provenance of pickled results
pytrain_git_sha = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
os.chdir(cwd)

functionalities for interactive training

In [2]:
from ipywidgets import Image
from ipywidgets import ColorPicker, IntSlider, link, AppLayout, HBox
from ipycanvas import  hold_canvas,  MultiCanvas #RoughCanvas,Canvas,

def on_mouse_down(x, y):
    """Mouse-press handler: begin a freehand stroke at (x, y).

    Resets the shared stroke state: marks drawing as active, records the
    anchor point, and starts a fresh polygon with that point.
    """
    global drawing, position, shape
    drawing = True
    position = (x, y)
    shape = [position]

def on_mouse_move(x, y):
    """Mouse-drag handler: extend the active stroke to (x, y).

    No-op when no stroke is active. While drawing, renders the new segment
    inside a hold_canvas batch, advances the anchor point, and appends the
    new point to the polygon being traced.
    """
    global drawing, position, shape
    if drawing:
        with hold_canvas():
            canvas.stroke_line(position[0], position[1], x, y)
            position = (x, y)
        shape.append(position)

def on_mouse_up(x, y, fill=False):
    """Mouse-release handler: finish the stroke and reset the stroke state.

    Draws the final segment from the last anchor point to (x, y); when
    ``fill`` is True the traced polygon is filled. Afterwards the polygon
    list is cleared for the next stroke.

    Bug fix: the global declaration previously read ``positiondu`` — a typo
    for ``position``. The declaration now names the intended variable.
    """
    global drawing
    global position
    global shape
    drawing = False
    with hold_canvas():
        canvas.stroke_line(position[0], position[1], x, y)
        if fill:
            canvas.fill_polygon(shape)
    shape = []
    
def display_feature(i, TS, feat_stack):
    """Return feature plane ``i`` of ``feat_stack`` rescaled to the 0..255 range.

    Parameters
    ----------
    i : int
        Index of the feature along the last axis of ``feat_stack``.
    TS : training
        Unused; kept for backward compatibility with existing callers.
    feat_stack : numpy.ndarray
        Array of shape (rows, cols, n_features).

    Returns
    -------
    numpy.ndarray
        Float image shifted so its minimum is 0 and scaled so its maximum
        is 255. A constant plane is returned as all zeros — previously the
        zero dynamic range caused a 0/0 division and produced NaNs.
    """
    im = feat_stack[:, :, i]
    im8 = im - im.min()
    peak = im8.max()
    if peak > 0:  # guard against a constant plane (zero dynamic range)
        im8 = im8 / peak * 255
    return im8

Fire up dask. The distributed Client is currently not fully usable; it is unclear how not setting up dask at all would affect the computation.

In [3]:
dask.config.config['temporary-directory'] = temppath
def boot_client(dashboard_address=':35000', memory_limit = memlim, n_workers=2): # 2 workers appears to be the optimum, will still distribute over the full machine
    """Start a LocalCluster and attach a Client for the feature computations.

    Parameters
    ----------
    dashboard_address : str
        Address where the dask dashboard is served.
    memory_limit : str
        Per-worker memory limit; defaults to the host-specific ``memlim``.
    n_workers : int
        Number of worker processes.

    Returns
    -------
    (Client, LocalCluster)
    """
    # A big SSD as dask temp dir is a major advantage: it allows spilling to
    # disk while staying efficient. Large datasets may crash with a too-small
    # SSD or be slow on a spinning disk. (The unused local copy of temppath
    # that used to live here was removed; the temp dir is set at module level.)
    #
    # There is also the option to attach to an external scheduler or even
    # slurm on ra (not attempted yet), e.g.:
    #   cluster = 'tcp://129.129.188.222:8786'  # scheduler on mpc2959
    cluster = LocalCluster(dashboard_address=dashboard_address, memory_limit = memory_limit, n_workers=n_workers, silence_logs=logging.ERROR) #settings optimised for mpc2959, play around if needed, if you know nothing else is using RAM then you can almost go to the limit
    # maybe fewer workers with more threads would make better use of shared memory
    client = Client(cluster) # logs below ERROR are silenced; too many warnings seem to block execution
    print('Dashboard at '+client.dashboard_link)
    return client, cluster

def reboot_client(client, dashboard_address=':35000', memory_limit = memlim, n_workers=2):
    """Shut down a (possibly broken) client and hand back a fresh one.

    Only the new Client is returned: callers keep their old ``cluster``
    reference and must reattach the returned client (e.g. to TS/IF)
    themselves.
    """
    client.shutdown()
    fresh_cluster = LocalCluster(dashboard_address=dashboard_address, memory_limit=memory_limit, n_workers=n_workers, silence_logs=logging.ERROR)
    return Client(fresh_cluster)
In [4]:
client, cluster = boot_client()
Dashboard at http://127.0.0.1:35000/status
2025-04-10 09:12:53,158 - distributed.scheduler - ERROR - Removing worker 'tcp://127.0.0.1:33721' caused the cluster to lose scattered data, which can't be recovered: {'DataArray-0ecdf6d6829f9a004fefff3779e158c0'} (stimulus_id='handle-worker-cleanup-1744269173.1579642')
2025-04-10 09:19:14,330 - distributed.core - ERROR - Exception while handling op scatter
Traceback (most recent call last):
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/core.py", line 834, in _handle_comm
    result = await result
             ^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/scheduler.py", line 6381, in scatter
    raise TimeoutError("No valid workers found")
TimeoutError: No valid workers found

Data preparation

let dask load the data

Store the data in an HDF5 file (e.g. xarray to .nc) as an entry 'image_data' containing a 4D array. There is potential in structuring the data on the disk (an SSD is recommended for fast data streaming).

In [5]:
# sample identifier and path of the preprocessed (cropped) 4D image stack
sample = 'R_m7_33_200_1_II'
imagepath = os.path.join(training_path, '01_'+sample+'_cropped.nc')
In [6]:
file = h5py.File(imagepath)
In [7]:
chunk_space = 36 # potential for optimisation by matching chunksize with planned image filter kernels and file structure on disk for fast data streaming
chunks = (chunk_space,chunk_space,chunk_space,len(file['time']))  # 36^3 spatial blocks, unchunked along time
da = dask.array.from_array(file['image_data'], chunks= chunks)

get data into image filter class

In [8]:
# TODO: include this routine into pytrainseg

# filter bank definition: sigmas are the Gaussian scales used by the
# blur / gradient / hessian features in space and time
IF = image_filter(sigmas = [0,1,3,6]) 
IF.data = da
shp = da.shape  # NOTE(review): apparently unused in the rest of this notebook

prepare features

creates a dask graph with dependent calculations to allow on-demand and larger-than-memory calculation

In [9]:
# load a feature list if it exists
# load a feature list if it exists
# prefix: short git SHA identifying the pytrainseg version the list belongs to
prefix = 'd001fb3'
# manually curated subset of feature names to keep; consumed by
# IF.reduce_feature_stack below (full list of 110 features is printed further down)
feature_names_to_use = ['Gaussian_4D_Blur_0.0',
 'Gaussian_4D_Blur_1.0',
 'Gaussian_4D_Blur_6.0',
 'diff_of_gauss_4D_6.0_0.0',
 'diff_of_gauss_4D_6.0_1.0',
 'Gradient_sigma_1.0_0',
 'Gradient_sigma_1.0_1',
 'Gradient_sigma_1.0_3',
 'hessian_sigma_1.0_00',
 'hessian_sigma_1.0_01',
 'hessian_sigma_1.0_11',
 'Gradient_sigma_3.0_3',
 'hessian_sigma_3.0_00',
 'hessian_sigma_3.0_01',
 'hessian_sigma_3.0_02',
 'hessian_sigma_3.0_03',
 'hessian_sigma_3.0_11',
 'hessian_sigma_3.0_33',
 'Gradient_sigma_6.0_0',
 'Gradient_sigma_6.0_1',
 'Gradient_sigma_6.0_2',
 'Gradient_sigma_6.0_3',
 'hessian_sigma_6.0_01',
 'hessian_sigma_6.0_03',
 'hessian_sigma_6.0_11',
 'hessian_sigma_6.0_12',
 'hessian_sigma_6.0_13',
 'hessian_sigma_6.0_22',
 'hessian_sigma_6.0_23',
 'hessian_sigma_6.0_33',
 'Gradient_sigma_2.0_0',
 'Gradient_sigma_2.0_3',
 'hessian_sigma_2.0_00',
 'hessian_sigma_2.0_01',
 'hessian_sigma_2.0_02',
 'hessian_sigma_2.0_11',
 'hessian_sigma_2.0_13',
 'Gaussian_time_0.0',
 'Gaussian_time_1.0',
 'Gaussian_time_6.0',
 'Gaussian_time_2.0',
 'Gaussian_space_0.0',
 'Gaussian_space_6.0',
 'diff_of_gauss_space_3.0_0.0',
 'diff_of_gauss_space_6.0_0.0',
 'diff_of_gauss_space_2.0_0.0',
 'diff_of_gauss_space_3.0_1.0',
 'diff_of_gauss_space_6.0_1.0',
 'diff_of_gauss_space_2.0_1.0',
 'diff_of_gauss_space_2.0_3.0',
 'diff_temp_min_Gauss_2.0',
 'diff_to_first_',
 'full_temp_min_Gauss_2.0',
 'first_',
 'last_']
In [10]:
IF.prepare()
In [11]:
IF.stack_features()
In [12]:
IF.feature_stack
Out[12]:
Array Chunk
Bytes 35.58 TiB 35.60 MiB
Shape (700, 330, 2016, 100, 105) (36, 36, 36, 100, 1)
Dask graph 1481760 chunks in 440 graph layers
Data type float64 numpy.ndarray
330 700 105 100 2016
In [13]:
# reduce the stack to the curated feature subset (comment out to keep all features)
IF.reduce_feature_stack(feature_names_to_use)
In [14]:
IF.make_xarray()
using reduced feature stack

Training

set up some objects

In [15]:
# per-sample output directory for label images and pickled training data;
# makedirs(exist_ok=True) avoids the check-then-create race and is idempotent
# across notebook re-runs
training_path_sample = os.path.join(training_path, sample)
os.makedirs(training_path_sample, exist_ok=True)
In [16]:
# training manager: holds label images, training dicts and the classifier
TS = training(training_path=training_path_sample)
# share the dask client/cluster and memory settings with both helper objects
TS.client = client
IF.client = client
TS.cluster = cluster
IF.cluster = cluster
TS.memlim = memlim
TS.n_workers = 2
There are existing training sets, run .train() if you want to use them:
label_image_x_500_time_81_.tif
label_image_x_556_time_79_.tif
label_image_x_590_time_99_.tif
label_image_y_300_time_36_.tif
label_image_y_80_time_67_.tif
label_image_z_100_time_88_.tif
label_image_z_1200_time_82_.tif
label_image_z_1304_time_51_.tif
label_image_z_1900_time_55_.tif
label_image_z_200_time_80_.tif
label_image_z_42_time_30_.tif
label_image_z_900_time_71_.tif

give the feature stack to the training class

In [17]:
TS.feat_data = IF.feature_xarray
In [18]:
# combined ordering: time-dependent features first, then time-independent ones
IF.combined_feature_names = list(IF.feature_names) + list(IF.feature_names_time_independent)
TS.combined_feature_names = IF.combined_feature_names
In [19]:
# list every feature with its index so features can be picked by number below
for feature_index, feature_name in enumerate(TS.combined_feature_names):
    print(feature_index, feature_name)
0 Gaussian_4D_Blur_0.0
1 Gaussian_4D_Blur_1.0
2 Gaussian_4D_Blur_3.0
3 Gaussian_4D_Blur_6.0
4 Gaussian_4D_Blur_2.0
5 diff_of_gauss_4D_1.0_0.0
6 diff_of_gauss_4D_3.0_0.0
7 diff_of_gauss_4D_6.0_0.0
8 diff_of_gauss_4D_2.0_0.0
9 diff_of_gauss_4D_3.0_1.0
10 diff_of_gauss_4D_6.0_1.0
11 diff_of_gauss_4D_2.0_1.0
12 diff_of_gauss_4D_6.0_3.0
13 diff_of_gauss_4D_2.0_3.0
14 diff_of_gauss_4D_2.0_6.0
15 Gradient_sigma_1.0_0
16 Gradient_sigma_1.0_1
17 Gradient_sigma_1.0_2
18 Gradient_sigma_1.0_3
19 hessian_sigma_1.0_00
20 hessian_sigma_1.0_01
21 hessian_sigma_1.0_02
22 hessian_sigma_1.0_03
23 hessian_sigma_1.0_11
24 hessian_sigma_1.0_12
25 hessian_sigma_1.0_13
26 hessian_sigma_1.0_22
27 hessian_sigma_1.0_23
28 hessian_sigma_1.0_33
29 Gradient_sigma_3.0_0
30 Gradient_sigma_3.0_1
31 Gradient_sigma_3.0_2
32 Gradient_sigma_3.0_3
33 hessian_sigma_3.0_00
34 hessian_sigma_3.0_01
35 hessian_sigma_3.0_02
36 hessian_sigma_3.0_03
37 hessian_sigma_3.0_11
38 hessian_sigma_3.0_12
39 hessian_sigma_3.0_13
40 hessian_sigma_3.0_22
41 hessian_sigma_3.0_23
42 hessian_sigma_3.0_33
43 Gradient_sigma_6.0_0
44 Gradient_sigma_6.0_1
45 Gradient_sigma_6.0_2
46 Gradient_sigma_6.0_3
47 hessian_sigma_6.0_00
48 hessian_sigma_6.0_01
49 hessian_sigma_6.0_02
50 hessian_sigma_6.0_03
51 hessian_sigma_6.0_11
52 hessian_sigma_6.0_12
53 hessian_sigma_6.0_13
54 hessian_sigma_6.0_22
55 hessian_sigma_6.0_23
56 hessian_sigma_6.0_33
57 Gradient_sigma_2.0_0
58 Gradient_sigma_2.0_1
59 Gradient_sigma_2.0_2
60 Gradient_sigma_2.0_3
61 hessian_sigma_2.0_00
62 hessian_sigma_2.0_01
63 hessian_sigma_2.0_02
64 hessian_sigma_2.0_03
65 hessian_sigma_2.0_11
66 hessian_sigma_2.0_12
67 hessian_sigma_2.0_13
68 hessian_sigma_2.0_22
69 hessian_sigma_2.0_23
70 hessian_sigma_2.0_33
71 Gaussian_time_0.0
72 Gaussian_time_1.0
73 Gaussian_time_3.0
74 Gaussian_time_6.0
75 Gaussian_time_2.0
76 diff_of_gauss_time_1.0_0.0
77 diff_of_gauss_time_3.0_0.0
78 diff_of_gauss_time_6.0_0.0
79 diff_of_gauss_time_2.0_0.0
80 diff_of_gauss_time_3.0_1.0
81 diff_of_gauss_time_6.0_1.0
82 diff_of_gauss_time_2.0_1.0
83 diff_of_gauss_time_6.0_3.0
84 diff_of_gauss_time_2.0_3.0
85 diff_of_gauss_time_2.0_6.0
86 Gaussian_space_0.0
87 Gaussian_space_1.0
88 Gaussian_space_3.0
89 Gaussian_space_6.0
90 Gaussian_space_2.0
91 diff_of_gauss_space_1.0_0.0
92 diff_of_gauss_space_3.0_0.0
93 diff_of_gauss_space_6.0_0.0
94 diff_of_gauss_space_2.0_0.0
95 diff_of_gauss_space_3.0_1.0
96 diff_of_gauss_space_6.0_1.0
97 diff_of_gauss_space_2.0_1.0
98 diff_of_gauss_space_6.0_3.0
99 diff_of_gauss_space_2.0_3.0
100 diff_of_gauss_space_2.0_6.0
101 diff_to_min_
102 diff_temp_min_Gauss_2.0
103 diff_to_first_
104 diff_to_last_
105 full_temp_mean_
106 full_temp_min_
107 full_temp_min_Gauss_2.0
108 first_
109 last_

interactive training

load training dict if you have one and want to use it

In [23]:
# short git SHA identifying the pickled training state to restore
training_prefix = '093c73c'
# NOTE: pickle.load executes arbitrary code — only load files you created yourself
TS.previous_training_dict = pickle.load(open(os.path.join(training_path, training_prefix+'_training_dict.p'), 'rb'))
TS.previous_feature_names = pickle.load(open(os.path.join(training_path, training_prefix+'_feature_names.p'), 'rb'))

reduce the previous training dict

In [25]:
num_feat = len(IF.combined_feature_names)
num_feat_to_use = len(feature_names_to_use)
feature_ids = np.zeros(num_feat, dtype=bool)
for i in range(num_feat):
    if IF.combined_feature_names[i] in feature_names_to_use:
        feature_ids[i] = True

TS.training_dict = {}
for training_set in TS.previous_training_dict.keys():
    X,y = TS.previous_training_dict[training_set]
    X = X[:,feature_ids]                          # since there is no copy, probably previous_training_dict and training_dict will become the same, reload the previous dict if necessary
    TS.training_dict[training_set] = X,y

print('Reduced the previous training dict to ',str(np.count_nonzero(feature_ids)),'/',str(num_feat),'features considering the ',str(num_feat_to_use),' desired features.')
print(str(num_feat_to_use-np.count_nonzero(feature_ids)), ' desired features were not found in the training dict.')
Reduced the previous training dict to  55 / 110 features considering the  55  desired features.
0  desired feature were not found in the training dict.

re-train with existing label sets. clear the training dictionary if necessary (training_dict)

don't use this if you already have a pickled training_dict
TODO: reduce training dict with feature_ids instead of retraining with the label images

In [ ]:
 
In [28]:
# TS.training_dict = {}
In [20]:
# TS.train()
training with existing label images
label_image_z_42_time_30_.tif
label_image_z_900_time_71_.tif
2025-04-10 09:12:53,131 - distributed.worker - ERROR - Failed to communicate with scheduler during heartbeat.
Traceback (most recent call last):
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/comm/tcp.py", line 226, in read
    frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tornado.iostream.StreamClosedError: Stream is closed

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/worker.py", line 1269, in heartbeat
    response = await retry_operation(
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/utils_comm.py", line 416, in retry_operation
    return await retry(
           ^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/utils_comm.py", line 395, in retry
    return await coro()
           ^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/core.py", line 1259, in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/core.py", line 1018, in send_recv
    response = await comm.read(deserializers=deserializers)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/comm/tcp.py", line 237, in read
    convert_stream_closed_error(self, e)
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/comm/tcp.py", line 137, in convert_stream_closed_error
    raise CommClosedError(f"in {obj}: {exc}") from exc
distributed.comm.core.CommClosedError: in <TCP (closed) ConnectionPool.heartbeat_worker local=tcp://127.0.0.1:44726 remote=tcp://127.0.0.1:45529>: Stream is closed
2025-04-10 09:12:53,143 - distributed.worker - ERROR - Failed to communicate with scheduler during heartbeat.
Traceback (most recent call last):
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/comm/tcp.py", line 226, in read
    frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tornado.iostream.StreamClosedError: Stream is closed

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/worker.py", line 1269, in heartbeat
    response = await retry_operation(
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/utils_comm.py", line 416, in retry_operation
    return await retry(
           ^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/utils_comm.py", line 395, in retry
    return await coro()
           ^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/core.py", line 1259, in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/core.py", line 1018, in send_recv
    response = await comm.read(deserializers=deserializers)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/comm/tcp.py", line 237, in read
    convert_stream_closed_error(self, e)
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/comm/tcp.py", line 137, in convert_stream_closed_error
    raise CommClosedError(f"in {obj}: {exc}") from exc
distributed.comm.core.CommClosedError: in <TCP (closed) ConnectionPool.heartbeat_worker local=tcp://127.0.0.1:44714 remote=tcp://127.0.0.1:45529>: Stream is closed
Process Dask Worker process (from Nanny):
2025-04-10 09:12:55,141 - distributed.nanny - ERROR - Worker process died unexpectedly
Process Dask Worker process (from Nanny):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/das/home/fische_r/miniconda3/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/asyncio/base_events.py", line 691, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/nanny.py", line 985, in run
    await worker.finished()
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/core.py", line 494, in finished
    await self._event_finished.wait()
  File "/das/home/fische_r/miniconda3/lib/python3.12/asyncio/locks.py", line 212, in wait
    await fut
asyncio.exceptions.CancelledError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/das/home/fische_r/miniconda3/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/asyncio/base_events.py", line 691, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/nanny.py", line 985, in run
    await worker.finished()
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/core.py", line 494, in finished
    await self._event_finished.wait()
  File "/das/home/fische_r/miniconda3/lib/python3.12/asyncio/locks.py", line 212, in wait
    await fut
asyncio.exceptions.CancelledError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/das/home/fische_r/miniconda3/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/das/home/fische_r/miniconda3/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/process.py", line 202, in _run
    target(*args, **kwargs)
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/nanny.py", line 1023, in _run
    asyncio_run(run(), loop_factory=get_loop_factory())
  File "/das/home/fische_r/miniconda3/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/asyncio/runners.py", line 123, in run
    raise KeyboardInterrupt()
KeyboardInterrupt
  File "/das/home/fische_r/miniconda3/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/das/home/fische_r/miniconda3/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/process.py", line 202, in _run
    target(*args, **kwargs)
  File "/das/home/fische_r/miniconda3/lib/python3.12/site-packages/distributed/nanny.py", line 1023, in _run
    asyncio_run(run(), loop_factory=get_loop_factory())
  File "/das/home/fische_r/miniconda3/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/das/home/fische_r/miniconda3/lib/python3.12/asyncio/runners.py", line 123, in run
    raise KeyboardInterrupt()
KeyboardInterrupt
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[20], line 1
----> 1 TS.train()

File ~/lib/pytrainseg/V2_training.py:253, in training.train(self, clear_dict, redo, first_set)
    251     continue
    252 print(label_name)
--> 253 X, y = self.training_set_per_image(label_name, path, feat_data)
    254 self.training_dict[label_name] = X,y
    255 if flag:

File ~/lib/pytrainseg/V2_training.py:210, in training.training_set_per_image(self, label_name, trainingpath, feat_data, lazy)
    208 if type(feat_stack) is not np.ndarray:
    209         fut = self.client.scatter(feat_stack)
--> 210         fut = fut.result()
    211         fut = fut.compute()
    212         feat_stack = fut.data

File ~/miniconda3/lib/python3.12/site-packages/distributed/client.py:401, in Future.result(self, timeout)
    399 self._verify_initialized()
    400 with shorten_traceback():
--> 401     return self.client.sync(self._result, callback_timeout=timeout)

File ~/miniconda3/lib/python3.12/threading.py:655, in Event.wait(self, timeout)
    653 signaled = self._flag
    654 if not signaled:
--> 655     signaled = self._cond.wait(timeout)
    656 return signaled

File ~/miniconda3/lib/python3.12/threading.py:359, in Condition.wait(self, timeout)
    357 else:
    358     if timeout > 0:
--> 359         gotit = waiter.acquire(True, timeout)
    360     else:
    361         gotit = waiter.acquire(False)

KeyboardInterrupt: 

import training dict of other samples

(replace sample name and repeat for multiple samples), if necessary check features for overlap

In [22]:
# TODO: make better naming convention below when pickling the training_dict

# oldsample = '4'
# oldgitsha = 'e42ad75' #'109a7ce3' #retrain at one point
# # if oldsample == '4':
# #     training_dict_old = pickle.load(open(os.path.join(toppathSSD, '05_water_GDL_ML', '4', 'ec4415d_training_dict_without_loc_feat.p'), 'rb'))
# # else:
# training_dict_old = pickle.load(open(os.path.join(training_path, oldsample,  oldgitsha+'_training_dict.p'),'rb'))
# oldfeatures = pickle.load(open(os.path.join(training_path, oldsample,  oldgitsha+'_feature_names.p'),'rb'))
    
#     # pickle.dump(TS.training_dict, open(os.path.join(TS.training_path, pytrain_git_sha+'_training_dict.p'),'wb'))
# # pickle.dump(TS.feature_names, open(os.path.join(TS.training_path, pytrain_git_sha+'_feature_names.p'),'wb'))

# for key in training_dict_old.keys():
#     TS.training_dict[oldsample+key] = training_dict_old[key]

suggest a new training coordinate

Currently, retraining with a new feature stack is not properly implemented. Workaround: choose from the existing training sets and train with them (additional labeling optional).

In [26]:
TS.suggest_training_set()
You could try  y = 67  at time step  56
In [27]:
c1 = 'z'  # slicing axis of the label image ('x', 'y' or 'z')
p1 = 1304  # position along c1
c2 = 'time'  # c2 always has to be 'time' currently. The option to choose two spatial coordinates was removed because it was not useful, but the syntax is kept so it could be added again in the future
p2 = 51

activate the training set and load label images if existent

In [30]:
TS.load_training_set(c1, p1, c2, p2)
existing label set loaded
In [31]:
# reboot the dask client if it lost a worker
# NOTE(review): reboot_client returns only the new client, so the module-level
# 'cluster' and TS.cluster/IF.cluster still point at the old cluster object
if not len(client.cluster.workers)>1:   
    client = reboot_client(client)
    TS.client = client
    IF.client = client
In [32]:
#  TODO: move  the routine into training class
# TODO: add if clause to not do anything if the coordinates did not change and the stack has already been calcualted

#  TODO: move  the routine into training class
# TODO: add if clause to not do anything if the coordinates did not change and the stack has already been calculated

# Select the lazy 2D feature slices for the active training coordinates.
# c2 is always 'time' currently; support for pairs of spatial coordinates was
# removed (it was not useful) but could be reinstated here.
feat_data = TS.feat_data
[c1,p1,c2,p2] = TS.current_coordinates
newslice = True

if c1 == 'x' and c2 == 'time':
    feat_stack = feat_data['feature_stack'].sel(x = p1, time = p2)
    feat_stack_t_idp = feat_data['feature_stack_time_independent'].sel(x = p1, time_0 = 0)
elif c1 == 'y' and c2 == 'time':
    feat_stack = feat_data['feature_stack'].sel(y = p1, time = p2)
    feat_stack_t_idp = feat_data['feature_stack_time_independent'].sel(y = p1, time_0 = 0)
elif c1 == 'z' and c2 == 'time':
    feat_stack = feat_data['feature_stack'].sel(z = p1, time = p2)
    feat_stack_t_idp = feat_data['feature_stack_time_independent'].sel(z = p1, time_0 = 0)
else:
    # previously an unmatched combination silently kept the stale feat_stack
    # from an earlier slice; fail loudly instead
    raise ValueError('unsupported coordinate combination: '+str((c1, c2)))

calculate the feature stack for the selected slice

time dependent features

In [33]:
#  TODO: move into training class and keep up to date with dask development for the best way to do this
# watch the dashboard for some colorful process tracing, having a eye on the "workers" can help to deal with memory issues
if type(feat_stack) is not np.ndarray:
        fut = client.scatter(feat_stack)
        fut = fut.result()
        fut = fut.compute()
        feat_stack = fut
        try:
            # restart dask client to wipe leaked memory
            client.restart()
        except:
            # do a full reboot if this fails
            client = reboot_client(client)
            TS.client = client
            IF.client = client   

check if the cluster survived the calculation

In [34]:
# needs to stay to be interactive
# display the worker dict so the reader can verify the cluster survived the computation
client.cluster.workers
Out[34]:
{0: <Nanny: tcp://127.0.0.1:36697, threads: 28>,
 1: <Nanny: tcp://127.0.0.1:42459, threads: 28>}
In [35]:
# # reboot the dask client if it lost a worker
# NOTE(review): reboot_client returns only the new client; the module-level
# 'cluster' and TS.cluster/IF.cluster keep pointing at the old cluster object
if not len(client.cluster.workers)>1:   
    client = reboot_client(client)
    TS.client = client
    IF.client = client

time independent features

In [36]:
# move into training class at one point
# Materialise the time-independent feature slice; same scatter/compute pattern
# as for the time-dependent features above.
if type(feat_stack_t_idp) is not np.ndarray:
        fut = client.scatter(feat_stack_t_idp)
        fut = fut.result()
        fut = fut.compute()
        feat_stack_t_idp = fut
        try:
            # restart to wipe leaked worker memory
            client.restart()
        except Exception:  # bare 'except:' would also catch KeyboardInterrupt
            client = reboot_client(client)
            TS.client = client
            IF.client = client   

check if the cluster survived the calculation

In [37]:
# needs to stay to be interactive
client.cluster.workers
Out[37]:
{0: <Nanny: tcp://127.0.0.1:40593, threads: 28>,
 1: <Nanny: tcp://127.0.0.1:42459, threads: 28>}
In [38]:
print('I am back from calculating and have still '+str(len(client.cluster.workers))+' workers')
I am back from calculating and have still 2 workers

merge the two feature stacks

In [39]:
# needs to stay to be interactive
# feat_stack_full = np.concatenate([feat_stack, feat_stack_t_idp], axis = 2)
# append the time-independent features as extra channels along axis 2
feat_stack = np.concatenate([feat_stack, feat_stack_t_idp], axis = 2) #this line to save a bit RAM
In [40]:
feat_stack.shape
Out[40]:
(700, 330, 55)
In [41]:
# keep a handle on the merged slice and flag whether it is a plain in-memory
# numpy array (i.e. fully computed) or still a lazy/wrapped object
TS.current_feat_stack_full = feat_stack
TS.current_computed = type(TS.current_feat_stack_full) is np.ndarray
In [42]:
TS.current_feat_stack_full.shape
Out[42]:
(700, 330, 55)

canvas for labeling and training

give index of feature in case you want to use it for training

This can be very useful, for example, to label static components.

In [50]:
# TODO: consider reduced feature number
i  = 10
print(i, IF.combined_feature_names[i])
10 diff_of_gauss_4D_6.0_1.0
In [51]:
im8 = TS.current_im8 # execute this line to get the orignal display

set up the canvas

In [58]:
# move to training class at low prio
# needs interactive buttons if the code is hidden in class
# button for color and alpha
# for now, leave exposed

alpha = 0.05 #transparance

## in case you want to zoom, but it works better if you pan (not zoom) in the browser. panning can be done with the trackpad of a laptop, but there is also some key combinations TODO lock up
# zoom1 = (-500,-1)
# zoom2 = (600,1400)
# zoom1 = (0, -1)
# zoom2 = (0, -1)
#trick: use gaussian_time_4_0 to label static phases ()
# im8 = display_feature(103, TS, feat_stack)

# print(IF.combined_feature_names[-20])
print('original shape: ',im8.shape)
im8_display = im8.copy() #[zoom1[0]:zoom1[1], zoom2[0]:zoom2[1]]
# print('diyplay shape : ',im8_display.shape,' at: ', (zoom1[0], zoom2[0]))

resultim = TS.current_result.copy()
resultim_display = resultim #[zoom1[0]:zoom1[1], zoom2[0]:zoom2[1]]


width = im8_display.shape[1]
height = im8_display.shape[0]
Mcanvas = MultiCanvas(4, width=width, height=height)
background = Mcanvas[0]
resultdisplay = Mcanvas[2]
truthdisplay = Mcanvas[1]
canvas = Mcanvas[3]
canvas.sync_image_data = True
drawing = False
position = None
shape = []
image_data = np.stack((im8_display, im8_display, im8_display), axis=2)
background.put_image_data(image_data, 0, 0)
slidealpha = IntSlider(description="Result overlay", value=0.15)
resultdisplay.global_alpha = alpha #slidealpha.value
if np.any(resultim>0):
    result_data = np.stack(((resultim_display==0), (resultim_display==1),(resultim_display==2)), axis=2)*255
    mask3 = resultim_display==3
    result_data[mask3,0] = 255
    result_data[mask3,1] = 255
else:
    result_data = np.stack((0*resultim, 0*resultim, 0*resultim), axis=2)
resultdisplay.put_image_data(result_data, 0, 0)
canvas.on_mouse_down(on_mouse_down)
canvas.on_mouse_move(on_mouse_move)
canvas.on_mouse_up(on_mouse_up)
picker = ColorPicker(description="Color:", value="#ff0000") #red
# picker = ColorPicker(description="Color:", value="#0000ff") #blue
# picker = ColorPicker(description="Color:", value="#00ff00") #green
# picker = ColorPicker(description="Color:", value="#ffff00") #yellow
### all currently supported color options. -> gives the possibility to label 4 different phases

link((picker, "value"), (canvas, "stroke_style"))
link((picker, "value"), (canvas, "fill_style"))
link((slidealpha, "value"), (resultdisplay, "global_alpha"))

HBox((Mcanvas,picker))
# HBox((Mcanvas,)) #picker 
original shape:  (700, 330)
Out[58]:
HBox(children=(MultiCanvas(height=700, width=330), ColorPicker(value='#ff0000', description='Color:')))

Adjust the grayscale range of the displayed image by playing with the lines below.

This has no effect on the data. If you mess up, there are some lines at the beginning of the canvas cell (or below) to restore the previous display.

In [53]:
# Plot the grayscale histogram of the display image to pick a sensible
# contrast range before adjusting it with the commented lines below.
tfs.display.plot_im_histogram(im8)
# im8 = TS.current_im8 # uncomment this line to get the original back display (raw at original grayscale range)
# im8 = tfs.display.adjust_image_contrast(im8,0,100)
No description has been provided for this image

Update the training set if the labels are OK; if not, clear the current canvas by re-running the cell above.

automatically updates the stored label image on the disk

In [54]:
# Transfer the strokes drawn on the canvas into the persistent label image.
# The canvas RGB channels encode the picked colors:
#   red only          -> class 1
#   any green         -> class 2
#   any blue          -> class 4
#   red + green (yel) -> class 3, applied last so it overrides classes 1/2
# Class ids 3 and 4 are swapped relative to channel order for legacy reasons
# (existing training labels on disk use this convention).
label_set = canvas.get_image_data()

red = label_set[:, :, 0] > 0
green = label_set[:, :, 1] > 0
blue = label_set[:, :, 2] > 0

updated_truth = TS.current_truth.copy()
updated_truth[red & (red ^ green)] = 1
updated_truth[green] = 2
updated_truth[blue] = 4
updated_truth[red & green] = 3

TS.current_truth = updated_truth.copy()
# Persist the updated label image so the labels survive notebook restarts.
imageio.imsave(TS.current_truthpath, TS.current_truth)

inspect labels and training progress

can be sometimes useful

In [55]:
# same as above

# Side-by-side inspection of training progress: current segmentation result,
# raw grayscale, label image, and either the selected feature (index `i` from
# the cell above) or the result again when no feature stack is in memory.
fig, axes = plt.subplots(1,4, figsize=(20,10))
axes[0].imshow(TS.current_result, 'gray')
axes[0].set_title('current result')
axes[1].imshow(TS.current_im8, 'gray')
axes[1].set_title('original grayscale')

# TS.current_diff_im = TS.current_im-TS.current_first_im
# TS.current_diff_im = TS.current_diff_im/TS.current_diff_im.max()*255
# axes[2].imshow(-TS.current_diff_im)#,vmin=6e4)
# axes[3].imshow(im8old, 'gray')
# axes[3].imshow(TS.current_first_im, 'gray')
axes[2].imshow(TS.current_truth)
axes[2].set_title('label image')
if TS.current_computed:
    axes[3].imshow(TS.current_feat_stack_full[:,:,i], 'gray')
    axes[3].set_title(str(i)+': '+IF.combined_feature_names[i])
else:
    # Feature stack not computed — fall back to showing the result again.
    axes[3].imshow(TS.current_result, 'gray')
    axes[3].set_title('current result')
# for ax in axes:
    # ax.set_xticks([])
    # ax.set_yticks([])
No description has been provided for this image

train!

In [56]:
# Train on the current slice (prints 'training and classifying' while running).
TS.train_slice()
training and classifying

Review the feature importances to decide on omitting a few and make the calculation more efficient.

TODO implementation to be done

In [60]:
# TODO consider reduced feature stack
# Stem plot of the classifier's feature importances by feature name.
# NOTE(review): `feature_ids` is only defined in a cell further down
# (executed as In[56] while this cell is In[60]) — on a fresh
# Restart & Run All this cell raises a NameError. Move the mask computation
# above this cell, or plot all importances without the mask.
plt.figure(figsize=(16,9))
plt.stem(np.array(IF.combined_feature_names)[feature_ids], TS.clf.feature_importances_,'x')
plt.xticks(rotation=90)
plt.ylabel('importance') 
Out[60]:
Text(0, 0.5, 'importance')
No description has been provided for this image
In [54]:
# Persist the training dict and feature names, keyed by the pytrainseg git sha
# so stored artifacts can be matched to the code version that produced them.
# Context managers ensure the file handles are flushed and closed promptly
# (the original left unclosed handles to the garbage collector).
with open(os.path.join(training_path, pytrain_git_sha + '_training_dict.p'), 'wb') as f:
    pickle.dump(TS.training_dict, f)
with open(os.path.join(training_path, pytrain_git_sha + '_feature_names.p'), 'wb') as f:
    pickle.dump(TS.combined_feature_names, f)
In [55]:
# Show where the artifacts were written.
training_path
Out[55]:
'/das/home/fische_r/interlaces/Tomcat_2'

Select features to keep to reduce feature stack

Throw away 50% of the features by using the median of the feature importances.
The yarn sample showed no perceptible reduction of segmentation quality.

In [56]:
# Keep only features whose importance is strictly above the median — i.e.
# roughly the better half of the feature stack.
importance_median = np.median(TS.clf.feature_importances_)
# Boolean mask over features; consumed by the name-collection cell below.
feature_ids = TS.clf.feature_importances_>importance_median
In [57]:
# Collect the names of the features that passed the median-importance cut
# (boolean mask `feature_ids` from the cell above).
# A comprehension replaces the append loop: more idiomatic, and it does not
# clobber the module-level index `i` that an earlier cell uses to pick a
# feature for display (Python 3 comprehension variables are scope-local).
features_to_use = [
    IF.combined_feature_names[idx]
    for idx in range(len(feature_ids))
    if feature_ids[idx]
]
In [58]:
# TODO write txt instead of pickle dump
# NOTE(review): the bare token below looks like a stray git short-sha —
# probably leftover output of a sha variable pasted into the cell. As code it
# would raise a NameError on execution; confirm and remove.
093c73c
In [59]:
# List the retained feature names (rendered as the cell output).
features_to_use
Out[59]:
['Gaussian_4D_Blur_0.0',
 'Gaussian_4D_Blur_1.0',
 'Gaussian_4D_Blur_6.0',
 'diff_of_gauss_4D_6.0_0.0',
 'diff_of_gauss_4D_6.0_1.0',
 'Gradient_sigma_1.0_0',
 'Gradient_sigma_1.0_1',
 'Gradient_sigma_1.0_3',
 'hessian_sigma_1.0_00',
 'hessian_sigma_1.0_01',
 'hessian_sigma_1.0_11',
 'Gradient_sigma_3.0_3',
 'hessian_sigma_3.0_00',
 'hessian_sigma_3.0_01',
 'hessian_sigma_3.0_02',
 'hessian_sigma_3.0_03',
 'hessian_sigma_3.0_11',
 'hessian_sigma_3.0_33',
 'Gradient_sigma_6.0_0',
 'Gradient_sigma_6.0_1',
 'Gradient_sigma_6.0_2',
 'Gradient_sigma_6.0_3',
 'hessian_sigma_6.0_01',
 'hessian_sigma_6.0_03',
 'hessian_sigma_6.0_11',
 'hessian_sigma_6.0_12',
 'hessian_sigma_6.0_13',
 'hessian_sigma_6.0_22',
 'hessian_sigma_6.0_23',
 'hessian_sigma_6.0_33',
 'Gradient_sigma_2.0_0',
 'Gradient_sigma_2.0_3',
 'hessian_sigma_2.0_00',
 'hessian_sigma_2.0_01',
 'hessian_sigma_2.0_02',
 'hessian_sigma_2.0_11',
 'hessian_sigma_2.0_13',
 'Gaussian_time_0.0',
 'Gaussian_time_1.0',
 'Gaussian_time_6.0',
 'Gaussian_time_2.0',
 'Gaussian_space_0.0',
 'Gaussian_space_6.0',
 'diff_of_gauss_space_3.0_0.0',
 'diff_of_gauss_space_6.0_0.0',
 'diff_of_gauss_space_2.0_0.0',
 'diff_of_gauss_space_3.0_1.0',
 'diff_of_gauss_space_6.0_1.0',
 'diff_of_gauss_space_2.0_1.0',
 'diff_of_gauss_space_2.0_3.0',
 'diff_temp_min_Gauss_2.0',
 'diff_to_first_',
 'full_temp_min_Gauss_2.0',
 'first_',
 'last_']

Segmentation

In [ ]: