feat(sim_waveform): added option to emit data with add_slice

This commit is contained in:
wakonig_k 2025-03-06 18:15:17 +01:00 committed by appel_c
parent 7797e4003b
commit 21746e5445
3 changed files with 177 additions and 17 deletions

View File

@ -4,7 +4,7 @@ import time
import numpy as np
from bec_lib import bec_logger
from ophyd import Kind, Signal
from ophyd import DeviceStatus, Kind, Signal
from ophyd.utils import ReadOnlyError
from ophyd_devices.utils.bec_device_base import BECDeviceBase
@ -87,10 +87,18 @@ class SetableSignal(Signal):
Core function for signal.
"""
self.check_value(value)
self._update_sim_state(value)
self._value = value
self._run_subs(sub_type=self.SUB_VALUE, value=value)
def set(self, value):
"""Set method"""
self.put(value)
status = DeviceStatus(self)
status.set_finished()
return status
def describe(self):
"""Describe the readback signal.

View File

@ -4,6 +4,7 @@ import os
import threading
import time
import traceback
from typing import Any
import numpy as np
from bec_lib import messages
@ -11,6 +12,7 @@ from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from ophyd import Component as Cpt
from ophyd import Device, DeviceStatus, Kind, Staged
from typeguard import typechecked
from ophyd_devices.sim.sim_data import SimulatedDataWaveform
from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal
@ -20,6 +22,31 @@ from ophyd_devices.utils.errors import DeviceStopError
logger = bec_logger.logger
class AsyncUpdateSignal(SetableSignal):
"""Async updated signal, with check for async_update type."""
def check_value(self, value, **kwargs) -> None:
"""Check the value of the async_update signal."""
if value not in ["add_slice", "add"]:
raise ValueError(f"Invalid async_update type: {value} for signal {self.name}")
# FIXME: BEC issue #443 remove this method once tests in BEC are updated.
def put(self, value: Any) -> None:
"""Put the value of the async_update signal."""
if value in ["append", "extend"]:
if value == "append":
logger.warning(
f"Deprecated async_update of type {value} for signal {self.name}, falling back to 'add_slice'"
)
value = "add_slice"
elif value == "extend":
logger.warning(
f"Deprecated async_update of type {value} for signal {self.name}, falling back to 'add'"
)
value = "add"
super().put(value)
class SimWaveform(Device):
"""A simulated device mimic any 1D Waveform detector.
@ -39,7 +66,7 @@ class SimWaveform(Device):
"""
USER_ACCESS = ["sim", "registered_proxies"]
USER_ACCESS = ["sim", "registered_proxies", "delay_slice_update"]
sim_cls = SimulatedDataWaveform
SHAPE = (1000,)
@ -60,10 +87,11 @@ class SimWaveform(Device):
name="waveform",
value=np.empty(SHAPE, dtype=BIT_DEPTH),
compute_readback=True,
kind=Kind.normal,
kind=Kind.hinted,
)
# Can be extend or append
async_update = Cpt(SetableSignal, value="append", kind=Kind.config)
async_update = Cpt(AsyncUpdateSignal, value="add", kind=Kind.config)
slice_size = Cpt(SetableSignal, value=100, dtype=np.int32, kind=Kind.config)
def __init__(
self,
@ -92,8 +120,20 @@ class SimWaveform(Device):
self._staged = Staged.no
self._trigger_thread = None
self.scan_info = scan_info
self._delay_slice_update = False
if self.sim_init:
self.sim.set_init(self.sim_init)
self._slice_index = 0
@property
def delay_slice_update(self) -> bool:
"""Delay updates in-between slices specified by waveform_shape and slice_size."""
return self._delay_slice_update
@typechecked
@delay_slice_update.setter
def delay_slice_update(self, value: bool) -> None:
self._delay_slice_update = value
@property
def registered_proxies(self) -> None:
@ -113,8 +153,33 @@ class SimWaveform(Device):
def acquire(status: DeviceStatus):
try:
for _ in range(self.burst.get()):
self._run_subs(sub_type=self.SUB_MONITOR, value=self.waveform.get())
self._send_async_update()
# values of the Waveform
values = self.waveform.get()
# add_slice option
if self.async_update.get() == "add_slice":
size = self.slice_size.get()
mod = len(values) % size
num_slices = len(values) // size + int(mod > 0)
for i in range(num_slices):
value_slice = values[i * size : min((i + 1) * size, len(values))]
logger.info(
f"Sending slice {i} of {self._slice_index} with length {len(value_slice)}"
)
self._run_subs(sub_type=self.SUB_MONITOR, value=value_slice)
self._send_async_update(index=self._slice_index, value=value_slice)
if self.delay_slice_update is True:
time.sleep(0.025) # 25ms to be really fast
if self.stopped:
raise DeviceStopError(f"{self.name} was stopped")
self._slice_index += 1
# option add
elif self.async_update.get() == "add":
self._run_subs(sub_type=self.SUB_MONITOR, value=values)
self._send_async_update(value=values)
else:
# This should never happen, but just in case
# we raise an exception
raise ValueError(f"Invalid async_update type: {self.async_update.get()}")
if self.stopped:
raise DeviceStopError(f"{self.name} was stopped")
status.set_finished()
@ -128,23 +193,41 @@ class SimWaveform(Device):
self._trigger_thread.start()
return status
def _send_async_update(self):
"""Send the async update to BEC."""
async_update_type = self.async_update.get()
if async_update_type not in ["extend", "append"]:
raise ValueError(f"Invalid async_update type: {async_update_type}")
def _send_async_update(self, value: Any, index: int | None = None) -> None:
"""
Send the async update to BEC.
Args:
index (int | None): The index of the slice to be sent. If None, the entire waveform is sent.
value (Any): The value to be sent.
"""
async_update_type = self.async_update.get()
waveform_shape = self.waveform_shape.get()
if async_update_type == "append":
metadata = {"async_update": {"type": "add", "max_shape": [None, waveform_shape]}}
else:
if async_update_type == "add_slice":
if index is not None:
metadata = {
"async_update": {
"type": "add_slice",
"index": index,
"max_shape": [None, waveform_shape],
}
}
else:
metadata = {"async_update": {"type": "add", "max_shape": [None, waveform_shape]}}
elif async_update_type == "add":
metadata = {"async_update": {"type": "add", "max_shape": [None]}}
else:
# Again, this should never happen -> check_value,
# but just in case we raise an exception
raise ValueError(
f"Invalid async_update type: {async_update_type} for device {self.name}"
)
msg = messages.DeviceMessage(
signals={self.waveform.name: {"value": self.waveform.get(), "timestamp": time.time()}},
signals={self.waveform.name: {"value": value, "timestamp": time.time()}},
metadata=metadata,
)
# logger.warning(f"Adding async update to {self.name} and {self.scan_info.msg.scan_id}")
# Send the message to BEC
self.connector.xadd(
MessageEndpoints.device_async_readback(
scan_id=self.scan_info.msg.scan_id, device=self.name
@ -177,6 +260,7 @@ class SimWaveform(Device):
self.exp_time.set(self.scan_info.msg.scan_parameters["exp_time"])
self.burst.set(self.scan_info.msg.scan_parameters["frames_per_trigger"])
self.stopped = False
self._slice_index = 0
logger.warning(f"Staged {self.name}, scan_id : {self.scan_info.msg.scan_id}")
return super().stage()
@ -186,6 +270,7 @@ class SimWaveform(Device):
Send reads from all config signals to redis
"""
logger.warning(f"Unstaging {self.name}, {self._staged}")
self._slice_index = 0
if self.stopped is True or not self._staged:
return super().unstage()
return super().unstage()

View File

@ -1,4 +1,4 @@
""" This module contains tests for the simulation devices in ophyd_devices """
"""This module contains tests for the simulation devices in ophyd_devices"""
# pylint: disable: all
import os
@ -765,3 +765,70 @@ def test_waveform(waveform):
assert status.done is True
assert mock_connector.xadd.call_count == 1
assert mock_run_subs.call_count == 1
@pytest.mark.parametrize(
"mode, mock_data, expected_calls",
[
(
"add",
np.zeros(5),
[{"sub_type": "device_monitor_1d", "value": np.zeros(5)}, {"value": np.zeros(5)}],
)
],
)
def test_waveform_update_modes(waveform, mode, mock_data, expected_calls):
"""Test the add and add_slice update modes of the SimWaveform class"""
waveform.sim.select_model("GaussianModel")
waveform.sim.params = {"amplitude": 500, "center": 500, "sigma": 10}
with pytest.raises(ValueError):
waveform.async_update.put("invalid_mode")
# Use add mode
waveform.async_update.put(mode)
with (
mock.patch.object(waveform, "_run_subs") as mock_run_subs,
mock.patch.object(waveform, "_send_async_update") as mock_send_async_update,
mock.patch.object(waveform.waveform, "get", return_value=mock_data),
):
status = waveform.trigger()
status_wait(status, timeout=10) # Raise if times out
assert status.done is True
# Run subs
assert mock_run_subs.call_args[1]["sub_type"] == expected_calls[0]["sub_type"]
assert np.array_equal(mock_run_subs.call_args[1]["value"], expected_calls[0]["value"])
# Send async update
assert np.array_equal(
mock_send_async_update.call_args[1]["value"], expected_calls[1]["value"]
)
@pytest.mark.parametrize(
"mode, index, expected_md",
[
(
"add_slice",
0,
{"async_update": {"type": "add_slice", "index": 0, "max_shape": [None, 100]}},
),
("add_slice", None, {"async_update": {"type": "add", "max_shape": [None, 100]}}),
("add", 0, {"async_update": {"type": "add", "max_shape": [None]}}),
],
)
def test_waveform_send_async_update(waveform, mode, index, expected_md):
"""Test the send_async_update method of SimWaveform."""
max_shape = expected_md["async_update"]["max_shape"]
if len(max_shape) > 1:
wv_shape = max_shape[1]
else:
wv_shape = 100
waveform.waveform_shape.put(wv_shape)
waveform.async_update.put(mode)
waveform.scan_info = get_mock_scan_info(device=waveform)
value = 0
with mock.patch.object(waveform.connector, "xadd") as mock_xadd:
waveform._send_async_update(index=index, value=value)
# Check here that metadata is properly set
args, kwargs = mock_xadd.call_args
msg = args[1]["data"]
assert msg.metadata == expected_md