272 lines
10 KiB
Python
272 lines
10 KiB
Python
# pylint: skip-file
|
|
import os
|
|
import threading
|
|
from typing import TYPE_CHECKING, Generator
|
|
from unittest import mock
|
|
|
|
import numpy as np
|
|
import ophyd
|
|
import pytest
|
|
from bec_lib.messages import ScanStatusMessage
|
|
from bec_server.scan_server.scan_worker import ScanWorker
|
|
from ophyd_devices import CompareStatus, DeviceStatus
|
|
from ophyd_devices.interfaces.base_classes.psi_device_base import DeviceStoppedError
|
|
from ophyd_devices.tests.utils import MockPV, patch_dual_pvs
|
|
from ophyd_devices.utils.psi_device_base_utils import TaskStatus
|
|
|
|
from debye_bec.devices.pilatus.pilatus import (
|
|
ACQUIREMODE,
|
|
COMPRESSIONALGORITHM,
|
|
DETECTORSTATE,
|
|
FILEWRITEMODE,
|
|
TRIGGERMODE,
|
|
Pilatus,
|
|
)
|
|
|
|
if TYPE_CHECKING: # pragma no cover
|
|
from bec_lib.messages import FileMessage
|
|
|
|
# @pytest.fixture(scope="function")
|
|
# def scan_worker_mock(scan_server_mock):
|
|
# scan_server_mock.device_manager.connector = mock.MagicMock()
|
|
# scan_worker = ScanWorker(parent=scan_server_mock)
|
|
# yield scan_worker
|
|
|
|
|
|
@pytest.fixture(
    scope="function",
    params=[
        (0.1, 1, 1, "line_scan"),
        (0.2, 2, 2, "time_scan"),
        (0.5, 5, 5, "xas_advanced_scan"),
    ],
)
def mock_scan_info(request, tmpdir):
    """Yield a ScanStatusMessage for one parametrized scan configuration.

    Each parameter tuple is ``(exp_time, frames_per_trigger, num_points, scan_name)``.
    The ``file_components`` entry points into pytest's per-test ``tmpdir``.
    """
    exp_time, frames_per_trigger, num_points, scan_name = request.param
    parameters = {
        "exp_time": exp_time,
        "frames_per_trigger": frames_per_trigger,
        "system_config": {},
    }
    file_base = f"{tmpdir}/data/S00000/S000001"
    yield ScanStatusMessage(
        scan_id="test_id",
        status="open",
        scan_number=1,
        scan_parameters=parameters,
        info={"file_components": (file_base, "h5")},
        num_points=num_points,
        scan_name=scan_name,
    )
|
|
|
|
|
|
@pytest.fixture(scope="function")
def pilatus(mock_scan_info) -> Generator[Pilatus, None, None]:
    """Yield a Pilatus device whose PVs are mocked, destroying it on teardown.

    ``ophyd.cl`` is patched so PVs are created as ``MockPV`` objects, and the
    device's ``task_handler`` attribute is replaced by a mock for the lifetime
    of the test. The parametrized ``mock_scan_info`` message is attached as
    the device's current scan info.
    """
    with mock.patch.object(ophyd, "cl") as control_layer:
        control_layer.get_pv = MockPV
        control_layer.thread_class = threading.Thread
        device = Pilatus(name="pilatus", prefix="X01DA-OP-MO1:PILATUS:")
        patch_dual_pvs(device)
        with mock.patch.object(device, "task_handler"):
            device.scan_info.msg = mock_scan_info
            try:
                yield device
            finally:
                device.destroy()
|
|
|
|
|
|
# TODO: figure out how to test on_connected; the set calls on the PVs below seem to break it.
|
|
# def test_pilatus_on_connected(pilatus):
|
|
# """Test the on_connected logic of the Pilatus detector."""
|
|
# pilatus.cam.acquire._read_pv.mock_data = ACQUIREMODE.DONE.value
|
|
# pilatus.hdf.capture._read_pv.mock_data = ACQUIREMODE.DONE.value
|
|
# pilatus.on_connected()
|
|
# assert pilatus.cam.trigger_mode.get() == TRIGGERMODE.MULT_TRIGGER
|
|
# assert pilatus.hdf.file_write_mode.get() == FILEWRITEMODE.STREAM
|
|
# assert pilatus.hdf.file_template.get() == "%s%s"
|
|
# assert pilatus.hdf.auto_save.get() == 1
|
|
# assert pilatus.hdf.lazy_open.get() == 1
|
|
# assert pilatus.hdf.compression.get() == COMPRESSIONALGORITHM.NONE
|
|
|
|
|
|
def test_pilatus_on_stop(pilatus):
    """on_stop must set both camera acquire and HDF5 capture back to DONE."""
    acquiring = ACQUIREMODE.ACQUIRING.value
    for pv_signal in (pilatus.cam.acquire, pilatus.hdf.capture):
        pv_signal._read_pv.mock_data = acquiring

    pilatus.on_stop()

    assert pilatus.cam.acquire.get() == ACQUIREMODE.DONE
    assert pilatus.hdf.capture.get() == ACQUIREMODE.DONE
|
|
|
|
|
|
def test_pilatus_on_destroy(pilatus):
    """destroy() must call on_stop exactly once and set the poll-thread stop event."""
    with mock.patch.object(pilatus, "on_stop") as patched_on_stop:
        pilatus.destroy()
        assert patched_on_stop.call_count == 1
        assert pilatus._poll_thread_stop_event.is_set()
|
|
|
|
|
|
def test_pilatus_on_failure_callback(pilatus):
    """A failed DeviceStatus must trigger on_stop; a successful one must not."""
    with mock.patch.object(pilatus, "on_stop") as patched_on_stop:
        ok_status = DeviceStatus(pilatus)
        ok_status.set_finished()  # success path: 'stop' must not be called
        assert patched_on_stop.call_count == 0

        failed_status = DeviceStatus(pilatus)
        failed_status.set_exception(RuntimeError("Test error"))  # failure path: triggers 'stop'
        assert patched_on_stop.call_count == 1
|
|
|
|
|
|
def test_pilatus_on_pre_scan(pilatus):
    """on_pre_scan returns None for XAS scans; otherwise a status gated on arming."""
    scan_name = pilatus.scan_info.msg.scan_name
    if scan_name.startswith("xas"):
        assert pilatus.on_pre_scan() is None
        return

    # Idle detector that is not yet armed.
    pilatus.cam.acquire._read_pv.mock_data = ACQUIREMODE.DONE.value
    pilatus.hdf.capture._read_pv.mock_data = ACQUIREMODE.DONE.value
    pilatus.cam.armed._read_pv.mock_data = DETECTORSTATE.UNARMED.value

    pre_scan_status = pilatus.on_pre_scan()
    # Expected to stay pending until `armed` reports ARMED.
    assert pre_scan_status.done is False

    pilatus.cam.armed.put(DETECTORSTATE.ARMED.value)
    pre_scan_status.wait(timeout=5)
    assert pre_scan_status.done is True
    assert pre_scan_status.success is True
|
|
|
|
|
|
def test_pilatus_on_trigger(pilatus):
    """trigger() fires one shot and resolves once num_captured advances by one.

    For XAS scans the returned status is expected to be finished successfully
    right away.
    """
    if pilatus.scan_info.msg.scan_name.startswith("xas"):
        xas_status = pilatus.trigger()
        assert xas_status.done is True
        assert xas_status.success is True
        return

    pilatus.hdf.num_captured._read_pv.mock_data = 0
    pilatus.trigger_shot.put(0)

    trigger_status = pilatus.trigger()
    assert trigger_status.done is False
    assert pilatus.trigger_shot.get() == 1

    # Simulate the HDF5 plugin reporting the newly captured frame.
    pilatus.hdf.num_captured._read_pv.mock_data = 1
    trigger_status.wait(timeout=5)
    assert trigger_status.done is True
    assert trigger_status.success is True
|
|
|
|
|
|
def test_pilatus_on_trigger_cancel_on_stop(pilatus):
    """Test that the status of the trigger is cancelled if stop is called.

    For XAS scans trigger() is expected to return an already-successful status,
    so there is nothing to cancel. Otherwise the pending trigger status must
    fail with DeviceStoppedError once the device is stopped, and must end up
    done but unsuccessful.
    """
    if pilatus.scan_info.msg.scan_name.startswith("xas"):
        status = pilatus.trigger()
        assert status.done is True
        assert status.success is True
        return

    pilatus.hdf.num_captured._read_pv.mock_data = 0
    pilatus.trigger_shot.put(0)
    status = pilatus.trigger()
    assert status.done is False

    with pytest.raises(DeviceStoppedError):
        pilatus.stop()
        status.wait(timeout=5)
    # pytest.raises only proves that *something* raised; also pin the final
    # state of the cancelled status (same expectation as test_pilatus_on_complete).
    assert status.done is True
    assert status.success is False
|
|
|
|
|
|
def test_pilatus_on_complete(pilatus):
    """Test the on_complete logic of the Pilatus detector.

    XAS scans are expected to complete immediately. For all other scans,
    complete() is exercised twice:

    * success: the last missing image arrives, the status succeeds and the
      file event is marked done and successful;
    * failure: the device is stopped before that happens, the status fails
      with DeviceStoppedError and the file event is marked done but not
      successful.
    """
    scan_msg = pilatus.scan_info.msg
    if scan_msg.scan_name.startswith("xas"):
        status = pilatus.complete()
        assert status.done is True
        assert status.success is True
        return

    expected_images = scan_msg.num_points * scan_msg.scan_parameters.get("frames_per_trigger", 1)

    for success in (True, False):
        # Seed the file event with the *opposite* outcome, so the assertions at
        # the end show the value was actually rewritten.
        pilatus.file_event.put(file_path="", done=False, successful=not success)
        if success:
            pilatus._full_path = "file-path-for-success"
        else:
            pilatus._full_path = "file-path-for-failure"

        # Detector still acquiring and one image short of the expected total.
        pilatus.cam.acquire._read_pv.mock_data = ACQUIREMODE.ACQUIRING.value
        pilatus.hdf.capture._read_pv.mock_data = ACQUIREMODE.ACQUIRING.value
        pilatus.cam.armed._read_pv.mock_data = DETECTORSTATE.ARMED.value
        pilatus.hdf.num_captured._read_pv.mock_data = expected_images - 1

        status = pilatus.complete()
        assert status.done is False

        # Acquisition has stopped, but the image count is still short.
        pilatus.cam.acquire.put(ACQUIREMODE.DONE.value)
        pilatus.hdf.capture.put(ACQUIREMODE.DONE.value)
        pilatus.cam.armed.put(DETECTORSTATE.UNARMED.value)
        assert status.done is False

        if success:
            pilatus.hdf.num_captured._read_pv.mock_data = expected_images
            status.wait(timeout=5)
            assert status.done is True
            assert status.success is True
            expected_path = "file-path-for-success"
        else:
            with pytest.raises(DeviceStoppedError):
                pilatus.stop()
                status.wait(timeout=5)
            assert status.done is True
            assert status.success is False
            expected_path = "file-path-for-failure"

        file_msg: FileMessage = pilatus.file_event.get()
        assert file_msg.file_path == expected_path
        assert file_msg.done is True
        assert file_msg.successful is success
|
|
|
|
|
|
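# TODO: figure out how to properly test on_stage.
# NOTE(review): before re-enabling, fix the file_components lookup below. It applies [0]
# twice, so base_path/file_name would be derived from the path's first character, not the path.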
|
|
# def test_pilatus_on_stage(pilatus):
|
|
# """Test the on_stage logic of the Pilatus detector."""
|
|
# # Make sure that no additional logic from stage signals of underlying components is triggered
|
|
# pilatus.stage_sigs = {}
|
|
# pilatus.cam.stage_sigs = {}
|
|
# pilatus.hdf.stage_sigs = {}
|
|
# if pilatus.scan_info.msg.scan_name.startswith("xas"):
|
|
# pilatus.on_stage()
|
|
# return
|
|
# exp_time = pilatus.scan_info.msg.scan_parameters.get("exp_time", 0.1)
|
|
# n_images = pilatus.scan_info.msg.num_points * pilatus.scan_info.msg.scan_parameters.get(
|
|
# "frames_per_trigger", 1
|
|
# )
|
|
# if exp_time <= 0.1:
|
|
# with pytest.raises(ValueError):
|
|
# pilatus.on_stage()
|
|
# return
|
|
# pilatus.filter_number.put(10)
|
|
# pilatus.cam.array_counter.put(1)
|
|
# file_components = pilatus.scan_info.msg.info.get("file_components", ("", ""))[0]
|
|
# base_path = file_components[0].rsplit("/", 1)[0]
|
|
# file_name = file_components[0].rsplit("/", 1)[1] + "_pilatus.h5"
|
|
# file_path = os.path.join(base_path, file_name)
|
|
# pilatus.on_stage()
|
|
# assert pilatus.cam.array_callbacks.get() == 0
|
|
# assert pilatus.hdf.enable.get() == 1
|
|
# assert pilatus.cam.num_exposures.get() == 1
|
|
# assert pilatus.cam.num_images.get() == n_images
|
|
# assert pilatus.cam.acquire_time.get() == exp_time - pilatus._readout_time
|
|
# assert pilatus.cam.acquire_period.get() == exp_time
|
|
# assert pilatus.filter_number.get() == 0
|
|
# assert pilatus.hdf.file_path.get() == base_path
|
|
# assert pilatus.hdf.file_name.get() == file_name
|
|
# assert pilatus.hdf.num_capture.get() == n_images
|
|
# assert pilatus.cam.array_counter.get() == 0
|
|
# file_msg: FileMessage = pilatus.file_event.get()
|
|
# assert file_msg.file_path == file_path
|
|
# assert file_msg.done is False
|
|
# assert file_msg.successful is False
|
|
|
|
|
|
def test_pilatus_on_stage_raises_low_exp_time(pilatus):
    """Test that on_stage raises a ValueError if the exposure time is too low.

    XAS scans are reported as skipped rather than silently passing, because
    this test does not exercise on_stage for them.
    """
    if pilatus.scan_info.msg.scan_name.startswith("xas"):
        pytest.skip("on_stage exposure-time validation is not tested for XAS scans")
    # Only mutate the scan parameters once we know the check will actually run.
    pilatus.scan_info.msg.scan_parameters["exp_time"] = 0.09
    with pytest.raises(ValueError):
        pilatus.on_stage()
|