feat(ScanHistoryMessage): add data shape and dtype to ScanHistoryMessage

This commit is contained in:
2025-07-08 17:45:08 +02:00
committed by Christian Appel
parent 5ba57d50ee
commit 824315b7ef
5 changed files with 49 additions and 2 deletions
+14
View File
@@ -661,6 +661,7 @@ class ScanHistoryMessage(BECMessage):
scan_name (str): Name of the scan.
num_points (int): Number of points in the scan.
request_inputs (dict, optional): Inputs for the scan request, if available.
stored_data_info (dict[str, dict[str, _StoredDataInfo]], optional): Information about the stored data for each device in the scan.
metadata (dict, optional): Additional metadata.
"""
@@ -676,6 +677,19 @@ class ScanHistoryMessage(BECMessage):
scan_name: str
num_points: int
request_inputs: dict | None = None
stored_data_info: dict[str, dict[str, _StoredDataInfo]] | None = None
class _StoredDataInfo(BaseModel):
    """
    Internal class to store the shape and dtype of one stored dataset in the
    scan history message.

    Instances are the innermost values of ``stored_data_info``, which is
    annotated as ``dict[str, dict[str, _StoredDataInfo]]``.

    Args:
        shape (tuple[int, ...]): Shape of the stored data. Defaults to an empty tuple.
        dtype (str, optional): Data type of the stored data. Defaults to None.
    """

    # Built through default_factory, so an omitted shape becomes ().
    shape: tuple[int, ...] = Field(default_factory=tuple)
    # Stays None when no dtype is provided.
    dtype: str | None = None
class ScanBaselineMessage(BECMessage):
@@ -77,7 +77,6 @@ class AsyncWriter(threading.Thread):
self.append_shapes = {}
self.written_devices = set()
self.file_handle = None
self.cursor = defaultdict(dict)
def initialize_stream_keys(self):
@@ -136,7 +135,6 @@ class AsyncWriter(threading.Thread):
self.poll_and_write_data()
# run one last time to get any remaining data
self.poll_and_write_data(final=True)
# self.send_file_message(done=True, successful=True)
logger.info(f"Finished writing async data file {self.file_path}")
# pylint: disable=broad-except
except Exception:
@@ -5,6 +5,7 @@ import json
import os
import traceback
import typing
from collections import defaultdict
import h5py
@@ -212,6 +213,7 @@ class HDF5FileWriter:
def __init__(self, file_writer_manager):
    """
    Initialize the writer.

    Args:
        file_writer_manager: Object kept on the instance as ``file_writer_manager``.
    """
    self.file_writer_manager = file_writer_manager
    # Two-level mapping filled by update_data_info:
    # stored_data_info[device_name][signal_name] -> {"shape": ..., "dtype": ...}
    self.stored_data_info = defaultdict(dict)
@staticmethod
def _create_device_data_storage(data):
@@ -300,9 +302,36 @@ class HDF5FileWriter:
file_handle = file_handle or h5py.File(file_path, mode=mode)
try:
HDF5StorageWriter.write(writer_storage, file_handle)
self.update_data_info(file_handle)
finally:
file_handle.close()
def update_data_info(self, file_handle: h5py.File):
    """
    Update the stored data information from the contents of an open HDF5 file.

    Walks ``/entry/collection/devices/<device>/<signal>/value`` and, for every
    ``value`` dataset found, records its shape and dtype name in
    ``self.stored_data_info[<device>][<signal>]``. The file itself is only
    read, never modified. If the devices group does not exist, nothing is
    recorded and the method returns silently.

    NOTE(review): entries are only added or overwritten here, never removed,
    so information recorded by an earlier call stays in
    ``self.stored_data_info`` -- confirm this is intended when the same
    writer instance is reused.

    Args:
        file_handle (h5py.File): The open HDF5 file to inspect.
    """
    devices = file_handle.get("/entry/collection/devices")
    # ``get`` yields None when the path is missing; calling ``.items()`` on it
    # would raise and turn an otherwise successful write into a failure.
    if not isinstance(devices, h5py.Group):
        return

    for device_name, device_entry in devices.items():
        if not isinstance(device_entry, h5py.Group):
            continue
        for signal_name, signal_entry in device_entry.items():
            if not isinstance(signal_entry, h5py.Group):
                continue
            # None (no "value" member) and non-dataset members are both skipped.
            value_dset = signal_entry.get("value")
            if not isinstance(value_dset, h5py.Dataset):
                continue
            value_dset_shape = value_dset.shape
            # Scalar datasets have shape (); report them as one element.
            if value_dset_shape == ():
                value_dset_shape = (1,)
            self.stored_data_info[device_name][signal_name] = {
                "shape": value_dset_shape,
                "dtype": value_dset.dtype.name,
            }
def dict_to_storage(storage, data):
for key, val in data.items():
@@ -367,6 +367,7 @@ class FileWriterManager(BECService):
MessageEndpoints.public_file(scan_id, "master"),
messages.FileMessage(file_path=file_path, done=True, successful=successful),
)
history_msg = messages.ScanHistoryMessage(
scan_id=scan_id,
scan_number=storage.scan_number,
@@ -378,6 +379,7 @@ class FileWriterManager(BECService):
num_points=storage.num_points,
scan_name=storage.metadata.get("scan_name"),
request_inputs=storage.metadata.get("request_inputs", {}),
stored_data_info=self.file_writer.stored_data_info or {},
)
self.connector.xadd(
topic=MessageEndpoints.scan_history(), msg_dict={"data": history_msg}, max_size=10000
@@ -188,6 +188,10 @@ def test_write_data_storage(segments, baseline, metadata, hdf5_file_writer):
file_writer.write("./test.h5", storage, configuration_data={})
data_info = file_writer.stored_data_info.get("samx")
assert data_info.get("samx").get("shape") == (2,)
assert data_info.get("samx_setpoint").get("shape") == (2,)
assert data_info.get("samx").get("dtype") == "float64"
# open file and check that time stamps are correct
with h5py.File("./test.h5", "r") as test_file:
assert (