# -*- coding: utf-8 -*-
"""
Created on Wed Aug 17 16:04:34 2022

Interactive training utilities for pixelwise image segmentation,
to be loaded in Jupyter.

TODO: store git commit sha

@author: fische_r
"""
|
|
import xarray as xr
|
|
import os
|
|
from skimage import io, exposure
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import dask
|
|
import pickle
|
|
|
|
#the classifier
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
|
|
#stuff for painting on the image
|
|
# from ipywidgets import Image
|
|
# from ipywidgets import ColorPicker, IntSlider, link, AppLayout, HBox
|
|
# from ipycanvas import hold_canvas, MultiCanvas #RoughCanvas,Canvas,
|
|
|
|
# Default pixel classifier: 300 trees, all CPU cores (n_jobs=-1), fixed seed
# for reproducibility, and every feature considered at each split.
default_classifier = RandomForestClassifier(n_estimators = 300, n_jobs=-1, random_state = 42, max_features=None)
|
|
|
|
def extract_training_data(truth, feat_stack):
    """Build a pixelwise training set from a painted label image.

    Parameters
    ----------
    truth : ndarray
        Label image; pixel values 1, 2 and 4 mark the three phases
        (anything else is treated as unlabelled and skipped).
    feat_stack : ndarray
        Per-pixel feature stack; same leading shape as ``truth`` with a
        trailing feature axis.

    Returns
    -------
    X : ndarray, shape (n_labelled_pixels, n_features)
    y : ndarray, shape (n_labelled_pixels,)
        Class labels 0.0, 1.0, 2.0 for truth values 1, 2, 4 respectively.
    """
    features = []
    labels = []
    # label-image value -> class label: 1 -> 0, 2 -> 1, 4 -> 2
    for class_label, truth_value in enumerate((1, 2, 4)):
        Xi = feat_stack[truth == truth_value]
        features.append(Xi)
        labels.append(np.full(Xi.shape[0], class_label, dtype=float))

    X = np.concatenate(features)
    y = np.concatenate(labels)
    return X, y
|
|
|
|
def classify(X, y, im, feat_stack, clf):
    """Fit *clf* on (X, y) and predict a label image for every pixel.

    Parameters
    ----------
    X, y : ndarray
        Training features and class labels.
    im : ndarray
        Reference image; only its shape is used for the output.
    feat_stack : ndarray
        Per-pixel features, trailing axis = features.
    clf : classifier
        Object with ``fit``/``predict`` (sklearn-style).

    Returns
    -------
    result : ndarray of uint8, same shape as ``im``
    clf : the fitted classifier
    """
    # TODO: allow choice and manipulation of ML method
    clf.fit(X, y)

    n_features = feat_stack.shape[-1]
    flat_features = feat_stack.reshape(-1, n_features)
    prediction = clf.predict(flat_features)

    result = prediction.reshape(im.shape).astype(np.uint8)
    return result, clf
|
|
|
|
def training_function(im, truth, feat_stack, training_dict, slice_name, clf):
|
|
flag = False
|
|
slices = list(training_dict.keys())
|
|
if slice_name in slices:
|
|
slices.remove(slice_name)
|
|
if len(slices)>0:
|
|
flag = True
|
|
Xall = training_dict[slices[0]][0]
|
|
yall = training_dict[slices[0]][1]
|
|
for i in range(1,len(slices)): #why was there 1, in range ? because first initiates the Xall, np.stack could be an alternative way
|
|
Xall = np.concatenate([Xall, training_dict[slices[i]][0]])
|
|
yall = np.concatenate([yall, training_dict[slices[i]][1]])
|
|
|
|
X,y = extract_training_data(truth, feat_stack)
|
|
|
|
print('training and classifying')
|
|
|
|
if flag:
|
|
Xt = np.concatenate([Xall,X])
|
|
yt = np.concatenate([yall,y])
|
|
Xall = None
|
|
yall = None
|
|
else:
|
|
Xt = X
|
|
yt = y
|
|
result, clf = classify(Xt, yt, im, feat_stack, clf)
|
|
|
|
# store training data of current slice in dict
|
|
training_dict[slice_name] = (X,y)
|
|
return result, clf, training_dict
|
|
|
|
def adjust_image_contrast(im, low, high):
    """Clip intensities to [low, high] and stretch the result.

    Careful: the output is always rescaled to the 0-255 range,
    regardless of the input dtype or range.
    """
    rescaled = exposure.rescale_intensity(im, (low, high))
    return rescaled * 255
|
|
|
|
def plot_im_histogram(im):
    """Plot a 100-bin intensity histogram of *im* into the current axes."""
    counts, bin_edges = np.histogram(im, bins=100)
    # plot counts against the right-hand edge of each bin
    plt.plot(bin_edges[1:], counts)
|
|
|
|
def extract_coords(labelname):
    """Parse slice coordinates out of a label-image file name.

    File names follow the pattern ``label_image_<c1>_<p1>_<c2>_<p2>_.tif``,
    so after splitting on '_' the coordinate names sit at indices 2 and 4
    and their positions at indices 3 and 5.

    Returns
    -------
    (c1, p1, c2, p2) : coordinate names (str) and positions (int).
    """
    parts = labelname.split('_')
    c1, c2 = parts[2], parts[4]
    p1, p2 = int(parts[3]), int(parts[5])
    return c1, p1, c2, p2
|
|
|
|
def training_set_per_image(label_name, trainingpath, feat_data, lazy = False):
    """Load one labelled slice and return its pixelwise training set.

    Parameters
    ----------
    label_name : str
        Label-image file name encoding the slice coordinates
        (see ``extract_coords``).
    trainingpath : str
        Directory containing the label images.
    feat_data : xarray.Dataset
        Must contain a 'feature_stack' variable selectable by the two
        coordinates encoded in ``label_name``.
    lazy : bool, optional
        Kept for interface compatibility; dask-backed slices are
        materialized below regardless.

    Returns
    -------
    (X, y) : pixelwise features and class labels for this slice.

    Raises
    ------
    ValueError
        If the encoded coordinates do not exist in ``feat_data``.
    """
    c1, p1, c2, p2 = extract_coords(label_name)
    truth = io.imread(os.path.join(trainingpath, label_name))

    # Select the 2-D feature slice by the two encoded coordinates.  A
    # keyword dict generalizes the original hard-coded if/elif chain over
    # (x,time), (x,y), (x,z), (y,z), (y,time), (z,time) to any valid pair.
    try:
        feat_stack = feat_data['feature_stack'].sel(**{c1: p1, c2: p2}).data
    except (KeyError, ValueError) as err:
        # The original only printed 'coordinates not found' and then
        # crashed with a NameError; fail early with a clear message.
        raise ValueError('coordinates not found: %s, %s' % (c1, c2)) from err

    # dask-backed slices must be materialized before boolean indexing
    if type(feat_stack) is not np.ndarray:
        feat_stack = feat_stack.compute()

    X, y = extract_training_data(truth, feat_stack)
    return X, y
|
|
|
|
class train_segmentation:
    # Interactive front-end for training a pixelwise segmentation
    # classifier: loads 2-D slices out of a (presumably 4-D) feature
    # stack, collects hand-painted label images and fits the classifier.

    def __init__(self,
                 feature_path = None,
                 training_path = None,
                 clf_method = default_classifier
                 ):
        """Set up paths, the training-data cache and the classifier.

        Parameters
        ----------
        feature_path : str, optional
            Path to a netCDF feature dataset (used by open_feature_data).
        training_path : str
            Directory for label images and pickled outputs.
            NOTE(review): must not be None — it is joined immediately
            below and would raise TypeError.
        clf_method : classifier, optional
            sklearn-style object with fit/predict; defaults to the
            module-level RandomForestClassifier.
        """
        self.feature_path = feature_path
        self.training_path = training_path
        # label images painted by the user live in a subdirectory
        self.label_path = os.path.join(training_path, 'label_images')
        # slice_name -> (X, y) pixelwise training data per labelled slice
        self.training_dict = {}
        self.clf_method = clf_method

        if not os.path.exists(self.label_path):
            os.mkdir(self.label_path)

        self.lazy = False #maybe this can be more elegant without flag

    def open_feature_data(self):
        """Open the feature dataset from ``self.feature_path`` (netCDF)."""
        self.feat_data = xr.open_dataset(self.feature_path)
        self.feature_names = self.feat_data['feature'].data

    def import_feature_data(self, data):
        """Attach an already-computed (in-memory) feature dataset."""
        self.feat_data = data
        self.feature_names = self.feat_data['feature'].data
        self.lazy = False

    def import_lazy_feature_data(self, data, rawdata, lazy = True):
        """Attach a lazy (dask-backed) feature dataset plus the raw data.

        ``rawdata`` must contain a 'tomo' variable (used by
        load_training_set to show the original image).
        """
        self.raw_data = rawdata
        self.feat_data = data
        self.feature_names = self.feat_data['feature'].data
        self.lazy = lazy

    def suggest_training_set(self):
        """Print a randomly chosen coordinate pair to label next."""
        # NOTE(review): assumes the last coord is the feature axis and the
        # remaining coords are the spatial/time dims — confirm ordering.
        dimensions = list(self.feat_data.coords.keys())[:-1]

        test_dims = np.random.choice(dimensions, 2, replace=False)
        p1 = np.random.choice(range(len(self.feat_data[test_dims[0]])))
        p2 = np.random.choice(range(len(self.feat_data[test_dims[1]])))

        print('You could try ',test_dims[0],'=',str(p1),' and ',test_dims[1],'=',str(p2))
        print('However, please sort it like the original '+''.join(dimensions))

    def load_training_set(self, c1, p1, c2, p2):
        """Load one 2-D slice (image, features, labels) as current state.

        Selects the slice at coordinate ``c1 = p1``, ``c2 = p2`` from the
        raw data and the feature stack, normalizes the image for display,
        and loads an existing label image from disk if one was saved
        earlier for this slice.

        Side effects: sets all ``current_*`` attributes.
        """

        data = self.feat_data['feature_stack']
        rawdata = self.raw_data['tomo']

        # this has to be possible in a more elegant way!
        if c1 == 'x':
            stage1 = rawdata.sel(x=p1)
            stage1feat = data.sel(x=p1)
        elif c1 == 'y':
            stage1 = rawdata.sel(y=p1)
            stage1feat = data.sel(y=p1)
        elif c1 == 'z':
            stage1 = rawdata.sel(z=p1)
            stage1feat = data.sel(z=p1)
        elif c1 == 'time':
            # NOTE(review): this only prints — stage1/im are never assigned
            # and the code below raises NameError for c1 == 'time'.
            print('time cannot be first coordinate')

        if not c1=='time':
            if c2 == 'x':
                im = stage1.sel( x = p2).data #feature = 'original',
                feat_stack = stage1feat.sel(x = p2).data
                imfirst = None
            elif c2 == 'y':
                im = stage1.sel( y = p2).data
                feat_stack = stage1feat.sel(y = p2).data
                imfirst = None
            elif c2 == 'z':
                im = stage1.sel( z = p2).data
                feat_stack = stage1feat.sel(z = p2).data
                imfirst = None
            elif c2 == 'time':
                im = stage1.sel(time = p2).data
                feat_stack = stage1feat.sel(time = p2).data
                # reference frame at time 0 for the difference image below
                imfirst = stage1.sel(time = 0).data

        if self.lazy:
            # get the reference images directly as numpy array
            # im = im.compute()
            # if imfirst is not None:
            #     imfirst = imfirst.compute()
            #already start calculating the feature stack
            # NOTE(review): persist() returns a new persisted object; the
            # return value is discarded here, so this line likely has no
            # effect — probably meant feat_stack = feat_stack.persist().
            feat_stack.persist()
            self.current_computed = False
        else:
            self.current_computed = True

        # materialize dask-backed images before numpy arithmetic below
        if type(im) is not np.ndarray:
            im = im.compute()
        if imfirst is not None and type(imfirst) is not np.ndarray:
            imfirst = imfirst.compute()

        # display copy, min/max-stretched to 0..255 (float, not uint8)
        im8 = im-im.min()
        im8 = im8/im8.max()*255

        if imfirst is not None:
            # difference to the first time step, kept unscaled
            diff = im-imfirst
            # diff = diff/diff.max()*255
            self.current_diff_im = diff
        else:
            self.current_diff_im = None

        # slice key, e.g. 'x_5_time_10_'; also encoded in the label file name
        slice_name = ''.join([c1,'_',str(p1),'_',c2,'_',str(p2),'_'])
        truthpath = os.path.join(self.label_path, ''.join(['label_image_',slice_name,'.tif']))

        resultim = np.zeros(im.shape, dtype=np.uint8)
        if os.path.exists(truthpath):
            truth = io.imread(truthpath)
            print('existing label set loaded')
        else:
            # no saved labels yet: start from an all-zero (unlabelled) image
            truth = resultim.copy()

        self.current_coordinates = [c1,p1,c2,p2]
        self.current_im = im
        self.current_im8 = im8
        self.current_feat_stack = feat_stack
        self.current_first_im = imfirst
        self.current_truth = truth
        self.current_result = resultim
        self.current_truthpath = truthpath
        self.current_slice_name = slice_name
        #TODO: maybe keep lazy computed feature stacks of older slices somewhere and purge if using up too much RAM

    def train_slice(self):
        """Train on the current slice (plus cached slices) and predict it.

        Materializes the lazy feature stack if needed, then delegates to
        ``training_function``. Side effects: updates ``current_result``,
        ``training_dict`` and ``clf``.
        """
        #fetch variables
        im = self.current_im
        truth = self.current_truth
        training_dict = self.training_dict
        slice_name = self.current_slice_name
        feat_stack = self.current_feat_stack

        #re-consider these lines
        if self.lazy and not self.current_computed and type(feat_stack) is not np.ndarray:
            print('now actually calculating the features')
            # self.current_feat_stack.rechunk('auto') #why rechunk 'auto' ?! if anything should be something small fot massive parallel
            feat_stack = feat_stack.compute()
            self.current_computed = True
        # safety net: whatever the flags say, boolean indexing below
        # requires a real numpy array
        if type(feat_stack) is not np.ndarray:
            print('feat_stack is not a numpy array!')
            feat_stack = feat_stack.compute()

        # cache the computed stack so a re-train of this slice is cheap
        self.current_feat_stack = feat_stack
        #train
        # print('training ...')
        resultim, clf, training_dict = training_function(im, truth, feat_stack, training_dict, slice_name, self.clf_method)

        # update variables

        self.current_result = resultim
        self.training_dict = training_dict #this necessary ? yes!
        self.clf = clf

    def plot_importance(self, figsize=(16,9)):
        """Stem-plot the fitted classifier's feature importances."""
        plt.figure(figsize=figsize)
        plt.stem(self.feature_names, self.clf.feature_importances_,'x')
        plt.xticks(rotation=90)
        plt.ylabel('importance')

    def pickle_training_dict(self):
        """Dump ``training_dict`` to training_path/training_dict.p."""
        pickle.dump(self.training_dict, open(os.path.join(self.training_path, 'training_dict.p'),'wb'))

    def pickle_classifier(self):
        """Dump the fitted classifier to training_path/classifier.p."""
        pickle.dump(self.clf, open(os.path.join(self.training_path, 'classifier.p'),'wb'))

    def train(self):
        """Batch-train from ALL label images found in ``label_path``.

        Rebuilds ``training_dict`` from disk and fits ``clf_method`` on the
        pooled pixels. NOTE(review): if the label directory is empty,
        ``Xall``/``yall`` are never assigned and ``clf.fit`` raises
        NameError.
        """
        path = self.label_path
        feat_data = self.feat_data #probably requires computed feature data, added the flag below
        training_dict = {}
        labelnames = os.listdir(path)
        flag = True
        for label_name in labelnames:
            print(label_name)
            X, y = training_set_per_image(label_name, path, feat_data, self.lazy)
            training_dict[label_name] = X,y
            # first iteration seeds the accumulators, later ones append
            if flag:
                Xall = X
                yall = y
                flag = False
            else:
                Xall = np.concatenate([Xall,X])
                yall = np.concatenate([yall,y])

        clf = self.clf_method
        clf.fit(Xall, yall)
        self.clf = clf
        self.training_dict = training_dict

    def train_parallel(self):
        """Experimental parallel variant of ``train`` via dask.delayed.

        NOTE(review): this looks broken as written — unpacking
        ``X, y = dask.delayed(...)`` raises TypeError for a Delayed of
        unspecified length (needs ``nout=2``), the delayed results are
        never ``compute()``d before ``np.concatenate``, and
        ``feat_data.persist()`` is re-called on every iteration.
        Confirm before use; ``train`` is the working path.
        """
        #come up with a way to train() in parallel
        # maybe with dask.delayed
        path = self.label_path
        feat_data = self.feat_data
        training_dict = {}
        labelnames = os.listdir(path)
        XX = []
        yy = []
        for label_name in labelnames:
            X, y = dask.delayed(training_set_per_image)(label_name, path, feat_data.persist(), self.lazy)
            training_dict[label_name] = X,y
            XX.append(X)
            yy.append(y)
        Xall = np.concatenate(XX)
        yall = np.concatenate(yy)
        clf = self.clf_method
        clf.fit(Xall, yall)
        self.clf = clf
        self.training_dict = training_dict