feat(sim_waveform): added option to emit data with add_slice

This commit is contained in:
wakonig_k 2025-03-06 18:15:17 +01:00 committed by appel_c
parent 7797e4003b
commit 21746e5445
3 changed files with 177 additions and 17 deletions

View File

@ -4,7 +4,7 @@ import time
import numpy as np import numpy as np
from bec_lib import bec_logger from bec_lib import bec_logger
from ophyd import Kind, Signal from ophyd import DeviceStatus, Kind, Signal
from ophyd.utils import ReadOnlyError from ophyd.utils import ReadOnlyError
from ophyd_devices.utils.bec_device_base import BECDeviceBase from ophyd_devices.utils.bec_device_base import BECDeviceBase
@ -87,10 +87,18 @@ class SetableSignal(Signal):
Core function for signal. Core function for signal.
""" """
self.check_value(value)
self._update_sim_state(value) self._update_sim_state(value)
self._value = value self._value = value
self._run_subs(sub_type=self.SUB_VALUE, value=value) self._run_subs(sub_type=self.SUB_VALUE, value=value)
def set(self, value):
"""Set method"""
self.put(value)
status = DeviceStatus(self)
status.set_finished()
return status
def describe(self): def describe(self):
"""Describe the readback signal. """Describe the readback signal.

View File

@ -4,6 +4,7 @@ import os
import threading import threading
import time import time
import traceback import traceback
from typing import Any
import numpy as np import numpy as np
from bec_lib import messages from bec_lib import messages
@ -11,6 +12,7 @@ from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger from bec_lib.logger import bec_logger
from ophyd import Component as Cpt from ophyd import Component as Cpt
from ophyd import Device, DeviceStatus, Kind, Staged from ophyd import Device, DeviceStatus, Kind, Staged
from typeguard import typechecked
from ophyd_devices.sim.sim_data import SimulatedDataWaveform from ophyd_devices.sim.sim_data import SimulatedDataWaveform
from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal
@ -20,6 +22,31 @@ from ophyd_devices.utils.errors import DeviceStopError
logger = bec_logger.logger logger = bec_logger.logger
class AsyncUpdateSignal(SetableSignal):
"""Async updated signal, with check for async_update type."""
def check_value(self, value, **kwargs) -> None:
"""Check the value of the async_update signal."""
if value not in ["add_slice", "add"]:
raise ValueError(f"Invalid async_update type: {value} for signal {self.name}")
# FIXME: BEC issue #443 remove this method once tests in BEC are updated.
def put(self, value: Any) -> None:
"""Put the value of the async_update signal."""
if value in ["append", "extend"]:
if value == "append":
logger.warning(
f"Deprecated async_update of type {value} for signal {self.name}, falling back to 'add_slice'"
)
value = "add_slice"
elif value == "extend":
logger.warning(
f"Deprecated async_update of type {value} for signal {self.name}, falling back to 'add'"
)
value = "add"
super().put(value)
class SimWaveform(Device): class SimWaveform(Device):
"""A simulated device mimic any 1D Waveform detector. """A simulated device mimic any 1D Waveform detector.
@ -39,7 +66,7 @@ class SimWaveform(Device):
""" """
USER_ACCESS = ["sim", "registered_proxies"] USER_ACCESS = ["sim", "registered_proxies", "delay_slice_update"]
sim_cls = SimulatedDataWaveform sim_cls = SimulatedDataWaveform
SHAPE = (1000,) SHAPE = (1000,)
@ -60,10 +87,11 @@ class SimWaveform(Device):
name="waveform", name="waveform",
value=np.empty(SHAPE, dtype=BIT_DEPTH), value=np.empty(SHAPE, dtype=BIT_DEPTH),
compute_readback=True, compute_readback=True,
kind=Kind.normal, kind=Kind.hinted,
) )
# Can be extend or append # Can be extend or append
async_update = Cpt(SetableSignal, value="append", kind=Kind.config) async_update = Cpt(AsyncUpdateSignal, value="add", kind=Kind.config)
slice_size = Cpt(SetableSignal, value=100, dtype=np.int32, kind=Kind.config)
def __init__( def __init__(
self, self,
@ -92,8 +120,20 @@ class SimWaveform(Device):
self._staged = Staged.no self._staged = Staged.no
self._trigger_thread = None self._trigger_thread = None
self.scan_info = scan_info self.scan_info = scan_info
self._delay_slice_update = False
if self.sim_init: if self.sim_init:
self.sim.set_init(self.sim_init) self.sim.set_init(self.sim_init)
self._slice_index = 0
@property
def delay_slice_update(self) -> bool:
"""Delay updates in-between slices specified by waveform_shape and slice_size."""
return self._delay_slice_update
@typechecked
@delay_slice_update.setter
def delay_slice_update(self, value: bool) -> None:
self._delay_slice_update = value
@property @property
def registered_proxies(self) -> None: def registered_proxies(self) -> None:
@ -113,8 +153,33 @@ class SimWaveform(Device):
def acquire(status: DeviceStatus): def acquire(status: DeviceStatus):
try: try:
for _ in range(self.burst.get()): for _ in range(self.burst.get()):
self._run_subs(sub_type=self.SUB_MONITOR, value=self.waveform.get()) # values of the Waveform
self._send_async_update() values = self.waveform.get()
# add_slice option
if self.async_update.get() == "add_slice":
size = self.slice_size.get()
mod = len(values) % size
num_slices = len(values) // size + int(mod > 0)
for i in range(num_slices):
value_slice = values[i * size : min((i + 1) * size, len(values))]
logger.info(
f"Sending slice {i} of {self._slice_index} with length {len(value_slice)}"
)
self._run_subs(sub_type=self.SUB_MONITOR, value=value_slice)
self._send_async_update(index=self._slice_index, value=value_slice)
if self.delay_slice_update is True:
time.sleep(0.025) # 25ms to be really fast
if self.stopped:
raise DeviceStopError(f"{self.name} was stopped")
self._slice_index += 1
# option add
elif self.async_update.get() == "add":
self._run_subs(sub_type=self.SUB_MONITOR, value=values)
self._send_async_update(value=values)
else:
# This should never happen, but just in case
# we raise an exception
raise ValueError(f"Invalid async_update type: {self.async_update.get()}")
if self.stopped: if self.stopped:
raise DeviceStopError(f"{self.name} was stopped") raise DeviceStopError(f"{self.name} was stopped")
status.set_finished() status.set_finished()
@ -128,23 +193,41 @@ class SimWaveform(Device):
self._trigger_thread.start() self._trigger_thread.start()
return status return status
def _send_async_update(self): def _send_async_update(self, value: Any, index: int | None = None) -> None:
"""Send the async update to BEC.""" """
async_update_type = self.async_update.get() Send the async update to BEC.
if async_update_type not in ["extend", "append"]:
raise ValueError(f"Invalid async_update type: {async_update_type}")
Args:
index (int | None): The index of the slice to be sent. If None, the entire waveform is sent.
value (Any): The value to be sent.
"""
async_update_type = self.async_update.get()
waveform_shape = self.waveform_shape.get() waveform_shape = self.waveform_shape.get()
if async_update_type == "append": if async_update_type == "add_slice":
metadata = {"async_update": {"type": "add", "max_shape": [None, waveform_shape]}} if index is not None:
else: metadata = {
"async_update": {
"type": "add_slice",
"index": index,
"max_shape": [None, waveform_shape],
}
}
else:
metadata = {"async_update": {"type": "add", "max_shape": [None, waveform_shape]}}
elif async_update_type == "add":
metadata = {"async_update": {"type": "add", "max_shape": [None]}} metadata = {"async_update": {"type": "add", "max_shape": [None]}}
else:
# Again, this should never happen -> check_value,
# but just in case we raise an exception
raise ValueError(
f"Invalid async_update type: {async_update_type} for device {self.name}"
)
msg = messages.DeviceMessage( msg = messages.DeviceMessage(
signals={self.waveform.name: {"value": self.waveform.get(), "timestamp": time.time()}}, signals={self.waveform.name: {"value": value, "timestamp": time.time()}},
metadata=metadata, metadata=metadata,
) )
# logger.warning(f"Adding async update to {self.name} and {self.scan_info.msg.scan_id}") # Send the message to BEC
self.connector.xadd( self.connector.xadd(
MessageEndpoints.device_async_readback( MessageEndpoints.device_async_readback(
scan_id=self.scan_info.msg.scan_id, device=self.name scan_id=self.scan_info.msg.scan_id, device=self.name
@ -177,6 +260,7 @@ class SimWaveform(Device):
self.exp_time.set(self.scan_info.msg.scan_parameters["exp_time"]) self.exp_time.set(self.scan_info.msg.scan_parameters["exp_time"])
self.burst.set(self.scan_info.msg.scan_parameters["frames_per_trigger"]) self.burst.set(self.scan_info.msg.scan_parameters["frames_per_trigger"])
self.stopped = False self.stopped = False
self._slice_index = 0
logger.warning(f"Staged {self.name}, scan_id : {self.scan_info.msg.scan_id}") logger.warning(f"Staged {self.name}, scan_id : {self.scan_info.msg.scan_id}")
return super().stage() return super().stage()
@ -186,6 +270,7 @@ class SimWaveform(Device):
Send reads from all config signals to redis Send reads from all config signals to redis
""" """
logger.warning(f"Unstaging {self.name}, {self._staged}") logger.warning(f"Unstaging {self.name}, {self._staged}")
self._slice_index = 0
if self.stopped is True or not self._staged: if self.stopped is True or not self._staged:
return super().unstage() return super().unstage()
return super().unstage() return super().unstage()

View File

@ -1,4 +1,4 @@
""" This module contains tests for the simulation devices in ophyd_devices """ """This module contains tests for the simulation devices in ophyd_devices"""
# pylint: disable: all # pylint: disable: all
import os import os
@ -765,3 +765,70 @@ def test_waveform(waveform):
assert status.done is True assert status.done is True
assert mock_connector.xadd.call_count == 1 assert mock_connector.xadd.call_count == 1
assert mock_run_subs.call_count == 1 assert mock_run_subs.call_count == 1
@pytest.mark.parametrize(
"mode, mock_data, expected_calls",
[
(
"add",
np.zeros(5),
[{"sub_type": "device_monitor_1d", "value": np.zeros(5)}, {"value": np.zeros(5)}],
)
],
)
def test_waveform_update_modes(waveform, mode, mock_data, expected_calls):
"""Test the add and add_slice update modes of the SimWaveform class"""
waveform.sim.select_model("GaussianModel")
waveform.sim.params = {"amplitude": 500, "center": 500, "sigma": 10}
with pytest.raises(ValueError):
waveform.async_update.put("invalid_mode")
# Use add mode
waveform.async_update.put(mode)
with (
mock.patch.object(waveform, "_run_subs") as mock_run_subs,
mock.patch.object(waveform, "_send_async_update") as mock_send_async_update,
mock.patch.object(waveform.waveform, "get", return_value=mock_data),
):
status = waveform.trigger()
status_wait(status, timeout=10) # Raise if times out
assert status.done is True
# Run subs
assert mock_run_subs.call_args[1]["sub_type"] == expected_calls[0]["sub_type"]
assert np.array_equal(mock_run_subs.call_args[1]["value"], expected_calls[0]["value"])
# Send async update
assert np.array_equal(
mock_send_async_update.call_args[1]["value"], expected_calls[1]["value"]
)
@pytest.mark.parametrize(
"mode, index, expected_md",
[
(
"add_slice",
0,
{"async_update": {"type": "add_slice", "index": 0, "max_shape": [None, 100]}},
),
("add_slice", None, {"async_update": {"type": "add", "max_shape": [None, 100]}}),
("add", 0, {"async_update": {"type": "add", "max_shape": [None]}}),
],
)
def test_waveform_send_async_update(waveform, mode, index, expected_md):
"""Test the send_async_update method of SimWaveform."""
max_shape = expected_md["async_update"]["max_shape"]
if len(max_shape) > 1:
wv_shape = max_shape[1]
else:
wv_shape = 100
waveform.waveform_shape.put(wv_shape)
waveform.async_update.put(mode)
waveform.scan_info = get_mock_scan_info(device=waveform)
value = 0
with mock.patch.object(waveform.connector, "xadd") as mock_xadd:
waveform._send_async_update(index=index, value=value)
# Check here that metadata is properly set
args, kwargs = mock_xadd.call_args
msg = args[1]["data"]
assert msg.metadata == expected_md