refactor(mcs-card): cleanup class, add tests

This commit is contained in:
2025-07-21 22:24:02 +02:00
parent c2cba873d4
commit 3fd3d54003
2 changed files with 452 additions and 281 deletions

View File

@@ -5,8 +5,8 @@ from __future__ import annotations
import enum import enum
from threading import RLock from threading import RLock
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import numpy as np
import numpy as np
from bec_lib.logger import bec_logger from bec_lib.logger import bec_logger
from ophyd import Component as Cpt from ophyd import Component as Cpt
from ophyd import Device, EpicsSignalRO, Kind, Signal from ophyd import Device, EpicsSignalRO, Kind, Signal
@@ -120,7 +120,7 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard):
super().__init__( super().__init__(
name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs
) )
self._mcs_clock = 1e-7 # 10MHz clock self._mcs_clock = 1e7 # 10MHz clock -> 1e7 Hz
self._pv_timeout = 2 self._pv_timeout = 2
self._rlock = RLock() self._rlock = RLock()
self.counter_mapping = { self.counter_mapping = {
@@ -139,8 +139,9 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard):
# Make sure card is not running # Make sure card is not running
self.stop_all.put(1) self.stop_all.put(1)
# TODO Check channel1_source !!
self.channel_advance.set(CHANNELADVANCE.EXTERNAL).wait(timeout=self._pv_timeout) self.channel_advance.set(CHANNELADVANCE.EXTERNAL).wait(timeout=self._pv_timeout)
self.channel_advance.set(CHANNEL1SOURCE.EXTERNAL).wait(timeout=self._pv_timeout) self.channel1_source.set(CHANNEL1SOURCE.EXTERNAL).wait(timeout=self._pv_timeout)
self.prescale.set(1).wait(timeout=self._pv_timeout) self.prescale.set(1).wait(timeout=self._pv_timeout)
# Set the user LED to off # Set the user LED to off
self.user_led.set(0).wait(timeout=self._pv_timeout) self.user_led.set(0).wait(timeout=self._pv_timeout)
@@ -156,15 +157,15 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard):
# Set appropriate read mode # Set appropriate read mode
self.read_mode.set(READMODE.PASSIVE).wait(timeout=self._pv_timeout) self.read_mode.set(READMODE.PASSIVE).wait(timeout=self._pv_timeout)
# Subscribe the progress signal
self.current_channel.subscribe(self._progress_update, run=False)
# Set the acquire mode # Set the acquire mode
self.acquire_mode.set(ACQUIREMODE.MCS).wait(timeout=self._pv_timeout) self.acquire_mode.set(ACQUIREMODE.MCS).wait(timeout=self._pv_timeout)
# Subscribe the progress signal
self.current_channel.subscribe(self._progress_update, run=False)
# Subscribe to the mca updates # Subscribe to the mca updates
for name in self.counter_mapping.keys(): for name in self.counter_mapping.keys():
sig: EpicsSignalRO = getattr(self.counters, name.split('_')[-1]) sig: EpicsSignalRO = getattr(self.counters, name.split("_")[-1])
sig.subscribe(self._on_counter_update, run=False) sig.subscribe(self._on_counter_update, run=False)
def _on_counter_update(self, value, **kwargs) -> None: def _on_counter_update(self, value, **kwargs) -> None:
@@ -178,18 +179,18 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard):
return return
mca_raw = getattr(self.mcs, signal.name.split("_")[-1], None) mca_raw = getattr(self.mcs, signal.name.split("_")[-1], None)
if mca_raw is None: if mca_raw is None:
return return
logger.info(f"Received update of type {type(value)} for {signal.name}") logger.info(f"Received update of type {type(value)} for {signal.name}")
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
mca_raw.put(value.tolist()) mca_raw.put(value.tolist())
if mapped_signal_name == "count_time": if mapped_signal_name == "count_time":
value = value*self._mcs_clock value = value / self._mcs_clock
value = float(value.mean()) value = float(value.mean())
else: else:
mca_raw.put(value) mca_raw.put(value)
if mapped_signal_name == "count_time": if mapped_signal_name == "count_time":
value = value*self._mcs_clock value = value / self._mcs_clock
# Mean signal for burst acquisition # Mean signal for burst acquisition
sig = getattr(self.bpm, mapped_signal_name) sig = getattr(self.bpm, mapped_signal_name)
sig.put(value) sig.put(value)
@@ -223,6 +224,7 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard):
""" """
self.stop_all.put(1) self.stop_all.put(1)
self.ready_to_read.put(READYTOREAD.DONE) self.ready_to_read.put(READYTOREAD.DONE)
# TODO why 0?
self.erase_all.set(0).wait(timeout=self._pv_timeout) self.erase_all.set(0).wait(timeout=self._pv_timeout)
def on_trigger(self) -> None: def on_trigger(self) -> None:

View File

@@ -2,311 +2,480 @@
import threading import threading
from unittest import mock from unittest import mock
import numpy as np
import ophyd import ophyd
import pytest import pytest
from bec_lib import messages from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints from bec_lib.endpoints import MessageEndpoints
from bec_server.device_server.tests.utils import DMMock from bec_server.device_server.tests.utils import DMMock
from ophyd_devices.tests.utils import MockPV from ophyd_devices.tests.utils import MockPV, patch_dual_pvs
from csaxs_bec.devices.epics.mcs_csaxs import ( from csaxs_bec.devices.epics.mcs_card.mcs_card import (
MCScSAXS, ACQUIREMODE,
MCSError, ACQUIRING,
MCSTimeoutError, CHANNEL1SOURCE,
ReadoutMode, CHANNELADVANCE,
TriggerSource, INPUTMODE,
OUTPUTMODE,
POLARITY,
READMODE,
MCSCard,
) )
from csaxs_bec.devices.tests_utils.utils import patch_dual_pvs from csaxs_bec.devices.epics.mcs_card.mcs_card_csaxs import READYTOREAD, MCSCardCSAXS
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mock_det(): def mock_mcs_card():
name = "mcs" """Fixture to mock the MCSCard device."""
name = "mcs_card"
prefix = "X12SA-MCS:" prefix = "X12SA-MCS:"
with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV
mock_cl.thread_class = threading
mcs_card = MCSCard(name=name, prefix=prefix)
patch_dual_pvs(mcs_card)
yield mcs_card
def test_mcs_card(mock_mcs_card):
"""Test the MCSCard initialization."""
assert mock_mcs_card.name == "mcs_card"
assert mock_mcs_card.prefix == "X12SA-MCS:"
assert len(mock_mcs_card.counters.component_names) == 32
assert mock_mcs_card.counters.mca1.name == "mcs_card_counters_mca1"
@pytest.fixture(scope="function")
def mock_mcs_csaxs():
"""Fixture to mock the MCSCardCSAXS device."""
name = "mcs_csaxs"
prefix = "X12SA-MCS-CSAXS:"
dm = DMMock() dm = DMMock()
with mock.patch.object(dm, "connector"): with mock.patch.object(ophyd, "cl") as mock_cl:
with ( mock_cl.get_pv = MockPV
mock.patch( mock_cl.thread_class = threading.Thread
"ophyd_devices.interfaces.base_classes.bec_device_base.FileWriter" mcs_card_csaxs = MCSCardCSAXS(name=name, prefix=prefix, device_manager=dm)
) as filemixin, patch_dual_pvs(mcs_card_csaxs)
mock.patch( yield mcs_card_csaxs
"ophyd_devices.interfaces.base_classes.psi_detector_base.PSIDetectorBase._update_service_config"
) as mock_service_config,
):
with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV
mock_cl.thread_class = threading.Thread
with mock.patch.object(MCScSAXS, "_init"):
det = MCScSAXS(name=name, prefix=prefix, device_manager=dm)
patch_dual_pvs(det)
det.TIMEOUT_FOR_SIGNALS = 0.1
yield det
def test_init(): def test_mcs_card_csaxs(mock_mcs_csaxs):
"""Test the _init function:""" """Test the MCSCardCSAXS initialization."""
name = "eiger" assert mock_mcs_csaxs.name == "mcs_csaxs"
prefix = "X12SA-ES-EIGER9M:" assert mock_mcs_csaxs.prefix == "X12SA-MCS-CSAXS:"
dm = DMMock() assert mock_mcs_csaxs.counter_mapping == {
with mock.patch.object(dm, "connector"): "mcs_csaxs_counters_mca1": "current1",
with ( "mcs_csaxs_counters_mca2": "current2",
mock.patch("ophyd_devices.interfaces.base_classes.bec_device_base.FileWriter"), "mcs_csaxs_counters_mca3": "current3",
mock.patch( "mcs_csaxs_counters_mca4": "current4",
"ophyd_devices.interfaces.base_classes.psi_detector_base.PSIDetectorBase._update_service_config" "mcs_csaxs_counters_mca5": "count_time",
), }
): assert mock_mcs_csaxs._mcs_clock == 1e7 # 10 MHz
with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV
with (
mock.patch(
"csaxs_bec.devices.epics.mcs_csaxs.MCSSetup.initialize_detector"
) as mock_init_det,
mock.patch(
"csaxs_bec.devices.epics.mcs_csaxs.MCSSetup.initialize_detector_backend"
) as mock_init_backend,
):
MCScSAXS(name=name, prefix=prefix, device_manager=dm)
mock_init_det.assert_called_once()
mock_init_backend.assert_called_once()
@pytest.mark.parametrize( def test_mcs_card_csaxs_on_connected(mock_mcs_csaxs):
"trigger_source, channel_advance, channel_source1, pv_channels", """Test the on_connected method of MCSCardCSAXS."""
[ mcs = mock_mcs_csaxs
( mcs.on_connected()
3, # Stop called
1, assert mcs.stop_all.get() == 1
0, # Channel advance settings
{ assert mcs.channel_advance.get() == CHANNELADVANCE.EXTERNAL
"user_led": 0, assert mcs.channel1_source.get() == CHANNEL1SOURCE.EXTERNAL
"mux_output": 5, assert mcs.prescale.get() == 1
"input_pol": 0, #
"output_pol": 1, assert mcs.user_led.get() == 0
"count_on_start": 0, # Only 5 channels are connected
"stop_all": 1, assert mcs.mux_output.get() == 5
}, # input output settings
) assert mcs.input_mode.get() == INPUTMODE.MODE_3
], assert mcs.input_polarity.get() == POLARITY.NORMAL
) assert mcs.output_mode.get() == OUTPUTMODE.MODE_2
def test_initialize_detector( assert mcs.output_polarity.get() == POLARITY.NORMAL
mock_det, trigger_source, channel_advance, channel_source1, pv_channels assert mcs.count_on_start.get() == 0
): assert mcs.read_mode.get() == READMODE.PASSIVE
"""Test the _init function: assert mcs.acquire_mode.get() == ACQUIREMODE.MCS
This includes testing the functions: with mock.patch.object(mcs.current_channel, "subscribe") as mock_cur_ch_subscribe:
- initialize_detector with mock.patch.object(mcs.counters.mca1, "subscribe") as mock_mca_subscribe:
- stop_det mcs.on_connected()
- parent.set_trigger assert mock_cur_ch_subscribe.call_args == mock.call(mcs._progress_update, run=False)
--> Testing the filewriter is done in test_init_filewriter assert mock_mca_subscribe.call_args == mock.call(mcs._on_counter_update, run=False)
Validation upon setting the correct PVs
"""
mock_det.custom_prepare.initialize_detector() # call the method you want to test
assert mock_det.channel_advance.get() == channel_advance
assert mock_det.channel1_source.get() == channel_source1
assert mock_det.user_led.get() == pv_channels["user_led"]
assert mock_det.mux_output.get() == pv_channels["mux_output"]
assert mock_det.input_polarity.get() == pv_channels["input_pol"]
assert mock_det.output_polarity.get() == pv_channels["output_pol"]
assert mock_det.count_on_start.get() == pv_channels["count_on_start"]
assert mock_det.input_mode.get() == trigger_source
def test_trigger(mock_det): def test_mcs_card_csaxs_stage(mock_mcs_csaxs):
"""Test the trigger function: """Test on stage method of MCSCardCSAXS"""
Validate that trigger calls the custom_prepare.on_trigger() function mcs = mock_mcs_csaxs
""" triggers = 5
with mock.patch.object(mock_det.custom_prepare, "on_trigger") as mock_on_trigger: mcs.scan_info.msg.scan_parameters["frames_per_trigger"] = triggers
mock_det.trigger() mcs.erase_all.put(0)
mock_on_trigger.assert_called_once() mcs.stage()
assert mcs._staged == ophyd.Staged.yes
assert mcs.erase_all.get() == 1
assert mcs.preset_real.get() == 0
assert mcs.num_use_all.get() == triggers
@pytest.mark.parametrize( def test_mcs_card_csaxs_unstage(mock_mcs_csaxs):
"value, num_lines, num_points, done", [(100, 5, 500, False), (500, 5, 500, True)] """Test unstage method of MCSCardCSAXS"""
) mcs = mock_mcs_csaxs
def test_progress_update(mock_det, value, num_lines, num_points, done): mcs.stop_all.put(0)
mock_det.num_lines.set(num_lines) mcs.ready_to_read.put(0)
mock_det.scaninfo.num_points = num_points mcs.erase_all.put(1)
calls = mock.call(sub_type="progress", value=value, max_value=num_points, done=done) mcs.unstage()
with mock.patch.object(mock_det, "_run_subs") as mock_run_subs: assert mcs.stop_all.get() == 1
mock_det.custom_prepare._progress_update(value=value) assert mcs.ready_to_read.get() == READYTOREAD.DONE
mock_run_subs.assert_called_once() assert mcs.erase_all.get() == 0
assert mock_run_subs.call_args == calls
@pytest.mark.parametrize( def test_mcs_card_csaxs_complete_and_stop(mock_mcs_csaxs):
"values, expected_nothing", """Test complete method of MCSCarcCSAXS"""
[([[100, 120, 140], [200, 220, 240], [300, 320, 340]], False), ([100, 200, 300], True)], mcs = mock_mcs_csaxs
) mcs.acquiring._read_pv.mock_data = ACQUIRING.ACQUIRING
def test_on_mca_data(mock_det, values, expected_nothing): st = mcs.complete()
"""Test the on_mca_data function: assert st.done is False
Validate that on_mca_data calls the custom_prepare.on_mca_data() function mcs.stop_all.put(0)
""" mcs.ready_to_read.put(READYTOREAD.PROCESSING)
with mock.patch.object(mock_det.custom_prepare, "_send_data_to_bec") as mock_send_data: mcs.stop()
mock_object = mock.MagicMock() with pytest.raises(Exception):
for ii, name in enumerate(mock_det.custom_prepare.mca_names): st.wait(timeout=3)
mock_object.attr_name = name assert st.done is True
mock_det.custom_prepare._on_mca_data(obj=mock_object, value=values[ii]) assert st.success is False
if not expected_nothing and ii < (len(values) - 1): assert mcs.stop_all.get() == 1
assert mock_det.custom_prepare.mca_data[name] == values[ii] assert mcs.ready_to_read.get() == READYTOREAD.DONE
if not expected_nothing:
mock_send_data.assert_called_once()
assert mock_det.custom_prepare.acquisition_done is True
@pytest.mark.parametrize( def test_mcs_card_csaxs_on_counter_updated(mock_mcs_csaxs):
"metadata, mca_data", mcs = mock_mcs_csaxs
[ # Called for mca1
( kwargs = {"obj": mcs.counters.mca1}
{"scan_id": 123}, mcs._on_counter_update(1, **kwargs)
{ assert mcs.mcs.mca1.get() == 1
"mca1": {"value": [100, 120, 140]}, assert mcs.bpm.current1.get() == 1
"mca3": {"value": [200, 220, 240]}, assert mcs.counter_updated == [mcs.counters.mca1.name]
"mca4": {"value": [300, 320, 340]}, # Called for mca2
}, kwargs = {"obj": mcs.counters.mca2}
) mcs._on_counter_update(np.array([2, 4]), **kwargs)
], assert mcs.mcs.mca2.get() == [2, 4]
) assert np.isclose(mcs.bpm.current2.get(), 3)
def test_send_data_to_bec(mock_det, metadata, mca_data): assert mcs.counter_updated == [mcs.counters.mca1.name, mcs.counters.mca2.name]
mock_det.scaninfo.scan_msg = mock.MagicMock() # Called for mca3
mock_det.scaninfo.scan_msg.metadata = metadata kwargs = {"obj": mcs.counters.mca3}
mock_det.scaninfo.scan_id = metadata["scan_id"] mcs._on_counter_update(1000, **kwargs)
mock_det.custom_prepare.mca_data = mca_data assert mcs.mcs.mca3.get() == 1000
mock_det.custom_prepare._send_data_to_bec() assert mcs.bpm.current3.get() == 1000
device_metadata = mock_det.scaninfo.scan_msg.metadata assert mcs.counter_updated == [
metadata.update({"async_update": "append", "num_lines": mock_det.num_lines.get()}) mcs.counters.mca1.name,
data = messages.DeviceMessage(signals=dict(mca_data), metadata=device_metadata) mcs.counters.mca2.name,
calls = mock.call( mcs.counters.mca3.name,
topic=MessageEndpoints.device_async_readback( ]
scan_id=metadata["scan_id"], device=mock_det.name # Called for mca4
), kwargs = {"obj": mcs.counters.mca4}
msg={"data": data}, mcs._on_counter_update(np.array([20, 40]), **kwargs)
expire=1800, assert mcs.mcs.mca4.get() == [20, 40]
) assert np.isclose(mcs.bpm.current4.get(), 30)
assert mcs.counter_updated == [
assert mock_det.connector.xadd.call_args == calls mcs.counters.mca1.name,
mcs.counters.mca2.name,
mcs.counters.mca3.name,
mcs.counters.mca4.name,
]
# Called for mca5
assert mcs.ready_to_read.get() == 0
kwargs = {"obj": mcs.counters.mca5}
mcs._on_counter_update(np.array([10000, 10000]), **kwargs)
assert np.isclose(mcs.bpm.count_time.get(), 10000 / 1e7)
assert mcs.mcs.mca5.get() == [10000, 10000]
@pytest.mark.parametrize( # @pytest.fixture(scope="function")
"scaninfo, triggersource, stopped, expected_exception", # def mock_det():
[ # name = "mcs"
( # prefix = "X12SA-MCS:"
{"num_points": 500, "frames_per_trigger": 1, "scan_type": "step"}, # dm = DMMock()
TriggerSource.MODE3, # with mock.patch.object(dm, "connector"):
False, # with (
False, # mock.patch(
), # "ophyd_devices.interfaces.base_classes.bec_device_base.FileWriter"
( # ) as filemixin,
{"num_points": 500, "frames_per_trigger": 1, "scan_type": "fly"}, # mock.patch(
TriggerSource.MODE3, # "ophyd_devices.interfaces.base_classes.psi_detector_base.PSIDetectorBase._update_service_config"
False, # ) as mock_service_config,
False, # ):
), # with mock.patch.object(ophyd, "cl") as mock_cl:
( # mock_cl.get_pv = MockPV
{"num_points": 5001, "frames_per_trigger": 2, "scan_type": "step"}, # mock_cl.thread_class = threading.Thread
TriggerSource.MODE3, # with mock.patch.object(MCScSAXS, "_init"):
False, # det = MCScSAXS(name=name, prefix=prefix, device_manager=dm)
True, # patch_dual_pvs(det)
), # det.TIMEOUT_FOR_SIGNALS = 0.1
( # yield det
{"num_points": 500, "frames_per_trigger": 2, "scan_type": "random"},
TriggerSource.MODE3,
False,
True,
),
],
)
def test_stage(mock_det, scaninfo, triggersource, stopped, expected_exception):
mock_det.scaninfo.num_points = scaninfo["num_points"]
mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"]
mock_det.scaninfo.scan_type = scaninfo["scan_type"]
mock_det.stopped = stopped
with mock.patch.object(mock_det.custom_prepare, "prepare_detector_backend") as mock_prep_fw:
if expected_exception:
with pytest.raises(MCSError):
mock_det.stage()
mock_prep_fw.assert_called_once()
else:
mock_det.stage()
mock_prep_fw.assert_called_once()
# Check set_trigger
mock_det.input_mode.get() == triggersource
if scaninfo["scan_type"] == "step":
assert mock_det.num_use_all.get() == int(scaninfo["frames_per_trigger"]) * int(
scaninfo["num_points"]
)
elif scaninfo["scan_type"] == "fly":
assert mock_det.num_use_all.get() == int(scaninfo["num_points"])
mock_det.preset_real.get() == 0
# # CHeck custom_prepare.arm_acquisition
# assert mock_det.custom_prepare.counter == 0
# assert mock_det.erase_start.get() == 1
# mock_prep_fw.assert_called_once()
# # Check _prep_det
# assert mock_det.cam.num_images.get() == int(
# scaninfo["num_points"] * scaninfo["frames_per_trigger"]
# )
# assert mock_det.cam.num_frames.get() == 1
# mock_publish_file_location.assert_called_with(done=False)
# assert mock_det.cam.acquire.get() == 1
def test_prepare_detector_backend(mock_det): # def test_init():
mock_det.custom_prepare.prepare_detector_backend() # """Test the _init function:"""
assert mock_det.erase_all.get() == 1 # name = "eiger"
assert mock_det.read_mode.get() == ReadoutMode.EVENT # prefix = "X12SA-ES-EIGER9M:"
# dm = DMMock()
# with mock.patch.object(dm, "connector"):
# with (
# mock.patch("ophyd_devices.interfaces.base_classes.bec_device_base.FileWriter"),
# mock.patch(
# "ophyd_devices.interfaces.base_classes.psi_detector_base.PSIDetectorBase._update_service_config"
# ),
# ):
# with mock.patch.object(ophyd, "cl") as mock_cl:
# mock_cl.get_pv = MockPV
# with (
# mock.patch(
# "csaxs_bec.devices.epics.mcs_csaxs.MCSSetup.initialize_detector"
# ) as mock_init_det,
# mock.patch(
# "csaxs_bec.devices.epics.mcs_csaxs.MCSSetup.initialize_detector_backend"
# ) as mock_init_backend,
# ):
# MCScSAXS(name=name, prefix=prefix, device_manager=dm)
# mock_init_det.assert_called_once()
# mock_init_backend.assert_called_once()
def test_complete(mock_det): # @pytest.mark.parametrize(
with (mock.patch.object(mock_det.custom_prepare, "finished") as mock_finished,): # "trigger_source, channel_advance, channel_source1, pv_channels",
mock_det.complete() # [
assert mock_finished.call_count == 1 # (
# 3,
# 1,
# 0,
# {
# "user_led": 0,
# "mux_output": 5,
# "input_pol": 0,
# "output_pol": 1,
# "count_on_start": 0,
# "stop_all": 1,
# },
# )
# ],
# )
# def test_initialize_detector(
# mock_det, trigger_source, channel_advance, channel_source1, pv_channels
# ):
# """Test the _init function:
# This includes testing the functions:
# - initialize_detector
# - stop_det
# - parent.set_trigger
# --> Testing the filewriter is done in test_init_filewriter
# Validation upon setting the correct PVs
# """
# mock_det.custom_prepare.initialize_detector() # call the method you want to test
# assert mock_det.channel_advance.get() == channel_advance
# assert mock_det.channel1_source.get() == channel_source1
# assert mock_det.user_led.get() == pv_channels["user_led"]
# assert mock_det.mux_output.get() == pv_channels["mux_output"]
# assert mock_det.input_polarity.get() == pv_channels["input_pol"]
# assert mock_det.output_polarity.get() == pv_channels["output_pol"]
# assert mock_det.count_on_start.get() == pv_channels["count_on_start"]
# assert mock_det.input_mode.get() == trigger_source
def test_stop_detector_backend(mock_det): # def test_trigger(mock_det):
mock_det.custom_prepare.stop_detector_backend() # """Test the trigger function:
assert mock_det.custom_prepare.acquisition_done is True # Validate that trigger calls the custom_prepare.on_trigger() function
# """
# with mock.patch.object(mock_det.custom_prepare, "on_trigger") as mock_on_trigger:
# mock_det.trigger()
# mock_on_trigger.assert_called_once()
def test_stop(mock_det): # @pytest.mark.parametrize(
with ( # "value, num_lines, num_points, done", [(100, 5, 500, False), (500, 5, 500, True)]
mock.patch.object(mock_det.custom_prepare, "stop_detector") as mock_stop_det, # )
mock.patch.object( # def test_progress_update(mock_det, value, num_lines, num_points, done):
mock_det.custom_prepare, "stop_detector_backend" # mock_det.num_lines.set(num_lines)
) as mock_stop_detector_backend, # mock_det.scaninfo.num_points = num_points
): # calls = mock.call(sub_type="progress", value=value, max_value=num_points, done=done)
mock_det.stop() # with mock.patch.object(mock_det, "_run_subs") as mock_run_subs:
mock_stop_det.assert_called_once() # mock_det.custom_prepare._progress_update(value=value)
mock_stop_detector_backend.assert_called_once() # mock_run_subs.assert_called_once()
assert mock_det.stopped is True # assert mock_run_subs.call_args == calls
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"stopped, acquisition_done, acquiring_state, expected_exception", # "values, expected_nothing",
[ # [([[100, 120, 140], [200, 220, 240], [300, 320, 340]], False), ([100, 200, 300], True)],
(False, True, 0, False), # )
(False, False, 0, True), # def test_on_mca_data(mock_det, values, expected_nothing):
(False, True, 1, True), # """Test the on_mca_data function:
(True, True, 0, True), # Validate that on_mca_data calls the custom_prepare.on_mca_data() function
], # """
) # with mock.patch.object(mock_det.custom_prepare, "_send_data_to_bec") as mock_send_data:
def test_finished(mock_det, stopped, acquisition_done, acquiring_state, expected_exception): # mock_object = mock.MagicMock()
mock_det.custom_prepare.acquisition_done = acquisition_done # for ii, name in enumerate(mock_det.custom_prepare.mca_names):
mock_det.acquiring._read_pv.mock_data = acquiring_state # mock_object.attr_name = name
mock_det.scaninfo.num_points = 500 # mock_det.custom_prepare._on_mca_data(obj=mock_object, value=values[ii])
mock_det.num_lines.put(500) # if not expected_nothing and ii < (len(values) - 1):
mock_det.current_channel._read_pv.mock_data = 1 # assert mock_det.custom_prepare.mca_data[name] == values[ii]
mock_det.stopped = stopped
if expected_exception: # if not expected_nothing:
with pytest.raises(MCSTimeoutError): # mock_send_data.assert_called_once()
mock_det.timeout = 0.1 # assert mock_det.custom_prepare.acquisition_done is True
mock_det.custom_prepare.finished()
else:
mock_det.custom_prepare.finished() # @pytest.mark.parametrize(
if stopped: # "metadata, mca_data",
assert mock_det.stopped is stopped # [
# (
# {"scan_id": 123},
# {
# "mca1": {"value": [100, 120, 140]},
# "mca3": {"value": [200, 220, 240]},
# "mca4": {"value": [300, 320, 340]},
# },
# )
# ],
# )
# def test_send_data_to_bec(mock_det, metadata, mca_data):
# mock_det.scaninfo.scan_msg = mock.MagicMock()
# mock_det.scaninfo.scan_msg.metadata = metadata
# mock_det.scaninfo.scan_id = metadata["scan_id"]
# mock_det.custom_prepare.mca_data = mca_data
# mock_det.custom_prepare._send_data_to_bec()
# device_metadata = mock_det.scaninfo.scan_msg.metadata
# metadata.update({"async_update": "append", "num_lines": mock_det.num_lines.get()})
# data = messages.DeviceMessage(signals=dict(mca_data), metadata=device_metadata)
# calls = mock.call(
# topic=MessageEndpoints.device_async_readback(
# scan_id=metadata["scan_id"], device=mock_det.name
# ),
# msg={"data": data},
# expire=1800,
# )
# assert mock_det.connector.xadd.call_args == calls
# @pytest.mark.parametrize(
# "scaninfo, triggersource, stopped, expected_exception",
# [
# (
# {"num_points": 500, "frames_per_trigger": 1, "scan_type": "step"},
# TriggerSource.MODE3,
# False,
# False,
# ),
# (
# {"num_points": 500, "frames_per_trigger": 1, "scan_type": "fly"},
# TriggerSource.MODE3,
# False,
# False,
# ),
# (
# {"num_points": 5001, "frames_per_trigger": 2, "scan_type": "step"},
# TriggerSource.MODE3,
# False,
# True,
# ),
# (
# {"num_points": 500, "frames_per_trigger": 2, "scan_type": "random"},
# TriggerSource.MODE3,
# False,
# True,
# ),
# ],
# )
# def test_stage(mock_det, scaninfo, triggersource, stopped, expected_exception):
# mock_det.scaninfo.num_points = scaninfo["num_points"]
# mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"]
# mock_det.scaninfo.scan_type = scaninfo["scan_type"]
# mock_det.stopped = stopped
# with mock.patch.object(mock_det.custom_prepare, "prepare_detector_backend") as mock_prep_fw:
# if expected_exception:
# with pytest.raises(MCSError):
# mock_det.stage()
# mock_prep_fw.assert_called_once()
# else:
# mock_det.stage()
# mock_prep_fw.assert_called_once()
# # Check set_trigger
# mock_det.input_mode.get() == triggersource
# if scaninfo["scan_type"] == "step":
# assert mock_det.num_use_all.get() == int(scaninfo["frames_per_trigger"]) * int(
# scaninfo["num_points"]
# )
# elif scaninfo["scan_type"] == "fly":
# assert mock_det.num_use_all.get() == int(scaninfo["num_points"])
# mock_det.preset_real.get() == 0
# # # CHeck custom_prepare.arm_acquisition
# # assert mock_det.custom_prepare.counter == 0
# # assert mock_det.erase_start.get() == 1
# # mock_prep_fw.assert_called_once()
# # # Check _prep_det
# # assert mock_det.cam.num_images.get() == int(
# # scaninfo["num_points"] * scaninfo["frames_per_trigger"]
# # )
# # assert mock_det.cam.num_frames.get() == 1
# # mock_publish_file_location.assert_called_with(done=False)
# # assert mock_det.cam.acquire.get() == 1
# def test_prepare_detector_backend(mock_det):
# mock_det.custom_prepare.prepare_detector_backend()
# assert mock_det.erase_all.get() == 1
# assert mock_det.read_mode.get() == ReadoutMode.EVENT
# def test_complete(mock_det):
# with (mock.patch.object(mock_det.custom_prepare, "finished") as mock_finished,):
# mock_det.complete()
# assert mock_finished.call_count == 1
# def test_stop_detector_backend(mock_det):
# mock_det.custom_prepare.stop_detector_backend()
# assert mock_det.custom_prepare.acquisition_done is True
# def test_stop(mock_det):
# with (
# mock.patch.object(mock_det.custom_prepare, "stop_detector") as mock_stop_det,
# mock.patch.object(
# mock_det.custom_prepare, "stop_detector_backend"
# ) as mock_stop_detector_backend,
# ):
# mock_det.stop()
# mock_stop_det.assert_called_once()
# mock_stop_detector_backend.assert_called_once()
# assert mock_det.stopped is True
# @pytest.mark.parametrize(
# "stopped, acquisition_done, acquiring_state, expected_exception",
# [
# (False, True, 0, False),
# (False, False, 0, True),
# (False, True, 1, True),
# (True, True, 0, True),
# ],
# )
# def test_finished(mock_det, stopped, acquisition_done, acquiring_state, expected_exception):
# mock_det.custom_prepare.acquisition_done = acquisition_done
# mock_det.acquiring._read_pv.mock_data = acquiring_state
# mock_det.scaninfo.num_points = 500
# mock_det.num_lines.put(500)
# mock_det.current_channel._read_pv.mock_data = 1
# mock_det.stopped = stopped
# if expected_exception:
# with pytest.raises(MCSTimeoutError):
# mock_det.timeout = 0.1
# mock_det.custom_prepare.finished()
# else:
# mock_det.custom_prepare.finished()
# if stopped:
# assert mock_det.stopped is stopped