mirror of
https://github.com/bec-project/bec.git
synced 2026-06-01 15:58:31 +02:00
feat(ScanHistoryMessage): add data shape and dtype to ScanHistoryMessage
This commit is contained in:
@@ -661,6 +661,7 @@ class ScanHistoryMessage(BECMessage):
|
||||
scan_name (str): Name of the scan.
|
||||
num_points (int): Number of points in the scan.
|
||||
request_inputs (dict, optional): Inputs for the scan request, if available.
|
||||
stored_data_info (dict[str, dict[str, _StoredDataInfo]], optional): Information about the stored data for each device in the scan.
|
||||
metadata (dict, optional): Additional metadata.
|
||||
|
||||
"""
|
||||
@@ -676,6 +677,19 @@ class ScanHistoryMessage(BECMessage):
|
||||
scan_name: str
|
||||
num_points: int
|
||||
request_inputs: dict | None = None
|
||||
stored_data_info: dict[str, dict[str, _StoredDataInfo]] | None = None
|
||||
|
||||
|
||||
class _StoredDataInfo(BaseModel):
|
||||
"""Internal class to store data info for each device in the scan history message
|
||||
|
||||
Args:
|
||||
shape (tuple): Shape of the data for the device.
|
||||
dtype (str, optional): Data type of the data for the device. Defaults to None.
|
||||
"""
|
||||
|
||||
shape: tuple[int, ...] = Field(default_factory=tuple)
|
||||
dtype: str | None = None
|
||||
|
||||
|
||||
class ScanBaselineMessage(BECMessage):
|
||||
|
||||
@@ -77,7 +77,6 @@ class AsyncWriter(threading.Thread):
|
||||
self.append_shapes = {}
|
||||
self.written_devices = set()
|
||||
self.file_handle = None
|
||||
|
||||
self.cursor = defaultdict(dict)
|
||||
|
||||
def initialize_stream_keys(self):
|
||||
@@ -136,7 +135,6 @@ class AsyncWriter(threading.Thread):
|
||||
self.poll_and_write_data()
|
||||
# run one last time to get any remaining data
|
||||
self.poll_and_write_data(final=True)
|
||||
# self.send_file_message(done=True, successful=True)
|
||||
logger.info(f"Finished writing async data file {self.file_path}")
|
||||
# pylint: disable=broad-except
|
||||
except Exception:
|
||||
|
||||
@@ -5,6 +5,7 @@ import json
|
||||
import os
|
||||
import traceback
|
||||
import typing
|
||||
from collections import defaultdict
|
||||
|
||||
import h5py
|
||||
|
||||
@@ -212,6 +213,7 @@ class HDF5FileWriter:
|
||||
|
||||
def __init__(self, file_writer_manager):
|
||||
self.file_writer_manager = file_writer_manager
|
||||
self.stored_data_info = defaultdict(dict)
|
||||
|
||||
@staticmethod
|
||||
def _create_device_data_storage(data):
|
||||
@@ -300,9 +302,36 @@ class HDF5FileWriter:
|
||||
file_handle = file_handle or h5py.File(file_path, mode=mode)
|
||||
try:
|
||||
HDF5StorageWriter.write(writer_storage, file_handle)
|
||||
self.update_data_info(file_handle)
|
||||
finally:
|
||||
file_handle.close()
|
||||
|
||||
def update_data_info(self, file_handle: h5py.File):
|
||||
"""
|
||||
Update the stored data information in the file handle.
|
||||
|
||||
Args:
|
||||
file_handle (h5py.File): The HDF5 file handle to update.
|
||||
"""
|
||||
device_group = file_handle.get("/entry/collection/devices")
|
||||
for device_name, device_group in device_group.items():
|
||||
if not isinstance(device_group, h5py.Group):
|
||||
continue
|
||||
for signal_name, signal_group in device_group.items():
|
||||
if not isinstance(signal_group, h5py.Group):
|
||||
continue
|
||||
if "value" in signal_group:
|
||||
value_dset = signal_group["value"]
|
||||
if not isinstance(value_dset, h5py.Dataset):
|
||||
continue
|
||||
value_dset_shape = value_dset.shape
|
||||
if value_dset_shape == ():
|
||||
value_dset_shape = (1,)
|
||||
self.stored_data_info[device_name][signal_name] = {
|
||||
"shape": value_dset_shape,
|
||||
"dtype": value_dset.dtype.name,
|
||||
}
|
||||
|
||||
|
||||
def dict_to_storage(storage, data):
|
||||
for key, val in data.items():
|
||||
|
||||
@@ -367,6 +367,7 @@ class FileWriterManager(BECService):
|
||||
MessageEndpoints.public_file(scan_id, "master"),
|
||||
messages.FileMessage(file_path=file_path, done=True, successful=successful),
|
||||
)
|
||||
|
||||
history_msg = messages.ScanHistoryMessage(
|
||||
scan_id=scan_id,
|
||||
scan_number=storage.scan_number,
|
||||
@@ -378,6 +379,7 @@ class FileWriterManager(BECService):
|
||||
num_points=storage.num_points,
|
||||
scan_name=storage.metadata.get("scan_name"),
|
||||
request_inputs=storage.metadata.get("request_inputs", {}),
|
||||
stored_data_info=self.file_writer.stored_data_info or {},
|
||||
)
|
||||
self.connector.xadd(
|
||||
topic=MessageEndpoints.scan_history(), msg_dict={"data": history_msg}, max_size=10000
|
||||
|
||||
@@ -188,6 +188,10 @@ def test_write_data_storage(segments, baseline, metadata, hdf5_file_writer):
|
||||
|
||||
file_writer.write("./test.h5", storage, configuration_data={})
|
||||
|
||||
data_info = file_writer.stored_data_info.get("samx")
|
||||
assert data_info.get("samx").get("shape") == (2,)
|
||||
assert data_info.get("samx_setpoint").get("shape") == (2,)
|
||||
assert data_info.get("samx").get("dtype") == "float64"
|
||||
# open file and check that time stamps are correct
|
||||
with h5py.File("./test.h5", "r") as test_file:
|
||||
assert (
|
||||
|
||||
Reference in New Issue
Block a user