mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2025-06-06 20:00:41 +02:00
feat(sim_waveform): added option to emit data with add_slice
This commit is contained in:
parent
7797e4003b
commit
21746e5445
@ -4,7 +4,7 @@ import time
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from bec_lib import bec_logger
|
from bec_lib import bec_logger
|
||||||
from ophyd import Kind, Signal
|
from ophyd import DeviceStatus, Kind, Signal
|
||||||
from ophyd.utils import ReadOnlyError
|
from ophyd.utils import ReadOnlyError
|
||||||
|
|
||||||
from ophyd_devices.utils.bec_device_base import BECDeviceBase
|
from ophyd_devices.utils.bec_device_base import BECDeviceBase
|
||||||
@ -87,10 +87,18 @@ class SetableSignal(Signal):
|
|||||||
|
|
||||||
Core function for signal.
|
Core function for signal.
|
||||||
"""
|
"""
|
||||||
|
self.check_value(value)
|
||||||
self._update_sim_state(value)
|
self._update_sim_state(value)
|
||||||
self._value = value
|
self._value = value
|
||||||
self._run_subs(sub_type=self.SUB_VALUE, value=value)
|
self._run_subs(sub_type=self.SUB_VALUE, value=value)
|
||||||
|
|
||||||
|
def set(self, value):
|
||||||
|
"""Set method"""
|
||||||
|
self.put(value)
|
||||||
|
status = DeviceStatus(self)
|
||||||
|
status.set_finished()
|
||||||
|
return status
|
||||||
|
|
||||||
def describe(self):
|
def describe(self):
|
||||||
"""Describe the readback signal.
|
"""Describe the readback signal.
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import os
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from bec_lib import messages
|
from bec_lib import messages
|
||||||
@ -11,6 +12,7 @@ from bec_lib.endpoints import MessageEndpoints
|
|||||||
from bec_lib.logger import bec_logger
|
from bec_lib.logger import bec_logger
|
||||||
from ophyd import Component as Cpt
|
from ophyd import Component as Cpt
|
||||||
from ophyd import Device, DeviceStatus, Kind, Staged
|
from ophyd import Device, DeviceStatus, Kind, Staged
|
||||||
|
from typeguard import typechecked
|
||||||
|
|
||||||
from ophyd_devices.sim.sim_data import SimulatedDataWaveform
|
from ophyd_devices.sim.sim_data import SimulatedDataWaveform
|
||||||
from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal
|
from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal
|
||||||
@ -20,6 +22,31 @@ from ophyd_devices.utils.errors import DeviceStopError
|
|||||||
logger = bec_logger.logger
|
logger = bec_logger.logger
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncUpdateSignal(SetableSignal):
|
||||||
|
"""Async updated signal, with check for async_update type."""
|
||||||
|
|
||||||
|
def check_value(self, value, **kwargs) -> None:
|
||||||
|
"""Check the value of the async_update signal."""
|
||||||
|
if value not in ["add_slice", "add"]:
|
||||||
|
raise ValueError(f"Invalid async_update type: {value} for signal {self.name}")
|
||||||
|
|
||||||
|
# FIXME: BEC issue #443 remove this method once tests in BEC are updated.
|
||||||
|
def put(self, value: Any) -> None:
|
||||||
|
"""Put the value of the async_update signal."""
|
||||||
|
if value in ["append", "extend"]:
|
||||||
|
if value == "append":
|
||||||
|
logger.warning(
|
||||||
|
f"Deprecated async_update of type {value} for signal {self.name}, falling back to 'add_slice'"
|
||||||
|
)
|
||||||
|
value = "add_slice"
|
||||||
|
elif value == "extend":
|
||||||
|
logger.warning(
|
||||||
|
f"Deprecated async_update of type {value} for signal {self.name}, falling back to 'add'"
|
||||||
|
)
|
||||||
|
value = "add"
|
||||||
|
super().put(value)
|
||||||
|
|
||||||
|
|
||||||
class SimWaveform(Device):
|
class SimWaveform(Device):
|
||||||
"""A simulated device mimic any 1D Waveform detector.
|
"""A simulated device mimic any 1D Waveform detector.
|
||||||
|
|
||||||
@ -39,7 +66,7 @@ class SimWaveform(Device):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
USER_ACCESS = ["sim", "registered_proxies"]
|
USER_ACCESS = ["sim", "registered_proxies", "delay_slice_update"]
|
||||||
|
|
||||||
sim_cls = SimulatedDataWaveform
|
sim_cls = SimulatedDataWaveform
|
||||||
SHAPE = (1000,)
|
SHAPE = (1000,)
|
||||||
@ -60,10 +87,11 @@ class SimWaveform(Device):
|
|||||||
name="waveform",
|
name="waveform",
|
||||||
value=np.empty(SHAPE, dtype=BIT_DEPTH),
|
value=np.empty(SHAPE, dtype=BIT_DEPTH),
|
||||||
compute_readback=True,
|
compute_readback=True,
|
||||||
kind=Kind.normal,
|
kind=Kind.hinted,
|
||||||
)
|
)
|
||||||
# Can be extend or append
|
# Can be extend or append
|
||||||
async_update = Cpt(SetableSignal, value="append", kind=Kind.config)
|
async_update = Cpt(AsyncUpdateSignal, value="add", kind=Kind.config)
|
||||||
|
slice_size = Cpt(SetableSignal, value=100, dtype=np.int32, kind=Kind.config)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -92,8 +120,20 @@ class SimWaveform(Device):
|
|||||||
self._staged = Staged.no
|
self._staged = Staged.no
|
||||||
self._trigger_thread = None
|
self._trigger_thread = None
|
||||||
self.scan_info = scan_info
|
self.scan_info = scan_info
|
||||||
|
self._delay_slice_update = False
|
||||||
if self.sim_init:
|
if self.sim_init:
|
||||||
self.sim.set_init(self.sim_init)
|
self.sim.set_init(self.sim_init)
|
||||||
|
self._slice_index = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def delay_slice_update(self) -> bool:
|
||||||
|
"""Delay updates in-between slices specified by waveform_shape and slice_size."""
|
||||||
|
return self._delay_slice_update
|
||||||
|
|
||||||
|
@typechecked
|
||||||
|
@delay_slice_update.setter
|
||||||
|
def delay_slice_update(self, value: bool) -> None:
|
||||||
|
self._delay_slice_update = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def registered_proxies(self) -> None:
|
def registered_proxies(self) -> None:
|
||||||
@ -113,8 +153,33 @@ class SimWaveform(Device):
|
|||||||
def acquire(status: DeviceStatus):
|
def acquire(status: DeviceStatus):
|
||||||
try:
|
try:
|
||||||
for _ in range(self.burst.get()):
|
for _ in range(self.burst.get()):
|
||||||
self._run_subs(sub_type=self.SUB_MONITOR, value=self.waveform.get())
|
# values of the Waveform
|
||||||
self._send_async_update()
|
values = self.waveform.get()
|
||||||
|
# add_slice option
|
||||||
|
if self.async_update.get() == "add_slice":
|
||||||
|
size = self.slice_size.get()
|
||||||
|
mod = len(values) % size
|
||||||
|
num_slices = len(values) // size + int(mod > 0)
|
||||||
|
for i in range(num_slices):
|
||||||
|
value_slice = values[i * size : min((i + 1) * size, len(values))]
|
||||||
|
logger.info(
|
||||||
|
f"Sending slice {i} of {self._slice_index} with length {len(value_slice)}"
|
||||||
|
)
|
||||||
|
self._run_subs(sub_type=self.SUB_MONITOR, value=value_slice)
|
||||||
|
self._send_async_update(index=self._slice_index, value=value_slice)
|
||||||
|
if self.delay_slice_update is True:
|
||||||
|
time.sleep(0.025) # 25ms to be really fast
|
||||||
|
if self.stopped:
|
||||||
|
raise DeviceStopError(f"{self.name} was stopped")
|
||||||
|
self._slice_index += 1
|
||||||
|
# option add
|
||||||
|
elif self.async_update.get() == "add":
|
||||||
|
self._run_subs(sub_type=self.SUB_MONITOR, value=values)
|
||||||
|
self._send_async_update(value=values)
|
||||||
|
else:
|
||||||
|
# This should never happen, but just in case
|
||||||
|
# we raise an exception
|
||||||
|
raise ValueError(f"Invalid async_update type: {self.async_update.get()}")
|
||||||
if self.stopped:
|
if self.stopped:
|
||||||
raise DeviceStopError(f"{self.name} was stopped")
|
raise DeviceStopError(f"{self.name} was stopped")
|
||||||
status.set_finished()
|
status.set_finished()
|
||||||
@ -128,23 +193,41 @@ class SimWaveform(Device):
|
|||||||
self._trigger_thread.start()
|
self._trigger_thread.start()
|
||||||
return status
|
return status
|
||||||
|
|
||||||
def _send_async_update(self):
|
def _send_async_update(self, value: Any, index: int | None = None) -> None:
|
||||||
"""Send the async update to BEC."""
|
"""
|
||||||
async_update_type = self.async_update.get()
|
Send the async update to BEC.
|
||||||
if async_update_type not in ["extend", "append"]:
|
|
||||||
raise ValueError(f"Invalid async_update type: {async_update_type}")
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index (int | None): The index of the slice to be sent. If None, the entire waveform is sent.
|
||||||
|
value (Any): The value to be sent.
|
||||||
|
"""
|
||||||
|
async_update_type = self.async_update.get()
|
||||||
waveform_shape = self.waveform_shape.get()
|
waveform_shape = self.waveform_shape.get()
|
||||||
if async_update_type == "append":
|
if async_update_type == "add_slice":
|
||||||
metadata = {"async_update": {"type": "add", "max_shape": [None, waveform_shape]}}
|
if index is not None:
|
||||||
else:
|
metadata = {
|
||||||
|
"async_update": {
|
||||||
|
"type": "add_slice",
|
||||||
|
"index": index,
|
||||||
|
"max_shape": [None, waveform_shape],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
metadata = {"async_update": {"type": "add", "max_shape": [None, waveform_shape]}}
|
||||||
|
elif async_update_type == "add":
|
||||||
metadata = {"async_update": {"type": "add", "max_shape": [None]}}
|
metadata = {"async_update": {"type": "add", "max_shape": [None]}}
|
||||||
|
else:
|
||||||
|
# Again, this should never happen -> check_value,
|
||||||
|
# but just in case we raise an exception
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid async_update type: {async_update_type} for device {self.name}"
|
||||||
|
)
|
||||||
|
|
||||||
msg = messages.DeviceMessage(
|
msg = messages.DeviceMessage(
|
||||||
signals={self.waveform.name: {"value": self.waveform.get(), "timestamp": time.time()}},
|
signals={self.waveform.name: {"value": value, "timestamp": time.time()}},
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
# logger.warning(f"Adding async update to {self.name} and {self.scan_info.msg.scan_id}")
|
# Send the message to BEC
|
||||||
self.connector.xadd(
|
self.connector.xadd(
|
||||||
MessageEndpoints.device_async_readback(
|
MessageEndpoints.device_async_readback(
|
||||||
scan_id=self.scan_info.msg.scan_id, device=self.name
|
scan_id=self.scan_info.msg.scan_id, device=self.name
|
||||||
@ -177,6 +260,7 @@ class SimWaveform(Device):
|
|||||||
self.exp_time.set(self.scan_info.msg.scan_parameters["exp_time"])
|
self.exp_time.set(self.scan_info.msg.scan_parameters["exp_time"])
|
||||||
self.burst.set(self.scan_info.msg.scan_parameters["frames_per_trigger"])
|
self.burst.set(self.scan_info.msg.scan_parameters["frames_per_trigger"])
|
||||||
self.stopped = False
|
self.stopped = False
|
||||||
|
self._slice_index = 0
|
||||||
logger.warning(f"Staged {self.name}, scan_id : {self.scan_info.msg.scan_id}")
|
logger.warning(f"Staged {self.name}, scan_id : {self.scan_info.msg.scan_id}")
|
||||||
return super().stage()
|
return super().stage()
|
||||||
|
|
||||||
@ -186,6 +270,7 @@ class SimWaveform(Device):
|
|||||||
Send reads from all config signals to redis
|
Send reads from all config signals to redis
|
||||||
"""
|
"""
|
||||||
logger.warning(f"Unstaging {self.name}, {self._staged}")
|
logger.warning(f"Unstaging {self.name}, {self._staged}")
|
||||||
|
self._slice_index = 0
|
||||||
if self.stopped is True or not self._staged:
|
if self.stopped is True or not self._staged:
|
||||||
return super().unstage()
|
return super().unstage()
|
||||||
return super().unstage()
|
return super().unstage()
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
""" This module contains tests for the simulation devices in ophyd_devices """
|
"""This module contains tests for the simulation devices in ophyd_devices"""
|
||||||
|
|
||||||
# pylint: disable: all
|
# pylint: disable: all
|
||||||
import os
|
import os
|
||||||
@ -765,3 +765,70 @@ def test_waveform(waveform):
|
|||||||
assert status.done is True
|
assert status.done is True
|
||||||
assert mock_connector.xadd.call_count == 1
|
assert mock_connector.xadd.call_count == 1
|
||||||
assert mock_run_subs.call_count == 1
|
assert mock_run_subs.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"mode, mock_data, expected_calls",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"add",
|
||||||
|
np.zeros(5),
|
||||||
|
[{"sub_type": "device_monitor_1d", "value": np.zeros(5)}, {"value": np.zeros(5)}],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_waveform_update_modes(waveform, mode, mock_data, expected_calls):
|
||||||
|
"""Test the add and add_slice update modes of the SimWaveform class"""
|
||||||
|
waveform.sim.select_model("GaussianModel")
|
||||||
|
waveform.sim.params = {"amplitude": 500, "center": 500, "sigma": 10}
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
waveform.async_update.put("invalid_mode")
|
||||||
|
# Use add mode
|
||||||
|
waveform.async_update.put(mode)
|
||||||
|
with (
|
||||||
|
mock.patch.object(waveform, "_run_subs") as mock_run_subs,
|
||||||
|
mock.patch.object(waveform, "_send_async_update") as mock_send_async_update,
|
||||||
|
mock.patch.object(waveform.waveform, "get", return_value=mock_data),
|
||||||
|
):
|
||||||
|
|
||||||
|
status = waveform.trigger()
|
||||||
|
status_wait(status, timeout=10) # Raise if times out
|
||||||
|
assert status.done is True
|
||||||
|
# Run subs
|
||||||
|
assert mock_run_subs.call_args[1]["sub_type"] == expected_calls[0]["sub_type"]
|
||||||
|
assert np.array_equal(mock_run_subs.call_args[1]["value"], expected_calls[0]["value"])
|
||||||
|
# Send async update
|
||||||
|
assert np.array_equal(
|
||||||
|
mock_send_async_update.call_args[1]["value"], expected_calls[1]["value"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"mode, index, expected_md",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"add_slice",
|
||||||
|
0,
|
||||||
|
{"async_update": {"type": "add_slice", "index": 0, "max_shape": [None, 100]}},
|
||||||
|
),
|
||||||
|
("add_slice", None, {"async_update": {"type": "add", "max_shape": [None, 100]}}),
|
||||||
|
("add", 0, {"async_update": {"type": "add", "max_shape": [None]}}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_waveform_send_async_update(waveform, mode, index, expected_md):
|
||||||
|
"""Test the send_async_update method of SimWaveform."""
|
||||||
|
max_shape = expected_md["async_update"]["max_shape"]
|
||||||
|
if len(max_shape) > 1:
|
||||||
|
wv_shape = max_shape[1]
|
||||||
|
else:
|
||||||
|
wv_shape = 100
|
||||||
|
waveform.waveform_shape.put(wv_shape)
|
||||||
|
waveform.async_update.put(mode)
|
||||||
|
waveform.scan_info = get_mock_scan_info(device=waveform)
|
||||||
|
value = 0
|
||||||
|
with mock.patch.object(waveform.connector, "xadd") as mock_xadd:
|
||||||
|
waveform._send_async_update(index=index, value=value)
|
||||||
|
# Check here that metadata is properly set
|
||||||
|
args, kwargs = mock_xadd.call_args
|
||||||
|
msg = args[1]["data"]
|
||||||
|
assert msg.metadata == expected_md
|
||||||
|
Loading…
x
Reference in New Issue
Block a user