import os
import struct
import bitshuffle
import h5py
import numpy as np
from bitshuffle.h5 import H5_COMPRESS_LZ4, H5FILTER  # pylint: disable=no-name-in-module

# bitshuffle hdf5 filter params
BLOCK_SIZE = 2048
compargs = {"compression": H5FILTER, "compression_opts": (BLOCK_SIZE, H5_COMPRESS_LZ4)}

# limit bitshuffle omp to a single thread
# a better fix would be to use bitshuffle compiled without omp support
os.environ["OMP_NUM_THREADS"] = "1"

DTYPE = np.dtype(np.uint16)
DTYPE_SIZE = DTYPE.itemsize

MODULE_SIZE_X = 1024
MODULE_SIZE_Y = 512


def postprocess_raw(
    source, dest, disabled_modules=(), index=None, compression=False, batch_size=100
):
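    """Copy a raw detector file, dropping disabled modules along the way.

    Args:
        source: path of the source hdf5 file
        dest: path of the destination hdf5 file
        disabled_modules: indices of detector modules to exclude from the output
        index: indices of images to copy, in increasing order (all images if None)
        compression: whether to apply bitshuffle+lz4 compression to the output data
        batch_size: number of images processed per iteration
    """
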
    # a function for 'visititems' should have the args (name, object)
    def _visititems(name, obj):
        if isinstance(obj, h5py.Group):
            h5_dest.create_group(name)

        elif isinstance(obj, h5py.Dataset):
            dset_source = h5_source[name]

            # process all but the raw data
            if name != data_dset:
                if name.startswith("data"):
                    # datasets with data per image, so indexing should be applied
                    if index is None:
                        data = dset_source[:]
                    else:
                        data = dset_source[index, :]

                    args = {"shape": data.shape}
                    h5_dest.create_dataset_like(name, dset_source, data=data, **args)
                else:
                    h5_dest.create_dataset_like(name, dset_source, data=dset_source)

        else:
            raise TypeError(f"Unknown h5py object type {obj}")

        # copy group/dataset attributes if it's not a dataset with the actual data
        if name != data_dset:
            for key, value in h5_source[name].attrs.items():
                h5_dest[name].attrs[key] = value
    with h5py.File(source, "r") as h5_source, h5py.File(dest, "w") as h5_dest:
        detector_name = h5_source["general/detector_name"][()].decode()
        data_dset = f"data/{detector_name}/data"

        # traverse the source file and copy/index all datasets, except the raw data
        h5_source.visititems(_visititems)

        # now process the raw data
        dset = h5_source[data_dset]

        args = dict()
        if index is None:
            n_images = dset.shape[0]
        else:
            index = np.array(index)
            n_images = len(index)

        n_modules = dset.shape[1] // MODULE_SIZE_Y
        out_shape = (MODULE_SIZE_Y * (n_modules - len(disabled_modules)), MODULE_SIZE_X)

        args["shape"] = (n_images, *out_shape)
        args["maxshape"] = (n_images, *out_shape)
        args["chunks"] = (1, *out_shape)

        if compression:
            args.update(compargs)

        h5_dest.create_dataset_like(data_dset, dset, **args)
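
        # note: chunks of (1, *out_shape) put exactly one image into each hdf5
        # chunk, which is what the write_direct_chunk calls below rely on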

        # calculate and save module_map: entry i is the position of module i in
        # the output data, or -1 if module i is disabled
        module_map = []
        tmp = 0
        for ind in range(n_modules):
            if ind in disabled_modules:
                module_map.append(-1)
            else:
                module_map.append(tmp)
                tmp += 1

        h5_dest[f"data/{detector_name}/module_map"] = np.tile(module_map, (n_images, 1))
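        # e.g. 3 modules with disabled_modules == (1,) give module_map == [0, -1, 1]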

        # prepare buffers to be reused for every batch
        read_buffer = np.empty((batch_size, *dset.shape[1:]), dtype=DTYPE)
        out_buffer = np.zeros((batch_size, *out_shape), dtype=DTYPE)

        # process and write data in batches
        for batch_start_ind in range(0, n_images, batch_size):
            batch_range = range(batch_start_ind, min(batch_start_ind + batch_size, n_images))

            if index is None:
                batch_ind = np.array(batch_range)
            else:
                batch_ind = index[batch_range]

            # TODO: avoid unnecessary buffers
            read_buffer_view = read_buffer[: len(batch_ind)]
            out_buffer_view = out_buffer[: len(batch_ind)]

            # Avoid a stride-bottleneck, see https://github.com/h5py/h5py/issues/977:
            # a contiguous slice reads much faster than a per-image point selection
            if np.all(np.diff(batch_ind) == 1):
                # consecutive index values can be read as a single block
                dset.read_direct(
                    read_buffer_view, source_sel=np.s_[batch_ind[0] : batch_ind[-1] + 1]
                )
            else:
                for i, j in enumerate(batch_ind):
                    dset.read_direct(read_buffer_view, source_sel=np.s_[j], dest_sel=np.s_[i])

            # reassemble the enabled modules contiguously into the output buffer
            for i, m in enumerate(module_map):
                if m == -1:
                    continue

                read_slice = read_buffer_view[:, i * MODULE_SIZE_Y : (i + 1) * MODULE_SIZE_Y, :]
                out_slice = out_buffer_view[:, m * MODULE_SIZE_Y : (m + 1) * MODULE_SIZE_Y, :]
                out_slice[:] = read_slice
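
            # the bitshuffle hdf5 filter stores each chunk with a 12-byte header:
            # 8 big-endian bytes of uncompressed size, then 4 big-endian bytes of
            # block size in bytes; prepend the same header here so chunks written
            # via write_direct_chunk stay readable through the filter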
            bytes_num_elem = struct.pack(">q", out_shape[0] * out_shape[1] * DTYPE_SIZE)
            bytes_block_size = struct.pack(">i", BLOCK_SIZE * DTYPE_SIZE)
            header = bytes_num_elem + bytes_block_size

            for pos, im in zip(batch_range, out_buffer_view):
                if compression:
                    byte_array = header + bitshuffle.compress_lz4(im, BLOCK_SIZE).tobytes()
                else:
                    byte_array = im.tobytes()

                h5_dest[data_dset].id.write_direct_chunk((pos, 0, 0), byte_array)
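

# a minimal usage sketch (file names and image count are hypothetical): copy
# every second of the first 100 images, drop module 1, and compress the output
if __name__ == "__main__":
    postprocess_raw(
        "raw.h5",
        "processed.h5",
        disabled_modules=(1,),
        index=list(range(0, 100, 2)),
        compression=True,
    )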