test: add first tests for pilatus

This commit is contained in:
appel_c 2023-11-03 17:36:53 +01:00
parent a80d13ae66
commit a02e0f09b0

474
tests/test_pilatus_csaxs.py Normal file
View File

@ -0,0 +1,474 @@
import pytest
from unittest import mock
from ophyd.signal import Signal
from ophyd import Staged
from bec_lib.core import BECMessage, MessageEndpoints
from bec_lib.core.devicemanager import DeviceContainer
from bec_lib.core.tests.utils import ProducerMock
class MockSignal(Signal):
def __init__(self, read_pv, *, string=False, name=None, parent=None, **kwargs):
self.read_pv = read_pv
self._string = bool(string)
super().__init__(name=name, parent=parent, **kwargs)
self._waited_for_connection = False
self._subscriptions = []
def wait_for_connection(self):
self._waited_for_connection = True
def subscribe(self, method, event_type, **kw):
self._subscriptions.append((method, event_type, kw))
def describe_configuration(self):
return {self.name + "_conf": {"source": "SIM:test"}}
def read_configuration(self):
return {self.name + "_conf": {"value": 0}}
with mock.patch("ophyd.EpicsSignal", new=MockSignal), mock.patch(
"ophyd.EpicsSignalRO", new=MockSignal
), mock.patch("ophyd.EpicsSignalWithRBV", new=MockSignal):
from ophyd_devices.epics.devices.pilatus_csaxs import PilatuscSAXS
# TODO maybe specify here that this DeviceMock is for usage in the DeviceServer
class DeviceMock:
def __init__(self, name: str, value: float = 0.0):
self.name = name
self.read_buffer = value
self._config = {"deviceConfig": {"limits": [-50, 50]}, "userParameter": None}
self._enabled_set = True
self._enabled = True
def read(self):
return {self.name: {"value": self.read_buffer}}
def readback(self):
return self.read_buffer
@property
def enabled_set(self) -> bool:
return self._enabled_set
@enabled_set.setter
def enabled_set(self, val: bool):
self._enabled_set = val
@property
def enabled(self) -> bool:
return self._enabled
@enabled.setter
def enabled(self, val: bool):
self._enabled = val
@property
def user_parameter(self):
return self._config["userParameter"]
@property
def obj(self):
return self
class DMMock:
"""Mock for DeviceManager
The mocked DeviceManager creates a device containert and a producer.
"""
def __init__(self):
self.devices = DeviceContainer()
self.producer = ProducerMock()
def add_device(self, name: str, value: float = 0.0):
self.devices[name] = DeviceMock(name, value)
@pytest.fixture(scope="function")
def mock_det():
name = "pilatus"
prefix = "X12SA-ES-PILATUS300K:"
sim_mode = False
dm = DMMock()
# dm.add_device("mokev", value=12.4)
with mock.patch.object(dm, "producer"):
with mock.patch.object(
PilatuscSAXS, "_update_service_config"
) as mock_update_service_config, mock.patch(
"ophyd_devices.epics.devices.pilatus_csaxs.FileWriterMixin"
) as filemixin:
with mock.patch.object(PilatuscSAXS, "_init"):
yield PilatuscSAXS(name=name, prefix=prefix, device_manager=dm, sim_mode=sim_mode)
@pytest.mark.parametrize(
"trigger_source, sim_mode, scan_status_msg, expected_exception",
[
(
1,
True,
BECMessage.ScanStatusMessage(
scanID="1",
status={},
info={
"RID": "mockrid1111",
"queueID": "mockqueueID111",
"scan_number": 1,
"exp_time": 0.012,
"num_points": 500,
"readout_time": 0.003,
"scan_type": "fly",
"num_lines": 0.012,
"frames_per_trigger": 1,
},
),
True,
),
(
1,
False,
BECMessage.ScanStatusMessage(
scanID="1",
status={},
info={
"RID": "mockrid1111",
"queueID": "mockqueueID111",
"scan_number": 1,
"exp_time": 0.012,
"num_points": 500,
"readout_time": 0.003,
"scan_type": "fly",
"num_lines": 0.012,
"frames_per_trigger": 1,
},
),
False,
),
],
)
# TODO rewrite this one, write test for init_detector, init_filewriter is tested
def test_init(
trigger_source,
sim_mode,
scan_status_msg,
expected_exception,
):
"""Test the _init function:
This includes testing the functions:
- _set_default_parameter
- _init_detector
- _stop_det
- _set_trigger
--> Testing the filewriter is done in test_init_filewriter
Validation upon setting the correct PVs
"""
name = "pilatus"
prefix = "X12SA-ES-PILATUS300K:"
sim_mode = sim_mode
dm = DMMock()
with mock.patch.object(dm, "producer") as producer, mock.patch.object(
PilatuscSAXS, "_init_filewriter"
) as mock_init_fw, mock.patch.object(
PilatuscSAXS, "_update_scaninfo"
) as mock_update_scaninfo, mock.patch.object(
PilatuscSAXS, "_update_filewriter"
) as mock_update_filewriter, mock.patch.object(
PilatuscSAXS, "_update_service_config"
) as mock_update_service_config:
mock_det = PilatuscSAXS(name=name, prefix=prefix, device_manager=dm, sim_mode=sim_mode)
if expected_exception:
with pytest.raises(Exception):
mock_det._init()
mock_init_fw.assert_called_once()
else:
mock_det._init() # call the method you want to test
assert mock_det.cam.acquire.get() == 0
assert mock_det.cam.trigger_mode.get() == trigger_source
mock_init_fw.assert_called()
mock_update_scaninfo.assert_called_once()
mock_update_filewriter.assert_called_once()
mock_update_service_config.assert_called_once()
assert mock_init_fw.call_count == 2
@pytest.mark.parametrize(
"scaninfo, stopped, expected_exception",
[
(
{
"eacc": "e12345",
"num_points": 500,
"frames_per_trigger": 1,
"filepath": "test.h5",
"scanID": "123",
"mokev": 12.4,
},
False,
False,
),
(
{
"eacc": "e12345",
"num_points": 500,
"frames_per_trigger": 1,
"filepath": "test.h5",
"scanID": "123",
"mokev": 12.4,
},
True,
False,
),
],
)
def test_stage(
mock_det,
scaninfo,
stopped,
expected_exception,
):
with mock.patch.object(PilatuscSAXS, "_publish_file_location") as mock_publish_file_location:
mock_det.scaninfo.num_points = scaninfo["num_points"]
mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"]
mock_det.filewriter.compile_full_filename.return_value = scaninfo["filepath"]
# TODO consider putting energy as variable in scaninfo
mock_det.device_manager.add_device("mokev", value=12.4)
mock_det._stopped = stopped
with mock.patch.object(mock_det, "_prep_file_writer") as mock_prep_fw:
mock_det.filepath = scaninfo["filepath"]
if expected_exception:
with pytest.raises(Exception):
mock_det.stage()
else:
mock_det.stage()
mock_prep_fw.assert_called_once()
# Check _prep_det
assert mock_det.cam.num_images.get() == int(
scaninfo["num_points"] * scaninfo["frames_per_trigger"]
)
assert mock_det.cam.num_frames.get() == 1
mock_publish_file_location.assert_called_with(done=False)
def test_pre_scan(mock_det):
mock_det.pre_scan()
assert mock_det.cam.acquire.get() == 1
@pytest.mark.parametrize(
"scaninfo",
[
({"filepath": "test.h5", "successful": True, "done": False, "scanID": "123"}),
({"filepath": "test.h5", "successful": False, "done": True, "scanID": "123"}),
({"filepath": "test.h5", "successful": None, "done": True, "scanID": "123"}),
],
)
def test_publish_file_location(mock_det, scaninfo):
mock_det.scaninfo.scanID = scaninfo["scanID"]
mock_det.filepath = scaninfo["filepath"]
mock_det._publish_file_location(done=scaninfo["done"], successful=scaninfo["successful"])
if scaninfo["successful"] is None:
msg = BECMessage.FileMessage(file_path=scaninfo["filepath"], done=scaninfo["done"]).dumps()
else:
msg = BECMessage.FileMessage(
file_path=scaninfo["filepath"], done=scaninfo["done"], successful=scaninfo["successful"]
).dumps()
expected_calls = [
mock.call(
MessageEndpoints.public_file(scaninfo["scanID"], mock_det.name),
msg,
pipe=mock_det._producer.pipeline.return_value,
),
mock.call(
MessageEndpoints.file_event(mock_det.name),
msg,
pipe=mock_det._producer.pipeline.return_value,
),
]
assert mock_det._producer.set_and_publish.call_args_list == expected_calls
# @pytest.mark.parametrize(
# "scaninfo, daq_status, expected_exception",
# [
# (
# {
# "eacc": "e12345",
# "num_points": 500,
# "frames_per_trigger": 1,
# "filepath": "test.h5",
# "scanID": "123",
# },
# {"state": "BUSY", "acquisition": {"state": "WAITING_IMAGES"}},
# False,
# ),
# (
# {
# "eacc": "e12345",
# "num_points": 500,
# "frames_per_trigger": 1,
# "filepath": "test.h5",
# "scanID": "123",
# },
# {"state": "BUSY", "acquisition": {"state": "WAITING_IMAGES"}},
# False,
# ),
# (
# {
# "eacc": "e12345",
# "num_points": 500,
# "frames_per_trigger": 1,
# "filepath": "test.h5",
# "scanID": "123",
# },
# {"state": "BUSY", "acquisition": {"state": "ERROR"}},
# True,
# ),
# ],
# )
# def test_prep_file_writer(mock_det, scaninfo, daq_status, expected_exception):
# with mock.patch.object(mock_det, "std_client") as mock_std_daq, mock.patch.object(
# mock_det, "_filepath_exists"
# ) as mock_file_path_exists, mock.patch.object(
# mock_det, "_stop_file_writer"
# ) as mock_stop_file_writer, mock.patch.object(
# mock_det, "scaninfo"
# ) as mock_scaninfo:
# # mock_det = eiger_factory(name, prefix, sim_mode)
# mock_det.std_client = mock_std_daq
# mock_std_daq.start_writer_async.return_value = None
# mock_std_daq.get_status.return_value = daq_status
# mock_det.filewriter.compile_full_filename.return_value = scaninfo["filepath"]
# mock_det.scaninfo.num_points = scaninfo["num_points"]
# mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"]
# if expected_exception:
# with pytest.raises(Exception):
# mock_det._prep_file_writer()
# mock_file_path_exists.assert_called_once()
# assert mock_stop_file_writer.call_count == 2
# else:
# mock_det._prep_file_writer()
# mock_file_path_exists.assert_called_once()
# mock_stop_file_writer.assert_called_once()
# daq_writer_call = {
# "output_file": scaninfo["filepath"],
# "n_images": int(scaninfo["num_points"] * scaninfo["frames_per_trigger"]),
# }
# mock_std_daq.start_writer_async.assert_called_with(daq_writer_call)
@pytest.mark.parametrize(
"stopped, expected_exception",
[
(
False,
False,
),
(
True,
True,
),
],
)
def test_unstage(
mock_det,
stopped,
expected_exception,
):
with mock.patch.object(mock_det, "_finished") as mock_finished, mock.patch.object(
mock_det, "_publish_file_location"
) as mock_publish_file_location, mock.patch.object(
mock_det, "_start_h5converter"
) as mock_start_h5converter:
mock_det._stopped = stopped
if expected_exception:
mock_det.unstage()
assert mock_det._stopped == True
else:
mock_det.unstage()
mock_finished.assert_called_once()
mock_publish_file_location.assert_called_with(done=True, successful=True)
mock_start_h5converter.assert_called_once()
assert mock_det._stopped == False
# def test_stop_fw(mock_det):
# with mock.patch.object(mock_det, "std_client") as mock_std_daq:
# mock_std_daq.stop_writer.return_value = None
# mock_det.std_client = mock_std_daq
# mock_det._stop_file_writer()
# mock_std_daq.stop_writer.assert_called_once()
def test_stop(mock_det):
with mock.patch.object(mock_det, "_stop_det") as mock_stop_det, mock.patch.object(
mock_det, "_stop_file_writer"
) as mock_stop_file_writer, mock.patch.object(
mock_det, "_close_file_writer"
) as mock_close_file_writer:
mock_det.stop()
mock_stop_det.assert_called_once()
mock_stop_file_writer.assert_called_once()
mock_close_file_writer.assert_called_once()
assert mock_det._stopped == True
@pytest.mark.parametrize(
"stopped, mcs_stage_state, expected_exception",
[
(
False,
Staged.no,
False,
),
(
True,
Staged.no,
False,
),
(
False,
Staged.yes,
True,
),
],
)
def test_finished(mock_det, stopped, mcs_stage_state, expected_exception):
with mock.patch.object(mock_det, "device_manager") as mock_dm, mock.patch.object(
mock_det, "_stop_file_writer"
) as mock_stop_file_friter, mock.patch.object(
mock_det, "_stop_det"
) as mock_stop_det, mock.patch.object(
mock_det, "_close_file_writer"
) as mock_close_file_writer:
mock_dm.devices.mcs.obj._staged = mcs_stage_state
mock_det._stopped = stopped
if expected_exception:
with pytest.raises(Exception):
mock_det._finished()
assert mock_det._stopped == stopped
mock_stop_file_friter.assert_called()
mock_stop_det.assert_called_once()
mock_close_file_writer.assert_called_once()
else:
mock_det._finished()
if stopped:
assert mock_det._stopped == stopped
mock_stop_file_friter.assert_called()
mock_stop_det.assert_called_once()
mock_close_file_writer.assert_called_once()