mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2025-06-17 16:57:12 +02:00
fix: Fixed SimWaveform, works as async device and device_monitor_1d simultaneously
This commit is contained in:
@ -351,7 +351,11 @@ class SimulatedDataMonitor(SimulatedDataBase):
|
||||
|
||||
def _init_default(self) -> None:
|
||||
"""Initialize the default parameters for the simulated data."""
|
||||
models = self.get_all_sim_models()
|
||||
if "ConstantModel" in models:
|
||||
self.select_model("ConstantModel")
|
||||
else:
|
||||
self.select_model(models[0])
|
||||
|
||||
def get_model_cls(self, model: str) -> any:
|
||||
"""Get the class for the active simulation model."""
|
||||
@ -512,6 +516,8 @@ class SimulatedDataWaveform(SimulatedDataMonitor):
|
||||
size = size[0] if isinstance(size, tuple) else size
|
||||
method = self._model
|
||||
value = method.eval(params=self._model_params, x=np.array(range(size)))
|
||||
# Upscale the normalised gaussian if possible
|
||||
if "amplitude" in method.param_names:
|
||||
value *= self.params["amplitude"] / np.max(value)
|
||||
return self._add_noise(value, self.params["noise"], self.params["noise_multiplier"])
|
||||
|
||||
|
@ -2,15 +2,19 @@
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
from bec_lib import messages
|
||||
from bec_lib.endpoints import MessageEndpoints
|
||||
from bec_lib.logger import bec_logger
|
||||
from ophyd import Component as Cpt
|
||||
from ophyd import Device, DeviceStatus, Kind
|
||||
from ophyd import Device, DeviceStatus, Kind, Staged
|
||||
|
||||
from ophyd_devices.sim.sim_data import SimulatedDataWaveform
|
||||
from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal
|
||||
from ophyd_devices.utils import bec_utils
|
||||
from ophyd_devices.utils.bec_scaninfo_mixin import BecScaninfoMixin
|
||||
from ophyd_devices.utils.errors import DeviceStopError
|
||||
|
||||
@ -42,7 +46,7 @@ class SimWaveform(Device):
|
||||
SHAPE = (1000,)
|
||||
BIT_DEPTH = np.uint16
|
||||
|
||||
SUB_MONITOR = "monitor"
|
||||
SUB_MONITOR = "device_monitor_1d"
|
||||
_default_sub = SUB_MONITOR
|
||||
|
||||
exp_time = Cpt(SetableSignal, name="exp_time", value=1, kind=Kind.config)
|
||||
@ -57,20 +61,28 @@ class SimWaveform(Device):
|
||||
name="waveform",
|
||||
value=np.empty(SHAPE, dtype=BIT_DEPTH),
|
||||
compute_readback=True,
|
||||
kind=Kind.omitted,
|
||||
kind=Kind.normal,
|
||||
)
|
||||
# Can be extend or append
|
||||
async_update = Cpt(SetableSignal, value="append", kind=Kind.config)
|
||||
|
||||
def __init__(
|
||||
self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs
|
||||
):
|
||||
self.device_manager = device_manager
|
||||
self.sim_init = sim_init
|
||||
self._registered_proxies = {}
|
||||
self.sim = self.sim_cls(parent=self, **kwargs)
|
||||
|
||||
super().__init__(name=name, parent=parent, kind=kind, **kwargs)
|
||||
if device_manager:
|
||||
self.device_manager = device_manager
|
||||
else:
|
||||
self.device_manager = bec_utils.DMMock()
|
||||
|
||||
self.connector = self.device_manager.connector
|
||||
self._stream_ttl = 1800 # 30 min max
|
||||
self.stopped = False
|
||||
self._staged = False
|
||||
self._staged = Staged.no
|
||||
self.scaninfo = None
|
||||
self._trigger_thread = None
|
||||
self._update_scaninfo()
|
||||
@ -96,6 +108,7 @@ class SimWaveform(Device):
|
||||
try:
|
||||
for _ in range(self.burst.get()):
|
||||
self._run_subs(sub_type=self.SUB_MONITOR, value=self.waveform.get())
|
||||
self._send_async_update()
|
||||
if self.stopped:
|
||||
raise DeviceStopError(f"{self.name} was stopped")
|
||||
status.set_finished()
|
||||
@ -109,6 +122,25 @@ class SimWaveform(Device):
|
||||
self._trigger_thread.start()
|
||||
return status
|
||||
|
||||
def _send_async_update(self):
|
||||
"""Send the async update to BEC."""
|
||||
metadata = self.scaninfo.scan_msg.metadata
|
||||
async_update_type = self.async_update.get()
|
||||
if async_update_type not in ["extend", "append"]:
|
||||
raise ValueError(f"Invalid async_update type: {async_update_type}")
|
||||
metadata.update({"async_update": async_update_type})
|
||||
|
||||
msg = messages.DeviceMessage(
|
||||
signals={self.waveform.name: {"value": self.waveform.get(), "timestamp": time.time()}},
|
||||
metadata=metadata,
|
||||
)
|
||||
# logger.warning(f"Adding async update to {self.name} and {self.scaninfo.scan_id}")
|
||||
self.connector.xadd(
|
||||
MessageEndpoints.device_async_readback(scan_id=self.scaninfo.scan_id, device=self.name),
|
||||
{"data": msg},
|
||||
expire=self._stream_ttl,
|
||||
)
|
||||
|
||||
def _update_scaninfo(self) -> None:
|
||||
"""Update scaninfo from BecScaninfoMixing
|
||||
This depends on device manager and operation/sim_mode
|
||||
@ -125,7 +157,8 @@ class SimWaveform(Device):
|
||||
FYI: No data is written to disk in the simulation, but upon each trigger it
|
||||
is published to the device_monitor endpoint in REDIS.
|
||||
"""
|
||||
if self._staged:
|
||||
if self._staged is Staged.yes:
|
||||
|
||||
return super().stage()
|
||||
self.scaninfo.load_scan_metadata()
|
||||
self.file_path.set(
|
||||
@ -137,6 +170,7 @@ class SimWaveform(Device):
|
||||
self.exp_time.set(self.scaninfo.exp_time)
|
||||
self.burst.set(self.scaninfo.frames_per_trigger)
|
||||
self.stopped = False
|
||||
logger.warning(f"Staged {self.name}, scan_id : {self.scaninfo.scan_id}")
|
||||
return super().stage()
|
||||
|
||||
def unstage(self) -> list[object]:
|
||||
@ -144,9 +178,9 @@ class SimWaveform(Device):
|
||||
|
||||
Send reads from all config signals to redis
|
||||
"""
|
||||
logger.warning(f"Unstaging {self.name}, {self._staged}")
|
||||
if self.stopped is True or not self._staged:
|
||||
return super().unstage()
|
||||
|
||||
return super().unstage()
|
||||
|
||||
def stop(self, *, success=False):
|
||||
@ -156,3 +190,8 @@ class SimWaveform(Device):
|
||||
self._trigger_thread.join()
|
||||
self._trigger_thread = None
|
||||
super().stop(success=success)
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
waveform = SimWaveform(name="waveform")
|
||||
waveform.sim.sim_select_model("GaussianModel")
|
||||
|
@ -29,9 +29,18 @@ from ophyd_devices.sim.sim_monitor import SimMonitor, SimMonitorAsync
|
||||
from ophyd_devices.sim.sim_positioner import SimLinearTrajectoryPositioner, SimPositioner
|
||||
from ophyd_devices.sim.sim_signals import ReadOnlySignal
|
||||
from ophyd_devices.sim.sim_utils import H5Writer, LinearTrajectory
|
||||
from ophyd_devices.sim.sim_waveform import SimWaveform
|
||||
from ophyd_devices.utils.bec_device_base import BECDevice, BECDeviceBase
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def waveform(name="waveform"):
|
||||
"""Fixture for SimWaveform."""
|
||||
dm = DMMock()
|
||||
wave = SimWaveform(name=name, device_manager=dm)
|
||||
yield wave
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def signal(name="signal"):
|
||||
"""Fixture for Signal."""
|
||||
@ -595,3 +604,33 @@ def test_positioner_updated_timestamp(positioner):
|
||||
readback = positioner.read()[positioner.name]
|
||||
assert readback["value"] == 5
|
||||
assert readback["timestamp"] > timestamp
|
||||
|
||||
|
||||
def test_waveform(waveform):
|
||||
"""Test the SimWaveform class"""
|
||||
waveform.sim.sim_select_model("GaussianModel")
|
||||
waveform.sim.params = {"amplitude": 500, "center": 500, "sigma": 10}
|
||||
data = waveform.waveform.get()
|
||||
assert isinstance(data, np.ndarray)
|
||||
assert data.shape == waveform.SHAPE
|
||||
assert np.isclose(np.argmax(data), 500, atol=5)
|
||||
waveform.waveform_shape.put(50)
|
||||
data = waveform.waveform.get()
|
||||
for model in waveform.sim.get_all_sim_models():
|
||||
waveform.sim.sim_select_model(model)
|
||||
waveform.waveform.get()
|
||||
# Now also test the async readback
|
||||
mock_connector = waveform.connector = mock.MagicMock()
|
||||
mock_run_subs = waveform._run_subs = mock.MagicMock()
|
||||
waveform.scaninfo.scan_msg = SimpleNamespace(metadata={})
|
||||
waveform.scaninfo.scan_id = "test"
|
||||
status = waveform.trigger()
|
||||
timer = 0
|
||||
while not status.done:
|
||||
time.sleep(0.1)
|
||||
timer += 0.1
|
||||
if timer > 5:
|
||||
raise TimeoutError("Trigger did not complete")
|
||||
assert status.done is True
|
||||
assert mock_connector.xadd.call_count == 1
|
||||
assert mock_run_subs.call_count == 1
|
||||
|
Reference in New Issue
Block a user