refactor: Cleanup, bugfixes, and docs

This commit is contained in:
2026-03-17 15:09:56 +01:00
parent 877156d200
commit 3ccd51ecd0
9 changed files with 991 additions and 288 deletions
+133 -60
View File
@@ -1,9 +1,19 @@
"""Module for virtual slit center implementation."""
"""Pseudo-motor implementations for slit center and slit width.
Both devices map one pseudo axis onto two real motors (left and right slit
edges). They can be instantiated from names resolved through the BEC device
manager and support an optional offset term in their coordinate transforms.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from bec_lib.logger import bec_logger
from ophyd import Component as Cpt
from ophyd import Kind
from ophyd.signal import SignalRO
from ophyd_devices.interfaces.base_classes.psi_pseudo_motor_base import PSIPseudoMotorBase
if TYPE_CHECKING: # pragma: no cover
@@ -11,19 +21,37 @@ if TYPE_CHECKING: # pragma: no cover
from ophyd import PositionerBase, Signal
logger = bec_logger.logger
class VirtualSlitCenter(PSIPseudoMotorBase):
"""
Virtual slit center implementation. It expects the left and right slit positioner names to be passed
as arguments. The named devices must be positioners and available in the device_manager. In addition,
it must have a readback (user_readback), setpoint (user_setpoint) and motor_is_moving signal.
"""Pseudo motor controlling slit center from two edge motors.
Both positioners must be present in the device manager, and the pseudo
motor entry in the current session config must declare them in ``needs``.
The forward calculation computes the center position based on the positions of the left and right positioners,
while the inverse calculation computes the setpoints for the left and right positioners based on a desired center
position. The motors_are_moving method checks if either of the positioners is currently moving.
Args:
name (str): The name of the virtual slit center.
left_slit (str): The name of the left slit positioner in the device manager.
right_slit (str): The name of the right slit positioner in the device manager.
device_manager (DeviceManagerBase): The device manager to use for connecting to the positioners.
name (str): The name of the pseudo motor device.
left_slit (str): The name of the left slit positioner device in the device manager.
right_slit (str): The name of the right slit positioner device in the device manager.
device_manager (DeviceManagerBase): The device manager instance to fetch the positioner devices from.
offset (float, optional): Constant center offset added in forward
calculation and removed in inverse calculation.
egu (str | None, optional): Engineering units. If omitted, units are
taken from the left positioner.
"""
offset = Cpt(
SignalRO,
name="offset",
kind=Kind.config,
doc="Offset applied to the position of the slit center when calculating the width.",
)
def __init__(
self,
name: str,
@@ -31,6 +59,7 @@ class VirtualSlitCenter(PSIPseudoMotorBase):
right_slit: str,
device_manager: DeviceManagerBase,
offset: float = 0,
egu: str | None = None,
**kwargs,
):
positioners = self.get_positioner_objects(
@@ -38,32 +67,44 @@ class VirtualSlitCenter(PSIPseudoMotorBase):
positioners={"left": left_slit, "right": right_slit},
device_manager=device_manager,
)
if egu is None: # if not specified, fetch it from the left positioner
egu = positioners["left"].egu
if positioners["right"].egu != egu:
logger.warning(
f"Device {name} found inconsistency for egu for positioner left {left_slit} and right {right_slit}. Using egu {egu}."
)
self._offset = offset
super().__init__(
name=name, device_manager=device_manager, positioners=positioners, **kwargs
name=name, device_manager=device_manager, positioners=positioners, egu=egu, **kwargs
)
def wait_for_connection(self, *args, **kwargs):
"""Connect and initialize the read-only ``offset`` configuration signal."""
super().wait_for_connection(*args, **kwargs)
# Set the initial value of the offset signal
# Config values are read by back after wait_for_connection is called.
self.offset._readback = self._offset
def _get_pos_motor(self, motor: PositionerBase) -> float:
"""
Helper method to get the position of a motor.
"""Return the current position read from ``motor``.
Args:
motor (PositionerBase): The motor to get the position of.
motor (PositionerBase): The positioner motor to read the position from.
Returns:
float: Current motor position.
"""
return motor.read()[motor.name]["value"]
# pylint: disable=arguments-differ
def forward_calculation(self, left: Signal, right: Signal) -> float:
"""
Forward calculation to compute the value for the pseudo motor readback
and setpoint based on the position of the left and right slit.
"""Compute slit center from left and right positions.
Args:
left (Signal): The left slit positioner signal.
right (Signal): The right slit positioner signal.
left (Signal): The signal representing the position of the left slit positioner.
right (Signal): The signal representing the position of the right slit positioner.
Returns:
float: The center position of the slit.
float: Center position ``(left + right) / 2 + offset``.
"""
left_pos = left.get()
right_pos = right.get()
@@ -71,33 +112,33 @@ class VirtualSlitCenter(PSIPseudoMotorBase):
return float(center)
def inverse_calculation(self, position: float, left: Signal, right: Signal) -> dict[str, float]:
"""
Inverse calculation to compute the position of the left and right slit based on the center position.
"""Compute left/right setpoints for a target center.
The current slit width is preserved.
Args:
center (float): The center position of the slit.
left (Signal): The left slit positioner signal.
right (Signal): The right slit positioner signal.
position (float): The desired center position of the slit.
left (Signal): The signal representing the position of the left slit positioner.
right (Signal): The signal representing the position of the right slit positioner.
Returns:
dict[str, float]: The positions of the left and right slit.
A dictionary with the new setpoints for the left and right positioners, with keys "left" and "right".
"""
position_with_offset = position - self._offset
# To access position, run read on the root (PositionerBase) of the signal
left_pos = self._get_pos_motor(left.root)
right_pos = self._get_pos_motor(right.root)
left_pos = left.get()
right_pos = right.get()
width = right_pos - left_pos
new_left_pos = position_with_offset - width / 2
new_right_pos = position_with_offset + width / 2
return {"left": new_left_pos, "right": new_right_pos}
def motors_are_moving(self, left: Signal, right: Signal) -> int:
"""
Calculate whether the motors are moving based on the motor_is_moving signal of the left and right slit.
"""Return 1 when either left or right motor is moving, else 0.
Args:
left (Signal): The left slit positioner signal.
right (Signal): The right slit positioner signal.
left (Signal): The signal representing the position of the left slit positioner.
right (Signal): The signal representing the position of the right slit positioner.
Returns:
int: 1 if either motor is moving, 0 otherwise.
"""
@@ -107,6 +148,28 @@ class VirtualSlitCenter(PSIPseudoMotorBase):
class VirtualSlitWidth(PSIPseudoMotorBase):
"""Pseudo motor controlling slit width from two edge motors.
Both positioners must be present in the device manager, and the pseudo
motor entry in the current session config must declare them in ``needs``.
Args:
name (str): The name of the pseudo motor device.
left_slit (str): The name of the left slit positioner device in the device manager.
right_slit (str): The name of the right slit positioner device in the device manager.
device_manager (DeviceManagerBase): The device manager instance to fetch the positioner devices from.
offset (float, optional): Constant width offset added in forward
calculation and removed in inverse calculation.
egu (str | None, optional): Engineering units. If omitted, units are
taken from the left positioner.
"""
offset = Cpt(
SignalRO,
name="offset",
kind=Kind.config,
doc="Offset applied to the position of the slit center when calculating the width.",
)
def __init__(
self,
@@ -114,6 +177,8 @@ class VirtualSlitWidth(PSIPseudoMotorBase):
left_slit: str,
right_slit: str,
device_manager: DeviceManagerBase,
offset: float = 0,
egu: str | None = None,
**kwargs,
):
positioners = self.get_positioner_objects(
@@ -121,62 +186,66 @@ class VirtualSlitWidth(PSIPseudoMotorBase):
positioners={"left": left_slit, "right": right_slit},
device_manager=device_manager,
)
if egu is None: # if not specified, fetch it from the left positioner
egu = positioners["left"].egu
if positioners["right"].egu != egu:
logger.warning(
f"Device {name} found inconsistency for egu for positioner left {left_slit} and right {right_slit}. Using egu {egu}."
)
self._offset = offset
super().__init__(
name=name, device_manager=device_manager, positioners=positioners, **kwargs
name=name, device_manager=device_manager, positioners=positioners, egu=egu, **kwargs
)
def _get_pos_motor(self, motor: PositionerBase) -> float:
"""
Helper method to get the position of a motor.
Args:
motor (PositionerBase): The motor to get the position of.
"""
return motor.read()[motor.name]["value"]
def wait_for_connection(self, *args, **kwargs):
"""Connect and initialize the read-only ``offset`` configuration signal."""
super().wait_for_connection(*args, **kwargs)
# Set the initial value of the offset signal
# Config values are read by back after wait_for_connection is called.
self.offset._readback = self._offset
# pylint: disable=arguments-differ
def forward_calculation(self, left: Signal, right: Signal) -> float:
"""
Forward calculation to compute the value for the pseudo motor readback
and setpoint based on the position of the left and right slit.
"""Compute slit width from left and right positions.
Args:
left (Signal): The left slit positioner signal.
right (Signal): The right slit positioner signal.
left (Signal): The signal representing the position of the left slit positioner.
right (Signal): The signal representing the position of the right slit positioner.
Returns:
float: The center position of the slit.
float: Width ``right - left + offset``.
"""
left_pos = left.get()
right_pos = right.get()
width = right_pos - left_pos
width = right_pos - left_pos + self._offset
return float(width)
def inverse_calculation(self, position: float, left: Signal, right: Signal) -> dict[str, float]:
"""
Inverse calculation to compute the position of the left and right slit based on the center position.
"""Compute left/right setpoints for a target width.
The current slit center is preserved.
Args:
position (float): The center position of the slit.
position (float): The desired width of the slit.
left (Signal): The signal representing the position of the left slit positioner.
right (Signal): The signal representing the position of the right slit positioner.
Returns:
dict[str, float]: The positions of the left and right slit.
A dictionary with the new setpoints for the left and right positioners, with keys "left" and "right".
"""
left_pos = self._get_pos_motor(left.root)
right_pos = self._get_pos_motor(right.root)
left_pos = left.get()
right_pos = right.get()
center = (left_pos + right_pos) / 2
width = right_pos - left_pos
width = position - self._offset
new_right_pos = center + width / 2
new_left_pos = center - width / 2
return {"left": new_left_pos, "right": new_right_pos}
def motors_are_moving(self, left: Signal, right: Signal) -> int:
"""
Calculate whether the motors are moving based on the motor_is_moving signal of the left and right slit.
"""Return 1 when either left or right motor is moving, else 0.
Args:
left (Signal): The left slit positioner signal.
right (Signal): The right slit positioner signal.
left (Signal): The signal representing the position of the left slit positioner.
right (Signal): The signal representing the position of the right slit positioner.
Returns:
int: 1 if either motor is moving, 0 otherwise.
"""
@@ -186,10 +255,14 @@ class VirtualSlitWidth(PSIPseudoMotorBase):
if __name__ == "__main__": # pragma: no cover
# pylint: disable=import-outside-toplevel, unused-import, missing-docstring, ungrouped-imports, arguments-differ, protected-access
from ophyd import Component as Cpt
from ophyd_devices.sim.sim_positioner import SimPositioner
###########
## Alternative approach for virtual slit center
###########
class TestPseudoMotor(PSIPseudoMotorBase):
motor_a = Cpt(SimPositioner, name="motor_a")
@@ -0,0 +1,13 @@
from ophyd_devices import PSIDeviceBase
from ophyd_devices.utils.bec_processed_signal import BECProcessedSignal
class PSIPseudoDeviceBase(PSIDeviceBase):
"""Base class for pseudo devices at PSI."""
def wait_for_connection(self, *args, **kwargs):
"""Wait for connection of the pseudo device has to be called manually on BECProcessedSignals"""
for walk in self.walk_signals():
if isinstance(walk.item, BECProcessedSignal):
walk.item.wait_for_connection(*args, **kwargs)
super().wait_for_connection(*args, **kwargs)
@@ -1,4 +1,9 @@
""" """
"""Base class for pseudo motors built from real positioner objects.
The class wires three :class:`BECProcessedSignal` instances (`readback`,
`setpoint`, `motor_is_moving`) to user-defined calculation methods and combines
child-motor move statuses into one pseudo-motor status.
"""
from __future__ import annotations
@@ -10,7 +15,7 @@ from ophyd import Component as Cpt
from ophyd import Kind, PositionerBase
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
from ophyd_devices.utils.bec_processed_signal import BECProcessedSignal, ProcessedSignalModel
from ophyd_devices.utils.bec_processed_signal import BECProcessedSignal
from ophyd_devices.utils.psi_device_base_utils import AndStatus, StatusBase
if TYPE_CHECKING: # pragma: no cover
@@ -18,21 +23,33 @@ if TYPE_CHECKING: # pragma: no cover
class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase):
"""
Base class for PSI pseudo motors.
"""Abstract base class for pseudo-positioners.
Subclasses implement coordinate transforms via:
- ``forward_calculation`` for readback/setpoint projection
- ``inverse_calculation`` for pseudo-to-real target mapping
- ``motors_are_moving`` for movement aggregation
Please note that forward_calculation, inverse_calculation and motors_are_moving methods must be implemented with
the same signature as the keys for the associated positioner objects stored in the positioner_objects attribute
which is either passed to __init__ or set using the set_positioner_objects method. The positioner objects are expected
to be ophyd PositionerBase object like devices or at least implement a 'move' method and have attributes
'readback' or 'user_readback', 'setpoint' or 'user_setpoint', and 'motor_is_moving'.
Args:
name (str): The name of the pseudo motor.
device_manager (DeviceManagerDS): The device manager to use for connecting to the positioners.
positioners (dict[str, str]): A dictionary mapping positioner names to device names in the device manager.
Keys of this dictionary must match the arguments of the forward_calculation,
inverse_calculation and motors_are_moving methods. The values must be the names
of the devices in the device manager that correspond to the positioners.
name (str): The name of the pseudo motor device.
device_manager (DeviceManagerDS): The device manager instance to fetch the positioner objects from based on the configuration.
positioners (dict[str, PositionerBase] | None): A dictionary of positioner objects that this pseudo motor depends on. The keys should match the input parameters of the forward_calculation, inverse_calculation and motors_are_moving methods. If not provided during initialization, it can be set later using the set_positioner_objects method.
egu (str): Engineering units for the pseudo motor.
**kwargs: Additional keyword arguments to pass to the parent classes.
"""
readback = Cpt(BECProcessedSignal, name="readback", model={}, kind=Kind.hinted)
setpoint = Cpt(BECProcessedSignal, name="setpoint", model={}, kind=Kind.normal)
motor_is_moving = Cpt(BECProcessedSignal, name="motor_is_moving", model={}, kind=Kind.omitted)
readback = Cpt(BECProcessedSignal, name="readback", model_config=None, kind=Kind.hinted)
setpoint = Cpt(BECProcessedSignal, name="setpoint", model_config=None, kind=Kind.normal)
motor_is_moving = Cpt(
BECProcessedSignal, name="motor_is_moving", model_config=None, kind=Kind.omitted
)
def __init__(
self,
@@ -46,23 +63,23 @@ class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase):
self._positioner_move_kwargs: dict[str, dict[str, Any]] = {}
self._egu = egu
super().__init__(name=name, device_manager=device_manager, **kwargs)
self.readback.name = self.name
@property
def egu(self) -> str:
"""Engineering units for the pseudo motor. This can be set during initialization or by setting the egu attribute."""
"""Engineering units for the pseudo motor."""
return self._egu
def set_positioner_objects(self, positioners: dict[str, PositionerBase]) -> None:
"""
Method to set the positioner objects after initialization. This can be used if the positioner objects are not available at the time of initialization.
"""Set the positioner objects for the pseudo motor.
Args:
positioners (dict[str, PositionerBase]): A dictionary mapping positioner names to positioner objects.
positioners (dict[str, PositionerBase]): A dictionary of positioner objects that this pseudo motor depends on.
"""
self.positioner_objects = positioners
def wait_for_connection(self, *args, **kwargs) -> None:
"""Connect to relevant positioners, setup processed signals."""
"""Validate signatures, wire processed signals, and connect dependencies."""
if not self.positioner_objects:
raise ConnectionError(
f"No positioners specified for pseudo motor {self.name}. Please use 'set_positioner_objects' or pass positioner objects during initialization."
@@ -70,21 +87,21 @@ class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase):
# Check if all methods have the required signature that matches the positioner_objects keys
self._check_method_signatures()
self._setup_pseudo_signal(
"readback", ["readback", "user_readback"], self.forward_calculation, return_type=float
"readback", ["readback", "user_readback"], self.forward_calculation
)
self._setup_pseudo_signal(
"setpoint", ["setpoint", "user_setpoint"], self.forward_calculation, return_type=float
)
self._setup_pseudo_signal(
"motor_is_moving", ["motor_is_moving"], self.motors_are_moving, return_type=int
"setpoint", ["setpoint", "user_setpoint"], self.forward_calculation
)
self._setup_pseudo_signal("motor_is_moving", ["motor_is_moving"], self.motors_are_moving)
# Prepare move kwargs for each positioner based on their move method signature
for name, positioner in self.positioner_objects.items():
move_signature = inspect.signature(positioner.move)
if "wait" in move_signature.parameters:
self._positioner_move_kwargs[name] = {"wait": False}
return super().wait_for_connection(*args, **kwargs)
def _check_method_signatures(self) -> None:
"""
Method to check that the forward_calculation, inverse_calculation and motors_are_moving methods
have the required signature that matches the positioner_objects keys.
"""
"""Ensure calculation method parameters match configured positioner keys."""
input_names = set(self.positioner_objects.keys())
for method in [self.forward_calculation, self.inverse_calculation, self.motors_are_moving]:
signature = inspect.signature(method)
@@ -96,22 +113,15 @@ class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase):
)
def _setup_pseudo_signal(
self,
pseudo_attr: str,
allowed_attributes: list[str],
compute_method: Callable[..., float],
return_type: type,
self, pseudo_attr: str, allowed_attributes: list[str], compute_method: Callable[..., float]
):
"""
Setup a pseudo signal with the given compute method and return type. The compute method will be called with
the values of the positioners as arguments. The allowed_attributes are used to determine which signal of the
positioner to use as input for the compute method. The first attribute that is found in the positioner will be used.
"""Configure one pseudo signal from selected positioner attributes.
Args:
pseudo_attr (str): The attribute of the pseudo motor to set the model for.
allowed_attributes (list[str]): The attributes of the positioner to look for as input for the compute method.
compute_method (Callable[..., float]): The method to compute the value of the pseudo signal.
return_type (type): The return type of the compute method.
pseudo_attr (str): The name of the pseudo attribute to set up.
allowed_attributes (list[str]): A list of allowed attributes for the positioner objects.
compute_method (Callable[..., float]): Function used to compute the
pseudo signal value.
"""
device_objects = {}
dotted_names = {}
@@ -130,26 +140,23 @@ class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase):
dotted_names[name] = f"{device_name}.{obj.name}"
device_objects[name] = obj
model = ProcessedSignalModel(
devices=dotted_names, compute_method=compute_method, return_type=return_type
pseudo_attr_obj.set_compute_method(
compute_method, **{name: obj for name, obj in device_objects.items()}
)
pseudo_attr_obj.set_device_object(device_objects)
pseudo_attr_obj.set_model(model)
pseudo_attr_obj.wait_for_connection()
def get_positioner_objects(
self, name: str, positioners: dict[str, str], device_manager: DeviceManagerDS
) -> dict[str, PositionerBase]:
"""
Helper method to get the positioner objects from the device manager based on a positioner dictionary.
"""Resolve and validate positioner objects from device-manager names.
Args:
name (str): The name of the pseudo motor device to look for in the device manager config.
positioners (dict[str, str]): A dictionary mapping positioner names to device names in the device manager.
device_manager (DeviceManagerDS): The device manager to use for connecting to the positioners.
name (str): The name of the pseudo motor device.
positioners (dict[str, str]): A dictionary mapping positioner names to device names.
device_manager (DeviceManagerDS): The device manager instance to fetch the positioner objects from.
Returns:
dict[str, PositionerBase]: A dictionary mapping positioner names to positioner objects.
dict[str, PositionerBase]: A dictionary of positioner objects.
"""
positioner_objs = {}
# First we check that the device config of this device specifies
@@ -180,16 +187,22 @@ class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase):
f"Device '{device_name}' must have at least one argument for each tuple in the following list of tuples: {required_attrs}."
)
positioner_objs[name] = device
move_signature = inspect.signature(device.move)
if "wait" in move_signature.parameters:
self._positioner_move_kwargs[name] = {"wait": False}
return positioner_objs
def _find_device_config_in_session(
self, device_name: str, device_manager: DeviceManagerDS
) -> dict[str, Any]:
"""
Helper method to find the device config for a given device name in the current session config of the device manager.
"""Find the session configuration entry for ``device_name``.
Args:
device_name (str): The name of the device to find the configuration for.
device_manager (DeviceManagerDS): The device manager instance to fetch the configuration from.
Returns:
dict[str, Any]: The configuration dictionary for the device.
Raises:
ConnectionError: If the device configuration is not found in the current session.
"""
configs = device_manager.current_session["devices"]
config = None
@@ -203,30 +216,44 @@ class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase):
@abstractmethod
def forward_calculation(self, *args) -> float:
"""Calculate the pseudo motor value based on the positioner values."""
"""Compute pseudo value from positioner signals.
Method parameters must include all keys defined in
``self.positioner_objects``.
"""
@abstractmethod
def inverse_calculation(self, position: float, **positioner_objects) -> dict[str, float]:
"""Calculate the positioner values based on the pseudo motor value."""
"""Map a pseudo target position to child-motor setpoints.
The first argument is always the desired pseudo position.
"""
@abstractmethod
def motors_are_moving(self, *args) -> int:
"""Calculate whether the motors are moving based on the positioner values. Should be 0 or 1."""
"""Return a movement flag derived from child-motor motion signals."""
# pylint: disable=arguments-differ
def move(self, position: float, **kwargs) -> StatusBase:
"""
Move method for the pseudo motor. This currently only supports moving all positioners
at once. If children want to implement more complex move logic, they can override this method
based on this implementation and the inverse_calculation method.
"""Move child motors to realize a pseudo target position.
The method calls :meth:`inverse_calculation` with the current method
inputs of the ``readback`` processed signal, then moves each configured
child positioner and combines all returned statuses with
:class:`AndStatus`.
Args:
position (float): The position to move the pseudo motor to.
kwargs: The kwargs to pass to the move method of the positioners.
position (float): The desired position to move the pseudo motor to.
**kwargs: Additional keyword arguments to pass to the move method of the positioner objects.
Returns:
StatusBase: A combined status object that represents the status of all the move operations on the
positioner objects.
"""
self.check_value(position)
status = None
motor_positions = self.inverse_calculation(position, **self.positioner_objects)
motor_positions = self.inverse_calculation(
position, **self.readback.compute_model.method_inputs
)
for name, pos in motor_positions.items():
positioner = self.positioner_objects[name]
move_kwargs = self._positioner_move_kwargs.get(name, {})
+7
View File
@@ -155,6 +155,13 @@ class SimPositioner(Device, PositionerBase):
value=self.sim.sim_state[self.readback.name]["value"],
timestamp=self.sim.sim_state[self.readback.name]["timestamp"],
)
# Run subscription on "value"
self.readback._run_subs(
sub_type=self.readback.SUB_VALUE,
old_value=old_readback,
value=self.sim.sim_state[self.readback.name]["value"],
timestamp=self.sim.sim_state[self.readback.name]["timestamp"],
)
def _move_to_setpoint(self) -> None:
"""Move the simulated device to the setpoint."""
+94 -1
View File
@@ -7,11 +7,13 @@ from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from ophyd import Component as Cpt
from ophyd import Device, DeviceStatus, Kind, OphydObject, PositionerBase, Staged
from ophyd import Device, DeviceStatus, Kind, OphydObject, PositionerBase, Signal, Staged
from ophyd_devices.interfaces.base_classes.psi_pseudo_motor_base import PSIPseudoMotorBase
from ophyd_devices.sim.sim_camera import SimCamera
from ophyd_devices.sim.sim_positioner import SimPositioner
from ophyd_devices.sim.sim_signals import SetableSignal
from ophyd_devices.utils.bec_processed_signal import BECProcessedSignal, ProcessedSignalModel
from ophyd_devices.utils.bec_signals import (
AsyncSignal,
DynamicSignal,
@@ -424,6 +426,97 @@ class SimCameraWithPSIComponents(SimCamera):
return status
class VirtualSlitCenter(PSIPseudoMotorBase):
"""
Alternative implementation of a Virtual Slit Center based on two sub-positioners.
The positioners are sub-devices called left_edge and right_edge, and are used
to calculate the center position of the slit.
Args:
name (str): The name of the pseudo motor device.
device_manager (DeviceManagerBase): The device manager instance to fetch the positioner devices from
"""
left_edge = Cpt(SimPositioner, name="left_edge", kind=Kind.normal)
right_edge = Cpt(SimPositioner, name="right_edge", kind=Kind.normal)
def __init__(self, name, device_manager, **kwargs):
super().__init__(name, device_manager, **kwargs)
positioners = {"left_edge": self.left_edge, "right_edge": self.right_edge}
self.set_positioner_objects(positioners)
def forward_calculation(self, left_edge: Signal, right_edge: Signal) -> float:
"""
Forward calculation to compute the center position of the slit based on the positions of the left and right edges.
Args:
left_edge (Signal): The signal representing the position of the left edge positioner.
right_edge (Signal): The signal representing the position of the right edge positioner.
"""
return float((left_edge.get() + right_edge.get()) / 2)
def inverse_calculation(
self, position, left_edge: Signal, right_edge: Signal
) -> dict[str, float]:
"""
Inverse calculation to compute the setpoints for the left and right edge positioners based on the desired center position.
Args:
position (float): The desired center position of the slit.
left_edge (Signal): The signal representing the position of the left edge positioner.
right_edge (Signal): The signal representing the position of the right edge positioner.
Returns:
dict[str, float]: A dictionary containing the setpoints for the left and right edge positioners.
"""
left_pos = left_edge.root.readback.get()
right_pos = right_edge.root.readback.get()
width = right_pos - left_pos
new_right_pos = position + width / 2
new_left_pos = position - width / 2
return {"left_edge": new_left_pos, "right_edge": new_right_pos}
def motors_are_moving(self, left_edge: Signal, right_edge: Signal) -> int:
"""
Check if either the left or right edge positioners are currently moving.
Args:
left_edge (Signal): The signal representing the position of the left edge positioner.
right_edge (Signal): The signal representing the position of the right edge positioner.
Returns:
int: 1 if either motor is moving, 0 otherwise.
"""
left_moving = left_edge.get()
right_moving = right_edge.get()
return int(left_moving or right_moving)
class BECPseudoSignal(BECProcessedSignal):
"""
Example of a pseudo signal implementation based on two signals available in the device manager.
The value of the signal is calculated based on the compute method.
Args:
name (str): The name of the pseudo signal.
signal_1 (str): The name of the first signal in the device manager.
signal_2 (str): The name of the second signal in the device manager.
device_manager (DeviceManagerBase): The device manager instance to fetch the signals from.
"""
def __init__(self, name, signal_1: str, signal_2: str, device_manager=None, **kwargs):
super().__init__(name, model_config=None, device_manager=device_manager, **kwargs)
signal_1 = self.get_device_object_from_bec(
object_name=signal_1, signal_name=self.name, device_manager=device_manager
)
signal_2 = self.get_device_object_from_bec(
object_name=signal_2, signal_name=self.name, device_manager=device_manager
)
self.set_compute_method(self.compute, signal_1=signal_1, signal_2=signal_2)
def compute(self, signal_1: Signal, signal_2: Signal) -> float:
"""Compute the value of the pseudo signal based on the values of signal_1 and signal_2."""
return float(signal_1.get() + signal_2.get())
if __name__ == "__main__":
cam = SimCameraWithPSIComponents(name="cam")
cam.read()
+206 -160
View File
@@ -1,222 +1,248 @@
"""Utilities for building computed read-only signals.
This module provides:
- :class:`ProcessedSignalModel`, which validates that ``method_inputs`` can be
passed to ``compute_method`` as keyword arguments.
- :class:`BECProcessedSignal`, a ``SignalRO`` subclass that subscribes to input
ophyd objects and recomputes its readback whenever an input updates.
Re-entrant callback execution is guarded to avoid recursive subscription loops.
"""
from __future__ import annotations
import inspect
from typing import TYPE_CHECKING, Any, Callable, Self, TypeAlias
import time
from typing import TYPE_CHECKING, Any, Callable, Literal, Self
import numpy as np
from ophyd import Device, Signal, SignalRO
from ophyd import Component, Device, Signal, SignalRO
from pydantic import BaseModel, model_validator
if TYPE_CHECKING: # pragma: no cover
from bec_server.device_server.devices.devicemanager import DeviceManagerDS
from bec_server.device_server.devices.devicemanager import DeviceManagerDS, DSDevice
ALLOWED_RETURN_TYPES = (float, int, str, np.ndarray)
ALLOWED_TYPE_ALIAS: TypeAlias = type[float] | type[int] | type[str] | type[np.ndarray]
ComputeMethod: TypeAlias = Callable[..., ALLOWED_TYPE_ALIAS]
def find_device_config_in_session(name: str, device_manager: DeviceManagerDS) -> dict[str, Any]:
"""Return the configuration entry of ``name`` from ``device_manager.current_session``.
The helper is used by lookups that resolve objects through the device
manager and need to validate the ``needs`` dependency list.
Args:
name: The name of the signal/device for which the config is being fetched.
device_manager: The device manager instance to fetch the current session config from.
"""
configs = device_manager.current_session["devices"]
config = None
for conf in configs:
if conf["name"] == name:
config = conf
break
if config is None:
raise ConnectionError(f"Device '{name}' not found in current session config.")
return config
class ProcessedSignalModel(BaseModel):
"""
Model for the BECProcessedSignal, which defines the devices/signals to subscribe to, and a
compute method that will be called with the device/signal objects as arguments. The method
`get_device_objects_from_device_manager` can be used to get the device/signal objects from the device manager
based on the devices dictionary.
"""Configuration model for :class:`BECProcessedSignal`.
The model stores arbitrary keyword inputs and a callable. Validation enforces
that ``compute_method(**method_inputs)`` is a compatible call.
Args:
signals (dict[str, DeviceSignalMapping]): A dictionary mapping signal names to their corresponding device
and attribute names. The keys of this dictionary must match the parameters of the compute_method.
compute_method (Callable): A callable that will be called with the device/signal objects as arguments.
The signature of this method must match the keys of the signals dictionary. The return value of
this method will be used as the value of the BECProcessedSignal. The compute method can also be a
coroutine function, in which case it will be awaited.
return_type (type): The expected return type of the compute method. This is used for validation and to ensure
that the value of the BECProcessedSignal is of the correct type. Must be one of the types specified
in ALLOWED_RETURN_TYPES.
method_inputs (dict[str, Any]): A dictionary mapping input names of the compute method to the
corresponding objects. They can be ophyd Signals, Devices, Components or any other additional argument that should be passed
to the compute method when called. The keys of this dictionary must match the signature of the compute method.
compute_method (Callable[..., Any]): A user-defined function that computes the value of the processed signal based on the input devices/signals.
Note:
Validation only checks call compatibility (missing/extra kwargs and
unsupported signature kinds). It does not enforce the runtime return
type of ``compute_method``.
"""
model_config = {"arbitrary_types_allowed": True}
devices: dict[str, str]
compute_method: ComputeMethod
return_type: ALLOWED_TYPE_ALIAS
method_inputs: dict[str, Any]
compute_method: Callable[..., Any]
@model_validator(mode="after")
def validate_signals_in_compute_method(self) -> Self:
"""Validator to check that the compute method signature contains all keys of the devices field."""
# Check that the compute method's signature contains all the keys in the devices dictionary.
input_names = set(self.devices.keys())
"""Validate compatibility of ``compute_method`` with ``method_inputs``."""
signature = inspect.signature(self.compute_method)
parameters = signature.parameters
input_names = set(self.method_inputs)
method_param_names = set(parameters.keys())
accepted_names = set()
required_names = set()
has_var_keyword = False
if method_param_names != input_names:
raise ValueError(
"The compute_method signature does not match with the devices keys provided. "
f"Expected parameters: {sorted(input_names)}, "
f"got: {sorted(method_param_names)}"
)
return self
def get_device_objects_from_device_manager(
self, device_manager: DeviceManagerDS, signal_name: str
) -> dict[str, Device | Signal]:
"""
Helper method to get the device/signal objects from the device manager based on the devices dictionary.
Args:
device_manager (DeviceManagerDS): The device manager to get the devices/signals from.
signal_name (str): The name of the ProcessedSignal
Returns:
dict[str, Device | Signal]: A dictionary mapping the device/signal names to their corresponding
ophyd Device or Signal objects.
"""
device_objs = {}
signal_config = self._find_device_config_in_session(signal_name, device_manager)
needs = signal_config.get("needs", [])
for name, dotted_name in self.devices.items():
dev_name = dotted_name.split(".")[0] # First part of the dotted name is device
if dev_name not in needs:
raise ConnectionError(
f"Device {dev_name} needs to be specified in the 'needs' field of the config for the current session"
f"for signal '{signal_name}' in order to fetch the device object with name {dotted_name}from the device manager."
for name, param in parameters.items():
if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.VAR_POSITIONAL):
raise ValueError(
"Compute method mus be compatible with compute_method(**method_inputs): "
f"unsupported parameter {name!r} ({param.kind.description})."
f"for compute method {self.compute_method.__name__!r} with signature {signature}."
)
# Attribute access resolves dotted name to fetch the correct signal/device object from the device manager
device = device_manager.devices[dotted_name]
# TODO: How to safeguard against using variables and not signals?
# if not isinstance(device, (Device, Signal)):
# raise ConnectionError(
# f"Device '{dotted_name}' does not point to a valid signal or device but to type {type(device)}."
# )
device_objs[name] = device
return device_objs
if param.kind is inspect.Parameter.VAR_KEYWORD:
has_var_keyword = True
continue
def _find_device_config_in_session(
self, name: str, device_manager: DeviceManagerDS
) -> dict[str, Any]:
"""
Helper method to find the device config for a given device name in the current session config of the device manager.
"""
configs = device_manager.current_session["devices"]
config = None
for conf in configs:
if conf["name"] == name:
config = conf
break
if config is None:
raise ConnectionError(f"Device '{name}' not found in current session config.")
return config
accepted_names.add(name)
if param.default is inspect.Parameter.empty:
required_names.add(name)
missing = required_names - input_names
extra = set() if has_var_keyword else input_names - accepted_names
if missing or extra:
problems = []
if missing:
problems.append(f"missing required inputs: {sorted(missing)}")
if extra:
problems.append(f"unexpected inputs: {sorted(extra)}")
raise ValueError("; ".join(problems))
return self
class BECProcessedSignal(SignalRO):
"""Read-only signal whose value is computed from other inputs.
A compute model can be provided at construction time via ``model_config`` or
later through :meth:`set_compute_method`. During
:meth:`wait_for_connection`, input ophyd objects are subscribed and the
current value is computed immediately.
Args:
name (str): The name of the signal.
model_config (dict[Literal["devices", "compute_method"], Any] | None):
Optional initialization payload for :meth:`set_compute_method`.
``devices`` is passed as keyword arguments to the compute method.
device_manager (DeviceManagerDS | None): Device manager used by helpers
that resolve objects from BEC names.
**kwargs: Additional keyword arguments passed to the SignalRO initializer.
"""
def __init__(
self,
name: str,
model: ProcessedSignalModel | dict[str, Any] | None = None,
model_config: dict[Literal["method_inputs", "compute_method"], Any] | None = None,
device_manager: DeviceManagerDS | None = None,
**kwargs,
):
super().__init__(name=name, **kwargs)
self._model = model or {}
self.compute_model: ProcessedSignalModel | None = None
self._device_manager: DeviceManagerDS = self._get_device_manager(device_manager)
self.device_objects: dict[str, Device | Signal] = {}
self._metadata["connected"] = False
self._callback_is_running = False
self._active_callbacks: set[str] = set()
if model_config:
self.set_compute_method(
compute_method=model_config["compute_method"], **model_config["devices"]
)
def set_model(self, model: ProcessedSignalModel | dict[str, Any]) -> None:
@staticmethod
def get_device_object_from_bec(
object_name: str, signal_name: str, device_manager: DeviceManagerDS
) -> Device | Signal:
"""Resolve one device/signal object from a BEC object name.
The method verifies that the resolved device is listed in the ``needs``
section of ``signal_name`` in the active session configuration.
"""
Method to set the model for the BECProcessedSignal. This method will be used to set the devices/signals to subscribe to,
and the compute method to use for processing the values from those devices/signals.
signal_config = find_device_config_in_session(signal_name, device_manager)
needs = signal_config.get("needs", [])
dev_name = object_name.split(".")[0] # First part of the dotted name is device
if dev_name not in needs:
raise ConnectionError(
f"Device {dev_name} needs to be specified in the 'needs' field of the config for the current session"
f"for signal '{signal_name}' in order to fetch the device object with name {object_name} from the device manager."
)
# Attribute access resolves dotted name to fetch the correct signal/device object from the device manager
# If this line crashes, there is likely an issue with the implementation of 'needs' in the device manager.
device = device_manager.devices[object_name]
return device
def set_compute_method(self, compute_method: Callable[..., Any], **kwargs) -> None:
"""Set or replace the compute method and its keyword inputs.
``kwargs`` may contain ophyd objects and/or plain values. All entries are
forwarded to the compute method as keyword arguments.
Args:
model (ProcessedSignalModel | dict): The model for the BECProcessedSignal. Can be either a ProcessedSignalModel instance or a dictionary that can be used to create a ProcessedSignalModel instance.
compute_method: Callable used to compute the readback value.
**kwargs: Keyword arguments forwarded to ``compute_method``.
"""
if isinstance(model, dict):
model = ProcessedSignalModel(**model)
self._model = model
# Lazy import DSDevice to avoid circular import issues
from bec_server.device_server.devices.devicemanager import DSDevice
def set_device_object(self, device_objects: dict[str, Device | Signal]) -> None:
"""
Method to set the device/signal objects for the BECProcessedSignal. This method will be used to provide the actual ophyd Device or Signal objects that correspond to the devices/signals specified in the model.
Args:
device_objects (dict[str, Device | Signal]): A dictionary mapping the device/signal names to their corresponding ophyd Device or Signal objects.
"""
self.device_objects = device_objects
method_inputs = {}
found_opd_objects = False
for kw, value in kwargs.items():
if isinstance(value, (Component, Device, Signal, DSDevice)):
found_opd_objects = True
method_inputs[kw] = value
if not found_opd_objects:
raise ValueError(
"At least one ophyd object (Component, Device, Signal, or DSDevice) must be provided as a keyword argument to set_compute_method."
)
self.compute_model = ProcessedSignalModel.model_validate(
{"method_inputs": method_inputs, "compute_method": compute_method}
)
def wait_for_connection(self, *args, **kwargs) -> None:
"""Connect to inputs and initialize computed readback.
Subscriptions are attached to every method input that exposes both
``wait_for_connection`` and ``subscribe``. Inputs without these methods
are treated as static keyword arguments.
"""
Wait for connection will try to connect to the device/signal objects specified in the
model. It will check the model for validity, and attempt to fetch the device_objects if
they are not already set. Once everything is set up, it subscribe to the default subscriptions
and compute an initial value while asserting that its return value is of the correct type.
"""
# Already connected, no need to do anything
if self._metadata.get("connected", False):
return
# I. Check that model is set, if not raise an error.
if self._model is None:
# Check that model is set, if not raise an error.
if self.compute_model is None:
raise ValueError(
f"No model provided for signal {self.name}. Please either provided model in init or use `set_model` before `wait_for_connection` is called."
f"No compute model provided for signal {self.name}. Please either provide a model_config in init or use `set_compute_method` before `wait_for_connection` is called."
)
# II. Check that device objects are set, if not raise an error.
if not self.device_objects:
# The method will also check that the device object names are specified in this signals config needs field.
device_objects = self._model.get_device_objects_from_device_manager(
self._device_manager, self.name
)
self.set_device_object(device_objects)
# Setup subscriptions to input devices/signals based on the model's configuration.
for input in self.compute_model.method_inputs.values():
# Lazy import DSDevice to avoid circular import issues
from bec_server.device_server.devices.devicemanager import DSDevice
# III. Check that device_objects dict contains all keys specified in the model's devices field.
model_keys = set(self._model.devices.keys())
device_object_keys = set(self.device_objects.keys())
if model_keys != device_object_keys:
raise ValueError(
f"The device_objects provided do not match with the devices specified in the model for signal {self.name}. "
f"Expected keys: {sorted(model_keys)}, got: {sorted(device_object_keys)}"
)
if not isinstance(input, (Component, Device, Signal, DSDevice)):
continue # Skip non-ophyd objects, they are additional arguments for the compute method
input.wait_for_connection(*args, **kwargs) # Ensure connected
# Now we connect to the default subscriptions of the provided signal/device objects.
for device_obj in self.device_objects.values():
device_obj.wait_for_connection(*args, **kwargs) # Ensure connected
device_obj.subscribe(
self._subscription_callback, event_type=device_obj._default_sub, run=False
)
input.subscribe(self._subscription_callback, event_type=input._default_sub, run=False)
# Run computation of the processed signal, this stores the value in _readback
self._subscription_callback(**self.device_objects)
if not isinstance(self._readback, self._model.return_type):
raise TypeError(
f"Computed value has type {type(self._readback)}, but expected {self._model.return_type}"
)
self._subscription_callback()
# Signal is connected
self._metadata["connected"] = True
def _subscription_callback(self, *args, **kwargs):
"""Callback method that is executed whenever any of the subscribed signal/device objects have an update."""
"""Recompute readback from ``compute_model`` and emit value subscriptions."""
if self._callback_is_running:
return # Callback is already running, skip to avoid multiple executions at the same time
try:
self._callback_is_running = True
old_value = self._readback
self._readback = self._model.compute_method(**self.device_objects)
timestamp = time.time()
self._metadata["timestamp"] = timestamp
self._readback = self.compute_model.compute_method(**self.compute_model.method_inputs)
self._run_subs(sub_type=self._default_sub, old_value=old_value, value=self._readback)
finally:
self._callback_is_running = False
def _run_subs(self, *args, sub_type, **kwargs):
"""
Run subs should not allow for recursive calls of the same subscription type to avoid recursion loops.
Args:
sub_type (str): The subscription type for which to run the callbacks.
"""
"""Prevent concurrent callbacks for the same subscription type."""
if sub_type in self._active_callbacks:
return
try:
@@ -227,14 +253,16 @@ class BECProcessedSignal(SignalRO):
self._active_callbacks.remove(sub_type)
def _get_device_manager(self, device_manager: DeviceManagerDS | None = None) -> DeviceManagerDS:
"""
Helper method to get the device manager instance.
"""Return the active device manager for this signal.
If ``device_manager`` is not provided, the method tries to read it from
``self.root.device_manager``.
Args:
device_manager (DeviceManagerDS | None): Optional device manager instance. If not provided, it will attempt to fetch from the root device.
device_manager (DeviceManagerDS | None): An optional device manager instance. If not provided, it will attempt to fetch
the device manager from the root device's `device_manager` attribute.
Returns:
DeviceManagerDS: The device manager instance.
DeviceManagerDS: The resolved device manager.
"""
if device_manager is None:
# PSIDeviceBase will have a reference to the device manager on device_manager attribute.
@@ -250,13 +278,33 @@ class BECProcessedSignal(SignalRO):
return device_manager
def describe(self):
"""Return ``describe`` metadata including compute model information."""
ret = super().describe()
ret[self.name]["device_objects"] = ", ".join(self._model.devices.values())
if self.compute_model is None:
ret[self.name]["method_inputs"] = ""
ret[self.name]["compute_method"] = ""
ret[self.name]["extra_kwargs"] = {}
return ret
ret[self.name]["method_inputs"] = ", ".join(
[
f"{obj.root.name}.{obj.dotted_name}"
for obj in self.compute_model.method_inputs.values()
if hasattr(obj, "dotted_name") and hasattr(obj, "root") # Ophyd obj
]
)
ret[self.name]["compute_method"] = self.compute_model.compute_method.__name__
ret[self.name]["extra_kwargs"] = {
kw: value
for kw, value in self.compute_model.method_inputs.items()
if not (hasattr(value, "dotted_name") and hasattr(value, "root")) # Ophyd obj
}
return ret
if __name__ == "__main__": # pragma: no cover
# pylint: disable=import-outside-toplevel, unused-import, missing-docstring, protected-access
from bec_server.device_server.tests.utils import DMMock
from ophyd_devices.sim.sim_positioner import SimPositioner
@@ -271,9 +319,8 @@ if __name__ == "__main__": # pragma: no cover
dm.devices._add_device("samx", samx)
dm.devices._add_device("samy", samy)
def compute_method(signal_1: Signal, signal_2: Signal) -> float:
"""Example compute method that adds two signals."""
return float(signal_1.get() + signal_2.get())
def compute_method(signal_1: Signal, signal_2: Signal, tmp: float = 2) -> float:
return float(signal_1.get() + signal_2.get()) + tmp
def _callback_print(value, **kwargs):
obj = kwargs.get("obj")
@@ -282,13 +329,12 @@ if __name__ == "__main__": # pragma: no cover
samx.readback.subscribe(_callback_print, run=False, event_type=samx.readback.SUB_VALUE)
samy.readback.subscribe(_callback_print, run=False, event_type=samy.readback.SUB_VALUE)
_model = ProcessedSignalModel(
devices={"signal_1": "samx.readback", "signal_2": "samy.readback"},
compute_method=compute_method,
return_type=float,
processed_signal = BECProcessedSignal(
name="processed_signal", model_config={}, device_manager=dm
)
processed_signal.set_compute_method(
compute_method, signal_1=samx.readback, signal_2=samy.readback, tmp=0
)
processed_signal = BECProcessedSignal(name="processed_signal", model=_model, device_manager=dm)
processed_signal.subscribe(_callback_print, run=False, event_type=processed_signal.SUB_VALUE)
dm.current_session = {}
dm.current_session["devices"] = [{"name": processed_signal.name, "needs": ["samx", "samy"]}]
+33 -2
View File
@@ -3,11 +3,14 @@ import time
from unittest import mock
import pytest
from ophyd import DeviceStatus, Staged
from bec_server.device_server.tests.utils import DMMock
from ophyd import Component as Cpt
from ophyd import DeviceStatus, Signal, Staged
from ophyd.utils.errors import RedundantStaging
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
from ophyd_devices.utils.errors import DeviceStopError, DeviceTimeoutError
from ophyd_devices.interfaces.base_classes.psi_pseudo_device_base import PSIPseudoDeviceBase
from ophyd_devices.utils.bec_processed_signal import BECProcessedSignal
@pytest.fixture
@@ -15,6 +18,34 @@ def detector_base():
yield PSIDeviceBase(name="test_detector")
class TestPSIPseudoDeviceBase(PSIPseudoDeviceBase):
"""Test class for PSIPseudoDeviceBase to test wait_for_connection method."""
signal = Cpt(BECProcessedSignal, name="signal", model_config=None)
tmp = Cpt(Signal, name="tmp")
def __init__(self, name, device_manager=None, **kwargs):
super().__init__(name=name, device_manager=device_manager, **kwargs)
self.signal.set_compute_method(self.compute, signal=self.tmp)
def compute(self, signal):
"""Compute method for the processed signal."""
return signal.get() * 2
@pytest.fixture
def pseudo_device_base():
dm = DMMock()
yield TestPSIPseudoDeviceBase(name="test_pseudo_device", device_manager=dm)
def test_psuedo_device_wait_for_connection(pseudo_device_base):
# The wait_for_connection method should connect to the BECProcessedSignal after connecting to the normal signals.
assert pseudo_device_base.signal._metadata["connected"] is False
pseudo_device_base.wait_for_connection()
assert pseudo_device_base.signal._metadata["connected"] is True
def test_detector_base_init(detector_base):
assert detector_base.stopped is False
assert detector_base.name == "test_detector"
+152
View File
@@ -0,0 +1,152 @@
"""Tests for processed-signal behavior and integration patterns."""
# pylint: disable=redefined-outer-name
import pytest
from bec_server.device_server.tests.utils import DMMock
from ophyd import Component as Cpt
from ophyd import Device
from ophyd_devices.sim.sim_positioner import SimPositioner
from ophyd_devices.utils.bec_processed_signal import BECProcessedSignal, ProcessedSignalModel
class TestProcessedSignalDevice(Device):
"""Fixture device with two sub-positioners and one processed signal."""
motor_a = Cpt(SimPositioner, name="motor_a", delay=0)
motor_b = Cpt(SimPositioner, name="motor_b", delay=0)
processed = Cpt(BECProcessedSignal, name="processed", model_config=None)
def __init__(self, name, device_manager, **kwargs):
self.device_manager = device_manager
super().__init__(name=name, **kwargs)
self.processed.set_compute_method(
self.compute, motor_a=self.motor_a.readback, motor_b=self.motor_b.readback, offset=0.5
)
@staticmethod
def compute(motor_a, motor_b, offset=0):
"""Compute processed value from two motor readbacks."""
return float(motor_a.get() + motor_b.get() + offset)
@pytest.fixture(name="device_manager")
def fixture_device_manager():
"""Mock device manager fixture."""
return DMMock()
@pytest.fixture(name="processed_device")
def fixture_processed_device(device_manager):
"""Fixture for TestProcessedSignalDevice."""
dev = TestProcessedSignalDevice(name="processed_dev", device_manager=device_manager)
dev.motor_a.wait_for_connection()
dev.motor_b.wait_for_connection()
dev.processed.wait_for_connection()
return dev
@pytest.fixture(name="samx")
def fixture_samx():
"""Standalone left motor fixture."""
return SimPositioner(name="samx", delay=0)
@pytest.fixture(name="samy")
def fixture_samy():
"""Standalone right motor fixture."""
return SimPositioner(name="samy", delay=0)
@pytest.fixture(name="device_manager_with_signals")
def fixture_device_manager_with_signals(samx, samy):
"""Device manager fixture with motor mapping and session needs."""
dm = DMMock()
dm.devices["samx"] = samx
dm.devices["samy"] = samy
dm.current_session = {"devices": [{"name": "processed_signal", "needs": ["samx", "samy"]}]}
return dm
@pytest.fixture(name="processed_signal_from_device_manager")
def fixture_processed_signal_from_device_manager(device_manager_with_signals):
"""Processed signal fixture using dotted-name resolution through the device manager."""
signal = BECProcessedSignal(name="processed_signal", device_manager=device_manager_with_signals)
signal_1 = BECProcessedSignal.get_device_object_from_bec(
"samx.readback", "processed_signal", device_manager_with_signals
)
signal_2 = BECProcessedSignal.get_device_object_from_bec(
"samy.readback", "processed_signal", device_manager_with_signals
)
signal.set_compute_method(
lambda signal_1, signal_2, offset=1.0: float(signal_1.get() + signal_2.get() + offset),
signal_1=signal_1,
signal_2=signal_2,
offset=1.0,
)
signal.wait_for_connection()
return signal
def test_processed_signal_with_sub_components(processed_device):
"""Test processed signal updates from sub-component motor readbacks."""
processed_device.motor_a.move(2).wait(timeout=2)
processed_device.motor_b.move(3).wait(timeout=2)
assert processed_device.processed.get() == pytest.approx(5.5)
processed_device.motor_a.move(-1).wait(timeout=2)
assert processed_device.processed.get() == pytest.approx(2.5)
def test_processed_signal_device_manager_resolution(
processed_signal_from_device_manager, samx, samy
):
"""Test processed signal using device-manager string key resolution."""
samx.move(1).wait(timeout=2)
samy.move(2).wait(timeout=2)
assert processed_signal_from_device_manager.get() == pytest.approx(4)
samy.move(5).wait(timeout=2)
assert processed_signal_from_device_manager.get() == pytest.approx(7)
def test_processed_signal_describe_metadata(processed_signal_from_device_manager):
"""Test describe contains compute method metadata and extra kwargs."""
info = processed_signal_from_device_manager.describe()["processed_signal"]
assert info["compute_method"] == "<lambda>"
assert info["extra_kwargs"] == {"offset": 1.0}
assert "samx.readback" in info["method_inputs"]
assert "samy.readback" in info["method_inputs"]
def test_processed_signal_model_rejects_missing_required_inputs():
"""Test compute model validation when required kwargs are missing."""
def compute(signal_1, signal_2):
return signal_1 + signal_2
with pytest.raises(ValueError, match="missing required inputs"):
ProcessedSignalModel.model_validate(
{"method_inputs": {"signal_1": 1}, "compute_method": compute}
)
def test_processed_signal_model_rejects_unexpected_inputs():
"""Test compute model validation when unknown kwargs are provided."""
def compute(signal_1):
return signal_1
with pytest.raises(ValueError, match="unexpected inputs"):
ProcessedSignalModel.model_validate(
{"method_inputs": {"signal_1": 1, "extra": 2}, "compute_method": compute}
)
def test_processed_signal_requires_compute_method(device_manager):
"""Test wait_for_connection fails when no compute model is configured."""
signal = BECProcessedSignal(name="no_model", device_manager=device_manager)
with pytest.raises(ValueError, match="No compute model provided"):
signal.wait_for_connection()
+261
View File
@@ -0,0 +1,261 @@
"""Module to test the virtual slit center and width classes."""
import pytest
from bec_server.device_server.tests.utils import DMMock
from ophyd import Component as Cpt
from ophyd_devices import TransitionStatus
from ophyd_devices.devices.virtual_slit import VirtualSlitCenter, VirtualSlitWidth
from ophyd_devices.interfaces.base_classes.psi_pseudo_motor_base import PSIPseudoMotorBase
from ophyd_devices.sim.sim_positioner import SimPositioner
class TestPseudoMotor(PSIPseudoMotorBase):
"""Pseudo motor fixture class with two sub-positioner components."""
motor_a = Cpt(SimPositioner, name="motor_a", delay=0)
motor_b = Cpt(SimPositioner, name="motor_b", delay=0)
def __init__(self, name, device_manager, **kwargs):
super().__init__(name=name, device_manager=device_manager, **kwargs)
self.set_positioner_objects({"a": self.motor_a, "b": self.motor_b})
def forward_calculation(self, a, b):
"""Forward calculation."""
return float(a.get() + b.get())
def inverse_calculation(self, position, a, b):
"""Inverse calculation."""
a_val = a.get()
b_val = position - a_val
return {"a": a_val, "b": b_val}
def motors_are_moving(self, a, b):
"""Check if the sub-positioners are moving."""
return int(a.get() or b.get())
@pytest.fixture
def pseudo_motor_fixture():
"""Fixture for the TestPseudoMotor with a mock device manager."""
dm = DMMock()
pseudo = TestPseudoMotor(name="test_pseudo", device_manager=dm)
pseudo.wait_for_connection()
return pseudo
@pytest.fixture
def samx():
"""Positioner fixture for the left slit motor."""
return SimPositioner(name="samx", delay=0)
@pytest.fixture
def samy():
"""Positioner fixture for the right slit motor."""
return SimPositioner(name="samy", delay=0)
@pytest.fixture
def device_manager_with_slit_motors(samx, samy):
"""Device manager fixture with slit motors and session configuration for virtual slits."""
dm = DMMock()
dm.devices["samx"] = samx
dm.devices["samy"] = samy
dm.current_session = {
"devices": [
{"name": "slit_center", "needs": ["samx", "samy"]},
{"name": "slit_width", "needs": ["samx", "samy"]},
]
}
return dm
@pytest.fixture
def slit_center(device_manager_with_slit_motors):
"""Slit Center fixture"""
center = VirtualSlitCenter(
name="slit_center",
left_slit="samx",
right_slit="samy",
device_manager=device_manager_with_slit_motors,
)
center.wait_for_connection()
return center
@pytest.fixture
def slit_width(device_manager_with_slit_motors):
"""Slit Width fixture"""
width = VirtualSlitWidth(
name="slit_width",
left_slit="samx",
right_slit="samy",
device_manager=device_manager_with_slit_motors,
)
width.wait_for_connection()
return width
def test_subcomponent_pseudo_motor_move(pseudo_motor_fixture):
"""
Test that moving the pseudo motor correctly updates the positions of the sub-positioners
and that the readback of the pseudo motor reflects the new position.
"""
pseudo_motor_fixture.motor_a.move(5).wait(timeout=2)
pseudo_motor_fixture.motor_b.move(-5).wait(timeout=2)
status = pseudo_motor_fixture.move(2)
status.wait(timeout=2)
assert pseudo_motor_fixture.motor_a.readback.get() == pytest.approx(5)
assert pseudo_motor_fixture.motor_b.readback.get() == pytest.approx(-3)
assert pseudo_motor_fixture.readback.get() == pytest.approx(2)
def test_virtual_slit_center_forward_calculation(slit_center, samx, samy):
"""
Test that the forward calculation for the virtual slit center correctly computes
the center position based on the current positions of the left and right slit motors.
"""
samx.move(1).wait(timeout=2)
samy.move(3).wait(timeout=2)
center = slit_center.forward_calculation(samx.readback, samy.readback)
assert center == pytest.approx(2)
def test_virtual_slit_center_inverse_calculation(slit_center, samx, samy):
"""
Test that the inverse calculation for the virtual slit center correctly
computes the positions of the left and right slit motors based on the
desired center position and current positions of the slit motors.
"""
status = TransitionStatus(slit_center.motor_is_moving, transitions=[0, 1, 0])
st_samx = samx.move(1)
st_samy = samy.move(3)
status.wait(timeout=2)
assert status.done is True, "Expected the slit center to start moving the slit motors."
assert status.success is True, "Expected the slit center to move the slit motors."
assert st_samy.done is True, "Expected the slit center to finish moving the right slit motor."
assert st_samy.success is True, "Expected the slit center to move the right slit motor."
assert st_samx.done is True, "Expected the slit center to finish moving the left slit motor."
assert st_samx.success is True, "Expected the slit center to move the left slit motor."
pos = slit_center.inverse_calculation(4, samx.readback, samy.readback)
assert pos["left"] == pytest.approx(3)
assert pos["right"] == pytest.approx(5)
def test_virtual_slit_width_forward_calculation(slit_width, samx, samy):
"""
Test that the forward calculation for the virtual slit width correctly computes
the width based on the current positions of the left and right slit motors.
"""
samx.move(1).wait(timeout=2)
samy.move(3).wait(timeout=2)
width = slit_width.forward_calculation(samx.readback, samy.readback)
assert width == pytest.approx(2)
def test_virtual_slit_width_inverse_calculation(slit_width, samx, samy):
"""
Test that the inverse calculation for the virtual slit width correctly
computes the positions of the left and right slit motors based on the
desired width and current positions of the slit motors.
"""
samx.move(1).wait(timeout=2)
samy.move(3).wait(timeout=2)
pos = slit_width.inverse_calculation(6, samx.readback, samy.readback)
assert pos["left"] == pytest.approx(-1)
assert pos["right"] == pytest.approx(5)
def test_virtual_slit_center_move(slit_center, samx, samy):
"""
Test that moving the virtual slit center correctly updates the
positions of the left and right slit motors.
"""
samx.move(1).wait(timeout=2)
samy.move(3).wait(timeout=2)
assert slit_center.readback.get() == pytest.approx(2)
status = slit_center.move(5)
status.wait(timeout=2)
assert samx.readback.get() == pytest.approx(4)
assert samy.readback.get() == pytest.approx(6)
assert slit_center.readback.get() == pytest.approx(5)
def test_virtual_slit_width_move(slit_width, samx, samy):
"""
Test that moving the virtual slit width correctly updates the
positions of the left and right slit motors.
"""
samx.move(1).wait(timeout=2)
samy.move(3).wait(timeout=2)
# EGU should be taken from the left slit motor.
assert slit_width.egu == samx.egu
assert slit_width.egu == samx.egu
status = slit_width.move(6)
status.wait(timeout=2)
assert samx.readback.get() == pytest.approx(-1)
assert samy.readback.get() == pytest.approx(5)
assert slit_width.readback.get() == pytest.approx(6)
def test_virtual_slit_offset_applied_in_forward_and_inverse(
device_manager_with_slit_motors, samx, samy
):
"""
Test that the offset is correctly applied in both forward and
inverse calculations for the virtual slit center and width.
"""
samx.move(1).wait(timeout=2)
samy.move(3).wait(timeout=2)
slit_center_offset = VirtualSlitCenter(
name="slit_center",
left_slit="samx",
right_slit="samy",
device_manager=device_manager_with_slit_motors,
offset=0.5,
egu="new_egu",
)
slit_width_offset = VirtualSlitWidth(
name="slit_width",
left_slit="samx",
right_slit="samy",
device_manager=device_manager_with_slit_motors,
offset=0.5,
egu="new_egu",
)
assert slit_width_offset.egu == "new_egu"
assert slit_center_offset.egu == "new_egu"
slit_center_offset.wait_for_connection()
slit_width_offset.wait_for_connection()
assert slit_center_offset.forward_calculation(samx.readback, samy.readback) == pytest.approx(
2.5
)
assert slit_width_offset.forward_calculation(samx.readback, samy.readback) == pytest.approx(2.5)
center_pos = slit_center_offset.inverse_calculation(5.5, samx.readback, samy.readback)
assert center_pos["left"] == pytest.approx(4)
assert center_pos["right"] == pytest.approx(6)
width_pos = slit_width_offset.inverse_calculation(6.5, samx.readback, samy.readback)
assert width_pos["left"] == pytest.approx(-1)
assert width_pos["right"] == pytest.approx(5)
# Check offset signal
assert slit_center_offset.offset.get() == pytest.approx(0.5)
assert slit_width_offset.offset.get() == pytest.approx(0.5)