diff --git a/ophyd_devices/devices/virtual_slit.py b/ophyd_devices/devices/virtual_slit.py index 857d64d..a19c421 100644 --- a/ophyd_devices/devices/virtual_slit.py +++ b/ophyd_devices/devices/virtual_slit.py @@ -1,9 +1,19 @@ -"""Module for virtual slit center implementation.""" +"""Pseudo-motor implementations for slit center and slit width. + +Both devices map one pseudo axis onto two real motors (left and right slit +edges). They can be instantiated from names resolved through the BEC device +manager and support an optional offset term in their coordinate transforms. +""" from __future__ import annotations from typing import TYPE_CHECKING +from bec_lib.logger import bec_logger +from ophyd import Component as Cpt +from ophyd import Kind +from ophyd.signal import SignalRO + from ophyd_devices.interfaces.base_classes.psi_pseudo_motor_base import PSIPseudoMotorBase if TYPE_CHECKING: # pragma: no cover @@ -11,19 +21,37 @@ if TYPE_CHECKING: # pragma: no cover from ophyd import PositionerBase, Signal +logger = bec_logger.logger + + class VirtualSlitCenter(PSIPseudoMotorBase): - """ - Virtual slit center implementation. It expects the left and right slit positioner names to be passed - as arguments. The named devices must be positioners and available in the device_manager. In addition, - it must have a readback (user_readback), setpoint (user_setpoint) and motor_is_moving signal. + """Pseudo motor controlling slit center from two edge motors. + + Both positioners must be present in the device manager, and the pseudo + motor entry in the current session config must declare them in ``needs``. + + The forward calculation computes the center position based on the positions of the left and right positioners, + while the inverse calculation computes the setpoints for the left and right positioners based on a desired center + position. The motors_are_moving method checks if either of the positioners is currently moving. Args: - name (str): The name of the virtual slit center. - left_slit (str): The name of the left slit positioner in the device manager. - right_slit (str): The name of the right slit positioner in the device manager. - device_manager (DeviceManagerBase): The device manager to use for connecting to the positioners. + name (str): The name of the pseudo motor device. + left_slit (str): The name of the left slit positioner device in the device manager. + right_slit (str): The name of the right slit positioner device in the device manager. + device_manager (DeviceManagerBase): The device manager instance to fetch the positioner devices from. + offset (float, optional): Constant center offset added in forward + calculation and removed in inverse calculation. + egu (str | None, optional): Engineering units. If omitted, units are + taken from the left positioner. """ + offset = Cpt( + SignalRO, + name="offset", + kind=Kind.config, + doc="Offset applied to the position of the slit center when calculating the width.", + ) + def __init__( self, name: str, @@ -31,6 +59,7 @@ class VirtualSlitCenter(PSIPseudoMotorBase): right_slit: str, device_manager: DeviceManagerBase, offset: float = 0, + egu: str | None = None, **kwargs, ): positioners = self.get_positioner_objects( @@ -38,32 +67,44 @@ class VirtualSlitCenter(PSIPseudoMotorBase): positioners={"left": left_slit, "right": right_slit}, device_manager=device_manager, ) + if egu is None: # if not specified, fetch it from the left positioner + egu = positioners["left"].egu + if positioners["right"].egu != egu: + logger.warning( + f"Device {name} found inconsistency for egu for positioner left {left_slit} and right {right_slit}. Using egu {egu}." + ) self._offset = offset super().__init__( - name=name, device_manager=device_manager, positioners=positioners, **kwargs + name=name, device_manager=device_manager, positioners=positioners, egu=egu, **kwargs ) + def wait_for_connection(self, *args, **kwargs): + """Connect and initialize the read-only ``offset`` configuration signal.""" + super().wait_for_connection(*args, **kwargs) + # Set the initial value of the offset signal + # Config values are read by back after wait_for_connection is called. + self.offset._readback = self._offset + def _get_pos_motor(self, motor: PositionerBase) -> float: - """ - Helper method to get the position of a motor. + """Return the current position read from ``motor``. Args: - motor (PositionerBase): The motor to get the position of. + motor (PositionerBase): The positioner motor to read the position from. + Returns: + float: Current motor position. """ return motor.read()[motor.name]["value"] # pylint: disable=arguments-differ def forward_calculation(self, left: Signal, right: Signal) -> float: - """ - Forward calculation to compute the value for the pseudo motor readback - and setpoint based on the position of the left and right slit. + """Compute slit center from left and right positions. Args: - left (Signal): The left slit positioner signal. - right (Signal): The right slit positioner signal. + left (Signal): The signal representing the position of the left slit positioner. + right (Signal): The signal representing the position of the right slit positioner. Returns: - float: The center position of the slit. + float: Center position ``(left + right) / 2 + offset``. """ left_pos = left.get() right_pos = right.get() @@ -71,33 +112,33 @@ class VirtualSlitCenter(PSIPseudoMotorBase): return float(center) def inverse_calculation(self, position: float, left: Signal, right: Signal) -> dict[str, float]: - """ - Inverse calculation to compute the position of the left and right slit based on the center position. + """Compute left/right setpoints for a target center. + + The current slit width is preserved. Args: - center (float): The center position of the slit. - left (Signal): The left slit positioner signal. - right (Signal): The right slit positioner signal. - + position (float): The desired center position of the slit. + left (Signal): The signal representing the position of the left slit positioner. + right (Signal): The signal representing the position of the right slit positioner. Returns: - dict[str, float]: The positions of the left and right slit. + A dictionary with the new setpoints for the left and right positioners, with keys "left" and "right". """ position_with_offset = position - self._offset # To access position, run read on the root (PositionerBase) of the signal - left_pos = self._get_pos_motor(left.root) - right_pos = self._get_pos_motor(right.root) + left_pos = left.get() + right_pos = right.get() width = right_pos - left_pos new_left_pos = position_with_offset - width / 2 new_right_pos = position_with_offset + width / 2 return {"left": new_left_pos, "right": new_right_pos} def motors_are_moving(self, left: Signal, right: Signal) -> int: - """ - Calculate whether the motors are moving based on the motor_is_moving signal of the left and right slit. + """Return 1 when either left or right motor is moving, else 0. Args: - left (Signal): The left slit positioner signal. - right (Signal): The right slit positioner signal. + left (Signal): The signal representing the position of the left slit positioner. + right (Signal): The signal representing the position of the right slit positioner. + Returns: int: 1 if either motor is moving, 0 otherwise. """ @@ -107,6 +148,28 @@ class VirtualSlitCenter(PSIPseudoMotorBase): class VirtualSlitWidth(PSIPseudoMotorBase): + """Pseudo motor controlling slit width from two edge motors. + + Both positioners must be present in the device manager, and the pseudo + motor entry in the current session config must declare them in ``needs``. + + Args: + name (str): The name of the pseudo motor device. + left_slit (str): The name of the left slit positioner device in the device manager. + right_slit (str): The name of the right slit positioner device in the device manager. + device_manager (DeviceManagerBase): The device manager instance to fetch the positioner devices from. + offset (float, optional): Constant width offset added in forward + calculation and removed in inverse calculation. + egu (str | None, optional): Engineering units. If omitted, units are + taken from the left positioner. + """ + + offset = Cpt( + SignalRO, + name="offset", + kind=Kind.config, + doc="Offset applied to the position of the slit center when calculating the width.", + ) def __init__( self, @@ -114,6 +177,8 @@ class VirtualSlitWidth(PSIPseudoMotorBase): left_slit: str, right_slit: str, device_manager: DeviceManagerBase, + offset: float = 0, + egu: str | None = None, **kwargs, ): positioners = self.get_positioner_objects( @@ -121,62 +186,66 @@ class VirtualSlitWidth(PSIPseudoMotorBase): positioners={"left": left_slit, "right": right_slit}, device_manager=device_manager, ) + if egu is None: # if not specified, fetch it from the left positioner + egu = positioners["left"].egu + if positioners["right"].egu != egu: + logger.warning( + f"Device {name} found inconsistency for egu for positioner left {left_slit} and right {right_slit}. Using egu {egu}." + ) + self._offset = offset super().__init__( - name=name, device_manager=device_manager, positioners=positioners, **kwargs + name=name, device_manager=device_manager, positioners=positioners, egu=egu, **kwargs ) - def _get_pos_motor(self, motor: PositionerBase) -> float: - """ - Helper method to get the position of a motor. - - Args: - motor (PositionerBase): The motor to get the position of. - """ - return motor.read()[motor.name]["value"] + def wait_for_connection(self, *args, **kwargs): + """Connect and initialize the read-only ``offset`` configuration signal.""" + super().wait_for_connection(*args, **kwargs) + # Set the initial value of the offset signal + # Config values are read by back after wait_for_connection is called. + self.offset._readback = self._offset # pylint: disable=arguments-differ def forward_calculation(self, left: Signal, right: Signal) -> float: - """ - Forward calculation to compute the value for the pseudo motor readback - and setpoint based on the position of the left and right slit. + """Compute slit width from left and right positions. Args: - left (Signal): The left slit positioner signal. - right (Signal): The right slit positioner signal. + left (Signal): The signal representing the position of the left slit positioner. + right (Signal): The signal representing the position of the right slit positioner. Returns: - float: The center position of the slit. + float: Width ``right - left + offset``. """ left_pos = left.get() right_pos = right.get() - width = right_pos - left_pos + width = right_pos - left_pos + self._offset return float(width) def inverse_calculation(self, position: float, left: Signal, right: Signal) -> dict[str, float]: - """ - Inverse calculation to compute the position of the left and right slit based on the center position. + """Compute left/right setpoints for a target width. + + The current slit center is preserved. Args: - position (float): The center position of the slit. - + position (float): The desired width of the slit. + left (Signal): The signal representing the position of the left slit positioner. + right (Signal): The signal representing the position of the right slit positioner. Returns: - dict[str, float]: The positions of the left and right slit. + A dictionary with the new setpoints for the left and right positioners, with keys "left" and "right". """ - left_pos = self._get_pos_motor(left.root) - right_pos = self._get_pos_motor(right.root) + left_pos = left.get() + right_pos = right.get() center = (left_pos + right_pos) / 2 - width = right_pos - left_pos + width = position - self._offset new_right_pos = center + width / 2 new_left_pos = center - width / 2 return {"left": new_left_pos, "right": new_right_pos} def motors_are_moving(self, left: Signal, right: Signal) -> int: - """ - Calculate whether the motors are moving based on the motor_is_moving signal of the left and right slit. + """Return 1 when either left or right motor is moving, else 0. Args: - left (Signal): The left slit positioner signal. - right (Signal): The right slit positioner signal. + left (Signal): The signal representing the position of the left slit positioner. + right (Signal): The signal representing the position of the right slit positioner. Returns: int: 1 if either motor is moving, 0 otherwise. """ @@ -186,10 +255,14 @@ class VirtualSlitWidth(PSIPseudoMotorBase): if __name__ == "__main__": # pragma: no cover + # pylint: disable=import-outside-toplevel, unused-import, missing-docstring, ungrouped-imports, arguments-differ, protected-access from ophyd import Component as Cpt from ophyd_devices.sim.sim_positioner import SimPositioner + ########### + ## Alternative approach for virtual slit center + ########### class TestPseudoMotor(PSIPseudoMotorBase): motor_a = Cpt(SimPositioner, name="motor_a") diff --git a/ophyd_devices/interfaces/base_classes/psi_pseudo_device_base.py b/ophyd_devices/interfaces/base_classes/psi_pseudo_device_base.py new file mode 100644 index 0000000..4b5409f --- /dev/null +++ b/ophyd_devices/interfaces/base_classes/psi_pseudo_device_base.py @@ -0,0 +1,13 @@ +from ophyd_devices import PSIDeviceBase +from ophyd_devices.utils.bec_processed_signal import BECProcessedSignal + + +class PSIPseudoDeviceBase(PSIDeviceBase): + """Base class for pseudo devices at PSI.""" + + def wait_for_connection(self, *args, **kwargs): + """Wait for connection of the pseudo device has to be called manually on BECProcessedSignals""" + for walk in self.walk_signals(): + if isinstance(walk.item, BECProcessedSignal): + walk.item.wait_for_connection(*args, **kwargs) + super().wait_for_connection(*args, **kwargs) diff --git a/ophyd_devices/interfaces/base_classes/psi_pseudo_motor_base.py b/ophyd_devices/interfaces/base_classes/psi_pseudo_motor_base.py index dc37810..52fc3f7 100644 --- a/ophyd_devices/interfaces/base_classes/psi_pseudo_motor_base.py +++ b/ophyd_devices/interfaces/base_classes/psi_pseudo_motor_base.py @@ -1,4 +1,9 @@ -""" """ +"""Base class for pseudo motors built from real positioner objects. + +The class wires three :class:`BECProcessedSignal` instances (`readback`, +`setpoint`, `motor_is_moving`) to user-defined calculation methods and combines +child-motor move statuses into one pseudo-motor status. +""" from __future__ import annotations @@ -10,7 +15,7 @@ from ophyd import Component as Cpt from ophyd import Kind, PositionerBase from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase -from ophyd_devices.utils.bec_processed_signal import BECProcessedSignal, ProcessedSignalModel +from ophyd_devices.utils.bec_processed_signal import BECProcessedSignal from ophyd_devices.utils.psi_device_base_utils import AndStatus, StatusBase if TYPE_CHECKING: # pragma: no cover @@ -18,21 +23,33 @@ if TYPE_CHECKING: # pragma: no cover class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase): - """ - Base class for PSI pseudo motors. + """Abstract base class for pseudo-positioners. + + Subclasses implement coordinate transforms via: + + - ``forward_calculation`` for readback/setpoint projection + - ``inverse_calculation`` for pseudo-to-real target mapping + - ``motors_are_moving`` for movement aggregation + + Please note that forward_calculation, inverse_calculation and motors_are_moving methods must be implemented with + the same signature as the keys for the associated positioner objects stored in the positioner_objects attribute + which is either passed to __init__ or set using the set_positioner_objects method. The positioner objects are expected + to be ophyd PositionerBase object like devices or at least implement a 'move' method and have attributes + 'readback' or 'user_readback', 'setpoint' or 'user_setpoint', and 'motor_is_moving'. Args: - name (str): The name of the pseudo motor. - device_manager (DeviceManagerDS): The device manager to use for connecting to the positioners. - positioners (dict[str, str]): A dictionary mapping positioner names to device names in the device manager. - Keys of this dictionary must match the arguments of the forward_calculation, - inverse_calculation and motors_are_moving methods. The values must be the names - of the devices in the device manager that correspond to the positioners. + name (str): The name of the pseudo motor device. + device_manager (DeviceManagerDS): The device manager instance to fetch the positioner objects from based on the configuration. + positioners (dict[str, PositionerBase] | None): A dictionary of positioner objects that this pseudo motor depends on. The keys should match the input parameters of the forward_calculation, inverse_calculation and motors_are_moving methods. If not provided during initialization, it can be set later using the set_positioner_objects method. + egu (str): Engineering units for the pseudo motor. + **kwargs: Additional keyword arguments to pass to the parent classes. """ - readback = Cpt(BECProcessedSignal, name="readback", model={}, kind=Kind.hinted) - setpoint = Cpt(BECProcessedSignal, name="setpoint", model={}, kind=Kind.normal) - motor_is_moving = Cpt(BECProcessedSignal, name="motor_is_moving", model={}, kind=Kind.omitted) + readback = Cpt(BECProcessedSignal, name="readback", model_config=None, kind=Kind.hinted) + setpoint = Cpt(BECProcessedSignal, name="setpoint", model_config=None, kind=Kind.normal) + motor_is_moving = Cpt( + BECProcessedSignal, name="motor_is_moving", model_config=None, kind=Kind.omitted + ) def __init__( self, @@ -46,23 +63,23 @@ class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase): self._positioner_move_kwargs: dict[str, dict[str, Any]] = {} self._egu = egu super().__init__(name=name, device_manager=device_manager, **kwargs) + self.readback.name = self.name @property def egu(self) -> str: - """Engineering units for the pseudo motor. This can be set during initialization or by setting the egu attribute.""" + """Engineering units for the pseudo motor.""" return self._egu def set_positioner_objects(self, positioners: dict[str, PositionerBase]) -> None: - """ - Method to set the positioner objects after initialization. This can be used if the positioner objects are not available at the time of initialization. + """Set the positioner objects for the pseudo motor. Args: - positioners (dict[str, PositionerBase]): A dictionary mapping positioner names to positioner objects. + positioners (dict[str, PositionerBase]): A dictionary of positioner objects that this pseudo motor depends on. """ self.positioner_objects = positioners def wait_for_connection(self, *args, **kwargs) -> None: - """Connect to relevant positioners, setup processed signals.""" + """Validate signatures, wire processed signals, and connect dependencies.""" if not self.positioner_objects: raise ConnectionError( f"No positioners specified for pseudo motor {self.name}. Please use 'set_positioner_objects' or pass positioner objects during initialization." @@ -70,21 +87,21 @@ class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase): # Check if all methods have the required signature that matches the positioner_objects keys self._check_method_signatures() self._setup_pseudo_signal( - "readback", ["readback", "user_readback"], self.forward_calculation, return_type=float + "readback", ["readback", "user_readback"], self.forward_calculation ) self._setup_pseudo_signal( - "setpoint", ["setpoint", "user_setpoint"], self.forward_calculation, return_type=float - ) - self._setup_pseudo_signal( - "motor_is_moving", ["motor_is_moving"], self.motors_are_moving, return_type=int + "setpoint", ["setpoint", "user_setpoint"], self.forward_calculation ) + self._setup_pseudo_signal("motor_is_moving", ["motor_is_moving"], self.motors_are_moving) + # Prepare move kwargs for each positioner based on their move method signature + for name, positioner in self.positioner_objects.items(): + move_signature = inspect.signature(positioner.move) + if "wait" in move_signature.parameters: + self._positioner_move_kwargs[name] = {"wait": False} return super().wait_for_connection(*args, **kwargs) def _check_method_signatures(self) -> None: - """ - Method to check that the forward_calculation, inverse_calculation and motors_are_moving methods - have the required signature that matches the positioner_objects keys. - """ + """Ensure calculation method parameters match configured positioner keys.""" input_names = set(self.positioner_objects.keys()) for method in [self.forward_calculation, self.inverse_calculation, self.motors_are_moving]: signature = inspect.signature(method) @@ -96,22 +113,15 @@ class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase): ) def _setup_pseudo_signal( - self, - pseudo_attr: str, - allowed_attributes: list[str], - compute_method: Callable[..., float], - return_type: type, + self, pseudo_attr: str, allowed_attributes: list[str], compute_method: Callable[..., float] ): - """ - Setup a pseudo signal with the given compute method and return type. The compute method will be called with - the values of the positioners as arguments. The allowed_attributes are used to determine which signal of the - positioner to use as input for the compute method. The first attribute that is found in the positioner will be used. + """Configure one pseudo signal from selected positioner attributes. Args: - pseudo_attr (str): The attribute of the pseudo motor to set the model for. - allowed_attributes (list[str]): The attributes of the positioner to look for as input for the compute method. - compute_method (Callable[..., float]): The method to compute the value of the pseudo signal. - return_type (type): The return type of the compute method. + pseudo_attr (str): The name of the pseudo attribute to set up. + allowed_attributes (list[str]): A list of allowed attributes for the positioner objects. + compute_method (Callable[..., float]): Function used to compute the + pseudo signal value. """ device_objects = {} dotted_names = {} @@ -130,26 +140,23 @@ class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase): dotted_names[name] = f"{device_name}.{obj.name}" device_objects[name] = obj - model = ProcessedSignalModel( - devices=dotted_names, compute_method=compute_method, return_type=return_type + pseudo_attr_obj.set_compute_method( + compute_method, **{name: obj for name, obj in device_objects.items()} ) - pseudo_attr_obj.set_device_object(device_objects) - pseudo_attr_obj.set_model(model) pseudo_attr_obj.wait_for_connection() def get_positioner_objects( self, name: str, positioners: dict[str, str], device_manager: DeviceManagerDS ) -> dict[str, PositionerBase]: - """ - Helper method to get the positioner objects from the device manager based on a positioner dictionary. + """Resolve and validate positioner objects from device-manager names. Args: - name (str): The name of the pseudo motor device to look for in the device manager config. - positioners (dict[str, str]): A dictionary mapping positioner names to device names in the device manager. - device_manager (DeviceManagerDS): The device manager to use for connecting to the positioners. + name (str): The name of the pseudo motor device. + positioners (dict[str, str]): A dictionary mapping positioner names to device names. + device_manager (DeviceManagerDS): The device manager instance to fetch the positioner objects from. Returns: - dict[str, PositionerBase]: A dictionary mapping positioner names to positioner objects. + dict[str, PositionerBase]: A dictionary of positioner objects. """ positioner_objs = {} # First we check that the device config of this device specifies @@ -180,16 +187,22 @@ class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase): f"Device '{device_name}' must have at least one argument for each tuple in the following list of tuples: {required_attrs}." ) positioner_objs[name] = device - move_signature = inspect.signature(device.move) - if "wait" in move_signature.parameters: - self._positioner_move_kwargs[name] = {"wait": False} return positioner_objs def _find_device_config_in_session( self, device_name: str, device_manager: DeviceManagerDS ) -> dict[str, Any]: - """ - Helper method to find the device config for a given device name in the current session config of the device manager. + """Find the session configuration entry for ``device_name``. + + Args: + device_name (str): The name of the device to find the configuration for. + device_manager (DeviceManagerDS): The device manager instance to fetch the configuration from. + + Returns: + dict[str, Any]: The configuration dictionary for the device. + + Raises: + ConnectionError: If the device configuration is not found in the current session. """ configs = device_manager.current_session["devices"] config = None @@ -203,30 +216,44 @@ class PSIPseudoMotorBase(ABC, PSIDeviceBase, PositionerBase): @abstractmethod def forward_calculation(self, *args) -> float: - """Calculate the pseudo motor value based on the positioner values.""" + """Compute pseudo value from positioner signals. + + Method parameters must include all keys defined in + ``self.positioner_objects``. + """ @abstractmethod def inverse_calculation(self, position: float, **positioner_objects) -> dict[str, float]: - """Calculate the positioner values based on the pseudo motor value.""" + """Map a pseudo target position to child-motor setpoints. + + The first argument is always the desired pseudo position. + """ @abstractmethod def motors_are_moving(self, *args) -> int: - """Calculate whether the motors are moving based on the positioner values. Should be 0 or 1.""" + """Return a movement flag derived from child-motor motion signals.""" # pylint: disable=arguments-differ def move(self, position: float, **kwargs) -> StatusBase: - """ - Move method for the pseudo motor. This currently only supports moving all positioners - at once. If children want to implement more complex move logic, they can override this method - based on this implementation and the inverse_calculation method. + """Move child motors to realize a pseudo target position. + + The method calls :meth:`inverse_calculation` with the current method + inputs of the ``readback`` processed signal, then moves each configured + child positioner and combines all returned statuses with + :class:`AndStatus`. Args: - position (float): The position to move the pseudo motor to. - kwargs: The kwargs to pass to the move method of the positioners. + position (float): The desired position to move the pseudo motor to. + **kwargs: Additional keyword arguments to pass to the move method of the positioner objects. + Returns: + StatusBase: A combined status object that represents the status of all the move operations on the + positioner objects. """ self.check_value(position) status = None - motor_positions = self.inverse_calculation(position, **self.positioner_objects) + motor_positions = self.inverse_calculation( + position, **self.readback.compute_model.method_inputs + ) for name, pos in motor_positions.items(): positioner = self.positioner_objects[name] move_kwargs = self._positioner_move_kwargs.get(name, {}) diff --git a/ophyd_devices/sim/sim_positioner.py b/ophyd_devices/sim/sim_positioner.py index aaece4a..99142ef 100644 --- a/ophyd_devices/sim/sim_positioner.py +++ b/ophyd_devices/sim/sim_positioner.py @@ -155,6 +155,13 @@ class SimPositioner(Device, PositionerBase): value=self.sim.sim_state[self.readback.name]["value"], timestamp=self.sim.sim_state[self.readback.name]["timestamp"], ) + # Run subscription on "value" + self.readback._run_subs( + sub_type=self.readback.SUB_VALUE, + old_value=old_readback, + value=self.sim.sim_state[self.readback.name]["value"], + timestamp=self.sim.sim_state[self.readback.name]["timestamp"], + ) def _move_to_setpoint(self) -> None: """Move the simulated device to the setpoint.""" diff --git a/ophyd_devices/sim/sim_test_devices.py b/ophyd_devices/sim/sim_test_devices.py index d9b0b41..b8abbb0 100644 --- a/ophyd_devices/sim/sim_test_devices.py +++ b/ophyd_devices/sim/sim_test_devices.py @@ -7,11 +7,13 @@ from bec_lib import messages from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger from ophyd import Component as Cpt -from ophyd import Device, DeviceStatus, Kind, OphydObject, PositionerBase, Staged +from ophyd import Device, DeviceStatus, Kind, OphydObject, PositionerBase, Signal, Staged +from ophyd_devices.interfaces.base_classes.psi_pseudo_motor_base import PSIPseudoMotorBase from ophyd_devices.sim.sim_camera import SimCamera from ophyd_devices.sim.sim_positioner import SimPositioner from ophyd_devices.sim.sim_signals import SetableSignal +from ophyd_devices.utils.bec_processed_signal import BECProcessedSignal, ProcessedSignalModel from ophyd_devices.utils.bec_signals import ( AsyncSignal, DynamicSignal, @@ -424,6 +426,97 @@ class SimCameraWithPSIComponents(SimCamera): return status +class VirtualSlitCenter(PSIPseudoMotorBase): + """ + Alternative implementation of a Virtual Slit Center based on two sub-positioners. + The positioners are sub-devices called left_edge and right_edge, and are used + to calculate the center position of the slit. + + Args: + name (str): The name of the pseudo motor device. + device_manager (DeviceManagerBase): The device manager instance to fetch the positioner devices from + """ + + left_edge = Cpt(SimPositioner, name="left_edge", kind=Kind.normal) + right_edge = Cpt(SimPositioner, name="right_edge", kind=Kind.normal) + + def __init__(self, name, device_manager, **kwargs): + super().__init__(name, device_manager, **kwargs) + positioners = {"left_edge": self.left_edge, "right_edge": self.right_edge} + self.set_positioner_objects(positioners) + + def forward_calculation(self, left_edge: Signal, right_edge: Signal) -> float: + """ + Forward calculation to compute the center position of the slit based on the positions of the left and right edges. + + Args: + left_edge (Signal): The signal representing the position of the left edge positioner. + right_edge (Signal): The signal representing the position of the right edge positioner. + """ + return float((left_edge.get() + right_edge.get()) / 2) + + def inverse_calculation( + self, position, left_edge: Signal, right_edge: Signal + ) -> dict[str, float]: + """ + Inverse calculation to compute the setpoints for the left and right edge positioners based on the desired center position. + + Args: + position (float): The desired center position of the slit. + left_edge (Signal): The signal representing the position of the left edge positioner. + right_edge (Signal): The signal representing the position of the right edge positioner. + Returns: + dict[str, float]: A dictionary containing the setpoints for the left and right edge positioners. + """ + left_pos = left_edge.root.readback.get() + right_pos = right_edge.root.readback.get() + width = right_pos - left_pos + new_right_pos = position + width / 2 + new_left_pos = position - width / 2 + return {"left_edge": new_left_pos, "right_edge": new_right_pos} + + def motors_are_moving(self, left_edge: Signal, right_edge: Signal) -> int: + """ + Check if either the left or right edge positioners are currently moving. + + Args: + left_edge (Signal): The signal representing the position of the left edge positioner. + right_edge (Signal): The signal representing the position of the right edge positioner. + Returns: + int: 1 if either motor is moving, 0 otherwise. + """ + left_moving = left_edge.get() + right_moving = right_edge.get() + return int(left_moving or right_moving) + + +class BECPseudoSignal(BECProcessedSignal): + """ + Example of a pseudo signal implementation based on two signals available in the device manager. + The value of the signal is calculated based on the compute method. + + Args: + name (str): The name of the pseudo signal. + signal_1 (str): The name of the first signal in the device manager. + signal_2 (str): The name of the second signal in the device manager. + device_manager (DeviceManagerBase): The device manager instance to fetch the signals from. + """ + + def __init__(self, name, signal_1: str, signal_2: str, device_manager=None, **kwargs): + super().__init__(name, model_config=None, device_manager=device_manager, **kwargs) + signal_1 = self.get_device_object_from_bec( + object_name=signal_1, signal_name=self.name, device_manager=device_manager + ) + signal_2 = self.get_device_object_from_bec( + object_name=signal_2, signal_name=self.name, device_manager=device_manager + ) + self.set_compute_method(self.compute, signal_1=signal_1, signal_2=signal_2) + + def compute(self, signal_1: Signal, signal_2: Signal) -> float: + """Compute the value of the pseudo signal based on the values of signal_1 and signal_2.""" + return float(signal_1.get() + signal_2.get()) + + if __name__ == "__main__": cam = SimCameraWithPSIComponents(name="cam") cam.read() diff --git a/ophyd_devices/utils/bec_processed_signal.py b/ophyd_devices/utils/bec_processed_signal.py index 671cd40..7bda9fc 100644 --- a/ophyd_devices/utils/bec_processed_signal.py +++ b/ophyd_devices/utils/bec_processed_signal.py @@ -1,222 +1,248 @@ +"""Utilities for building computed read-only signals. + +This module provides: + +- :class:`ProcessedSignalModel`, which validates that ``method_inputs`` can be + passed to ``compute_method`` as keyword arguments. +- :class:`BECProcessedSignal`, a ``SignalRO`` subclass that subscribes to input + ophyd objects and recomputes its readback whenever an input updates. + +Re-entrant callback execution is guarded to avoid recursive subscription loops. +""" + from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Any, Callable, Self, TypeAlias +import time +from typing import TYPE_CHECKING, Any, Callable, Literal, Self -import numpy as np -from ophyd import Device, Signal, SignalRO +from ophyd import Component, Device, Signal, SignalRO from pydantic import BaseModel, model_validator if TYPE_CHECKING: # pragma: no cover - from bec_server.device_server.devices.devicemanager import DeviceManagerDS + from bec_server.device_server.devices.devicemanager import DeviceManagerDS, DSDevice -ALLOWED_RETURN_TYPES = (float, int, str, np.ndarray) -ALLOWED_TYPE_ALIAS: TypeAlias = type[float] | type[int] | type[str] | type[np.ndarray] -ComputeMethod: TypeAlias = Callable[..., ALLOWED_TYPE_ALIAS] +def find_device_config_in_session(name: str, device_manager: DeviceManagerDS) -> dict[str, Any]: + """Return the configuration entry of ``name`` from ``device_manager.current_session``. + + The helper is used by lookups that resolve objects through the device + manager and need to validate the ``needs`` dependency list. + + Args: + name: The name of the signal/device for which the config is being fetched. + device_manager: The device manager instance to fetch the current session config from. + """ + configs = device_manager.current_session["devices"] + config = None + for conf in configs: + if conf["name"] == name: + config = conf + break + if config is None: + raise ConnectionError(f"Device '{name}' not found in current session config.") + return config class ProcessedSignalModel(BaseModel): - """ - Model for the BECProcessedSignal, which defines the devices/signals to subscribe to, and a - compute method that will be called with the device/signal objects as arguments. The method - `get_device_objects_from_device_manager` can be used to get the device/signal objects from the device manager - based on the devices dictionary. + """Configuration model for :class:`BECProcessedSignal`. + + The model stores arbitrary keyword inputs and a callable. Validation enforces + that ``compute_method(**method_inputs)`` is a compatible call. Args: - signals (dict[str, DeviceSignalMapping]): A dictionary mapping signal names to their corresponding device - and attribute names. The keys of this dictionary must match the parameters of the compute_method. - compute_method (Callable): A callable that will be called with the device/signal objects as arguments. - The signature of this method must match the keys of the signals dictionary. The return value of - this method will be used as the value of the BECProcessedSignal. The compute method can also be a - coroutine function, in which case it will be awaited. - return_type (type): The expected return type of the compute method. This is used for validation and to ensure - that the value of the BECProcessedSignal is of the correct type. Must be one of the types specified - in ALLOWED_RETURN_TYPES. + method_inputs (dict[str, Any]): A dictionary mapping input names of the compute method to the + corresponding objects. They can be ophyd Signals, Devices, Components or any other additional argument that should be passed + to the compute method when called. The keys of this dictionary must match the signature of the compute method. + compute_method (Callable[..., Any]): A user-defined function that computes the value of the processed signal based on the input devices/signals. + + Note: + Validation only checks call compatibility (missing/extra kwargs and + unsupported signature kinds). It does not enforce the runtime return + type of ``compute_method``. """ model_config = {"arbitrary_types_allowed": True} - devices: dict[str, str] - compute_method: ComputeMethod - return_type: ALLOWED_TYPE_ALIAS + method_inputs: dict[str, Any] + compute_method: Callable[..., Any] @model_validator(mode="after") def validate_signals_in_compute_method(self) -> Self: - """Validator to check that the compute method signature contains all keys of the devices field.""" - # Check that the compute method's signature contains all the keys in the devices dictionary. - input_names = set(self.devices.keys()) - + """Validate compatibility of ``compute_method`` with ``method_inputs``.""" signature = inspect.signature(self.compute_method) - parameters = signature.parameters + input_names = set(self.method_inputs) - method_param_names = set(parameters.keys()) + accepted_names = set() + required_names = set() + has_var_keyword = False - if method_param_names != input_names: - raise ValueError( - "The compute_method signature does not match with the devices keys provided. " - f"Expected parameters: {sorted(input_names)}, " - f"got: {sorted(method_param_names)}" - ) - return self - - def get_device_objects_from_device_manager( - self, device_manager: DeviceManagerDS, signal_name: str - ) -> dict[str, Device | Signal]: - """ - Helper method to get the device/signal objects from the device manager based on the devices dictionary. - - Args: - device_manager (DeviceManagerDS): The device manager to get the devices/signals from. - signal_name (str): The name of the ProcessedSignal - Returns: - dict[str, Device | Signal]: A dictionary mapping the device/signal names to their corresponding - ophyd Device or Signal objects. - """ - device_objs = {} - signal_config = self._find_device_config_in_session(signal_name, device_manager) - needs = signal_config.get("needs", []) - for name, dotted_name in self.devices.items(): - dev_name = dotted_name.split(".")[0] # First part of the dotted name is device - if dev_name not in needs: - raise ConnectionError( - f"Device {dev_name} needs to be specified in the 'needs' field of the config for the current session" - f"for signal '{signal_name}' in order to fetch the device object with name {dotted_name}from the device manager." + for name, param in parameters.items(): + if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.VAR_POSITIONAL): + raise ValueError( + "Compute method mus be compatible with compute_method(**method_inputs): " + f"unsupported parameter {name!r} ({param.kind.description})." + f"for compute method {self.compute_method.__name__!r} with signature {signature}." ) - # Attribute access resolves dotted name to fetch the correct signal/device object from the device manager - device = device_manager.devices[dotted_name] - # TODO: How to safeguard against using variables and not signals? - # if not isinstance(device, (Device, Signal)): - # raise ConnectionError( - # f"Device '{dotted_name}' does not point to a valid signal or device but to type {type(device)}." - # ) - device_objs[name] = device - return device_objs + if param.kind is inspect.Parameter.VAR_KEYWORD: + has_var_keyword = True + continue - def _find_device_config_in_session( - self, name: str, device_manager: DeviceManagerDS - ) -> dict[str, Any]: - """ - Helper method to find the device config for a given device name in the current session config of the device manager. - """ - configs = device_manager.current_session["devices"] - config = None - for conf in configs: - if conf["name"] == name: - config = conf - break - if config is None: - raise ConnectionError(f"Device '{name}' not found in current session config.") - return config + accepted_names.add(name) + + if param.default is inspect.Parameter.empty: + required_names.add(name) + + missing = required_names - input_names + extra = set() if has_var_keyword else input_names - accepted_names + + if missing or extra: + problems = [] + if missing: + problems.append(f"missing required inputs: {sorted(missing)}") + if extra: + problems.append(f"unexpected inputs: {sorted(extra)}") + raise ValueError("; ".join(problems)) + + return self class BECProcessedSignal(SignalRO): + """Read-only signal whose value is computed from other inputs. + + A compute model can be provided at construction time via ``model_config`` or + later through :meth:`set_compute_method`. During + :meth:`wait_for_connection`, input ophyd objects are subscribed and the + current value is computed immediately. + + Args: + name (str): The name of the signal. + model_config (dict[Literal["devices", "compute_method"], Any] | None): + Optional initialization payload for :meth:`set_compute_method`. + ``devices`` is passed as keyword arguments to the compute method. + device_manager (DeviceManagerDS | None): Device manager used by helpers + that resolve objects from BEC names. + **kwargs: Additional keyword arguments passed to the SignalRO initializer. + """ def __init__( self, name: str, - model: ProcessedSignalModel | dict[str, Any] | None = None, + model_config: dict[Literal["method_inputs", "compute_method"], Any] | None = None, device_manager: DeviceManagerDS | None = None, **kwargs, ): super().__init__(name=name, **kwargs) - self._model = model or {} + self.compute_model: ProcessedSignalModel | None = None self._device_manager: DeviceManagerDS = self._get_device_manager(device_manager) - self.device_objects: dict[str, Device | Signal] = {} self._metadata["connected"] = False self._callback_is_running = False self._active_callbacks: set[str] = set() + if model_config: + self.set_compute_method( + compute_method=model_config["compute_method"], **model_config["devices"] + ) - def set_model(self, model: ProcessedSignalModel | dict[str, Any]) -> None: + @staticmethod + def get_device_object_from_bec( + object_name: str, signal_name: str, device_manager: DeviceManagerDS + ) -> Device | Signal: + """Resolve one device/signal object from a BEC object name. + + The method verifies that the resolved device is listed in the ``needs`` + section of ``signal_name`` in the active session configuration. """ - Method to set the model for the BECProcessedSignal. This method will be used to set the devices/signals to subscribe to, - and the compute method to use for processing the values from those devices/signals. + signal_config = find_device_config_in_session(signal_name, device_manager) + needs = signal_config.get("needs", []) + dev_name = object_name.split(".")[0] # First part of the dotted name is device + if dev_name not in needs: + raise ConnectionError( + f"Device {dev_name} needs to be specified in the 'needs' field of the config for the current session" + f"for signal '{signal_name}' in order to fetch the device object with name {object_name} from the device manager." + ) + # Attribute access resolves dotted name to fetch the correct signal/device object from the device manager + # If this line crashes, there is likely an issue with the implementation of 'needs' in the device manager. + device = device_manager.devices[object_name] + return device + + def set_compute_method(self, compute_method: Callable[..., Any], **kwargs) -> None: + """Set or replace the compute method and its keyword inputs. + + ``kwargs`` may contain ophyd objects and/or plain values. All entries are + forwarded to the compute method as keyword arguments. Args: - model (ProcessedSignalModel | dict): The model for the BECProcessedSignal. Can be either a ProcessedSignalModel instance or a dictionary that can be used to create a ProcessedSignalModel instance. + compute_method: Callable used to compute the readback value. + **kwargs: Keyword arguments forwarded to ``compute_method``. """ - if isinstance(model, dict): - model = ProcessedSignalModel(**model) - self._model = model + # Lazy import DSDevice to avoid circular import issues + from bec_server.device_server.devices.devicemanager import DSDevice - def set_device_object(self, device_objects: dict[str, Device | Signal]) -> None: - """ - Method to set the device/signal objects for the BECProcessedSignal. This method will be used to provide the actual ophyd Device or Signal objects that correspond to the devices/signals specified in the model. - - Args: - device_objects (dict[str, Device | Signal]): A dictionary mapping the device/signal names to their corresponding ophyd Device or Signal objects. - """ - self.device_objects = device_objects + method_inputs = {} + found_opd_objects = False + for kw, value in kwargs.items(): + if isinstance(value, (Component, Device, Signal, DSDevice)): + found_opd_objects = True + method_inputs[kw] = value + if not found_opd_objects: + raise ValueError( + "At least one ophyd object (Component, Device, Signal, or DSDevice) must be provided as a keyword argument to set_compute_method." + ) + self.compute_model = ProcessedSignalModel.model_validate( + {"method_inputs": method_inputs, "compute_method": compute_method} + ) def wait_for_connection(self, *args, **kwargs) -> None: + """Connect to inputs and initialize computed readback. + + Subscriptions are attached to every method input that exposes both + ``wait_for_connection`` and ``subscribe``. Inputs without these methods + are treated as static keyword arguments. """ - Wait for connection will try to connect to the device/signal objects specified in the - model. It will check the model for validity, and attempt to fetch the device_objects if - they are not already set. Once everything is set up, it subscribe to the default subscriptions - and compute an initial value while asserting that its return value is of the correct type. - """ + # Already connected, no need to do anything + if self._metadata.get("connected", False): + return - # I. Check that model is set, if not raise an error. - if self._model is None: + # Check that model is set, if not raise an error. + if self.compute_model is None: raise ValueError( - f"No model provided for signal {self.name}. Please either provided model in init or use `set_model` before `wait_for_connection` is called." + f"No compute model provided for signal {self.name}. Please either provide a model_config in init or use `set_compute_method` before `wait_for_connection` is called." ) - # II. Check that device objects are set, if not raise an error. - if not self.device_objects: - # The method will also check that the device object names are specified in this signals config needs field. - device_objects = self._model.get_device_objects_from_device_manager( - self._device_manager, self.name - ) - self.set_device_object(device_objects) + # Setup subscriptions to input devices/signals based on the model's configuration. + for input in self.compute_model.method_inputs.values(): + # Lazy import DSDevice to avoid circular import issues + from bec_server.device_server.devices.devicemanager import DSDevice - # III. Check that device_objects dict contains all keys specified in the model's devices field. - model_keys = set(self._model.devices.keys()) - device_object_keys = set(self.device_objects.keys()) - if model_keys != device_object_keys: - raise ValueError( - f"The device_objects provided do not match with the devices specified in the model for signal {self.name}. " - f"Expected keys: {sorted(model_keys)}, got: {sorted(device_object_keys)}" - ) + if not isinstance(input, (Component, Device, Signal, DSDevice)): + continue # Skip non-ophyd objects, they are additional arguments for the compute method + input.wait_for_connection(*args, **kwargs) # Ensure connected - # Now we connect to the default subscriptions of the provided signal/device objects. - for device_obj in self.device_objects.values(): - device_obj.wait_for_connection(*args, **kwargs) # Ensure connected - - device_obj.subscribe( - self._subscription_callback, event_type=device_obj._default_sub, run=False - ) + input.subscribe(self._subscription_callback, event_type=input._default_sub, run=False) # Run computation of the processed signal, this stores the value in _readback - self._subscription_callback(**self.device_objects) - if not isinstance(self._readback, self._model.return_type): - raise TypeError( - f"Computed value has type {type(self._readback)}, but expected {self._model.return_type}" - ) + self._subscription_callback() # Signal is connected self._metadata["connected"] = True def _subscription_callback(self, *args, **kwargs): - """Callback method that is executed whenever any of the subscribed signal/device objects have an update.""" + """Recompute readback from ``compute_model`` and emit value subscriptions.""" if self._callback_is_running: return # Callback is already running, skip to avoid multiple executions at the same time try: self._callback_is_running = True old_value = self._readback - self._readback = self._model.compute_method(**self.device_objects) + timestamp = time.time() + self._metadata["timestamp"] = timestamp + self._readback = self.compute_model.compute_method(**self.compute_model.method_inputs) self._run_subs(sub_type=self._default_sub, old_value=old_value, value=self._readback) finally: self._callback_is_running = False def _run_subs(self, *args, sub_type, **kwargs): - """ - Run subs should not allow for recursive calls of the same subscription type to avoid recursion loops. - - Args: - sub_type (str): The subscription type for which to run the callbacks. - """ + """Prevent concurrent callbacks for the same subscription type.""" if sub_type in self._active_callbacks: return try: @@ -227,14 +253,16 @@ class BECProcessedSignal(SignalRO): self._active_callbacks.remove(sub_type) def _get_device_manager(self, device_manager: DeviceManagerDS | None = None) -> DeviceManagerDS: - """ - Helper method to get the device manager instance. + """Return the active device manager for this signal. + + If ``device_manager`` is not provided, the method tries to read it from + ``self.root.device_manager``. Args: - device_manager (DeviceManagerDS | None): Optional device manager instance. If not provided, it will attempt to fetch from the root device. - + device_manager (DeviceManagerDS | None): An optional device manager instance. If not provided, it will attempt to fetch + the device manager from the root device's `device_manager` attribute. Returns: - DeviceManagerDS: The device manager instance. + DeviceManagerDS: The resolved device manager. """ if device_manager is None: # PSIDeviceBase will have a reference to the device manager on device_manager attribute. @@ -250,13 +278,33 @@ class BECProcessedSignal(SignalRO): return device_manager def describe(self): + """Return ``describe`` metadata including compute model information.""" ret = super().describe() - ret[self.name]["device_objects"] = ", ".join(self._model.devices.values()) + if self.compute_model is None: + ret[self.name]["method_inputs"] = "" + ret[self.name]["compute_method"] = "" + ret[self.name]["extra_kwargs"] = {} + return ret + ret[self.name]["method_inputs"] = ", ".join( + [ + f"{obj.root.name}.{obj.dotted_name}" + for obj in self.compute_model.method_inputs.values() + if hasattr(obj, "dotted_name") and hasattr(obj, "root") # Ophyd obj + ] + ) + ret[self.name]["compute_method"] = self.compute_model.compute_method.__name__ + ret[self.name]["extra_kwargs"] = { + kw: value + for kw, value in self.compute_model.method_inputs.items() + if not (hasattr(value, "dotted_name") and hasattr(value, "root")) # Ophyd obj + } return ret if __name__ == "__main__": # pragma: no cover + # pylint: disable=import-outside-toplevel, unused-import, missing-docstring, protected-access + from bec_server.device_server.tests.utils import DMMock from ophyd_devices.sim.sim_positioner import SimPositioner @@ -271,9 +319,8 @@ if __name__ == "__main__": # pragma: no cover dm.devices._add_device("samx", samx) dm.devices._add_device("samy", samy) - def compute_method(signal_1: Signal, signal_2: Signal) -> float: - """Example compute method that adds two signals.""" - return float(signal_1.get() + signal_2.get()) + def compute_method(signal_1: Signal, signal_2: Signal, tmp: float = 2) -> float: + return float(signal_1.get() + signal_2.get()) + tmp def _callback_print(value, **kwargs): obj = kwargs.get("obj") @@ -282,13 +329,12 @@ if __name__ == "__main__": # pragma: no cover samx.readback.subscribe(_callback_print, run=False, event_type=samx.readback.SUB_VALUE) samy.readback.subscribe(_callback_print, run=False, event_type=samy.readback.SUB_VALUE) - _model = ProcessedSignalModel( - devices={"signal_1": "samx.readback", "signal_2": "samy.readback"}, - compute_method=compute_method, - return_type=float, + processed_signal = BECProcessedSignal( + name="processed_signal", model_config={}, device_manager=dm + ) + processed_signal.set_compute_method( + compute_method, signal_1=samx.readback, signal_2=samy.readback, tmp=0 ) - - processed_signal = BECProcessedSignal(name="processed_signal", model=_model, device_manager=dm) processed_signal.subscribe(_callback_print, run=False, event_type=processed_signal.SUB_VALUE) dm.current_session = {} dm.current_session["devices"] = [{"name": processed_signal.name, "needs": ["samx", "samy"]}] diff --git a/tests/test_base_classes.py b/tests/test_base_classes.py index d04783f..bf74dcc 100644 --- a/tests/test_base_classes.py +++ b/tests/test_base_classes.py @@ -3,11 +3,14 @@ import time from unittest import mock import pytest -from ophyd import DeviceStatus, Staged +from bec_server.device_server.tests.utils import DMMock +from ophyd import Component as Cpt +from ophyd import DeviceStatus, Signal, Staged from ophyd.utils.errors import RedundantStaging from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase -from ophyd_devices.utils.errors import DeviceStopError, DeviceTimeoutError +from ophyd_devices.interfaces.base_classes.psi_pseudo_device_base import PSIPseudoDeviceBase +from ophyd_devices.utils.bec_processed_signal import BECProcessedSignal @pytest.fixture @@ -15,6 +18,34 @@ def detector_base(): yield PSIDeviceBase(name="test_detector") +class TestPSIPseudoDeviceBase(PSIPseudoDeviceBase): + """Test class for PSIPseudoDeviceBase to test wait_for_connection method.""" + + signal = Cpt(BECProcessedSignal, name="signal", model_config=None) + tmp = Cpt(Signal, name="tmp") + + def __init__(self, name, device_manager=None, **kwargs): + super().__init__(name=name, device_manager=device_manager, **kwargs) + self.signal.set_compute_method(self.compute, signal=self.tmp) + + def compute(self, signal): + """Compute method for the processed signal.""" + return signal.get() * 2 + + +@pytest.fixture +def pseudo_device_base(): + dm = DMMock() + yield TestPSIPseudoDeviceBase(name="test_pseudo_device", device_manager=dm) + + +def test_psuedo_device_wait_for_connection(pseudo_device_base): + # The wait_for_connection method should connect to the BECProcessedSignal after connecting to the normal signals. + assert pseudo_device_base.signal._metadata["connected"] is False + pseudo_device_base.wait_for_connection() + assert pseudo_device_base.signal._metadata["connected"] is True + + def test_detector_base_init(detector_base): assert detector_base.stopped is False assert detector_base.name == "test_detector" diff --git a/tests/test_processed_signal.py b/tests/test_processed_signal.py new file mode 100644 index 0000000..a42b577 --- /dev/null +++ b/tests/test_processed_signal.py @@ -0,0 +1,152 @@ +"""Tests for processed-signal behavior and integration patterns.""" + +# pylint: disable=redefined-outer-name + +import pytest +from bec_server.device_server.tests.utils import DMMock +from ophyd import Component as Cpt +from ophyd import Device + +from ophyd_devices.sim.sim_positioner import SimPositioner +from ophyd_devices.utils.bec_processed_signal import BECProcessedSignal, ProcessedSignalModel + + +class TestProcessedSignalDevice(Device): + """Fixture device with two sub-positioners and one processed signal.""" + + motor_a = Cpt(SimPositioner, name="motor_a", delay=0) + motor_b = Cpt(SimPositioner, name="motor_b", delay=0) + processed = Cpt(BECProcessedSignal, name="processed", model_config=None) + + def __init__(self, name, device_manager, **kwargs): + self.device_manager = device_manager + super().__init__(name=name, **kwargs) + self.processed.set_compute_method( + self.compute, motor_a=self.motor_a.readback, motor_b=self.motor_b.readback, offset=0.5 + ) + + @staticmethod + def compute(motor_a, motor_b, offset=0): + """Compute processed value from two motor readbacks.""" + return float(motor_a.get() + motor_b.get() + offset) + + +@pytest.fixture(name="device_manager") +def fixture_device_manager(): + """Mock device manager fixture.""" + return DMMock() + + +@pytest.fixture(name="processed_device") +def fixture_processed_device(device_manager): + """Fixture for TestProcessedSignalDevice.""" + dev = TestProcessedSignalDevice(name="processed_dev", device_manager=device_manager) + dev.motor_a.wait_for_connection() + dev.motor_b.wait_for_connection() + dev.processed.wait_for_connection() + return dev + + +@pytest.fixture(name="samx") +def fixture_samx(): + """Standalone left motor fixture.""" + return SimPositioner(name="samx", delay=0) + + +@pytest.fixture(name="samy") +def fixture_samy(): + """Standalone right motor fixture.""" + return SimPositioner(name="samy", delay=0) + + +@pytest.fixture(name="device_manager_with_signals") +def fixture_device_manager_with_signals(samx, samy): + """Device manager fixture with motor mapping and session needs.""" + dm = DMMock() + dm.devices["samx"] = samx + dm.devices["samy"] = samy + dm.current_session = {"devices": [{"name": "processed_signal", "needs": ["samx", "samy"]}]} + return dm + + +@pytest.fixture(name="processed_signal_from_device_manager") +def fixture_processed_signal_from_device_manager(device_manager_with_signals): + """Processed signal fixture using dotted-name resolution through the device manager.""" + + signal = BECProcessedSignal(name="processed_signal", device_manager=device_manager_with_signals) + signal_1 = BECProcessedSignal.get_device_object_from_bec( + "samx.readback", "processed_signal", device_manager_with_signals + ) + signal_2 = BECProcessedSignal.get_device_object_from_bec( + "samy.readback", "processed_signal", device_manager_with_signals + ) + signal.set_compute_method( + lambda signal_1, signal_2, offset=1.0: float(signal_1.get() + signal_2.get() + offset), + signal_1=signal_1, + signal_2=signal_2, + offset=1.0, + ) + signal.wait_for_connection() + return signal + + +def test_processed_signal_with_sub_components(processed_device): + """Test processed signal updates from sub-component motor readbacks.""" + processed_device.motor_a.move(2).wait(timeout=2) + processed_device.motor_b.move(3).wait(timeout=2) + assert processed_device.processed.get() == pytest.approx(5.5) + + processed_device.motor_a.move(-1).wait(timeout=2) + assert processed_device.processed.get() == pytest.approx(2.5) + + +def test_processed_signal_device_manager_resolution( + processed_signal_from_device_manager, samx, samy +): + """Test processed signal using device-manager string key resolution.""" + samx.move(1).wait(timeout=2) + samy.move(2).wait(timeout=2) + assert processed_signal_from_device_manager.get() == pytest.approx(4) + + samy.move(5).wait(timeout=2) + assert processed_signal_from_device_manager.get() == pytest.approx(7) + + +def test_processed_signal_describe_metadata(processed_signal_from_device_manager): + """Test describe contains compute method metadata and extra kwargs.""" + info = processed_signal_from_device_manager.describe()["processed_signal"] + assert info["compute_method"] == "" + assert info["extra_kwargs"] == {"offset": 1.0} + assert "samx.readback" in info["method_inputs"] + assert "samy.readback" in info["method_inputs"] + + +def test_processed_signal_model_rejects_missing_required_inputs(): + """Test compute model validation when required kwargs are missing.""" + + def compute(signal_1, signal_2): + return signal_1 + signal_2 + + with pytest.raises(ValueError, match="missing required inputs"): + ProcessedSignalModel.model_validate( + {"method_inputs": {"signal_1": 1}, "compute_method": compute} + ) + + +def test_processed_signal_model_rejects_unexpected_inputs(): + """Test compute model validation when unknown kwargs are provided.""" + + def compute(signal_1): + return signal_1 + + with pytest.raises(ValueError, match="unexpected inputs"): + ProcessedSignalModel.model_validate( + {"method_inputs": {"signal_1": 1, "extra": 2}, "compute_method": compute} + ) + + +def test_processed_signal_requires_compute_method(device_manager): + """Test wait_for_connection fails when no compute model is configured.""" + signal = BECProcessedSignal(name="no_model", device_manager=device_manager) + with pytest.raises(ValueError, match="No compute model provided"): + signal.wait_for_connection() diff --git a/tests/test_virtual_slits.py b/tests/test_virtual_slits.py new file mode 100644 index 0000000..7a61439 --- /dev/null +++ b/tests/test_virtual_slits.py @@ -0,0 +1,261 @@ +"""Module to test the virtual slit center and width classes.""" + +import pytest +from bec_server.device_server.tests.utils import DMMock +from ophyd import Component as Cpt + +from ophyd_devices import TransitionStatus +from ophyd_devices.devices.virtual_slit import VirtualSlitCenter, VirtualSlitWidth +from ophyd_devices.interfaces.base_classes.psi_pseudo_motor_base import PSIPseudoMotorBase +from ophyd_devices.sim.sim_positioner import SimPositioner + + +class TestPseudoMotor(PSIPseudoMotorBase): + """Pseudo motor fixture class with two sub-positioner components.""" + + motor_a = Cpt(SimPositioner, name="motor_a", delay=0) + motor_b = Cpt(SimPositioner, name="motor_b", delay=0) + + def __init__(self, name, device_manager, **kwargs): + super().__init__(name=name, device_manager=device_manager, **kwargs) + self.set_positioner_objects({"a": self.motor_a, "b": self.motor_b}) + + def forward_calculation(self, a, b): + """Forward calculation.""" + return float(a.get() + b.get()) + + def inverse_calculation(self, position, a, b): + """Inverse calculation.""" + a_val = a.get() + b_val = position - a_val + return {"a": a_val, "b": b_val} + + def motors_are_moving(self, a, b): + """Check if the sub-positioners are moving.""" + return int(a.get() or b.get()) + + +@pytest.fixture +def pseudo_motor_fixture(): + """Fixture for the TestPseudoMotor with a mock device manager.""" + dm = DMMock() + pseudo = TestPseudoMotor(name="test_pseudo", device_manager=dm) + pseudo.wait_for_connection() + return pseudo + + +@pytest.fixture +def samx(): + """Positioner fixture for the left slit motor.""" + return SimPositioner(name="samx", delay=0) + + +@pytest.fixture +def samy(): + """Positioner fixture for the right slit motor.""" + return SimPositioner(name="samy", delay=0) + + +@pytest.fixture +def device_manager_with_slit_motors(samx, samy): + """Device manager fixture with slit motors and session configuration for virtual slits.""" + dm = DMMock() + dm.devices["samx"] = samx + dm.devices["samy"] = samy + dm.current_session = { + "devices": [ + {"name": "slit_center", "needs": ["samx", "samy"]}, + {"name": "slit_width", "needs": ["samx", "samy"]}, + ] + } + return dm + + +@pytest.fixture +def slit_center(device_manager_with_slit_motors): + """Slit Center fixture""" + center = VirtualSlitCenter( + name="slit_center", + left_slit="samx", + right_slit="samy", + device_manager=device_manager_with_slit_motors, + ) + center.wait_for_connection() + return center + + +@pytest.fixture +def slit_width(device_manager_with_slit_motors): + """Slit Width fixture""" + width = VirtualSlitWidth( + name="slit_width", + left_slit="samx", + right_slit="samy", + device_manager=device_manager_with_slit_motors, + ) + width.wait_for_connection() + return width + + +def test_subcomponent_pseudo_motor_move(pseudo_motor_fixture): + """ + Test that moving the pseudo motor correctly updates the positions of the sub-positioners + and that the readback of the pseudo motor reflects the new position. + """ + pseudo_motor_fixture.motor_a.move(5).wait(timeout=2) + pseudo_motor_fixture.motor_b.move(-5).wait(timeout=2) + + status = pseudo_motor_fixture.move(2) + status.wait(timeout=2) + + assert pseudo_motor_fixture.motor_a.readback.get() == pytest.approx(5) + assert pseudo_motor_fixture.motor_b.readback.get() == pytest.approx(-3) + assert pseudo_motor_fixture.readback.get() == pytest.approx(2) + + +def test_virtual_slit_center_forward_calculation(slit_center, samx, samy): + """ + Test that the forward calculation for the virtual slit center correctly computes + the center position based on the current positions of the left and right slit motors. + """ + samx.move(1).wait(timeout=2) + samy.move(3).wait(timeout=2) + + center = slit_center.forward_calculation(samx.readback, samy.readback) + assert center == pytest.approx(2) + + +def test_virtual_slit_center_inverse_calculation(slit_center, samx, samy): + """ + Test that the inverse calculation for the virtual slit center correctly + computes the positions of the left and right slit motors based on the + desired center position and current positions of the slit motors. + """ + status = TransitionStatus(slit_center.motor_is_moving, transitions=[0, 1, 0]) + st_samx = samx.move(1) + st_samy = samy.move(3) + + status.wait(timeout=2) + assert status.done is True, "Expected the slit center to start moving the slit motors." + assert status.success is True, "Expected the slit center to move the slit motors." + assert st_samy.done is True, "Expected the slit center to finish moving the right slit motor." + assert st_samy.success is True, "Expected the slit center to move the right slit motor." + assert st_samx.done is True, "Expected the slit center to finish moving the left slit motor." + assert st_samx.success is True, "Expected the slit center to move the left slit motor." + + pos = slit_center.inverse_calculation(4, samx.readback, samy.readback) + assert pos["left"] == pytest.approx(3) + assert pos["right"] == pytest.approx(5) + + +def test_virtual_slit_width_forward_calculation(slit_width, samx, samy): + """ + Test that the forward calculation for the virtual slit width correctly computes + the width based on the current positions of the left and right slit motors. + """ + samx.move(1).wait(timeout=2) + samy.move(3).wait(timeout=2) + + width = slit_width.forward_calculation(samx.readback, samy.readback) + assert width == pytest.approx(2) + + +def test_virtual_slit_width_inverse_calculation(slit_width, samx, samy): + """ + Test that the inverse calculation for the virtual slit width correctly + computes the positions of the left and right slit motors based on the + desired width and current positions of the slit motors. + """ + samx.move(1).wait(timeout=2) + samy.move(3).wait(timeout=2) + + pos = slit_width.inverse_calculation(6, samx.readback, samy.readback) + assert pos["left"] == pytest.approx(-1) + assert pos["right"] == pytest.approx(5) + + +def test_virtual_slit_center_move(slit_center, samx, samy): + """ + Test that moving the virtual slit center correctly updates the + positions of the left and right slit motors. + """ + samx.move(1).wait(timeout=2) + samy.move(3).wait(timeout=2) + + assert slit_center.readback.get() == pytest.approx(2) + + status = slit_center.move(5) + status.wait(timeout=2) + + assert samx.readback.get() == pytest.approx(4) + assert samy.readback.get() == pytest.approx(6) + assert slit_center.readback.get() == pytest.approx(5) + + +def test_virtual_slit_width_move(slit_width, samx, samy): + """ + Test that moving the virtual slit width correctly updates the + positions of the left and right slit motors. + """ + samx.move(1).wait(timeout=2) + samy.move(3).wait(timeout=2) + + # EGU should be taken from the left slit motor. + assert slit_width.egu == samx.egu + assert slit_width.egu == samx.egu + + status = slit_width.move(6) + status.wait(timeout=2) + + assert samx.readback.get() == pytest.approx(-1) + assert samy.readback.get() == pytest.approx(5) + assert slit_width.readback.get() == pytest.approx(6) + + +def test_virtual_slit_offset_applied_in_forward_and_inverse( + device_manager_with_slit_motors, samx, samy +): + """ + Test that the offset is correctly applied in both forward and + inverse calculations for the virtual slit center and width. + """ + samx.move(1).wait(timeout=2) + samy.move(3).wait(timeout=2) + + slit_center_offset = VirtualSlitCenter( + name="slit_center", + left_slit="samx", + right_slit="samy", + device_manager=device_manager_with_slit_motors, + offset=0.5, + egu="new_egu", + ) + slit_width_offset = VirtualSlitWidth( + name="slit_width", + left_slit="samx", + right_slit="samy", + device_manager=device_manager_with_slit_motors, + offset=0.5, + egu="new_egu", + ) + assert slit_width_offset.egu == "new_egu" + assert slit_center_offset.egu == "new_egu" + slit_center_offset.wait_for_connection() + slit_width_offset.wait_for_connection() + + assert slit_center_offset.forward_calculation(samx.readback, samy.readback) == pytest.approx( + 2.5 + ) + assert slit_width_offset.forward_calculation(samx.readback, samy.readback) == pytest.approx(2.5) + + center_pos = slit_center_offset.inverse_calculation(5.5, samx.readback, samy.readback) + assert center_pos["left"] == pytest.approx(4) + assert center_pos["right"] == pytest.approx(6) + + width_pos = slit_width_offset.inverse_calculation(6.5, samx.readback, samy.readback) + assert width_pos["left"] == pytest.approx(-1) + assert width_pos["right"] == pytest.approx(5) + + # Check offset signal + assert slit_center_offset.offset.get() == pytest.approx(0.5) + assert slit_width_offset.offset.get() == pytest.approx(0.5)