diff --git a/ophyd_devices/sim/sim_data.py b/ophyd_devices/sim/sim_data.py index 1f3cd91..39478e2 100644 --- a/ophyd_devices/sim/sim_data.py +++ b/ophyd_devices/sim/sim_data.py @@ -351,7 +351,11 @@ class SimulatedDataMonitor(SimulatedDataBase): def _init_default(self) -> None: """Initialize the default parameters for the simulated data.""" - self.select_model("ConstantModel") + models = self.get_all_sim_models() + if "ConstantModel" in models: + self.select_model("ConstantModel") + else: + self.select_model(models[0]) def get_model_cls(self, model: str) -> any: """Get the class for the active simulation model.""" @@ -512,7 +516,9 @@ class SimulatedDataWaveform(SimulatedDataMonitor): size = size[0] if isinstance(size, tuple) else size method = self._model value = method.eval(params=self._model_params, x=np.array(range(size))) - value *= self.params["amplitude"] / np.max(value) + # Upscale the normalised gaussian if possible + if "amplitude" in method.param_names: + value *= self.params["amplitude"] / np.max(value) return self._add_noise(value, self.params["noise"], self.params["noise_multiplier"]) def _add_noise(self, v: np.ndarray, noise: NoiseType, noise_multiplier: float) -> np.ndarray: diff --git a/ophyd_devices/sim/sim_waveform.py b/ophyd_devices/sim/sim_waveform.py index 2b06eb4..5e70e62 100644 --- a/ophyd_devices/sim/sim_waveform.py +++ b/ophyd_devices/sim/sim_waveform.py @@ -2,15 +2,19 @@ import os import threading +import time import traceback import numpy as np +from bec_lib import messages +from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger from ophyd import Component as Cpt -from ophyd import Device, DeviceStatus, Kind +from ophyd import Device, DeviceStatus, Kind, Staged from ophyd_devices.sim.sim_data import SimulatedDataWaveform from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal +from ophyd_devices.utils import bec_utils from ophyd_devices.utils.bec_scaninfo_mixin import BecScaninfoMixin from ophyd_devices.utils.errors import DeviceStopError @@ -42,7 +46,7 @@ class SimWaveform(Device): SHAPE = (1000,) BIT_DEPTH = np.uint16 - SUB_MONITOR = "monitor" + SUB_MONITOR = "device_monitor_1d" _default_sub = SUB_MONITOR exp_time = Cpt(SetableSignal, name="exp_time", value=1, kind=Kind.config) @@ -57,20 +61,28 @@ class SimWaveform(Device): name="waveform", value=np.empty(SHAPE, dtype=BIT_DEPTH), compute_readback=True, - kind=Kind.omitted, + kind=Kind.normal, ) + # Can be extend or append + async_update = Cpt(SetableSignal, value="append", kind=Kind.config) def __init__( self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs ): - self.device_manager = device_manager self.sim_init = sim_init self._registered_proxies = {} self.sim = self.sim_cls(parent=self, **kwargs) super().__init__(name=name, parent=parent, kind=kind, **kwargs) + if device_manager: + self.device_manager = device_manager + else: + self.device_manager = bec_utils.DMMock() + + self.connector = self.device_manager.connector + self._stream_ttl = 1800 # 30 min max self.stopped = False - self._staged = False + self._staged = Staged.no self.scaninfo = None self._trigger_thread = None self._update_scaninfo() @@ -96,6 +108,7 @@ class SimWaveform(Device): try: for _ in range(self.burst.get()): self._run_subs(sub_type=self.SUB_MONITOR, value=self.waveform.get()) + self._send_async_update() if self.stopped: raise DeviceStopError(f"{self.name} was stopped") status.set_finished() @@ -109,6 +122,25 @@ class SimWaveform(Device): self._trigger_thread.start() return status + def _send_async_update(self): + """Send the async update to BEC.""" + metadata = self.scaninfo.scan_msg.metadata + async_update_type = self.async_update.get() + if async_update_type not in ["extend", "append"]: + raise ValueError(f"Invalid async_update type: {async_update_type}") + metadata.update({"async_update": async_update_type}) + + msg = messages.DeviceMessage( + signals={self.waveform.name: {"value": self.waveform.get(), "timestamp": time.time()}}, + metadata=metadata, + ) + # logger.warning(f"Adding async update to {self.name} and {self.scaninfo.scan_id}") + self.connector.xadd( + MessageEndpoints.device_async_readback(scan_id=self.scaninfo.scan_id, device=self.name), + {"data": msg}, + expire=self._stream_ttl, + ) + def _update_scaninfo(self) -> None: """Update scaninfo from BecScaninfoMixing This depends on device manager and operation/sim_mode @@ -125,7 +157,8 @@ class SimWaveform(Device): FYI: No data is written to disk in the simulation, but upon each trigger it is published to the device_monitor endpoint in REDIS. """ - if self._staged: + if self._staged is Staged.yes: + return super().stage() self.scaninfo.load_scan_metadata() self.file_path.set( @@ -137,6 +170,7 @@ class SimWaveform(Device): self.exp_time.set(self.scaninfo.exp_time) self.burst.set(self.scaninfo.frames_per_trigger) self.stopped = False + logger.warning(f"Staged {self.name}, scan_id : {self.scaninfo.scan_id}") return super().stage() def unstage(self) -> list[object]: @@ -144,9 +178,9 @@ class SimWaveform(Device): Send reads from all config signals to redis """ + logger.warning(f"Unstaging {self.name}, {self._staged}") if self.stopped is True or not self._staged: return super().unstage() - return super().unstage() def stop(self, *, success=False): @@ -156,3 +190,8 @@ class SimWaveform(Device): self._trigger_thread.join() self._trigger_thread = None super().stop(success=success) + + +if __name__ == "__main__": # pragma: no cover + waveform = SimWaveform(name="waveform") + waveform.sim.sim_select_model("GaussianModel") diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 343ef15..cb45975 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -29,9 +29,18 @@ from ophyd_devices.sim.sim_monitor import SimMonitor, SimMonitorAsync from ophyd_devices.sim.sim_positioner import SimLinearTrajectoryPositioner, SimPositioner from ophyd_devices.sim.sim_signals import ReadOnlySignal from ophyd_devices.sim.sim_utils import H5Writer, LinearTrajectory +from ophyd_devices.sim.sim_waveform import SimWaveform from ophyd_devices.utils.bec_device_base import BECDevice, BECDeviceBase +@pytest.fixture(scope="function") +def waveform(name="waveform"): + """Fixture for SimWaveform.""" + dm = DMMock() + wave = SimWaveform(name=name, device_manager=dm) + yield wave + + @pytest.fixture(scope="function") def signal(name="signal"): """Fixture for Signal.""" @@ -595,3 +604,33 @@ def test_positioner_updated_timestamp(positioner): readback = positioner.read()[positioner.name] assert readback["value"] == 5 assert readback["timestamp"] > timestamp + + +def test_waveform(waveform): + """Test the SimWaveform class""" + waveform.sim.sim_select_model("GaussianModel") + waveform.sim.params = {"amplitude": 500, "center": 500, "sigma": 10} + data = waveform.waveform.get() + assert isinstance(data, np.ndarray) + assert data.shape == waveform.SHAPE + assert np.isclose(np.argmax(data), 500, atol=5) + waveform.waveform_shape.put(50) + data = waveform.waveform.get() + for model in waveform.sim.get_all_sim_models(): + waveform.sim.sim_select_model(model) + waveform.waveform.get() + # Now also test the async readback + mock_connector = waveform.connector = mock.MagicMock() + mock_run_subs = waveform._run_subs = mock.MagicMock() + waveform.scaninfo.scan_msg = SimpleNamespace(metadata={}) + waveform.scaninfo.scan_id = "test" + status = waveform.trigger() + timer = 0 + while not status.done: + time.sleep(0.1) + timer += 0.1 + if timer > 5: + raise TimeoutError("Trigger did not complete") + assert status.done is True + assert mock_connector.xadd.call_count == 1 + assert mock_run_subs.call_count == 1