diff --git a/ophyd_devices/sim/sim_signals.py b/ophyd_devices/sim/sim_signals.py index fe9653d..9cc443d 100644 --- a/ophyd_devices/sim/sim_signals.py +++ b/ophyd_devices/sim/sim_signals.py @@ -4,7 +4,7 @@ import time import numpy as np from bec_lib import bec_logger -from ophyd import Kind, Signal +from ophyd import DeviceStatus, Kind, Signal from ophyd.utils import ReadOnlyError from ophyd_devices.utils.bec_device_base import BECDeviceBase @@ -87,10 +87,18 @@ class SetableSignal(Signal): Core function for signal. """ + self.check_value(value) self._update_sim_state(value) self._value = value self._run_subs(sub_type=self.SUB_VALUE, value=value) + def set(self, value): + """Set method""" + self.put(value) + status = DeviceStatus(self) + status.set_finished() + return status + def describe(self): """Describe the readback signal. diff --git a/ophyd_devices/sim/sim_waveform.py b/ophyd_devices/sim/sim_waveform.py index ea4af98..3bec4cf 100644 --- a/ophyd_devices/sim/sim_waveform.py +++ b/ophyd_devices/sim/sim_waveform.py @@ -4,6 +4,7 @@ import os import threading import time import traceback +from typing import Any import numpy as np from bec_lib import messages @@ -11,6 +12,7 @@ from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger from ophyd import Component as Cpt from ophyd import Device, DeviceStatus, Kind, Staged +from typeguard import typechecked from ophyd_devices.sim.sim_data import SimulatedDataWaveform from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal @@ -20,6 +22,31 @@ from ophyd_devices.utils.errors import DeviceStopError logger = bec_logger.logger +class AsyncUpdateSignal(SetableSignal): + """Async updated signal, with check for async_update type.""" + + def check_value(self, value, **kwargs) -> None: + """Check the value of the async_update signal.""" + if value not in ["add_slice", "add"]: + raise ValueError(f"Invalid async_update type: {value} for signal {self.name}") + + # FIXME: BEC issue #443 remove this method once tests in BEC are updated. + def put(self, value: Any) -> None: + """Put the value of the async_update signal.""" + if value in ["append", "extend"]: + if value == "append": + logger.warning( + f"Deprecated async_update of type {value} for signal {self.name}, falling back to 'add_slice'" + ) + value = "add_slice" + elif value == "extend": + logger.warning( + f"Deprecated async_update of type {value} for signal {self.name}, falling back to 'add'" + ) + value = "add" + super().put(value) + + class SimWaveform(Device): """A simulated device mimic any 1D Waveform detector. @@ -39,7 +66,7 @@ class SimWaveform(Device): """ - USER_ACCESS = ["sim", "registered_proxies"] + USER_ACCESS = ["sim", "registered_proxies", "delay_slice_update"] sim_cls = SimulatedDataWaveform SHAPE = (1000,) @@ -60,10 +87,11 @@ class SimWaveform(Device): name="waveform", value=np.empty(SHAPE, dtype=BIT_DEPTH), compute_readback=True, - kind=Kind.normal, + kind=Kind.hinted, ) # Can be extend or append - async_update = Cpt(SetableSignal, value="append", kind=Kind.config) + async_update = Cpt(AsyncUpdateSignal, value="add", kind=Kind.config) + slice_size = Cpt(SetableSignal, value=100, dtype=np.int32, kind=Kind.config) def __init__( self, @@ -92,8 +120,20 @@ class SimWaveform(Device): self._staged = Staged.no self._trigger_thread = None self.scan_info = scan_info + self._delay_slice_update = False if self.sim_init: self.sim.set_init(self.sim_init) + self._slice_index = 0 + + @property + def delay_slice_update(self) -> bool: + """Delay updates in-between slices specified by waveform_shape and slice_size.""" + return self._delay_slice_update + + @typechecked + @delay_slice_update.setter + def delay_slice_update(self, value: bool) -> None: + self._delay_slice_update = value @property def registered_proxies(self) -> None: @@ -113,8 +153,33 @@ class SimWaveform(Device): def acquire(status: DeviceStatus): try: for _ in range(self.burst.get()): - self._run_subs(sub_type=self.SUB_MONITOR, value=self.waveform.get()) - self._send_async_update() + # values of the Waveform + values = self.waveform.get() + # add_slice option + if self.async_update.get() == "add_slice": + size = self.slice_size.get() + mod = len(values) % size + num_slices = len(values) // size + int(mod > 0) + for i in range(num_slices): + value_slice = values[i * size : min((i + 1) * size, len(values))] + logger.info( + f"Sending slice {i} of {self._slice_index} with length {len(value_slice)}" + ) + self._run_subs(sub_type=self.SUB_MONITOR, value=value_slice) + self._send_async_update(index=self._slice_index, value=value_slice) + if self.delay_slice_update is True: + time.sleep(0.025) # 25ms to be really fast + if self.stopped: + raise DeviceStopError(f"{self.name} was stopped") + self._slice_index += 1 + # option add + elif self.async_update.get() == "add": + self._run_subs(sub_type=self.SUB_MONITOR, value=values) + self._send_async_update(value=values) + else: + # This should never happen, but just in case + # we raise an exception + raise ValueError(f"Invalid async_update type: {self.async_update.get()}") if self.stopped: raise DeviceStopError(f"{self.name} was stopped") status.set_finished() @@ -128,23 +193,41 @@ class SimWaveform(Device): self._trigger_thread.start() return status - def _send_async_update(self): - """Send the async update to BEC.""" - async_update_type = self.async_update.get() - if async_update_type not in ["extend", "append"]: - raise ValueError(f"Invalid async_update type: {async_update_type}") + def _send_async_update(self, value: Any, index: int | None = None) -> None: + """ + Send the async update to BEC. + Args: + index (int | None): The index of the slice to be sent. If None, the entire waveform is sent. + value (Any): The value to be sent. + """ + async_update_type = self.async_update.get() waveform_shape = self.waveform_shape.get() - if async_update_type == "append": - metadata = {"async_update": {"type": "add", "max_shape": [None, waveform_shape]}} - else: + if async_update_type == "add_slice": + if index is not None: + metadata = { + "async_update": { + "type": "add_slice", + "index": index, + "max_shape": [None, waveform_shape], + } + } + else: + metadata = {"async_update": {"type": "add", "max_shape": [None, waveform_shape]}} + elif async_update_type == "add": metadata = {"async_update": {"type": "add", "max_shape": [None]}} + else: + # Again, this should never happen -> check_value, + # but just in case we raise an exception + raise ValueError( + f"Invalid async_update type: {async_update_type} for device {self.name}" + ) msg = messages.DeviceMessage( - signals={self.waveform.name: {"value": self.waveform.get(), "timestamp": time.time()}}, + signals={self.waveform.name: {"value": value, "timestamp": time.time()}}, metadata=metadata, ) - # logger.warning(f"Adding async update to {self.name} and {self.scan_info.msg.scan_id}") + # Send the message to BEC self.connector.xadd( MessageEndpoints.device_async_readback( scan_id=self.scan_info.msg.scan_id, device=self.name @@ -177,6 +260,7 @@ class SimWaveform(Device): self.exp_time.set(self.scan_info.msg.scan_parameters["exp_time"]) self.burst.set(self.scan_info.msg.scan_parameters["frames_per_trigger"]) self.stopped = False + self._slice_index = 0 logger.warning(f"Staged {self.name}, scan_id : {self.scan_info.msg.scan_id}") return super().stage() @@ -186,6 +270,7 @@ class SimWaveform(Device): Send reads from all config signals to redis """ logger.warning(f"Unstaging {self.name}, {self._staged}") + self._slice_index = 0 if self.stopped is True or not self._staged: return super().unstage() return super().unstage() diff --git a/tests/test_simulation.py b/tests/test_simulation.py index beb6284..7293f8b 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -1,4 +1,4 @@ -""" This module contains tests for the simulation devices in ophyd_devices """ +"""This module contains tests for the simulation devices in ophyd_devices""" # pylint: disable: all import os @@ -765,3 +765,70 @@ def test_waveform(waveform): assert status.done is True assert mock_connector.xadd.call_count == 1 assert mock_run_subs.call_count == 1 + + +@pytest.mark.parametrize( + "mode, mock_data, expected_calls", + [ + ( + "add", + np.zeros(5), + [{"sub_type": "device_monitor_1d", "value": np.zeros(5)}, {"value": np.zeros(5)}], + ) + ], +) +def test_waveform_update_modes(waveform, mode, mock_data, expected_calls): + """Test the add and add_slice update modes of the SimWaveform class""" + waveform.sim.select_model("GaussianModel") + waveform.sim.params = {"amplitude": 500, "center": 500, "sigma": 10} + with pytest.raises(ValueError): + waveform.async_update.put("invalid_mode") + # Use add mode + waveform.async_update.put(mode) + with ( + mock.patch.object(waveform, "_run_subs") as mock_run_subs, + mock.patch.object(waveform, "_send_async_update") as mock_send_async_update, + mock.patch.object(waveform.waveform, "get", return_value=mock_data), + ): + + status = waveform.trigger() + status_wait(status, timeout=10) # Raise if times out + assert status.done is True + # Run subs + assert mock_run_subs.call_args[1]["sub_type"] == expected_calls[0]["sub_type"] + assert np.array_equal(mock_run_subs.call_args[1]["value"], expected_calls[0]["value"]) + # Send async update + assert np.array_equal( + mock_send_async_update.call_args[1]["value"], expected_calls[1]["value"] + ) + + +@pytest.mark.parametrize( + "mode, index, expected_md", + [ + ( + "add_slice", + 0, + {"async_update": {"type": "add_slice", "index": 0, "max_shape": [None, 100]}}, + ), + ("add_slice", None, {"async_update": {"type": "add", "max_shape": [None, 100]}}), + ("add", 0, {"async_update": {"type": "add", "max_shape": [None]}}), + ], +) +def test_waveform_send_async_update(waveform, mode, index, expected_md): + """Test the send_async_update method of SimWaveform.""" + max_shape = expected_md["async_update"]["max_shape"] + if len(max_shape) > 1: + wv_shape = max_shape[1] + else: + wv_shape = 100 + waveform.waveform_shape.put(wv_shape) + waveform.async_update.put(mode) + waveform.scan_info = get_mock_scan_info(device=waveform) + value = 0 + with mock.patch.object(waveform.connector, "xadd") as mock_xadd: + waveform._send_async_update(index=index, value=value) + # Check here that metadata is properly set + args, kwargs = mock_xadd.call_args + msg = args[1]["data"] + assert msg.metadata == expected_md