From 52c0e90f178b0512dec778c6650fc8038bc67cfc Mon Sep 17 00:00:00 2001 From: appel_c Date: Thu, 4 Sep 2025 17:54:34 +0200 Subject: [PATCH] test(pilatus): add tests for the pilatus. on_stage & on_connected tests fail due to AD baseclass callbacks --- debye_bec/devices/pilatus/pilatus.py | 14 +- tests/tests_devices/test_pilatus.py | 271 +++++++++++++++++++++++++++ 2 files changed, 278 insertions(+), 7 deletions(-) create mode 100644 tests/tests_devices/test_pilatus.py diff --git a/debye_bec/devices/pilatus/pilatus.py b/debye_bec/devices/pilatus/pilatus.py index 8a570db..c2fbad9 100644 --- a/debye_bec/devices/pilatus/pilatus.py +++ b/debye_bec/devices/pilatus/pilatus.py @@ -33,9 +33,9 @@ if TYPE_CHECKING: # pragma: no cover from bec_server.device_server.device_server import DeviceManagerDS PILATUS_READOUT_TIME = 0.1 # in s -PILATUS_ACQUIRE_TIME = ( - 999999 # This time is the timeout of the detector in operation mode, so it needs to be large. -) +# PILATUS_ACQUIRE_TIME = ( +# 999999 # This time is the timeout of the detector in operation mode, so it needs to be large. +# ) # pylint: disable=redefined-outer-name @@ -256,7 +256,7 @@ class Pilatus(PSIDeviceBase, ADBase): return None # TODO implement logic for 'xas' scans else: - exp_time = scan_msg.scan_parameters.get("exposure_time", 0.0) + exp_time = scan_msg.scan_parameters.get("exp_time", 0.0) if exp_time - self._readout_time <= 0: raise ValueError( f"Exposure time {exp_time} is too short for Pilatus with readout_time {self._readout_time}." @@ -273,8 +273,8 @@ class Pilatus(PSIDeviceBase, ADBase): # Camera settings self.cam.num_exposures.set(1).wait(5) self.cam.num_images.set(n_images).wait(5) - self.cam.acquire_time.set(exp_time).wait(5) # let's try this - self.cam.acquire_period.set(PILATUS_ACQUIRE_TIME).wait(5) + self.cam.acquire_time.set(detector_exp_time).wait(5) # let's try this + self.cam.acquire_period.set(exp_time).wait(5) self.filter_number.set(0).wait(5) # HDF5 settings logger.debug(f"Setting HDF5 file path to {file_path} and file name to {file_name}") @@ -379,7 +379,7 @@ if __name__ == "__main__": logger.info(f"Sleeping for 5s") time.sleep(5) pilatus.scan_info.msg.num_points = n_pnts - pilatus.scan_info.msg.scan_parameters["exposure_time"] = exp_time + pilatus.scan_info.msg.scan_parameters["exp_time"] = exp_time pilatus.scan_info.msg.scan_parameters["frames_per_trigger"] = 1 pilatus.scan_info.msg.info["file_components"] = ( f"/sls/x01da/data/p22481/raw/data/S00000-00999/S{scan_number:05d}/S{scan_number:05d}", diff --git a/tests/tests_devices/test_pilatus.py b/tests/tests_devices/test_pilatus.py new file mode 100644 index 0000000..75c7c69 --- /dev/null +++ b/tests/tests_devices/test_pilatus.py @@ -0,0 +1,271 @@ +# pylint: skip-file +import os +import threading +from typing import TYPE_CHECKING, Generator +from unittest import mock + +import numpy as np +import ophyd +import pytest +from bec_lib.messages import ScanStatusMessage +from bec_server.scan_server.scan_worker import ScanWorker +from ophyd_devices import CompareStatus, DeviceStatus +from ophyd_devices.interfaces.base_classes.psi_device_base import DeviceStoppedError +from ophyd_devices.tests.utils import MockPV, patch_dual_pvs +from ophyd_devices.utils.psi_device_base_utils import TaskStatus + +from debye_bec.devices.pilatus.pilatus import ( + ACQUIREMODE, + COMPRESSIONALGORITHM, + DETECTORSTATE, + FILEWRITEMODE, + TRIGGERMODE, + Pilatus, +) + +if TYPE_CHECKING: # pragma no cover + from bec_lib.messages import FileMessage + +# @pytest.fixture(scope="function") +# def scan_worker_mock(scan_server_mock): +# scan_server_mock.device_manager.connector = mock.MagicMock() +# scan_worker = ScanWorker(parent=scan_server_mock) +# yield scan_worker + + +@pytest.fixture( + scope="function", + params=[(0.1, 1, 1, "line_scan"), (0.2, 2, 2, "time_scan"), (0.5, 5, 5, "xas_advanced_scan")], +) +def mock_scan_info(request, tmpdir): + exp_time, frames_per_trigger, num_points, scan_name = request.param + scan_info = ScanStatusMessage( + scan_id="test_id", + status="open", + scan_number=1, + scan_parameters={ + "exp_time": exp_time, + "frames_per_trigger": frames_per_trigger, + "system_config": {}, + }, + info={"file_components": (f"{tmpdir}/data/S00000/S000001", "h5")}, + num_points=num_points, + scan_name=scan_name, + ) + yield scan_info + + +@pytest.fixture(scope="function") +def pilatus(mock_scan_info) -> Generator[Pilatus, None, None]: + name = "pilatus" + prefix = "X01DA-OP-MO1:PILATUS:" + with mock.patch.object(ophyd, "cl") as mock_cl: + mock_cl.get_pv = MockPV + mock_cl.thread_class = threading.Thread + dev = Pilatus(name=name, prefix=prefix) + patch_dual_pvs(dev) + # dev.image1 = mock.MagicMock() + # with mock.patch.object(dev, "image1"): + with mock.patch.object(dev, "task_handler"): + dev.scan_info.msg = mock_scan_info + try: + yield dev + finally: + dev.destroy() + + +# TODO figure out how to test as set calls on the PV below seem to break it.. +# def test_pilatus_on_connected(pilatus): +# """Test the on_connected logic of the Pilatus detector.""" +# pilatus.cam.acquire._read_pv.mock_data = ACQUIREMODE.DONE.value +# pilatus.hdf.capture._read_pv.mock_data = ACQUIREMODE.DONE.value +# pilatus.on_connected() +# assert pilatus.cam.trigger_mode.get() == TRIGGERMODE.MULT_TRIGGER +# assert pilatus.hdf.file_write_mode.get() == FILEWRITEMODE.STREAM +# assert pilatus.hdf.file_template.get() == "%s%s" +# assert pilatus.hdf.auto_save.get() == 1 +# assert pilatus.hdf.lazy_open.get() == 1 +# assert pilatus.hdf.compression.get() == COMPRESSIONALGORITHM.NONE + + +def test_pilatus_on_stop(pilatus): + """Test the on_stop logic of the Pilatus detector.""" + pilatus.cam.acquire._read_pv.mock_data = ACQUIREMODE.ACQUIRING.value + pilatus.hdf.capture._read_pv.mock_data = ACQUIREMODE.ACQUIRING.value + pilatus.on_stop() + assert pilatus.cam.acquire.get() == ACQUIREMODE.DONE + assert pilatus.hdf.capture.get() == ACQUIREMODE.DONE + + +def test_pilatus_on_destroy(pilatus): + """Test the on_destroy logic of the Pilatus detector.""" + with mock.patch.object(pilatus, "on_stop") as mock_on_stop: + pilatus.destroy() + assert mock_on_stop.call_count == 1 + assert pilatus._poll_thread_stop_event.is_set() + + +def test_pilatus_on_failure_callback(pilatus): + """Test the on_failure_callback logic of the Pilatus detector.""" + + with mock.patch.object(pilatus, "on_stop") as mock_on_stop: + status = DeviceStatus(pilatus) + status.set_finished() # Does not trigger 'stop' + assert mock_on_stop.call_count == 0 + status = DeviceStatus(pilatus) + status.set_exception(RuntimeError("Test error")) # triggers 'stop' + assert mock_on_stop.call_count == 1 + + +def test_pilatus_on_pre_scan(pilatus): + """Test the on_pre_scan logic of the Pilatus detector.""" + if pilatus.scan_info.msg.scan_name.startswith("xas"): + assert pilatus.on_pre_scan() is None + return + pilatus.cam.acquire._read_pv.mock_data = ACQUIREMODE.DONE.value + pilatus.hdf.capture._read_pv.mock_data = ACQUIREMODE.DONE.value + pilatus.cam.armed._read_pv.mock_data = DETECTORSTATE.UNARMED.value + status = pilatus.on_pre_scan() + assert status.done is False + pilatus.cam.armed.put(DETECTORSTATE.ARMED.value) + status.wait(timeout=5) + assert status.done is True + assert status.success is True + + +def test_pilatus_on_trigger(pilatus): + """test on trigger logic of the Pilatus detector.""" + if pilatus.scan_info.msg.scan_name.startswith("xas"): + status = pilatus.trigger() + assert status.done is True + assert status.success is True + return None + pilatus.hdf.num_captured._read_pv.mock_data = 0 + pilatus.trigger_shot.put(0) + status = pilatus.trigger() + assert status.done is False + assert pilatus.trigger_shot.get() == 1 + pilatus.hdf.num_captured._read_pv.mock_data = 1 + status.wait(timeout=5) + assert status.done is True + assert status.success is True + + +def test_pilatus_on_trigger_cancel_on_stop(pilatus): + """Test that the status of the trigger is cancelled if stop is called""" + if pilatus.scan_info.msg.scan_name.startswith("xas"): + status = pilatus.trigger() + assert status.done is True + assert status.success is True + return + pilatus.hdf.num_captured._read_pv.mock_data = 0 + pilatus.trigger_shot.put(0) + status = pilatus.trigger() + assert status.done is False + with pytest.raises(DeviceStoppedError): + pilatus.stop() + status.wait(timeout=5) + + +def test_pilatus_on_complete(pilatus): + """Test the on_complete logic of the Pilatus detector.""" + if pilatus.scan_info.msg.scan_name.startswith("xas"): + status = pilatus.complete() + assert status.done is True + assert status.success is True + return + # Check in addition that the file event is set properly, once with if it works, and once if not (i.e. when cancelled) + for success in [True, False]: + if success is True: + pilatus.file_event.put(file_path="", done=False, successful=False) + pilatus._full_path = "file-path-for-success" + else: + pilatus.file_event.put(file_path="", done=False, successful=True) + pilatus._full_path = "file-path-for-failure" + # Set values for relevant PVs + pilatus.cam.acquire._read_pv.mock_data = ACQUIREMODE.ACQUIRING.value + pilatus.hdf.capture._read_pv.mock_data = ACQUIREMODE.ACQUIRING.value + pilatus.cam.armed._read_pv.mock_data = DETECTORSTATE.ARMED.value + num_images = pilatus.scan_info.msg.num_points * pilatus.scan_info.msg.scan_parameters.get( + "frames_per_trigger", 1 + ) + pilatus.hdf.num_captured._read_pv.mock_data = num_images - 1 + # Call on complete + status = pilatus.complete() + # Should not be finished + assert status.done is False + pilatus.cam.acquire.put(ACQUIREMODE.DONE.value) + pilatus.hdf.capture.put(ACQUIREMODE.DONE.value) + pilatus.cam.armed.put(DETECTORSTATE.UNARMED.value) + assert status.done is False + if success is True: + pilatus.hdf.num_captured._read_pv.mock_data = num_images + # Now it should resolve + status.wait(timeout=5) + assert status.done is True + assert status.success is True + file_msg: FileMessage = pilatus.file_event.get() + assert file_msg.file_path == "file-path-for-success" + assert file_msg.done is True + assert file_msg.successful is True + else: + with pytest.raises(DeviceStoppedError): + pilatus.stop() + status.wait(timeout=5) + assert status.done is True + assert status.success is False + file_msg: FileMessage = pilatus.file_event.get() + assert file_msg.file_path == "file-path-for-failure" + assert file_msg.done is True + assert file_msg.successful is False + + +# TODO, figure out how to properly test this.. +# def test_pilatus_on_stage(pilatus): +# """Test the on_stage logic of the Pilatus detector.""" +# # Make sure that no additional logic from stage signals of underlying components is triggered +# pilatus.stage_sigs = {} +# pilatus.cam.stage_sigs = {} +# pilatus.hdf.stage_sigs = {} +# if pilatus.scan_info.msg.scan_name.startswith("xas"): +# pilatus.on_stage() +# return +# exp_time = pilatus.scan_info.msg.scan_parameters.get("exp_time", 0.1) +# n_images = pilatus.scan_info.msg.num_points * pilatus.scan_info.msg.scan_parameters.get( +# "frames_per_trigger", 1 +# ) +# if exp_time <= 0.1: +# with pytest.raises(ValueError): +# pilatus.on_stage() +# return +# pilatus.filter_number.put(10) +# pilatus.cam.array_counter.put(1) +# file_components = pilatus.scan_info.msg.info.get("file_components", ("", ""))[0] +# base_path = file_components[0].rsplit("/", 1)[0] +# file_name = file_components[0].rsplit("/", 1)[1] + "_pilatus.h5" +# file_path = os.path.join(base_path, file_name) +# pilatus.on_stage() +# assert pilatus.cam.array_callbacks.get() == 0 +# assert pilatus.hdf.enable.get() == 1 +# assert pilatus.cam.num_exposures.get() == 1 +# assert pilatus.cam.num_images.get() == n_images +# assert pilatus.cam.acquire_time.get() == exp_time - pilatus._readout_time +# assert pilatus.cam.acquire_period.get() == exp_time +# assert pilatus.filter_number.get() == 0 +# assert pilatus.hdf.file_path.get() == base_path +# assert pilatus.hdf.file_name.get() == file_name +# assert pilatus.hdf.num_capture.get() == n_images +# assert pilatus.cam.array_counter.get() == 0 +# file_msg: FileMessage = pilatus.file_event.get() +# assert file_msg.file_path == file_path +# assert file_msg.done is False +# assert file_msg.successful is False + + +def test_pilatus_on_stage_raises_low_exp_time(pilatus): + """Test that on_stage raises a ValueError if the exposure time is too low.""" + pilatus.scan_info.msg.scan_parameters["exp_time"] = 0.09 + if pilatus.scan_info.msg.scan_name.startswith("xas"): + return + with pytest.raises(ValueError): + pilatus.on_stage()