mirror of
https://github.com/paulscherrerinstitute/sf_daq_buffer.git
synced 2026-05-01 17:32:21 +02:00
Add postprocess_raw.py
This commit is contained in:
@@ -0,0 +1,140 @@
|
||||
import os
|
||||
import struct
|
||||
|
||||
import bitshuffle
|
||||
import h5py
|
||||
import numpy as np
|
||||
from bitshuffle.h5 import H5_COMPRESS_LZ4, H5FILTER # pylint: disable=no-name-in-module
|
||||
|
||||
# bitshuffle hdf5 filter params
|
||||
BLOCK_SIZE = 2048
|
||||
compargs = {"compression": H5FILTER, "compression_opts": (BLOCK_SIZE, H5_COMPRESS_LZ4)}
|
||||
# limit bitshuffle omp to a single thread
|
||||
# a better fix would be to use bitshuffle compiled without omp support
|
||||
os.environ["OMP_NUM_THREADS"] = "1"
|
||||
|
||||
DTYPE = np.dtype(np.uint16)
|
||||
DTYPE_SIZE = DTYPE.itemsize
|
||||
|
||||
MODULE_SIZE_X = 1024
|
||||
MODULE_SIZE_Y = 512
|
||||
|
||||
|
||||
def postprocess_raw(
|
||||
source, dest, disabled_modules=(), index=None, compression=False, batch_size=100
|
||||
):
|
||||
# a function for 'visititems' should have the args (name, object)
|
||||
def _visititems(name, obj):
|
||||
if isinstance(obj, h5py.Group):
|
||||
h5_dest.create_group(name)
|
||||
|
||||
elif isinstance(obj, h5py.Dataset):
|
||||
dset_source = h5_source[name]
|
||||
|
||||
# process all but the raw data
|
||||
if name != data_dset:
|
||||
if name.startswith("data"):
|
||||
# datasets with data per image, so indexing should be applied
|
||||
if index is None:
|
||||
data = dset_source[:]
|
||||
else:
|
||||
data = dset_source[index, :]
|
||||
|
||||
args = {"shape": data.shape}
|
||||
h5_dest.create_dataset_like(name, dset_source, data=data, **args)
|
||||
else:
|
||||
h5_dest.create_dataset_like(name, dset_source, data=dset_source)
|
||||
|
||||
else:
|
||||
raise TypeError(f"Unknown h5py object type {obj}")
|
||||
|
||||
# copy group/dataset attributes if it's not a dataset with the actual data
|
||||
if name != data_dset:
|
||||
for key, value in h5_source[name].attrs.items():
|
||||
h5_dest[name].attrs[key] = value
|
||||
|
||||
with h5py.File(source, "r") as h5_source, h5py.File(dest, "w") as h5_dest:
|
||||
detector_name = h5_source["general/detector_name"][()].decode()
|
||||
data_dset = f"data/{detector_name}/data"
|
||||
|
||||
# traverse the source file and copy/index all datasets, except the raw data
|
||||
h5_source.visititems(_visititems)
|
||||
|
||||
# now process the raw data
|
||||
dset = h5_source[data_dset]
|
||||
|
||||
args = dict()
|
||||
if index is None:
|
||||
n_images = dset.shape[0]
|
||||
else:
|
||||
index = np.array(index)
|
||||
n_images = len(index)
|
||||
|
||||
n_modules = dset.shape[1] // MODULE_SIZE_Y
|
||||
out_shape = (MODULE_SIZE_Y * (n_modules - len(disabled_modules)), MODULE_SIZE_X)
|
||||
|
||||
args["shape"] = (n_images, *out_shape)
|
||||
args["maxshape"] = (n_images, *out_shape)
|
||||
args["chunks"] = (1, *out_shape)
|
||||
|
||||
if compression:
|
||||
args.update(compargs)
|
||||
|
||||
h5_dest.create_dataset_like(data_dset, dset, **args)
|
||||
|
||||
# calculate and save module_map
|
||||
module_map = []
|
||||
tmp = 0
|
||||
for ind in range(n_modules):
|
||||
if ind in disabled_modules:
|
||||
module_map.append(-1)
|
||||
else:
|
||||
module_map.append(tmp)
|
||||
tmp += 1
|
||||
|
||||
h5_dest[f"data/{detector_name}/module_map"] = np.tile(module_map, (n_images, 1))
|
||||
|
||||
# prepare buffers to be reused for every batch
|
||||
read_buffer = np.empty((batch_size, *dset.shape[1:]), dtype=DTYPE)
|
||||
out_buffer = np.zeros((batch_size, *out_shape), dtype=DTYPE)
|
||||
|
||||
# process and write data in batches
|
||||
for batch_start_ind in range(0, n_images, batch_size):
|
||||
batch_range = range(batch_start_ind, min(batch_start_ind + batch_size, n_images))
|
||||
|
||||
if index is None:
|
||||
batch_ind = np.array(batch_range)
|
||||
else:
|
||||
batch_ind = index[batch_range]
|
||||
|
||||
# TODO: avoid unnecessary buffers
|
||||
read_buffer_view = read_buffer[: len(batch_ind)]
|
||||
out_buffer_view = out_buffer[: len(batch_ind)]
|
||||
|
||||
# Avoid a stride-bottleneck, see https://github.com/h5py/h5py/issues/977
|
||||
if np.sum(np.diff(batch_ind)) == len(batch_ind) - 1:
|
||||
# consecutive index values
|
||||
dset.read_direct(read_buffer_view, source_sel=np.s_[batch_ind])
|
||||
else:
|
||||
for i, j in enumerate(batch_ind):
|
||||
dset.read_direct(read_buffer_view, source_sel=np.s_[j], dest_sel=np.s_[i])
|
||||
|
||||
for i, m in enumerate(module_map):
|
||||
if m == -1:
|
||||
continue
|
||||
|
||||
read_slice = read_buffer_view[:, i * MODULE_SIZE_Y : (i + 1) * MODULE_SIZE_Y, :]
|
||||
out_slice = out_buffer_view[:, m * MODULE_SIZE_Y : (m + 1) * MODULE_SIZE_Y, :]
|
||||
out_slice[:] = read_slice
|
||||
|
||||
bytes_num_elem = struct.pack(">q", out_shape[0] * out_shape[1] * DTYPE_SIZE)
|
||||
bytes_block_size = struct.pack(">i", BLOCK_SIZE * DTYPE_SIZE)
|
||||
header = bytes_num_elem + bytes_block_size
|
||||
|
||||
for pos, im in zip(batch_range, out_buffer_view):
|
||||
if compression:
|
||||
byte_array = header + bitshuffle.compress_lz4(im, BLOCK_SIZE).tobytes()
|
||||
else:
|
||||
byte_array = im.tobytes()
|
||||
|
||||
h5_dest[data_dset].id.write_direct_chunk((pos, 0, 0), byte_array)
|
||||
Reference in New Issue
Block a user