From 17485fbb6b25686554ad12be9e56b48b031db5c8 Mon Sep 17 00:00:00 2001 From: Ivan Usov Date: Wed, 24 Feb 2021 10:31:19 +0100 Subject: [PATCH] Add postprocess_raw.py --- scripts/postprocess_raw.py | 140 +++++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 scripts/postprocess_raw.py diff --git a/scripts/postprocess_raw.py b/scripts/postprocess_raw.py new file mode 100644 index 0000000..3d3a285 --- /dev/null +++ b/scripts/postprocess_raw.py @@ -0,0 +1,140 @@ +import os +import struct + +import bitshuffle +import h5py +import numpy as np +from bitshuffle.h5 import H5_COMPRESS_LZ4, H5FILTER # pylint: disable=no-name-in-module + +# bitshuffle hdf5 filter params +BLOCK_SIZE = 2048 +compargs = {"compression": H5FILTER, "compression_opts": (BLOCK_SIZE, H5_COMPRESS_LZ4)} +# limit bitshuffle omp to a single thread +# a better fix would be to use bitshuffle compiled without omp support +os.environ["OMP_NUM_THREADS"] = "1" + +DTYPE = np.dtype(np.uint16) +DTYPE_SIZE = DTYPE.itemsize + +MODULE_SIZE_X = 1024 +MODULE_SIZE_Y = 512 + + +def postprocess_raw( + source, dest, disabled_modules=(), index=None, compression=False, batch_size=100 +): + # a function for 'visititems' should have the args (name, object) + def _visititems(name, obj): + if isinstance(obj, h5py.Group): + h5_dest.create_group(name) + + elif isinstance(obj, h5py.Dataset): + dset_source = h5_source[name] + + # process all but the raw data + if name != data_dset: + if name.startswith("data"): + # datasets with data per image, so indexing should be applied + if index is None: + data = dset_source[:] + else: + data = dset_source[index, :] + + args = {"shape": data.shape} + h5_dest.create_dataset_like(name, dset_source, data=data, **args) + else: + h5_dest.create_dataset_like(name, dset_source, data=dset_source) + + else: + raise TypeError(f"Unknown h5py object type {obj}") + + # copy group/dataset attributes if it's not a dataset with the actual data + if name != data_dset: + for key, value in h5_source[name].attrs.items(): + h5_dest[name].attrs[key] = value + + with h5py.File(source, "r") as h5_source, h5py.File(dest, "w") as h5_dest: + detector_name = h5_source["general/detector_name"][()].decode() + data_dset = f"data/{detector_name}/data" + + # traverse the source file and copy/index all datasets, except the raw data + h5_source.visititems(_visititems) + + # now process the raw data + dset = h5_source[data_dset] + + args = dict() + if index is None: + n_images = dset.shape[0] + else: + index = np.array(index) + n_images = len(index) + + n_modules = dset.shape[1] // MODULE_SIZE_Y + out_shape = (MODULE_SIZE_Y * (n_modules - len(disabled_modules)), MODULE_SIZE_X) + + args["shape"] = (n_images, *out_shape) + args["maxshape"] = (n_images, *out_shape) + args["chunks"] = (1, *out_shape) + + if compression: + args.update(compargs) + + h5_dest.create_dataset_like(data_dset, dset, **args) + + # calculate and save module_map + module_map = [] + tmp = 0 + for ind in range(n_modules): + if ind in disabled_modules: + module_map.append(-1) + else: + module_map.append(tmp) + tmp += 1 + + h5_dest[f"data/{detector_name}/module_map"] = np.tile(module_map, (n_images, 1)) + + # prepare buffers to be reused for every batch + read_buffer = np.empty((batch_size, *dset.shape[1:]), dtype=DTYPE) + out_buffer = np.zeros((batch_size, *out_shape), dtype=DTYPE) + + # process and write data in batches + for batch_start_ind in range(0, n_images, batch_size): + batch_range = range(batch_start_ind, min(batch_start_ind + batch_size, n_images)) + + if index is None: + batch_ind = np.array(batch_range) + else: + batch_ind = index[batch_range] + + # TODO: avoid unnecessary buffers + read_buffer_view = read_buffer[: len(batch_ind)] + out_buffer_view = out_buffer[: len(batch_ind)] + + # Avoid a stride-bottleneck, see https://github.com/h5py/h5py/issues/977 + if np.sum(np.diff(batch_ind)) == len(batch_ind) - 1: + # consecutive index values + dset.read_direct(read_buffer_view, source_sel=np.s_[batch_ind]) + else: + for i, j in enumerate(batch_ind): + dset.read_direct(read_buffer_view, source_sel=np.s_[j], dest_sel=np.s_[i]) + + for i, m in enumerate(module_map): + if m == -1: + continue + + read_slice = read_buffer_view[:, i * MODULE_SIZE_Y : (i + 1) * MODULE_SIZE_Y, :] + out_slice = out_buffer_view[:, m * MODULE_SIZE_Y : (m + 1) * MODULE_SIZE_Y, :] + out_slice[:] = read_slice + + bytes_num_elem = struct.pack(">q", out_shape[0] * out_shape[1] * DTYPE_SIZE) + bytes_block_size = struct.pack(">i", BLOCK_SIZE * DTYPE_SIZE) + header = bytes_num_elem + bytes_block_size + + for pos, im in zip(batch_range, out_buffer_view): + if compression: + byte_array = header + bitshuffle.compress_lz4(im, BLOCK_SIZE).tobytes() + else: + byte_array = im.tobytes() + + h5_dest[data_dset].id.write_direct_chunk((pos, 0, 0), byte_array)