refactor(trigger): refactor trigger device, add tests
This commit is contained in:
@@ -1,45 +1,66 @@
|
||||
from ophyd import Device, Kind, Component as Cpt
|
||||
from ophyd import EpicsSignal, EpicsSignalRO, DeviceStatus, StatusBase
|
||||
from ophyd.status import SubscriptionStatus
|
||||
"""SuperXAS Trigger Device"""
|
||||
|
||||
from bec_lib.logger import bec_logger
|
||||
from ophyd import Component as Cpt
|
||||
from ophyd import Device, DeviceStatus, EpicsSignal, EpicsSignalRO, Kind, StatusBase
|
||||
|
||||
logger = bec_logger.logger
|
||||
|
||||
from bec_lib.devicemanager import ScanInfo
|
||||
|
||||
import enum
|
||||
|
||||
from bec_lib.devicemanager import ScanInfo
|
||||
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
|
||||
|
||||
|
||||
class ContinuousSamplingMode(int, enum.Enum):
|
||||
""" Options for start_csmpl signal"""
|
||||
"""Options for start_csmpl signal"""
|
||||
|
||||
OFF = 0
|
||||
ON = 1
|
||||
|
||||
|
||||
class SamplingDone(int, enum.Enum):
|
||||
""" Status of sampling """
|
||||
"""Status of sampling"""
|
||||
|
||||
RUNNING = 0
|
||||
DONE = 1
|
||||
|
||||
|
||||
class TriggerControl(Device):
|
||||
""" Trigger Device Control PVs at X10DA, prefix: X10DA-ES1: """
|
||||
|
||||
total_cycles = Cpt(EpicsSignal, suffix='TOTAL-CYCLES', kind=Kind.config, doc="Number of cycles (multiplies by 0.2s)")
|
||||
start_csmpl = Cpt(EpicsSignal, suffix='START-CSMPL', kind=Kind.config, doc="Continous sampling mode on/off")
|
||||
smpl = Cpt(EpicsSignal, suffix='SMPL', kind=Kind.config, doc="Sampling Trigger if cont mode is off")
|
||||
smpl_done = Cpt(EpicsSignalRO, suffix='SMPL-DONE', kind=Kind.config, doc="Done status of trigger")
|
||||
"""Trigger Device Control PVs at X10DA, prefix: X10DA-ES1:"""
|
||||
|
||||
total_cycles = Cpt(
|
||||
EpicsSignal,
|
||||
suffix="TOTAL-CYCLES",
|
||||
kind=Kind.config,
|
||||
doc="Number of cycles (multiplies by 0.2s)",
|
||||
)
|
||||
start_csmpl = Cpt(
|
||||
EpicsSignal, suffix="START-CSMPL", kind=Kind.config, doc="Continous sampling mode on/off"
|
||||
)
|
||||
smpl = Cpt(
|
||||
EpicsSignal, suffix="SMPL", kind=Kind.config, doc="Sampling Trigger if cont mode is off"
|
||||
)
|
||||
smpl_done = Cpt(
|
||||
EpicsSignalRO, suffix="SMPL-DONE", kind=Kind.config, doc="Done status of trigger"
|
||||
)
|
||||
|
||||
|
||||
class Trigger(PSIDeviceBase, TriggerControl):
|
||||
""" Trigger Device of X10DA (SUPERXAS), prefix: X10DA-ES1: """
|
||||
"""Trigger Device of X10DA (SUPERXAS), prefix: X10DA-ES1:"""
|
||||
|
||||
def __init__(self, name: str, prefix:str='',scan_info: ScanInfo | None = None, device_manager=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
prefix: str = "",
|
||||
scan_info: ScanInfo | None = None,
|
||||
device_manager=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(name=name, prefix=prefix, scan_info=scan_info, **kwargs)
|
||||
self.device_manager = device_manager
|
||||
self._pv_timeout = 1
|
||||
|
||||
|
||||
########################################
|
||||
# Beamline Specific Implementations #
|
||||
########################################
|
||||
@@ -57,7 +78,6 @@ class Trigger(PSIDeviceBase, TriggerControl):
|
||||
Called after the device is connected and its signals are connected.
|
||||
Default values for signals should be set here.
|
||||
"""
|
||||
|
||||
|
||||
def on_stage(self) -> DeviceStatus | StatusBase | None:
|
||||
"""
|
||||
@@ -65,8 +85,8 @@ class Trigger(PSIDeviceBase, TriggerControl):
|
||||
|
||||
Information about the upcoming scan can be accessed from the scan_info (self.scan_info.msg) object.
|
||||
"""
|
||||
self.start_csmpl.set(ContinuousSamplingMode.OFF).wait()
|
||||
exp_time = self.scan_info.msg.scan_parameters['exp_time']
|
||||
self.start_csmpl.set(ContinuousSamplingMode.OFF).wait(timeout=self._pv_timeout)
|
||||
exp_time = self.scan_info.msg.scan_parameters["exp_time"]
|
||||
if self.scan_info.msg.scan_name != "exafs_scan":
|
||||
self.set_exposure_time(exp_time).wait()
|
||||
|
||||
@@ -81,8 +101,9 @@ class Trigger(PSIDeviceBase, TriggerControl):
|
||||
def on_trigger(self) -> DeviceStatus | StatusBase | None:
|
||||
"""Called when the device is triggered."""
|
||||
falcon = self.device_manager.devices.get("falcon", None)
|
||||
|
||||
|
||||
if falcon is not None:
|
||||
# pylint: disable=protected-access
|
||||
status = falcon._stop_erase_and_wait_for_acquiring()
|
||||
status.wait()
|
||||
|
||||
@@ -97,11 +118,10 @@ class Trigger(PSIDeviceBase, TriggerControl):
|
||||
return True
|
||||
|
||||
return self.smpl_done.get() == SamplingDone.DONE
|
||||
|
||||
self.smpl.put(1)
|
||||
status = self.task_handler.submit_task(_sampling_done,run=True)
|
||||
return status
|
||||
|
||||
self.smpl.put(1)
|
||||
status = self.task_handler.submit_task(_sampling_done, run=True)
|
||||
return status
|
||||
|
||||
def on_complete(self) -> DeviceStatus | StatusBase | None:
|
||||
"""Called to inquire if a device has completed a scans."""
|
||||
@@ -113,10 +133,7 @@ class Trigger(PSIDeviceBase, TriggerControl):
|
||||
"""Called when the device is stopped."""
|
||||
self.task_handler.shutdown()
|
||||
|
||||
def set_exposure_time(self, value:float) -> DeviceStatus:
|
||||
""" Utility method to set exposure time complying to device logic with cycle of min 0.2s."""
|
||||
cycles = max(int(value*5),1)
|
||||
def set_exposure_time(self, value: float) -> DeviceStatus:
|
||||
"""Utility method to set exposure time complying to device logic with cycle of min 0.2s."""
|
||||
cycles = max(int(value * 5), 1)
|
||||
return self.total_cycles.set(cycles)
|
||||
|
||||
|
||||
|
||||
|
||||
105
tests/tests_devices/test_devices_trigger.py
Normal file
105
tests/tests_devices/test_devices_trigger.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Tests for Trigger device."""
|
||||
|
||||
import threading
|
||||
from unittest import mock
|
||||
|
||||
import ophyd
|
||||
import pytest
|
||||
from bec_server.device_server.tests.utils import DMMock
|
||||
from ophyd import DeviceStatus
|
||||
from ophyd_devices.tests.utils import MockPV, patch_dual_pvs
|
||||
|
||||
from superxas_bec.devices.trigger import ContinuousSamplingMode, Trigger
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def trigger():
|
||||
"""Trigger device with mocked EPICS PVs."""
|
||||
name = "trigger"
|
||||
prefix = "X10DA-ES1:"
|
||||
with mock.patch.object(ophyd, "cl") as mock_cl:
|
||||
mock_cl.get_pv = MockPV
|
||||
mock_cl.thread_class = threading.Thread
|
||||
dev = Trigger(name=name, prefix=prefix, device_manager=DMMock())
|
||||
patch_dual_pvs(dev)
|
||||
yield dev
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["exp_time", "cycles"], [(0.1, 1), (0.5, 2), (2, 10)])
|
||||
def test_devices_trigger_stage_core_scans(trigger, exp_time, cycles):
|
||||
"""Test on_connected method of Trigger device.
|
||||
|
||||
The pytest.mark.parametrize decorator is used to run the test for each parameter in the list.
|
||||
"""
|
||||
assert trigger.prefix == "X10DA-ES1:"
|
||||
assert trigger.name == "trigger"
|
||||
assert trigger._pv_timeout == 1
|
||||
#
|
||||
trigger.on_connected()
|
||||
trigger._pv_timeout = 0.2
|
||||
trigger.start_csmpl.put(ContinuousSamplingMode.ON)
|
||||
assert trigger.start_csmpl.get() == ContinuousSamplingMode.ON
|
||||
|
||||
# Set scan_info information for scan
|
||||
trigger.scan_info.msg.scan_name = "step_scan"
|
||||
trigger.scan_info.msg.scan_parameters["exp_time"] = exp_time
|
||||
|
||||
# On stage should set exposure time
|
||||
status = trigger.stage()
|
||||
if isinstance(status, DeviceStatus):
|
||||
status.wait()
|
||||
assert trigger.start_csmpl.get() == ContinuousSamplingMode.OFF
|
||||
# cycles will be multiple of exp_time/0.2 as int, minimum 1.
|
||||
assert trigger.total_cycles.get() == cycles
|
||||
|
||||
|
||||
def test_devices_trigger_unstage(trigger):
|
||||
"""
|
||||
Test on_unstage method of Trigger device.
|
||||
|
||||
This should put start_csmpl to ON.
|
||||
"""
|
||||
trigger.start_csmpl.put(ContinuousSamplingMode.OFF)
|
||||
assert trigger.start_csmpl.get() == ContinuousSamplingMode.OFF
|
||||
status = trigger.unstage()
|
||||
status.wait()
|
||||
assert trigger.start_csmpl.get() == ContinuousSamplingMode.ON
|
||||
|
||||
|
||||
def test_devices_trigger_stop(trigger):
|
||||
"""
|
||||
Test on_stop method of Trigger device.
|
||||
|
||||
This should stop the task_handler.
|
||||
"""
|
||||
assert trigger.stopped is False
|
||||
with mock.patch.object(trigger, "task_handler") as mock_handler:
|
||||
trigger.stop()
|
||||
mock_handler.shutdown.assert_called_once()
|
||||
assert trigger.stopped is True
|
||||
|
||||
|
||||
def test_devices_trigger_trigger(trigger):
|
||||
"""Test trigger method of Trigger device."""
|
||||
# TODO we should use the ScanStatusMessage to update scan_info here
|
||||
falcon = mock.MagicMock()
|
||||
falcon.name.return_value = "falcon"
|
||||
status = DeviceStatus(device=falcon)
|
||||
status.set_finished()
|
||||
falcon._stop_erase_and_wait_for_acquiring.return_value = status
|
||||
trigger.device_manager.devices["falcon"] = falcon
|
||||
|
||||
trigger_status = DeviceStatus(device=trigger)
|
||||
trigger_status.set_finished()
|
||||
with mock.patch.object(
|
||||
trigger.task_handler, "submit_task", return_value=trigger_status
|
||||
) as mock_submit:
|
||||
status = trigger.trigger()
|
||||
assert falcon._stop_erase_and_wait_for_acquiring.call_count == 1
|
||||
assert trigger.smpl.get() == 1 # smpl called with 1
|
||||
# TODO check that the task_handler is called with the correct function
|
||||
# This is currently not easily testable
|
||||
assert mock_submit.call_count == 1
|
||||
assert trigger_status == status
|
||||
Reference in New Issue
Block a user