"""
|
|
TODO: stor git_sha
|
|
"""
|
|
|
|
|
|
import os
import pickle

import dask
import numpy as np
import xarray as xr
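
# Hedged sketch for the git_sha TODO in the module docstring (hypothetical helper,
# not part of the original module). It assumes the code is run from inside a git
# checkout with git available on the PATH; the returned value could then be
# attached to stored datasets, e.g. as a netCDF attribute.
import subprocess


def current_git_sha():
    """Return the SHA of the current git HEAD, or None if it cannot be determined."""
    try:
        return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode().strip()
    except (OSError, subprocess.CalledProcessError):
        return None
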

class segmentation:
    """Segment feature data with a pre-trained classifier."""

    def __init__(self,
                 feature_path=None,
                 classifier_path=None,
                 training_path=None
                 ):
        # TODO: get these paths from the training class
        self.feature_path = feature_path
        self.clf_path = classifier_path
        self.training_path = training_path

    def import_classifier(self, clf):
        self.clf = clf

    def load_classifier(self):
        # use a context manager so the file handle is closed after unpickling
        with open(self.clf_path, 'rb') as f:
            self.clf = pickle.load(f)

    def open_feature_data(self):
        self.feat_data = xr.open_dataset(self.feature_path)
        self.feature_names = self.feat_data['feature'].data
        self.lazy = False

    def import_feature_data(self, data):
        self.feat_data = data
        self.feature_names = self.feat_data['feature'].data
        self.lazy = False

    def import_lazy_feature_data(self, data):
        self.feat_data = data
        self.feature_names = self.feat_data['feature'].data
        self.feature_names_t_idp = self.feat_data['feature_time_independent'].data
        # join the two name arrays; element-wise '+' on the numpy arrays would not
        # concatenate the lists of names
        self.combined_names = np.concatenate([self.feature_names,
                                              self.feature_names_t_idp])
        self.lazy = True

    # def full_5D_feature_stack(self):
    #     feat_stack = self.feat_data['feature_stack']
    #     feat_stack_t_idp = self.feat_data['feature_stack_time_independent']
    #
    #     # TODO: this cannot work ?! first t_idp needs to be expanded in time
    #     feat_stack = dask.array.concatenate([feat_stack, feat_stack_t_idp], axis=4)
    #     self.feat_stack = feat_stack

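    # A possible resolution of the TODO above (hedged, untested sketch): broadcast the
    # time-independent stack along the time axis before concatenating. This assumes
    # 'feature_stack' is dimensioned (x, y, z, time, feature) and
    # 'feature_stack_time_independent' is dimensioned (x, y, z, feature).
    #
    # def full_5D_feature_stack(self):
    #     feat_stack = self.feat_data['feature_stack'].data
    #     t_idp = self.feat_data['feature_stack_time_independent'].data
    #     # insert a length-1 time axis, then replicate it for every time step
    #     t_idp = dask.array.broadcast_to(t_idp[:, :, :, None, :],
    #                                     feat_stack.shape[:4] + (t_idp.shape[-1],))
    #     self.feat_stack = dask.array.concatenate([feat_stack, t_idp], axis=4)
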
    # def classify_all(self):
    #     # TODO: streamline classifier and feature calculation. maybe integrate both within dask
    #     # especially if original and segmented dataset don't fit in RAM
    #     feat_stack = self.feat_data['feature_stack']
    #     num_feat = feat_stack.shape[-1]
    #     clf = self.clf
    #     if not self.lazy:
    #         print('classifying ...')
    #         # result = clf.predict(feat_stack.reshape(-1, num_feat))
    #         result = clf.predict(feat_stack.data.reshape(-1, num_feat))
    #     else:
    #         print('calculate feature stack and then classify. might take a while ...')
    #         result = clf.predict(feat_stack.data.reshape(-1, num_feat))
    #     result = result.reshape(feat_stack[..., 0].shape).astype(np.uint8)
    #     self.segmented_data = result

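    # Hedged sketch for the RAM concern in the TODO above (hypothetical method name,
    # untested): let dask predict block-wise instead of materialising the full feature
    # stack. Assumes the feature axis is last and that every chunk spans it completely.
    #
    # def classify_all_lazy(self):
    #     feat_stack = self.feat_data['feature_stack'].data   # dask array (..., feature)
    #     num_feat = feat_stack.shape[-1]
    #
    #     def _predict_block(block):
    #         labels = self.clf.predict(block.reshape(-1, num_feat))
    #         return labels.reshape(block.shape[:-1]).astype(np.uint8)
    #
    #     self.segmented_data = feat_stack.map_blocks(_predict_block,
    #                                                 drop_axis=feat_stack.ndim - 1,
    #                                                 dtype=np.uint8)
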
    def store_segmented_data(self):
        path = os.path.join(self.training_path, 'segmented.nc')

        # TODO: propagate labels from raw data
        # TODO: if self.segmented_data is a dask array, rechunk for saving
        shp = self.segmented_data.shape
        data = xr.Dataset({'segmented': (['x', 'y', 'z', 'time'], self.segmented_data)},
                          coords={'x': np.arange(shp[0]),
                                  'y': np.arange(shp[1]),
                                  'z': np.arange(shp[2]),
                                  'time': np.arange(shp[3]),
                                  'feature': self.feature_names}
                          )
        data.to_netcdf(path)
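
# Hedged usage sketch (hypothetical paths, not part of the original module):
#
#     seg = segmentation(feature_path='features.nc',
#                        classifier_path='classifier.pkl',
#                        training_path='training_output')
#     seg.load_classifier()
#     seg.open_feature_data()
#     # ... run a classification step to populate seg.segmented_data ...
#     seg.store_segmented_data()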