diff --git a/ophyd_devices/sim/sim_data.py b/ophyd_devices/sim/sim_data.py index f56ec5e..fa7f81a 100644 --- a/ophyd_devices/sim/sim_data.py +++ b/ophyd_devices/sim/sim_data.py @@ -164,6 +164,10 @@ class SimulatedDataBase(ABC): """ Method to set the parameters for the active simulation model. """ + self._set_params(params) + + def _set_params(self, params: dict) -> None: + """Utility method to set parameters for active model.""" for k, v in params.items(): if k in self.params: if k == "noise": @@ -327,6 +331,32 @@ class SimulatedDataMonitor(SimulatedDataBase): self.bit_depth = self.parent.BIT_DEPTH self._init_default() + @SimulatedDataBase.params.setter + def params(self, params: dict) -> None: + SimulatedDataBase.params.fset(self, params) + self._add_callback_to_motor() + + def _add_callback_to_motor(self) -> None: + # Setup subscription to the reference motor if available + mot_name = self.params.get("ref_motor", "") + if not hasattr(self.parent, "device_manager"): + return + if mot_name in self.parent.device_manager.devices: + if hasattr(self.parent, "setup_readback_monitor"): + self.parent.setup_readback_monitor(mot_name) + + def select_model(self, model: str) -> None: + """ + Method to select the active simulation model. + It will initiate the model_cls and parameters for the model. + + Args: + model (str): Name of the simulation model to select. + + """ + super().select_model(model) + self._add_callback_to_motor() + def _get_additional_params(self) -> None: params = deepcopy(DEFAULT_PARAMS_NOISE) params.update(deepcopy(DEFAULT_PARAMS_MOTOR)) @@ -432,7 +462,7 @@ class SimulatedDataMonitor(SimulatedDataBase): Returns: float: Value computed by the active model. """ - mot_name = self.params["ref_motor"] + mot_name = self.params.get("ref_motor", "") if self.parent.device_manager and mot_name in self.parent.device_manager.devices: motor_pos = self.parent.device_manager.devices[mot_name].obj.read()[mot_name]["value"] else: diff --git a/ophyd_devices/sim/sim_monitor.py b/ophyd_devices/sim/sim_monitor.py index 5ed1acc..0eb3fa5 100644 --- a/ophyd_devices/sim/sim_monitor.py +++ b/ophyd_devices/sim/sim_monitor.py @@ -1,5 +1,7 @@ """Module for simulated monitor devices.""" +from dataclasses import dataclass + import numpy as np from bec_lib import messages from bec_lib.endpoints import MessageEndpoints @@ -15,6 +17,13 @@ from ophyd_devices.utils import bec_utils logger = bec_logger.logger +@dataclass +class RegisteredCallback: + + motor: str + callback_id: int + + class SimMonitor(ReadOnlySignal): """ A simulated device mimic any 1D Axis (position, temperature, beam). @@ -61,8 +70,9 @@ class SimMonitor(ReadOnlySignal): self.precision = precision self.sim_init = sim_init self.device_manager = device_manager - self.sim = self.sim_cls(parent=self, **kwargs) self._registered_proxies = {} + self._registered_callback: RegisteredCallback | None = None + self.sim = self.sim_cls(parent=self, **kwargs) super().__init__( name=name, @@ -77,10 +87,37 @@ class SimMonitor(ReadOnlySignal): self.sim.set_init(self.sim_init) @property - def registered_proxies(self) -> None: + def registered_proxies(self) -> dict: """Dictionary of registered signal_names and proxies.""" return self._registered_proxies + def setup_readback_monitor(self, motor_name: str) -> None: + """ + Set up monitoring of the readback signal of a motor. + + Args: + motor_name (str): The name of the motor to monitor. + """ + + if self._registered_callback: + if self._registered_callback.motor == motor_name: + # Already registered callback + return + else: # Unregister callback from previous motor if necessary + motor = self.device_manager.devices.get(self._registered_callback.motor, None) + if motor: + motor.unsubscribe(self._registered_callback.callback_id) + + # Register new callback + motor = self.device_manager.devices.get(motor_name, None) + if motor: + cb_id = motor.subscribe(self._update_readback, run=True) + self._registered_callback = RegisteredCallback(motor=motor_name, callback_id=cb_id) + + def _update_readback(self, value, **kwargs): + """Callback function to update the readback value.""" + self.get() # Trigger a read to update the readback value + class SimMonitorAsyncControl(Device): """SimMonitor Sync Control Device""" diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 28af8d1..9f08cdf 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -221,10 +221,12 @@ def test_init_async_monitor(async_monitor): @pytest.mark.parametrize("center", [-10, 0, 10]) -def test_monitor_readback(monitor, center): +def test_monitor_readback(monitor, center, positioner): """Test the readback method of SimMonitor.""" motor_pos = 0 - monitor.device_manager.add_device(name="samx", value=motor_pos) + samx = SimPositioner(name="samx", device_manager=monitor.device_manager) + setattr(samx, "obj", samx) # Set obj attribute to itself for proxy lookup + monitor.device_manager.devices["samx"] = samx for model_name in monitor.sim.get_models(): monitor.sim.select_model(model_name) monitor.sim.params["noise_multipler"] = 10 @@ -234,17 +236,29 @@ def test_monitor_readback(monitor, center): elif "center" in monitor.sim.params: monitor.sim.params["center"] = center assert isinstance(monitor.read()[monitor.name]["value"], monitor.BIT_DEPTH) - expected_value = _safeint(monitor.sim._model.eval(monitor.sim._model_params, x=motor_pos)) - print(expected_value, monitor.read()[monitor.name]["value"]) + expected_value = _safeint( + monitor.sim._model.eval(monitor.sim._model_params, x=samx.read()["samx"]["value"]) + ) tolerance = ( monitor.sim.params["noise_multipler"] + 1 ) # due to ceiling in calculation, but maximum +1int assert np.isclose( monitor.read()[monitor.name]["value"], expected_value, - atol=monitor.sim.params["noise_multipler"] + 1, + atol=monitor.sim.params["noise_multipler"] + 2, # allow extra tolerance for ceiling ) + # Test callback on motor motion + _callback_bucket = [] + + def _callback(value, **kwargs): + _callback_bucket.append(value) + + monitor.subscribe(_callback, run=False) + assert not _callback_bucket + monitor.device_manager.devices["samx"].move(10).wait() + assert len(_callback_bucket) > 0 + @pytest.mark.parametrize("amplitude, noise_multiplier", [(0, 1), (100, 10), (1000, 50)]) def test_camera_readback(camera, amplitude, noise_multiplier):