fix: Fixed SimWaveform, works as async device and device_monitor_1d simultaneously

This commit is contained in:
2024-10-01 09:13:53 +02:00
parent 0f06adcd7c
commit 7ff37c0dcd
3 changed files with 93 additions and 9 deletions

View File

@ -351,7 +351,11 @@ class SimulatedDataMonitor(SimulatedDataBase):
def _init_default(self) -> None: def _init_default(self) -> None:
"""Initialize the default parameters for the simulated data.""" """Initialize the default parameters for the simulated data."""
self.select_model("ConstantModel") models = self.get_all_sim_models()
if "ConstantModel" in models:
self.select_model("ConstantModel")
else:
self.select_model(models[0])
def get_model_cls(self, model: str) -> any: def get_model_cls(self, model: str) -> any:
"""Get the class for the active simulation model.""" """Get the class for the active simulation model."""
@ -512,7 +516,9 @@ class SimulatedDataWaveform(SimulatedDataMonitor):
size = size[0] if isinstance(size, tuple) else size size = size[0] if isinstance(size, tuple) else size
method = self._model method = self._model
value = method.eval(params=self._model_params, x=np.array(range(size))) value = method.eval(params=self._model_params, x=np.array(range(size)))
value *= self.params["amplitude"] / np.max(value) # Upscale the normalised gaussian if possible
if "amplitude" in method.param_names:
value *= self.params["amplitude"] / np.max(value)
return self._add_noise(value, self.params["noise"], self.params["noise_multiplier"]) return self._add_noise(value, self.params["noise"], self.params["noise_multiplier"])
def _add_noise(self, v: np.ndarray, noise: NoiseType, noise_multiplier: float) -> np.ndarray: def _add_noise(self, v: np.ndarray, noise: NoiseType, noise_multiplier: float) -> np.ndarray:

View File

@ -2,15 +2,19 @@
import os import os
import threading import threading
import time
import traceback import traceback
import numpy as np import numpy as np
from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger from bec_lib.logger import bec_logger
from ophyd import Component as Cpt from ophyd import Component as Cpt
from ophyd import Device, DeviceStatus, Kind from ophyd import Device, DeviceStatus, Kind, Staged
from ophyd_devices.sim.sim_data import SimulatedDataWaveform from ophyd_devices.sim.sim_data import SimulatedDataWaveform
from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal
from ophyd_devices.utils import bec_utils
from ophyd_devices.utils.bec_scaninfo_mixin import BecScaninfoMixin from ophyd_devices.utils.bec_scaninfo_mixin import BecScaninfoMixin
from ophyd_devices.utils.errors import DeviceStopError from ophyd_devices.utils.errors import DeviceStopError
@ -42,7 +46,7 @@ class SimWaveform(Device):
SHAPE = (1000,) SHAPE = (1000,)
BIT_DEPTH = np.uint16 BIT_DEPTH = np.uint16
SUB_MONITOR = "monitor" SUB_MONITOR = "device_monitor_1d"
_default_sub = SUB_MONITOR _default_sub = SUB_MONITOR
exp_time = Cpt(SetableSignal, name="exp_time", value=1, kind=Kind.config) exp_time = Cpt(SetableSignal, name="exp_time", value=1, kind=Kind.config)
@ -57,20 +61,28 @@ class SimWaveform(Device):
name="waveform", name="waveform",
value=np.empty(SHAPE, dtype=BIT_DEPTH), value=np.empty(SHAPE, dtype=BIT_DEPTH),
compute_readback=True, compute_readback=True,
kind=Kind.omitted, kind=Kind.normal,
) )
# Can be extend or append
async_update = Cpt(SetableSignal, value="append", kind=Kind.config)
def __init__( def __init__(
self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs
): ):
self.device_manager = device_manager
self.sim_init = sim_init self.sim_init = sim_init
self._registered_proxies = {} self._registered_proxies = {}
self.sim = self.sim_cls(parent=self, **kwargs) self.sim = self.sim_cls(parent=self, **kwargs)
super().__init__(name=name, parent=parent, kind=kind, **kwargs) super().__init__(name=name, parent=parent, kind=kind, **kwargs)
if device_manager:
self.device_manager = device_manager
else:
self.device_manager = bec_utils.DMMock()
self.connector = self.device_manager.connector
self._stream_ttl = 1800 # 30 min max
self.stopped = False self.stopped = False
self._staged = False self._staged = Staged.no
self.scaninfo = None self.scaninfo = None
self._trigger_thread = None self._trigger_thread = None
self._update_scaninfo() self._update_scaninfo()
@ -96,6 +108,7 @@ class SimWaveform(Device):
try: try:
for _ in range(self.burst.get()): for _ in range(self.burst.get()):
self._run_subs(sub_type=self.SUB_MONITOR, value=self.waveform.get()) self._run_subs(sub_type=self.SUB_MONITOR, value=self.waveform.get())
self._send_async_update()
if self.stopped: if self.stopped:
raise DeviceStopError(f"{self.name} was stopped") raise DeviceStopError(f"{self.name} was stopped")
status.set_finished() status.set_finished()
@ -109,6 +122,25 @@ class SimWaveform(Device):
self._trigger_thread.start() self._trigger_thread.start()
return status return status
def _send_async_update(self):
"""Send the async update to BEC."""
metadata = self.scaninfo.scan_msg.metadata
async_update_type = self.async_update.get()
if async_update_type not in ["extend", "append"]:
raise ValueError(f"Invalid async_update type: {async_update_type}")
metadata.update({"async_update": async_update_type})
msg = messages.DeviceMessage(
signals={self.waveform.name: {"value": self.waveform.get(), "timestamp": time.time()}},
metadata=metadata,
)
# logger.warning(f"Adding async update to {self.name} and {self.scaninfo.scan_id}")
self.connector.xadd(
MessageEndpoints.device_async_readback(scan_id=self.scaninfo.scan_id, device=self.name),
{"data": msg},
expire=self._stream_ttl,
)
def _update_scaninfo(self) -> None: def _update_scaninfo(self) -> None:
"""Update scaninfo from BecScaninfoMixing """Update scaninfo from BecScaninfoMixing
This depends on device manager and operation/sim_mode This depends on device manager and operation/sim_mode
@ -125,7 +157,8 @@ class SimWaveform(Device):
FYI: No data is written to disk in the simulation, but upon each trigger it FYI: No data is written to disk in the simulation, but upon each trigger it
is published to the device_monitor endpoint in REDIS. is published to the device_monitor endpoint in REDIS.
""" """
if self._staged: if self._staged is Staged.yes:
return super().stage() return super().stage()
self.scaninfo.load_scan_metadata() self.scaninfo.load_scan_metadata()
self.file_path.set( self.file_path.set(
@ -137,6 +170,7 @@ class SimWaveform(Device):
self.exp_time.set(self.scaninfo.exp_time) self.exp_time.set(self.scaninfo.exp_time)
self.burst.set(self.scaninfo.frames_per_trigger) self.burst.set(self.scaninfo.frames_per_trigger)
self.stopped = False self.stopped = False
logger.warning(f"Staged {self.name}, scan_id : {self.scaninfo.scan_id}")
return super().stage() return super().stage()
def unstage(self) -> list[object]: def unstage(self) -> list[object]:
@ -144,9 +178,9 @@ class SimWaveform(Device):
Send reads from all config signals to redis Send reads from all config signals to redis
""" """
logger.warning(f"Unstaging {self.name}, {self._staged}")
if self.stopped is True or not self._staged: if self.stopped is True or not self._staged:
return super().unstage() return super().unstage()
return super().unstage() return super().unstage()
def stop(self, *, success=False): def stop(self, *, success=False):
@ -156,3 +190,8 @@ class SimWaveform(Device):
self._trigger_thread.join() self._trigger_thread.join()
self._trigger_thread = None self._trigger_thread = None
super().stop(success=success) super().stop(success=success)
if __name__ == "__main__": # pragma: no cover
waveform = SimWaveform(name="waveform")
waveform.sim.sim_select_model("GaussianModel")

View File

@ -29,9 +29,18 @@ from ophyd_devices.sim.sim_monitor import SimMonitor, SimMonitorAsync
from ophyd_devices.sim.sim_positioner import SimLinearTrajectoryPositioner, SimPositioner from ophyd_devices.sim.sim_positioner import SimLinearTrajectoryPositioner, SimPositioner
from ophyd_devices.sim.sim_signals import ReadOnlySignal from ophyd_devices.sim.sim_signals import ReadOnlySignal
from ophyd_devices.sim.sim_utils import H5Writer, LinearTrajectory from ophyd_devices.sim.sim_utils import H5Writer, LinearTrajectory
from ophyd_devices.sim.sim_waveform import SimWaveform
from ophyd_devices.utils.bec_device_base import BECDevice, BECDeviceBase from ophyd_devices.utils.bec_device_base import BECDevice, BECDeviceBase
@pytest.fixture(scope="function")
def waveform(name="waveform"):
"""Fixture for SimWaveform."""
dm = DMMock()
wave = SimWaveform(name=name, device_manager=dm)
yield wave
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def signal(name="signal"): def signal(name="signal"):
"""Fixture for Signal.""" """Fixture for Signal."""
@ -595,3 +604,33 @@ def test_positioner_updated_timestamp(positioner):
readback = positioner.read()[positioner.name] readback = positioner.read()[positioner.name]
assert readback["value"] == 5 assert readback["value"] == 5
assert readback["timestamp"] > timestamp assert readback["timestamp"] > timestamp
def test_waveform(waveform):
"""Test the SimWaveform class"""
waveform.sim.sim_select_model("GaussianModel")
waveform.sim.params = {"amplitude": 500, "center": 500, "sigma": 10}
data = waveform.waveform.get()
assert isinstance(data, np.ndarray)
assert data.shape == waveform.SHAPE
assert np.isclose(np.argmax(data), 500, atol=5)
waveform.waveform_shape.put(50)
data = waveform.waveform.get()
for model in waveform.sim.get_all_sim_models():
waveform.sim.sim_select_model(model)
waveform.waveform.get()
# Now also test the async readback
mock_connector = waveform.connector = mock.MagicMock()
mock_run_subs = waveform._run_subs = mock.MagicMock()
waveform.scaninfo.scan_msg = SimpleNamespace(metadata={})
waveform.scaninfo.scan_id = "test"
status = waveform.trigger()
timer = 0
while not status.done:
time.sleep(0.1)
timer += 0.1
if timer > 5:
raise TimeoutError("Trigger did not complete")
assert status.done is True
assert mock_connector.xadd.call_count == 1
assert mock_run_subs.call_count == 1