Files
pyTrainSeg/training_functions.py
T
2022-10-13 15:13:18 +02:00

345 lines
12 KiB
Python

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 17 16:04:34 2022
to be loaded in Jupyter
TODO: store git commit sha
@author: fische_r
"""
import xarray as xr
import os
from skimage import io, exposure
import matplotlib.pyplot as plt
import numpy as np
import dask
import pickle
#the classifier
from sklearn.ensemble import RandomForestClassifier
#stuff for painting on the image
# from ipywidgets import Image
# from ipywidgets import ColorPicker, IntSlider, link, AppLayout, HBox
# from ipycanvas import hold_canvas, MultiCanvas #RoughCanvas,Canvas,
# Classifier used when no other estimator is passed in: a random forest with
# 300 trees (n_estimators), all available cores (n_jobs=-1), every feature
# considered at each split (max_features=None) and a fixed seed (random_state).
# NOTE(review): this is a single module-level instance; anything that fits it
# in place shares that fitted state -- confirm this is intended.
default_classifier = RandomForestClassifier(n_estimators = 300, n_jobs=-1, random_state = 42, max_features=None)
def extract_training_data(truth, feat_stack, label_values=(1, 2, 4)):
    """Collect pixelwise training samples from a labelled slice.

    Parameters
    ----------
    truth : array
        Label image. Pixels equal to one of ``label_values`` become training
        samples, all other pixels are ignored.
    feat_stack : array
        Feature array that is indexed with the boolean masks derived from
        ``truth`` (presumably shaped ``truth.shape + (n_features,)`` --
        confirm with the callers).
    label_values : sequence of int, optional
        Pixel values that mark the individual phases. The value at position
        ``k`` is mapped to class ``k``. The default ``(1, 2, 4)`` reproduces
        the previously hard-coded mapping (1 to class 0, 2 to class 1,
        4 to class 2). Must not be empty.

    Returns
    -------
    X : array
        Feature rows of all labelled pixels, grouped by class in the order of
        ``label_values``.
    y : array
        Class index of every row of ``X`` as float.
    """
    samples = []
    targets = []
    for class_index, value in enumerate(label_values):
        phase_features = feat_stack[truth == value]
        samples.append(phase_features)
        # np.ones keeps the float dtype of the former zeros/ones label arrays
        targets.append(class_index * np.ones(phase_features.shape[0]))

    X = np.concatenate(samples)
    y = np.concatenate(targets)
    return X, y
def classify(X, y, im, feat_stack, clf):
    """Fit ``clf`` on the samples and label every pixel of ``feat_stack``.

    Parameters
    ----------
    X, y : array
        Training features and targets handed to ``clf.fit``.
    im : array
        Only its ``shape`` is used, to give the prediction its image shape.
    feat_stack : array
        Features of the slice; flattened to ``(-1, n_features)`` where
        ``n_features`` is the length of the last axis.
    clf : estimator
        Object offering ``fit`` and ``predict``; it is fitted in place.

    Returns
    -------
    tuple
        ``(label_image, clf)`` with ``label_image`` as ``uint8`` array of
        shape ``im.shape`` and ``clf`` the fitted estimator passed in.
    """
    # TODO: allow choice and manipulation of ML method
    clf.fit(X, y)

    n_features = feat_stack.shape[-1]
    pixel_rows = feat_stack.reshape(-1, n_features)
    predicted = clf.predict(pixel_rows)

    label_image = predicted.reshape(im.shape).astype(np.uint8)
    return label_image, clf
def training_function(im, truth, feat_stack, training_dict, slice_name, clf):
    """Train on the current slice together with all previously stored slices.

    Parameters
    ----------
    im : array
        Image of the current slice (forwarded to ``classify``).
    truth : array
        Label image of the current slice.
    feat_stack : array
        Feature stack of the current slice.
    training_dict : dict
        Maps slice names to ``(X, y)`` tuples of earlier slices. It is
        updated in place with the samples of the current slice.
    slice_name : str
        Key of the current slice in ``training_dict``.
    clf : estimator
        Classifier forwarded to ``classify``.

    Returns
    -------
    tuple
        ``(result, clf, training_dict)`` as returned by ``classify`` plus the
        updated dictionary.
    """
    # Samples of every stored slice except the current one: its entry is
    # rebuilt from the present label image below.
    previous = [data for name, data in training_dict.items() if name != slice_name]

    X, y = extract_training_data(truth, feat_stack)
    print('training and classifying')

    if previous:
        # one concatenation instead of re-copying the growing arrays per slice
        Xt = np.concatenate([data[0] for data in previous] + [X])
        yt = np.concatenate([data[1] for data in previous] + [y])
    else:
        Xt = X
        yt = y

    result, clf = classify(Xt, yt, im, feat_stack, clf)

    # store training data of current slice in dict
    training_dict[slice_name] = (X, y)
    return result, clf, training_dict
def adjust_image_contrast(im, low, high):
    """Rescale the intensities of ``im`` and multiply the result by 255.

    ``(low, high)`` is handed to ``exposure.rescale_intensity`` as its second
    positional argument (the input range in scikit-image).
    """
    # careful, rescales in every case to 255
    window = (low, high)
    rescaled = exposure.rescale_intensity(im, window)
    return rescaled * 255
def plot_im_histogram(im):
    """Plot a 100-bin histogram of the values in ``im`` with ``plt.plot``."""
    counts, bin_edges = np.histogram(im, bins=100)
    # np.histogram yields one edge more than bins; the counts are drawn
    # against the upper edge of each bin
    plt.plot(bin_edges[1:], counts)
def extract_coords(labelname):
    """Read the two slicing coordinates out of a label file name.

    The name is split at ``'_'``; fields 2 and 4 are taken as coordinate
    names, fields 3 and 5 as their integer positions, e.g.
    ``'label_image_x_10_time_5_.tif'`` gives ``('x', 10, 'time', 5)``.

    Returns
    -------
    tuple
        ``(c1, p1, c2, p2)``.
    """
    fields = labelname.split('_')
    return fields[2], int(fields[3]), fields[4], int(fields[5])
def training_set_per_image(label_name, trainingpath, feat_data, lazy = False):
    """Build the training samples belonging to one stored label image.

    Parameters
    ----------
    label_name : str
        File name of the label image; the slicing coordinates are parsed from
        it with ``extract_coords``.
    trainingpath : path-like
        Folder containing ``label_name``.
    feat_data : mapping
        Feature data; ``feat_data['feature_stack']`` must support
        ``.sel(...)`` with the coordinate names found in ``label_name``.
    lazy : bool, optional
        Currently unused; kept so existing calls keep working.

    Returns
    -------
    tuple
        ``(X, y)`` as produced by ``extract_training_data``.

    Raises
    ------
    ValueError
        If the two coordinates in ``label_name`` are not two different names
        out of ``x``, ``y``, ``z`` and ``time``.
    """
    c1, p1, c2, p2 = extract_coords(label_name)

    valid_dims = ('x', 'y', 'z', 'time')
    if c1 not in valid_dims or c2 not in valid_dims or c1 == c2:
        # previously only printed, followed by a crash on the unbound feat_stack
        raise ValueError(
            'coordinates not found: {!r} and {!r} in {!r}'.format(c1, c2, label_name)
        )

    truth = io.imread(os.path.join(trainingpath, label_name))

    # keyword expansion selects along any valid pair of coordinates, in
    # whichever order they appear in the file name
    feat_stack = feat_data['feature_stack'].sel(**{c1: p1, c2: p2}).data

    # anything that is not yet a numpy array is expected to offer .compute()
    if not isinstance(feat_stack, np.ndarray):
        feat_stack = feat_stack.compute()

    X, y = extract_training_data(truth, feat_stack)
    return X, y
class train_segmentation:
    """Slice-wise training of a pixel classifier on a feature stack.

    The instance keeps the feature data, the currently loaded slice (the
    ``current_*`` attributes set by ``load_training_set``), the training
    samples gathered per slice (``training_dict``) and, once trained, the
    fitted classifier (``clf``).
    """

    def __init__(self,
                 feature_path = None,
                 training_path = None,
                 clf_method = default_classifier
                 ):
        """Store the configuration and make sure the label folder exists.

        Parameters
        ----------
        feature_path : path-like, optional
            File opened by ``open_feature_data``.
        training_path : path-like, optional
            Folder that holds the ``label_images`` sub-folder and receives
            the pickled outputs. With ``None`` no folder is created and
            ``label_path`` stays ``None``.
        clf_method : estimator, optional
            Object with ``fit``/``predict`` used for training. The default is
            the module-level ``default_classifier`` instance, which is shared
            by every object created without this argument.
        """
        self.feature_path = feature_path
        self.training_path = training_path
        self.training_dict = {}
        self.clf_method = clf_method
        self.lazy = False  # maybe this can be more elegant without flag

        if training_path is None:
            # os.path.join(None, ...) would raise; defer until a path is set
            self.label_path = None
        else:
            self.label_path = os.path.join(training_path, 'label_images')
            if not os.path.exists(self.label_path):
                os.mkdir(self.label_path)

    def open_feature_data(self):
        """Open ``feature_path`` with xarray and read the feature names."""
        self.feat_data = xr.open_dataset(self.feature_path)
        self.feature_names = self.feat_data['feature'].data

    def import_feature_data(self, data):
        """Use an already available feature dataset; switches lazy mode off."""
        self.feat_data = data
        self.feature_names = self.feat_data['feature'].data
        self.lazy = False

    def import_lazy_feature_data(self, data, rawdata, lazy = True):
        """Use a feature dataset together with the raw data it was built from.

        Parameters
        ----------
        data : dataset
            Feature data providing ``'feature'`` and ``'feature_stack'``.
        rawdata : dataset
            Raw data; ``load_training_set`` reads ``rawdata['tomo']``.
        lazy : bool, optional
            Value stored in ``self.lazy``.
        """
        self.raw_data = rawdata
        self.feat_data = data
        self.feature_names = self.feat_data['feature'].data
        self.lazy = lazy

    def suggest_training_set(self):
        """Print a randomly drawn pair of coordinates and positions."""
        # every coordinate except the last one
        # (presumably the last one is 'feature' -- TODO confirm)
        dimensions = list(self.feat_data.coords.keys())[:-1]
        test_dims = np.random.choice(dimensions, 2, replace=False)
        p1 = np.random.choice(range(len(self.feat_data[test_dims[0]])))
        p2 = np.random.choice(range(len(self.feat_data[test_dims[1]])))
        print('You could try ',test_dims[0],'=',str(p1),' and ',test_dims[1],'=',str(p2))
        print('However, please sort it like the original '+''.join(dimensions))

    def load_training_set(self, c1, p1, c2, p2):
        """Load the slice at ``c1 = p1``, ``c2 = p2`` as current training slice.

        Sets ``current_coordinates``, ``current_im``, ``current_im8``,
        ``current_diff_im``, ``current_first_im``, ``current_feat_stack``,
        ``current_truth``, ``current_result``, ``current_truthpath``,
        ``current_slice_name`` and ``current_computed``. An existing label
        image at ``current_truthpath`` is loaded, otherwise an empty one is
        used.

        Parameters
        ----------
        c1 : str
            First coordinate, one of ``'x'``, ``'y'``, ``'z'``.
        p1 : int
            Position along ``c1``.
        c2 : str
            Second coordinate, one of ``'x'``, ``'y'``, ``'z'``, ``'time'``
            and different from ``c1``.
        p2 : int
            Position along ``c2``.

        Raises
        ------
        ValueError
            For an unsupported combination of coordinates.
        AttributeError
            If no raw data has been provided yet.
        """
        # Fail early with a clear message; these cases used to print a hint
        # (or nothing) and then crash on variables that were never assigned.
        if c1 == 'time':
            raise ValueError('time cannot be first coordinate')
        if c1 not in ('x', 'y', 'z'):
            raise ValueError('unknown first coordinate {!r}'.format(c1))
        if c2 not in ('x', 'y', 'z', 'time') or c2 == c1:
            raise ValueError(
                'invalid second coordinate {!r} for first coordinate {!r}'.format(c2, c1)
            )

        data = self.feat_data['feature_stack']
        if getattr(self, 'raw_data', None) is None:
            # raw_data is only assigned by import_lazy_feature_data
            raise AttributeError(
                'raw_data is not set, provide it with import_lazy_feature_data'
            )
        rawdata = self.raw_data['tomo']

        stage1 = rawdata.sel(**{c1: p1})
        stage1feat = data.sel(**{c1: p1})

        im = stage1.sel(**{c2: p2}).data
        feat_stack = stage1feat.sel(**{c2: p2}).data
        # a reference image (first time step) only exists for time series
        imfirst = stage1.sel(time = 0).data if c2 == 'time' else None

        if self.lazy:
            # Already start calculating the feature stack. persist() returns
            # a new collection instead of changing the array in place, so the
            # result has to be kept; numpy arrays need (and offer) no persist.
            if not isinstance(feat_stack, np.ndarray):
                feat_stack = feat_stack.persist()
            self.current_computed = False
        else:
            self.current_computed = True

        # NOTE(review): the reference images are materialised in both modes
        # here; the original indentation of these checks was lost -- confirm.
        if not isinstance(im, np.ndarray):
            im = im.compute()
        if imfirst is not None and not isinstance(imfirst, np.ndarray):
            imfirst = imfirst.compute()

        # NOTE(review): a constant image makes im8.max() zero and the
        # division below undefined.
        im8 = im-im.min()
        im8 = im8/im8.max()*255

        if imfirst is not None:
            self.current_diff_im = im-imfirst
        else:
            self.current_diff_im = None

        slice_name = ''.join([c1,'_',str(p1),'_',c2,'_',str(p2),'_'])
        truthpath = os.path.join(self.label_path, ''.join(['label_image_',slice_name,'.tif']))

        resultim = np.zeros(im.shape, dtype=np.uint8)
        if os.path.exists(truthpath):
            truth = io.imread(truthpath)
            print('existing label set loaded')
        else:
            truth = resultim.copy()

        self.current_coordinates = [c1,p1,c2,p2]
        self.current_im = im
        self.current_im8 = im8
        self.current_feat_stack = feat_stack
        self.current_first_im = imfirst
        self.current_truth = truth
        self.current_result = resultim
        self.current_truthpath = truthpath
        self.current_slice_name = slice_name
        #TODO: maybe keep lazy computed feature stacks of older slices somewhere and purge if using up too much RAM

    def train_slice(self):
        """Train with the current slice and classify it.

        Makes sure the current feature stack is a numpy array, calls
        ``training_function`` and stores the result image, the updated
        ``training_dict`` and the classifier on the instance.
        """
        #fetch variables
        im = self.current_im
        truth = self.current_truth
        slice_name = self.current_slice_name
        feat_stack = self.current_feat_stack

        #re-consider these lines
        if self.lazy and not self.current_computed and not isinstance(feat_stack, np.ndarray):
            print('now actually calculating the features')
            feat_stack = feat_stack.compute()
            self.current_computed = True
        if not isinstance(feat_stack, np.ndarray):
            print('feat_stack is not a numpy array!')
            feat_stack = feat_stack.compute()
        self.current_feat_stack = feat_stack

        #train
        resultim, clf, training_dict = training_function(
            im, truth, feat_stack, self.training_dict, slice_name, self.clf_method
        )

        # update variables
        self.current_result = resultim
        self.training_dict = training_dict
        self.clf = clf

    def plot_importance(self, figsize=(16,9)):
        """Stem plot of the feature importances of the trained classifier."""
        plt.figure(figsize=figsize)
        plt.stem(self.feature_names, self.clf.feature_importances_,'x')
        plt.xticks(rotation=90)
        plt.ylabel('importance')

    def pickle_training_dict(self):
        """Write ``training_dict`` to ``training_dict.p`` in ``training_path``."""
        target = os.path.join(self.training_path, 'training_dict.p')
        # context manager closes the file even if pickling fails
        with open(target, 'wb') as handle:
            pickle.dump(self.training_dict, handle)

    def pickle_classifier(self):
        """Write the trained classifier to ``classifier.p`` in ``training_path``."""
        target = os.path.join(self.training_path, 'classifier.p')
        with open(target, 'wb') as handle:
            pickle.dump(self.clf, handle)

    def train(self):
        """Fit the classifier on every label image found in ``label_path``.

        Replaces ``training_dict`` with the samples of all label images and
        stores the fitted classifier in ``clf``.

        Raises
        ------
        ValueError
            If ``label_path`` contains no files.
        """
        path = self.label_path
        feat_data = self.feat_data #probably requires computed feature data
        labelnames = os.listdir(path)
        if not labelnames:
            raise ValueError('no label images found in {}'.format(path))

        training_dict = {}
        samples = []
        targets = []
        for label_name in labelnames:
            print(label_name)
            X, y = training_set_per_image(label_name, path, feat_data, self.lazy)
            training_dict[label_name] = X, y
            samples.append(X)
            targets.append(y)

        # concatenate once instead of re-copying the growing arrays per image
        Xall = np.concatenate(samples)
        yall = np.concatenate(targets)

        clf = self.clf_method
        clf.fit(Xall, yall)

        self.clf = clf
        self.training_dict = training_dict

    def train_parallel(self):
        """Like ``train``, but the per-image extraction runs as dask tasks.

        Raises
        ------
        ValueError
            If ``label_path`` contains no files.
        """
        path = self.label_path
        labelnames = os.listdir(path)
        if not labelnames:
            raise ValueError('no label images found in {}'.format(path))

        # persist once; it used to be requested again for every label image
        feat_data = self.feat_data.persist()

        extract = dask.delayed(training_set_per_image)
        tasks = [extract(label_name, path, feat_data, self.lazy)
                 for label_name in labelnames]
        # A Delayed result cannot be unpacked into X, y directly; compute all
        # tasks together and work with the resulting (X, y) tuples.
        results = dask.compute(*tasks)

        training_dict = dict(zip(labelnames, results))
        Xall = np.concatenate([X for X, _ in results])
        yall = np.concatenate([y for _, y in results])

        clf = self.clf_method
        clf.fit(Xall, yall)

        self.clf = clf
        self.training_dict = training_dict