diff --git a/superxas_bec/devices/trigger.py b/superxas_bec/devices/trigger.py index 0ed09d2..21e7c06 100644 --- a/superxas_bec/devices/trigger.py +++ b/superxas_bec/devices/trigger.py @@ -1,45 +1,66 @@ -from ophyd import Device, Kind, Component as Cpt -from ophyd import EpicsSignal, EpicsSignalRO, DeviceStatus, StatusBase -from ophyd.status import SubscriptionStatus +"""SuperXAS Trigger Device""" + from bec_lib.logger import bec_logger +from ophyd import Component as Cpt +from ophyd import Device, DeviceStatus, EpicsSignal, EpicsSignalRO, Kind, StatusBase logger = bec_logger.logger -from bec_lib.devicemanager import ScanInfo - import enum +from bec_lib.devicemanager import ScanInfo from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase + class ContinuousSamplingMode(int, enum.Enum): - """ Options for start_csmpl signal""" + """Options for start_csmpl signal""" + OFF = 0 ON = 1 + class SamplingDone(int, enum.Enum): - """ Status of sampling """ + """Status of sampling""" + RUNNING = 0 DONE = 1 + class TriggerControl(Device): - """ Trigger Device Control PVs at X10DA, prefix: X10DA-ES1: """ - - total_cycles = Cpt(EpicsSignal, suffix='TOTAL-CYCLES', kind=Kind.config, doc="Number of cycles (multiplies by 0.2s)") - start_csmpl = Cpt(EpicsSignal, suffix='START-CSMPL', kind=Kind.config, doc="Continous sampling mode on/off") - smpl = Cpt(EpicsSignal, suffix='SMPL', kind=Kind.config, doc="Sampling Trigger if cont mode is off") - smpl_done = Cpt(EpicsSignalRO, suffix='SMPL-DONE', kind=Kind.config, doc="Done status of trigger") + """Trigger Device Control PVs at X10DA, prefix: X10DA-ES1:""" + total_cycles = Cpt( + EpicsSignal, + suffix="TOTAL-CYCLES", + kind=Kind.config, + doc="Number of cycles (multiplies by 0.2s)", + ) + start_csmpl = Cpt( + EpicsSignal, suffix="START-CSMPL", kind=Kind.config, doc="Continous sampling mode on/off" + ) + smpl = Cpt( + EpicsSignal, suffix="SMPL", kind=Kind.config, doc="Sampling Trigger if cont mode is off" + ) + smpl_done = Cpt( + EpicsSignalRO, suffix="SMPL-DONE", kind=Kind.config, doc="Done status of trigger" + ) class Trigger(PSIDeviceBase, TriggerControl): - """ Trigger Device of X10DA (SUPERXAS), prefix: X10DA-ES1: """ + """Trigger Device of X10DA (SUPERXAS), prefix: X10DA-ES1:""" - def __init__(self, name: str, prefix:str='',scan_info: ScanInfo | None = None, device_manager=None, **kwargs): + def __init__( + self, + name: str, + prefix: str = "", + scan_info: ScanInfo | None = None, + device_manager=None, + **kwargs, + ): super().__init__(name=name, prefix=prefix, scan_info=scan_info, **kwargs) self.device_manager = device_manager self._pv_timeout = 1 - ######################################## # Beamline Specific Implementations # ######################################## @@ -57,7 +78,6 @@ class Trigger(PSIDeviceBase, TriggerControl): Called after the device is connected and its signals are connected. Default values for signals should be set here. """ - def on_stage(self) -> DeviceStatus | StatusBase | None: """ @@ -65,8 +85,8 @@ class Trigger(PSIDeviceBase, TriggerControl): Information about the upcoming scan can be accessed from the scan_info (self.scan_info.msg) object. """ - self.start_csmpl.set(ContinuousSamplingMode.OFF).wait() - exp_time = self.scan_info.msg.scan_parameters['exp_time'] + self.start_csmpl.set(ContinuousSamplingMode.OFF).wait(timeout=self._pv_timeout) + exp_time = self.scan_info.msg.scan_parameters["exp_time"] if self.scan_info.msg.scan_name != "exafs_scan": self.set_exposure_time(exp_time).wait() @@ -81,8 +101,9 @@ class Trigger(PSIDeviceBase, TriggerControl): def on_trigger(self) -> DeviceStatus | StatusBase | None: """Called when the device is triggered.""" falcon = self.device_manager.devices.get("falcon", None) - + if falcon is not None: + # pylint: disable=protected-access status = falcon._stop_erase_and_wait_for_acquiring() status.wait() @@ -97,11 +118,10 @@ class Trigger(PSIDeviceBase, TriggerControl): return True return self.smpl_done.get() == SamplingDone.DONE - - self.smpl.put(1) - status = self.task_handler.submit_task(_sampling_done,run=True) - return status + self.smpl.put(1) + status = self.task_handler.submit_task(_sampling_done, run=True) + return status def on_complete(self) -> DeviceStatus | StatusBase | None: """Called to inquire if a device has completed a scans.""" @@ -113,10 +133,7 @@ class Trigger(PSIDeviceBase, TriggerControl): """Called when the device is stopped.""" self.task_handler.shutdown() - def set_exposure_time(self, value:float) -> DeviceStatus: - """ Utility method to set exposure time complying to device logic with cycle of min 0.2s.""" - cycles = max(int(value*5),1) + def set_exposure_time(self, value: float) -> DeviceStatus: + """Utility method to set exposure time complying to device logic with cycle of min 0.2s.""" + cycles = max(int(value * 5), 1) return self.total_cycles.set(cycles) - - - diff --git a/tests/tests_devices/test_devices_trigger.py b/tests/tests_devices/test_devices_trigger.py new file mode 100644 index 0000000..a8ab305 --- /dev/null +++ b/tests/tests_devices/test_devices_trigger.py @@ -0,0 +1,105 @@ +"""Tests for Trigger device.""" + +import threading +from unittest import mock + +import ophyd +import pytest +from bec_server.device_server.tests.utils import DMMock +from ophyd import DeviceStatus +from ophyd_devices.tests.utils import MockPV, patch_dual_pvs + +from superxas_bec.devices.trigger import ContinuousSamplingMode, Trigger + +# pylint: disable=protected-access + + +@pytest.fixture(scope="function") +def trigger(): + """Trigger device with mocked EPICS PVs.""" + name = "trigger" + prefix = "X10DA-ES1:" + with mock.patch.object(ophyd, "cl") as mock_cl: + mock_cl.get_pv = MockPV + mock_cl.thread_class = threading.Thread + dev = Trigger(name=name, prefix=prefix, device_manager=DMMock()) + patch_dual_pvs(dev) + yield dev + + +@pytest.mark.parametrize(["exp_time", "cycles"], [(0.1, 1), (0.5, 2), (2, 10)]) +def test_devices_trigger_stage_core_scans(trigger, exp_time, cycles): + """Test on_connected method of Trigger device. + + The pytest.mark.parametrize decorator is used to run the test for each parameter in the list. + """ + assert trigger.prefix == "X10DA-ES1:" + assert trigger.name == "trigger" + assert trigger._pv_timeout == 1 + # + trigger.on_connected() + trigger._pv_timeout = 0.2 + trigger.start_csmpl.put(ContinuousSamplingMode.ON) + assert trigger.start_csmpl.get() == ContinuousSamplingMode.ON + + # Set scan_info information for scan + trigger.scan_info.msg.scan_name = "step_scan" + trigger.scan_info.msg.scan_parameters["exp_time"] = exp_time + + # On stage should set exposure time + status = trigger.stage() + if isinstance(status, DeviceStatus): + status.wait() + assert trigger.start_csmpl.get() == ContinuousSamplingMode.OFF + # cycles will be multiple of exp_time/0.2 as int, minimum 1. + assert trigger.total_cycles.get() == cycles + + +def test_devices_trigger_unstage(trigger): + """ + Test on_unstage method of Trigger device. + + This should put start_csmpl to ON. + """ + trigger.start_csmpl.put(ContinuousSamplingMode.OFF) + assert trigger.start_csmpl.get() == ContinuousSamplingMode.OFF + status = trigger.unstage() + status.wait() + assert trigger.start_csmpl.get() == ContinuousSamplingMode.ON + + +def test_devices_trigger_stop(trigger): + """ + Test on_stop method of Trigger device. + + This should stop the task_handler. + """ + assert trigger.stopped is False + with mock.patch.object(trigger, "task_handler") as mock_handler: + trigger.stop() + mock_handler.shutdown.assert_called_once() + assert trigger.stopped is True + + +def test_devices_trigger_trigger(trigger): + """Test trigger method of Trigger device.""" + # TODO we should use the ScanStatusMessage to update scan_info here + falcon = mock.MagicMock() + falcon.name.return_value = "falcon" + status = DeviceStatus(device=falcon) + status.set_finished() + falcon._stop_erase_and_wait_for_acquiring.return_value = status + trigger.device_manager.devices["falcon"] = falcon + + trigger_status = DeviceStatus(device=trigger) + trigger_status.set_finished() + with mock.patch.object( + trigger.task_handler, "submit_task", return_value=trigger_status + ) as mock_submit: + status = trigger.trigger() + assert falcon._stop_erase_and_wait_for_acquiring.call_count == 1 + assert trigger.smpl.get() == 1 # smpl called with 1 + # TODO check that the task_handler is called with the correct function + # This is currently not easily testable + assert mock_submit.call_count == 1 + assert trigger_status == status