diff --git a/ophyd_devices/epics/devices/pilatus_csaxs.py b/ophyd_devices/epics/devices/pilatus_csaxs.py index 2aa2634..fe9a667 100644 --- a/ophyd_devices/epics/devices/pilatus_csaxs.py +++ b/ophyd_devices/epics/devices/pilatus_csaxs.py @@ -190,6 +190,7 @@ class PilatuscSAXS(DetectorBase): def _init_detector(self) -> None: """Initialize the detector""" # TODO add check if detector is running + self._stop_det() self._set_trigger(TriggerSource.EXT_ENABLE) def _init_filewriter(self) -> None: @@ -212,14 +213,14 @@ class PilatuscSAXS(DetectorBase): setpoint = int(self.mokev * factor) threshold = self.cam.threshold_energy.read()[self.cam.threshold_energy.name]["value"] if not np.isclose(setpoint / 2, threshold, rtol=0.05): - self.cam.threshold_energy.set(setpoint / 2) + self.cam.threshold_energy.put(setpoint / 2) def _set_acquisition_params(self) -> None: """set acquisition parameters on the detector""" # self.cam.acquire_time.set(self.exp_time) # self.cam.acquire_period.set(self.exp_time + self.readout) - self.cam.num_images.set(int(self.scaninfo.num_points * self.scaninfo.frames_per_trigger)) - self.cam.num_frames.set(1) + self.cam.num_images.put(int(self.scaninfo.num_points * self.scaninfo.frames_per_trigger)) + self.cam.num_frames.put(1) self._update_readout_time() def _set_trigger(self, trigger_source: int) -> None: diff --git a/tests/test_pilatus_csaxs.py b/tests/test_pilatus_csaxs.py index f20c5fa..dfedb85 100644 --- a/tests/test_pilatus_csaxs.py +++ b/tests/test_pilatus_csaxs.py @@ -2,95 +2,22 @@ import os import pytest from unittest import mock -from ophyd.signal import Signal -from ophyd import Staged +import ophyd from bec_lib.core import BECMessage, MessageEndpoints -from bec_lib.core.devicemanager import DeviceContainer -from bec_lib.core.tests.utils import ProducerMock -import requests +from ophyd_devices.epics.devices.pilatus_csaxs import PilatuscSAXS + +from tests.utils import DMMock, MockPV -class MockSignal(Signal): - def __init__(self, read_pv, *, string=False, name=None, parent=None, **kwargs): - self.read_pv = read_pv - self._string = bool(string) - super().__init__(name=name, parent=parent, **kwargs) - self._waited_for_connection = False - self._subscriptions = [] - - def wait_for_connection(self): - self._waited_for_connection = True - - def subscribe(self, method, event_type, **kw): - self._subscriptions.append((method, event_type, kw)) - - def describe_configuration(self): - return {self.name + "_conf": {"source": "SIM:test"}} - - def read_configuration(self): - return {self.name + "_conf": {"value": 0}} - - -with mock.patch("ophyd.EpicsSignal", new=MockSignal), mock.patch( - "ophyd.EpicsSignalRO", new=MockSignal -), mock.patch("ophyd.EpicsSignalWithRBV", new=MockSignal): - from ophyd_devices.epics.devices.pilatus_csaxs import PilatuscSAXS - - -# TODO maybe specify here that this DeviceMock is for usage in the DeviceServer -class DeviceMock: - def __init__(self, name: str, value: float = 0.0): - self.name = name - self.read_buffer = value - self._config = {"deviceConfig": {"limits": [-50, 50]}, "userParameter": None} - self._enabled_set = True - self._enabled = True - - def read(self): - return {self.name: {"value": self.read_buffer}} - - def readback(self): - return self.read_buffer - - @property - def enabled_set(self) -> bool: - return self._enabled_set - - @enabled_set.setter - def enabled_set(self, val: bool): - self._enabled_set = val - - @property - def enabled(self) -> bool: - return self._enabled - - @enabled.setter - def enabled(self, val: bool): - self._enabled = val - - @property - def user_parameter(self): - return self._config["userParameter"] - - @property - def obj(self): - return self - - -class DMMock: - """Mock for DeviceManager - - The mocked DeviceManager creates a device containert and a producer. - - """ - - def __init__(self): - self.devices = DeviceContainer() - self.producer = ProducerMock() - - def add_device(self, name: str, value: float = 0.0): - self.devices[name] = DeviceMock(name, value) +def patch_dual_pvs(device): + for walk in device.walk_signals(): + if not hasattr(walk.item, "_read_pv"): + continue + if not hasattr(walk.item, "_write_pv"): + continue + if walk.item._read_pv.pvname.endswith("_RBV"): + walk.item._read_pv = walk.item._write_pv @pytest.fixture(scope="function") @@ -99,73 +26,37 @@ def mock_det(): prefix = "X12SA-ES-PILATUS300K:" sim_mode = False dm = DMMock() - # dm.add_device("mokev", value=12.4) with mock.patch.object(dm, "producer"): - with mock.patch.object( - PilatuscSAXS, "_update_service_config" - ) as mock_update_service_config, mock.patch( + with mock.patch( "ophyd_devices.epics.devices.pilatus_csaxs.FileWriterMixin" - ) as filemixin: - with mock.patch.object(PilatuscSAXS, "_init"): - yield PilatuscSAXS(name=name, prefix=prefix, device_manager=dm, sim_mode=sim_mode) + ) as filemixin, mock.patch( + "ophyd_devices.epics.devices.pilatus_csaxs.PilatuscSAXS._update_service_config" + ) as mock_service_config: + with mock.patch.object(ophyd, "cl") as mock_cl: + mock_cl.get_pv = MockPV + with mock.patch.object(PilatuscSAXS, "_init"): + det = PilatuscSAXS( + name=name, prefix=prefix, device_manager=dm, sim_mode=sim_mode + ) + patch_dual_pvs(det) + yield det @pytest.mark.parametrize( - "trigger_source, sim_mode, scan_status_msg, expected_exception", + "trigger_source, detector_state", [ - ( - 1, - True, - BECMessage.ScanStatusMessage( - scanID="1", - status={}, - info={ - "RID": "mockrid1111", - "queueID": "mockqueueID111", - "scan_number": 1, - "exp_time": 0.012, - "num_points": 500, - "readout_time": 0.003, - "scan_type": "fly", - "num_lines": 0.012, - "frames_per_trigger": 1, - }, - ), - True, - ), - ( - 1, - False, - BECMessage.ScanStatusMessage( - scanID="1", - status={}, - info={ - "RID": "mockrid1111", - "queueID": "mockqueueID111", - "scan_number": 1, - "exp_time": 0.012, - "num_points": 500, - "readout_time": 0.003, - "scan_type": "fly", - "num_lines": 0.012, - "frames_per_trigger": 1, - }, - ), - False, - ), + (1, 0), ], ) # TODO rewrite this one, write test for init_detector, init_filewriter is tested -def test_init( +def test_init_detector( + mock_det, trigger_source, - sim_mode, - scan_status_msg, - expected_exception, + detector_state, ): """Test the _init function: This includes testing the functions: - - _set_default_parameter - _init_detector - _stop_det - _set_trigger @@ -174,34 +65,9 @@ def test_init( Validation upon setting the correct PVs """ - name = "pilatus" - prefix = "X12SA-ES-PILATUS300K:" - sim_mode = sim_mode - dm = DMMock() - with mock.patch.object(dm, "producer") as producer, mock.patch.object( - PilatuscSAXS, "_init_filewriter" - ) as mock_init_fw, mock.patch.object( - PilatuscSAXS, "_update_scaninfo" - ) as mock_update_scaninfo, mock.patch.object( - PilatuscSAXS, "_update_filewriter" - ) as mock_update_filewriter, mock.patch.object( - PilatuscSAXS, "_update_service_config" - ) as mock_update_service_config: - mock_det = PilatuscSAXS(name=name, prefix=prefix, device_manager=dm, sim_mode=sim_mode) - if expected_exception: - with pytest.raises(Exception): - mock_det._init() - mock_init_fw.assert_called_once() - else: - mock_det._init() # call the method you want to test - assert mock_det.cam.acquire.get() == 0 - assert mock_det.cam.trigger_mode.get() == trigger_source - mock_init_fw.assert_called() - mock_update_scaninfo.assert_called_once() - mock_update_filewriter.assert_called_once() - mock_update_service_config.assert_called_once() - - assert mock_init_fw.call_count == 2 + mock_det._init_detector() # call the method you want to test + assert mock_det.cam.acquire.get() == detector_state + assert mock_det.cam.trigger_mode.get() == trigger_source @pytest.mark.parametrize( @@ -239,7 +105,7 @@ def test_stage( stopped, expected_exception, ): - with mock.patch.object(PilatuscSAXS, "_publish_file_location") as mock_publish_file_location: + with mock.patch.object(mock_det, "_publish_file_location") as mock_publish_file_location: mock_det.scaninfo.num_points = scaninfo["num_points"] mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"] mock_det.filewriter.compile_full_filename.return_value = scaninfo["filepath"] @@ -542,78 +408,6 @@ def test_prep_file_writer(mock_det, scaninfo, data_msgs, urls, requests_state, e assert call == mock_call -# @pytest.mark.parametrize( -# "scaninfo, daq_status, expected_exception", -# [ -# ( -# { -# "eacc": "e12345", -# "num_points": 500, -# "frames_per_trigger": 1, -# "filepath": "test.h5", -# "scanID": "123", -# }, -# {"state": "BUSY", "acquisition": {"state": "WAITING_IMAGES"}}, -# False, -# ), -# ( -# { -# "eacc": "e12345", -# "num_points": 500, -# "frames_per_trigger": 1, -# "filepath": "test.h5", -# "scanID": "123", -# }, -# {"state": "BUSY", "acquisition": {"state": "WAITING_IMAGES"}}, -# False, -# ), -# ( -# { -# "eacc": "e12345", -# "num_points": 500, -# "frames_per_trigger": 1, -# "filepath": "test.h5", -# "scanID": "123", -# }, -# {"state": "BUSY", "acquisition": {"state": "ERROR"}}, -# True, -# ), -# ], -# ) -# def test_prep_file_writer(mock_det, scaninfo, daq_status, expected_exception): -# with mock.patch.object(mock_det, "std_client") as mock_std_daq, mock.patch.object( -# mock_det, "_filepath_exists" -# ) as mock_file_path_exists, mock.patch.object( -# mock_det, "_stop_file_writer" -# ) as mock_stop_file_writer, mock.patch.object( -# mock_det, "scaninfo" -# ) as mock_scaninfo: -# # mock_det = eiger_factory(name, prefix, sim_mode) -# mock_det.std_client = mock_std_daq -# mock_std_daq.start_writer_async.return_value = None -# mock_std_daq.get_status.return_value = daq_status -# mock_det.filewriter.compile_full_filename.return_value = scaninfo["filepath"] -# mock_det.scaninfo.num_points = scaninfo["num_points"] -# mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"] - -# if expected_exception: -# with pytest.raises(Exception): -# mock_det._prep_file_writer() -# mock_file_path_exists.assert_called_once() -# assert mock_stop_file_writer.call_count == 2 - -# else: -# mock_det._prep_file_writer() -# mock_file_path_exists.assert_called_once() -# mock_stop_file_writer.assert_called_once() - -# daq_writer_call = { -# "output_file": scaninfo["filepath"], -# "n_images": int(scaninfo["num_points"] * scaninfo["frames_per_trigger"]), -# } -# mock_std_daq.start_writer_async.assert_called_with(daq_writer_call) - - @pytest.mark.parametrize( "stopped, expected_exception", [ @@ -649,14 +443,6 @@ def test_unstage( assert mock_det._stopped == False -# def test_stop_fw(mock_det): -# with mock.patch.object(mock_det, "std_client") as mock_std_daq: -# mock_std_daq.stop_writer.return_value = None -# mock_det.std_client = mock_std_daq -# mock_det._stop_file_writer() -# mock_std_daq.stop_writer.assert_called_once() - - def test_stop(mock_det): with mock.patch.object(mock_det, "_stop_det") as mock_stop_det, mock.patch.object( mock_det, "_stop_file_writer" @@ -675,17 +461,17 @@ def test_stop(mock_det): [ ( False, - Staged.no, + ophyd.Staged.no, False, ), ( True, - Staged.no, + ophyd.Staged.no, False, ), ( False, - Staged.yes, + ophyd.Staged.yes, True, ), ],