mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2025-06-17 16:57:12 +02:00
fix: Fixed SimWaveform, works as async device and device_monitor_1d simultaneously
This commit is contained in:
@ -351,7 +351,11 @@ class SimulatedDataMonitor(SimulatedDataBase):
|
|||||||
|
|
||||||
def _init_default(self) -> None:
|
def _init_default(self) -> None:
|
||||||
"""Initialize the default parameters for the simulated data."""
|
"""Initialize the default parameters for the simulated data."""
|
||||||
self.select_model("ConstantModel")
|
models = self.get_all_sim_models()
|
||||||
|
if "ConstantModel" in models:
|
||||||
|
self.select_model("ConstantModel")
|
||||||
|
else:
|
||||||
|
self.select_model(models[0])
|
||||||
|
|
||||||
def get_model_cls(self, model: str) -> any:
|
def get_model_cls(self, model: str) -> any:
|
||||||
"""Get the class for the active simulation model."""
|
"""Get the class for the active simulation model."""
|
||||||
@ -512,7 +516,9 @@ class SimulatedDataWaveform(SimulatedDataMonitor):
|
|||||||
size = size[0] if isinstance(size, tuple) else size
|
size = size[0] if isinstance(size, tuple) else size
|
||||||
method = self._model
|
method = self._model
|
||||||
value = method.eval(params=self._model_params, x=np.array(range(size)))
|
value = method.eval(params=self._model_params, x=np.array(range(size)))
|
||||||
value *= self.params["amplitude"] / np.max(value)
|
# Upscale the normalised gaussian if possible
|
||||||
|
if "amplitude" in method.param_names:
|
||||||
|
value *= self.params["amplitude"] / np.max(value)
|
||||||
return self._add_noise(value, self.params["noise"], self.params["noise_multiplier"])
|
return self._add_noise(value, self.params["noise"], self.params["noise_multiplier"])
|
||||||
|
|
||||||
def _add_noise(self, v: np.ndarray, noise: NoiseType, noise_multiplier: float) -> np.ndarray:
|
def _add_noise(self, v: np.ndarray, noise: NoiseType, noise_multiplier: float) -> np.ndarray:
|
||||||
|
@ -2,15 +2,19 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from bec_lib import messages
|
||||||
|
from bec_lib.endpoints import MessageEndpoints
|
||||||
from bec_lib.logger import bec_logger
|
from bec_lib.logger import bec_logger
|
||||||
from ophyd import Component as Cpt
|
from ophyd import Component as Cpt
|
||||||
from ophyd import Device, DeviceStatus, Kind
|
from ophyd import Device, DeviceStatus, Kind, Staged
|
||||||
|
|
||||||
from ophyd_devices.sim.sim_data import SimulatedDataWaveform
|
from ophyd_devices.sim.sim_data import SimulatedDataWaveform
|
||||||
from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal
|
from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal
|
||||||
|
from ophyd_devices.utils import bec_utils
|
||||||
from ophyd_devices.utils.bec_scaninfo_mixin import BecScaninfoMixin
|
from ophyd_devices.utils.bec_scaninfo_mixin import BecScaninfoMixin
|
||||||
from ophyd_devices.utils.errors import DeviceStopError
|
from ophyd_devices.utils.errors import DeviceStopError
|
||||||
|
|
||||||
@ -42,7 +46,7 @@ class SimWaveform(Device):
|
|||||||
SHAPE = (1000,)
|
SHAPE = (1000,)
|
||||||
BIT_DEPTH = np.uint16
|
BIT_DEPTH = np.uint16
|
||||||
|
|
||||||
SUB_MONITOR = "monitor"
|
SUB_MONITOR = "device_monitor_1d"
|
||||||
_default_sub = SUB_MONITOR
|
_default_sub = SUB_MONITOR
|
||||||
|
|
||||||
exp_time = Cpt(SetableSignal, name="exp_time", value=1, kind=Kind.config)
|
exp_time = Cpt(SetableSignal, name="exp_time", value=1, kind=Kind.config)
|
||||||
@ -57,20 +61,28 @@ class SimWaveform(Device):
|
|||||||
name="waveform",
|
name="waveform",
|
||||||
value=np.empty(SHAPE, dtype=BIT_DEPTH),
|
value=np.empty(SHAPE, dtype=BIT_DEPTH),
|
||||||
compute_readback=True,
|
compute_readback=True,
|
||||||
kind=Kind.omitted,
|
kind=Kind.normal,
|
||||||
)
|
)
|
||||||
|
# Can be extend or append
|
||||||
|
async_update = Cpt(SetableSignal, value="append", kind=Kind.config)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs
|
self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs
|
||||||
):
|
):
|
||||||
self.device_manager = device_manager
|
|
||||||
self.sim_init = sim_init
|
self.sim_init = sim_init
|
||||||
self._registered_proxies = {}
|
self._registered_proxies = {}
|
||||||
self.sim = self.sim_cls(parent=self, **kwargs)
|
self.sim = self.sim_cls(parent=self, **kwargs)
|
||||||
|
|
||||||
super().__init__(name=name, parent=parent, kind=kind, **kwargs)
|
super().__init__(name=name, parent=parent, kind=kind, **kwargs)
|
||||||
|
if device_manager:
|
||||||
|
self.device_manager = device_manager
|
||||||
|
else:
|
||||||
|
self.device_manager = bec_utils.DMMock()
|
||||||
|
|
||||||
|
self.connector = self.device_manager.connector
|
||||||
|
self._stream_ttl = 1800 # 30 min max
|
||||||
self.stopped = False
|
self.stopped = False
|
||||||
self._staged = False
|
self._staged = Staged.no
|
||||||
self.scaninfo = None
|
self.scaninfo = None
|
||||||
self._trigger_thread = None
|
self._trigger_thread = None
|
||||||
self._update_scaninfo()
|
self._update_scaninfo()
|
||||||
@ -96,6 +108,7 @@ class SimWaveform(Device):
|
|||||||
try:
|
try:
|
||||||
for _ in range(self.burst.get()):
|
for _ in range(self.burst.get()):
|
||||||
self._run_subs(sub_type=self.SUB_MONITOR, value=self.waveform.get())
|
self._run_subs(sub_type=self.SUB_MONITOR, value=self.waveform.get())
|
||||||
|
self._send_async_update()
|
||||||
if self.stopped:
|
if self.stopped:
|
||||||
raise DeviceStopError(f"{self.name} was stopped")
|
raise DeviceStopError(f"{self.name} was stopped")
|
||||||
status.set_finished()
|
status.set_finished()
|
||||||
@ -109,6 +122,25 @@ class SimWaveform(Device):
|
|||||||
self._trigger_thread.start()
|
self._trigger_thread.start()
|
||||||
return status
|
return status
|
||||||
|
|
||||||
|
def _send_async_update(self):
|
||||||
|
"""Send the async update to BEC."""
|
||||||
|
metadata = self.scaninfo.scan_msg.metadata
|
||||||
|
async_update_type = self.async_update.get()
|
||||||
|
if async_update_type not in ["extend", "append"]:
|
||||||
|
raise ValueError(f"Invalid async_update type: {async_update_type}")
|
||||||
|
metadata.update({"async_update": async_update_type})
|
||||||
|
|
||||||
|
msg = messages.DeviceMessage(
|
||||||
|
signals={self.waveform.name: {"value": self.waveform.get(), "timestamp": time.time()}},
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
# logger.warning(f"Adding async update to {self.name} and {self.scaninfo.scan_id}")
|
||||||
|
self.connector.xadd(
|
||||||
|
MessageEndpoints.device_async_readback(scan_id=self.scaninfo.scan_id, device=self.name),
|
||||||
|
{"data": msg},
|
||||||
|
expire=self._stream_ttl,
|
||||||
|
)
|
||||||
|
|
||||||
def _update_scaninfo(self) -> None:
|
def _update_scaninfo(self) -> None:
|
||||||
"""Update scaninfo from BecScaninfoMixing
|
"""Update scaninfo from BecScaninfoMixing
|
||||||
This depends on device manager and operation/sim_mode
|
This depends on device manager and operation/sim_mode
|
||||||
@ -125,7 +157,8 @@ class SimWaveform(Device):
|
|||||||
FYI: No data is written to disk in the simulation, but upon each trigger it
|
FYI: No data is written to disk in the simulation, but upon each trigger it
|
||||||
is published to the device_monitor endpoint in REDIS.
|
is published to the device_monitor endpoint in REDIS.
|
||||||
"""
|
"""
|
||||||
if self._staged:
|
if self._staged is Staged.yes:
|
||||||
|
|
||||||
return super().stage()
|
return super().stage()
|
||||||
self.scaninfo.load_scan_metadata()
|
self.scaninfo.load_scan_metadata()
|
||||||
self.file_path.set(
|
self.file_path.set(
|
||||||
@ -137,6 +170,7 @@ class SimWaveform(Device):
|
|||||||
self.exp_time.set(self.scaninfo.exp_time)
|
self.exp_time.set(self.scaninfo.exp_time)
|
||||||
self.burst.set(self.scaninfo.frames_per_trigger)
|
self.burst.set(self.scaninfo.frames_per_trigger)
|
||||||
self.stopped = False
|
self.stopped = False
|
||||||
|
logger.warning(f"Staged {self.name}, scan_id : {self.scaninfo.scan_id}")
|
||||||
return super().stage()
|
return super().stage()
|
||||||
|
|
||||||
def unstage(self) -> list[object]:
|
def unstage(self) -> list[object]:
|
||||||
@ -144,9 +178,9 @@ class SimWaveform(Device):
|
|||||||
|
|
||||||
Send reads from all config signals to redis
|
Send reads from all config signals to redis
|
||||||
"""
|
"""
|
||||||
|
logger.warning(f"Unstaging {self.name}, {self._staged}")
|
||||||
if self.stopped is True or not self._staged:
|
if self.stopped is True or not self._staged:
|
||||||
return super().unstage()
|
return super().unstage()
|
||||||
|
|
||||||
return super().unstage()
|
return super().unstage()
|
||||||
|
|
||||||
def stop(self, *, success=False):
|
def stop(self, *, success=False):
|
||||||
@ -156,3 +190,8 @@ class SimWaveform(Device):
|
|||||||
self._trigger_thread.join()
|
self._trigger_thread.join()
|
||||||
self._trigger_thread = None
|
self._trigger_thread = None
|
||||||
super().stop(success=success)
|
super().stop(success=success)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__": # pragma: no cover
|
||||||
|
waveform = SimWaveform(name="waveform")
|
||||||
|
waveform.sim.sim_select_model("GaussianModel")
|
||||||
|
@ -29,9 +29,18 @@ from ophyd_devices.sim.sim_monitor import SimMonitor, SimMonitorAsync
|
|||||||
from ophyd_devices.sim.sim_positioner import SimLinearTrajectoryPositioner, SimPositioner
|
from ophyd_devices.sim.sim_positioner import SimLinearTrajectoryPositioner, SimPositioner
|
||||||
from ophyd_devices.sim.sim_signals import ReadOnlySignal
|
from ophyd_devices.sim.sim_signals import ReadOnlySignal
|
||||||
from ophyd_devices.sim.sim_utils import H5Writer, LinearTrajectory
|
from ophyd_devices.sim.sim_utils import H5Writer, LinearTrajectory
|
||||||
|
from ophyd_devices.sim.sim_waveform import SimWaveform
|
||||||
from ophyd_devices.utils.bec_device_base import BECDevice, BECDeviceBase
|
from ophyd_devices.utils.bec_device_base import BECDevice, BECDeviceBase
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def waveform(name="waveform"):
|
||||||
|
"""Fixture for SimWaveform."""
|
||||||
|
dm = DMMock()
|
||||||
|
wave = SimWaveform(name=name, device_manager=dm)
|
||||||
|
yield wave
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def signal(name="signal"):
|
def signal(name="signal"):
|
||||||
"""Fixture for Signal."""
|
"""Fixture for Signal."""
|
||||||
@ -595,3 +604,33 @@ def test_positioner_updated_timestamp(positioner):
|
|||||||
readback = positioner.read()[positioner.name]
|
readback = positioner.read()[positioner.name]
|
||||||
assert readback["value"] == 5
|
assert readback["value"] == 5
|
||||||
assert readback["timestamp"] > timestamp
|
assert readback["timestamp"] > timestamp
|
||||||
|
|
||||||
|
|
||||||
|
def test_waveform(waveform):
|
||||||
|
"""Test the SimWaveform class"""
|
||||||
|
waveform.sim.sim_select_model("GaussianModel")
|
||||||
|
waveform.sim.params = {"amplitude": 500, "center": 500, "sigma": 10}
|
||||||
|
data = waveform.waveform.get()
|
||||||
|
assert isinstance(data, np.ndarray)
|
||||||
|
assert data.shape == waveform.SHAPE
|
||||||
|
assert np.isclose(np.argmax(data), 500, atol=5)
|
||||||
|
waveform.waveform_shape.put(50)
|
||||||
|
data = waveform.waveform.get()
|
||||||
|
for model in waveform.sim.get_all_sim_models():
|
||||||
|
waveform.sim.sim_select_model(model)
|
||||||
|
waveform.waveform.get()
|
||||||
|
# Now also test the async readback
|
||||||
|
mock_connector = waveform.connector = mock.MagicMock()
|
||||||
|
mock_run_subs = waveform._run_subs = mock.MagicMock()
|
||||||
|
waveform.scaninfo.scan_msg = SimpleNamespace(metadata={})
|
||||||
|
waveform.scaninfo.scan_id = "test"
|
||||||
|
status = waveform.trigger()
|
||||||
|
timer = 0
|
||||||
|
while not status.done:
|
||||||
|
time.sleep(0.1)
|
||||||
|
timer += 0.1
|
||||||
|
if timer > 5:
|
||||||
|
raise TimeoutError("Trigger did not complete")
|
||||||
|
assert status.done is True
|
||||||
|
assert mock_connector.xadd.call_count == 1
|
||||||
|
assert mock_run_subs.call_count == 1
|
||||||
|
Reference in New Issue
Block a user