diff --git a/csaxs_bec/devices/epics/mcs_card/mcs_card_csaxs.py b/csaxs_bec/devices/epics/mcs_card/mcs_card_csaxs.py index e2423f2..0e32fd1 100644 --- a/csaxs_bec/devices/epics/mcs_card/mcs_card_csaxs.py +++ b/csaxs_bec/devices/epics/mcs_card/mcs_card_csaxs.py @@ -5,8 +5,8 @@ from __future__ import annotations import enum from threading import RLock from typing import TYPE_CHECKING -import numpy as np +import numpy as np from bec_lib.logger import bec_logger from ophyd import Component as Cpt from ophyd import Device, EpicsSignalRO, Kind, Signal @@ -120,7 +120,7 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard): super().__init__( name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs ) - self._mcs_clock = 1e-7 # 10MHz clock + self._mcs_clock = 1e7 # 10MHz clock -> 1e7 Hz self._pv_timeout = 2 self._rlock = RLock() self.counter_mapping = { @@ -139,8 +139,9 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard): # Make sure card is not running self.stop_all.put(1) + # TODO Check channel1_source !! self.channel_advance.set(CHANNELADVANCE.EXTERNAL).wait(timeout=self._pv_timeout) - self.channel_advance.set(CHANNEL1SOURCE.EXTERNAL).wait(timeout=self._pv_timeout) + self.channel1_source.set(CHANNEL1SOURCE.EXTERNAL).wait(timeout=self._pv_timeout) self.prescale.set(1).wait(timeout=self._pv_timeout) # Set the user LED to off self.user_led.set(0).wait(timeout=self._pv_timeout) @@ -156,15 +157,15 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard): # Set appropriate read mode self.read_mode.set(READMODE.PASSIVE).wait(timeout=self._pv_timeout) - # Subscribe the progress signal - self.current_channel.subscribe(self._progress_update, run=False) - # Set the acquire mode self.acquire_mode.set(ACQUIREMODE.MCS).wait(timeout=self._pv_timeout) + # Subscribe the progress signal + self.current_channel.subscribe(self._progress_update, run=False) + # Subscribe to the mca updates for name in self.counter_mapping.keys(): - sig: EpicsSignalRO = getattr(self.counters, name.split('_')[-1]) + sig: EpicsSignalRO = getattr(self.counters, name.split("_")[-1]) sig.subscribe(self._on_counter_update, run=False) def _on_counter_update(self, value, **kwargs) -> None: @@ -178,18 +179,18 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard): return mca_raw = getattr(self.mcs, signal.name.split("_")[-1], None) if mca_raw is None: - return + return logger.info(f"Received update of type {type(value)} for {signal.name}") if isinstance(value, np.ndarray): mca_raw.put(value.tolist()) if mapped_signal_name == "count_time": - value = value*self._mcs_clock + value = value / self._mcs_clock value = float(value.mean()) else: mca_raw.put(value) if mapped_signal_name == "count_time": - value = value*self._mcs_clock - + value = value / self._mcs_clock + # Mean signal for burst acquisition sig = getattr(self.bpm, mapped_signal_name) sig.put(value) @@ -223,6 +224,7 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard): """ self.stop_all.put(1) self.ready_to_read.put(READYTOREAD.DONE) + # TODO why 0? self.erase_all.set(0).wait(timeout=self._pv_timeout) def on_trigger(self) -> None: diff --git a/tests/tests_devices/test_mcs_card.py b/tests/tests_devices/test_mcs_card.py index f9163d4..19447c3 100644 --- a/tests/tests_devices/test_mcs_card.py +++ b/tests/tests_devices/test_mcs_card.py @@ -2,311 +2,480 @@ import threading from unittest import mock +import numpy as np import ophyd import pytest from bec_lib import messages from bec_lib.endpoints import MessageEndpoints from bec_server.device_server.tests.utils import DMMock -from ophyd_devices.tests.utils import MockPV +from ophyd_devices.tests.utils import MockPV, patch_dual_pvs -from csaxs_bec.devices.epics.mcs_csaxs import ( - MCScSAXS, - MCSError, - MCSTimeoutError, - ReadoutMode, - TriggerSource, +from csaxs_bec.devices.epics.mcs_card.mcs_card import ( + ACQUIREMODE, + ACQUIRING, + CHANNEL1SOURCE, + CHANNELADVANCE, + INPUTMODE, + OUTPUTMODE, + POLARITY, + READMODE, + MCSCard, ) -from csaxs_bec.devices.tests_utils.utils import patch_dual_pvs +from csaxs_bec.devices.epics.mcs_card.mcs_card_csaxs import READYTOREAD, MCSCardCSAXS @pytest.fixture(scope="function") -def mock_det(): - name = "mcs" +def mock_mcs_card(): + """Fixture to mock the MCSCard device.""" + name = "mcs_card" prefix = "X12SA-MCS:" + with mock.patch.object(ophyd, "cl") as mock_cl: + mock_cl.get_pv = MockPV + mock_cl.thread_class = threading + mcs_card = MCSCard(name=name, prefix=prefix) + patch_dual_pvs(mcs_card) + yield mcs_card + + +def test_mcs_card(mock_mcs_card): + """Test the MCSCard initialization.""" + assert mock_mcs_card.name == "mcs_card" + assert mock_mcs_card.prefix == "X12SA-MCS:" + assert len(mock_mcs_card.counters.component_names) == 32 + assert mock_mcs_card.counters.mca1.name == "mcs_card_counters_mca1" + + +@pytest.fixture(scope="function") +def mock_mcs_csaxs(): + """Fixture to mock the MCSCardCSAXS device.""" + name = "mcs_csaxs" + prefix = "X12SA-MCS-CSAXS:" dm = DMMock() - with mock.patch.object(dm, "connector"): - with ( - mock.patch( - "ophyd_devices.interfaces.base_classes.bec_device_base.FileWriter" - ) as filemixin, - mock.patch( - "ophyd_devices.interfaces.base_classes.psi_detector_base.PSIDetectorBase._update_service_config" - ) as mock_service_config, - ): - with mock.patch.object(ophyd, "cl") as mock_cl: - mock_cl.get_pv = MockPV - mock_cl.thread_class = threading.Thread - with mock.patch.object(MCScSAXS, "_init"): - det = MCScSAXS(name=name, prefix=prefix, device_manager=dm) - patch_dual_pvs(det) - det.TIMEOUT_FOR_SIGNALS = 0.1 - yield det + with mock.patch.object(ophyd, "cl") as mock_cl: + mock_cl.get_pv = MockPV + mock_cl.thread_class = threading.Thread + mcs_card_csaxs = MCSCardCSAXS(name=name, prefix=prefix, device_manager=dm) + patch_dual_pvs(mcs_card_csaxs) + yield mcs_card_csaxs -def test_init(): - """Test the _init function:""" - name = "eiger" - prefix = "X12SA-ES-EIGER9M:" - dm = DMMock() - with mock.patch.object(dm, "connector"): - with ( - mock.patch("ophyd_devices.interfaces.base_classes.bec_device_base.FileWriter"), - mock.patch( - "ophyd_devices.interfaces.base_classes.psi_detector_base.PSIDetectorBase._update_service_config" - ), - ): - with mock.patch.object(ophyd, "cl") as mock_cl: - mock_cl.get_pv = MockPV - with ( - mock.patch( - "csaxs_bec.devices.epics.mcs_csaxs.MCSSetup.initialize_detector" - ) as mock_init_det, - mock.patch( - "csaxs_bec.devices.epics.mcs_csaxs.MCSSetup.initialize_detector_backend" - ) as mock_init_backend, - ): - MCScSAXS(name=name, prefix=prefix, device_manager=dm) - mock_init_det.assert_called_once() - mock_init_backend.assert_called_once() +def test_mcs_card_csaxs(mock_mcs_csaxs): + """Test the MCSCardCSAXS initialization.""" + assert mock_mcs_csaxs.name == "mcs_csaxs" + assert mock_mcs_csaxs.prefix == "X12SA-MCS-CSAXS:" + assert mock_mcs_csaxs.counter_mapping == { + "mcs_csaxs_counters_mca1": "current1", + "mcs_csaxs_counters_mca2": "current2", + "mcs_csaxs_counters_mca3": "current3", + "mcs_csaxs_counters_mca4": "current4", + "mcs_csaxs_counters_mca5": "count_time", + } + assert mock_mcs_csaxs._mcs_clock == 1e7 # 10 MHz -@pytest.mark.parametrize( - "trigger_source, channel_advance, channel_source1, pv_channels", - [ - ( - 3, - 1, - 0, - { - "user_led": 0, - "mux_output": 5, - "input_pol": 0, - "output_pol": 1, - "count_on_start": 0, - "stop_all": 1, - }, - ) - ], -) -def test_initialize_detector( - mock_det, trigger_source, channel_advance, channel_source1, pv_channels -): - """Test the _init function: +def test_mcs_card_csaxs_on_connected(mock_mcs_csaxs): + """Test the on_connected method of MCSCardCSAXS.""" + mcs = mock_mcs_csaxs + mcs.on_connected() + # Stop called + assert mcs.stop_all.get() == 1 + # Channel advance settings + assert mcs.channel_advance.get() == CHANNELADVANCE.EXTERNAL + assert mcs.channel1_source.get() == CHANNEL1SOURCE.EXTERNAL + assert mcs.prescale.get() == 1 + # + assert mcs.user_led.get() == 0 + # Only 5 channels are connected + assert mcs.mux_output.get() == 5 + # input output settings + assert mcs.input_mode.get() == INPUTMODE.MODE_3 + assert mcs.input_polarity.get() == POLARITY.NORMAL + assert mcs.output_mode.get() == OUTPUTMODE.MODE_2 + assert mcs.output_polarity.get() == POLARITY.NORMAL + assert mcs.count_on_start.get() == 0 + assert mcs.read_mode.get() == READMODE.PASSIVE + assert mcs.acquire_mode.get() == ACQUIREMODE.MCS - This includes testing the functions: - - initialize_detector - - stop_det - - parent.set_trigger - --> Testing the filewriter is done in test_init_filewriter - - Validation upon setting the correct PVs - - """ - mock_det.custom_prepare.initialize_detector() # call the method you want to test - assert mock_det.channel_advance.get() == channel_advance - assert mock_det.channel1_source.get() == channel_source1 - assert mock_det.user_led.get() == pv_channels["user_led"] - assert mock_det.mux_output.get() == pv_channels["mux_output"] - assert mock_det.input_polarity.get() == pv_channels["input_pol"] - assert mock_det.output_polarity.get() == pv_channels["output_pol"] - assert mock_det.count_on_start.get() == pv_channels["count_on_start"] - assert mock_det.input_mode.get() == trigger_source + with mock.patch.object(mcs.current_channel, "subscribe") as mock_cur_ch_subscribe: + with mock.patch.object(mcs.counters.mca1, "subscribe") as mock_mca_subscribe: + mcs.on_connected() + assert mock_cur_ch_subscribe.call_args == mock.call(mcs._progress_update, run=False) + assert mock_mca_subscribe.call_args == mock.call(mcs._on_counter_update, run=False) -def test_trigger(mock_det): - """Test the trigger function: - Validate that trigger calls the custom_prepare.on_trigger() function - """ - with mock.patch.object(mock_det.custom_prepare, "on_trigger") as mock_on_trigger: - mock_det.trigger() - mock_on_trigger.assert_called_once() +def test_mcs_card_csaxs_stage(mock_mcs_csaxs): + """Test on stage method of MCSCardCSAXS""" + mcs = mock_mcs_csaxs + triggers = 5 + mcs.scan_info.msg.scan_parameters["frames_per_trigger"] = triggers + mcs.erase_all.put(0) + mcs.stage() + assert mcs._staged == ophyd.Staged.yes + assert mcs.erase_all.get() == 1 + assert mcs.preset_real.get() == 0 + assert mcs.num_use_all.get() == triggers -@pytest.mark.parametrize( - "value, num_lines, num_points, done", [(100, 5, 500, False), (500, 5, 500, True)] -) -def test_progress_update(mock_det, value, num_lines, num_points, done): - mock_det.num_lines.set(num_lines) - mock_det.scaninfo.num_points = num_points - calls = mock.call(sub_type="progress", value=value, max_value=num_points, done=done) - with mock.patch.object(mock_det, "_run_subs") as mock_run_subs: - mock_det.custom_prepare._progress_update(value=value) - mock_run_subs.assert_called_once() - assert mock_run_subs.call_args == calls +def test_mcs_card_csaxs_unstage(mock_mcs_csaxs): + """Test unstage method of MCSCardCSAXS""" + mcs = mock_mcs_csaxs + mcs.stop_all.put(0) + mcs.ready_to_read.put(0) + mcs.erase_all.put(1) + mcs.unstage() + assert mcs.stop_all.get() == 1 + assert mcs.ready_to_read.get() == READYTOREAD.DONE + assert mcs.erase_all.get() == 0 -@pytest.mark.parametrize( - "values, expected_nothing", - [([[100, 120, 140], [200, 220, 240], [300, 320, 340]], False), ([100, 200, 300], True)], -) -def test_on_mca_data(mock_det, values, expected_nothing): - """Test the on_mca_data function: - Validate that on_mca_data calls the custom_prepare.on_mca_data() function - """ - with mock.patch.object(mock_det.custom_prepare, "_send_data_to_bec") as mock_send_data: - mock_object = mock.MagicMock() - for ii, name in enumerate(mock_det.custom_prepare.mca_names): - mock_object.attr_name = name - mock_det.custom_prepare._on_mca_data(obj=mock_object, value=values[ii]) - if not expected_nothing and ii < (len(values) - 1): - assert mock_det.custom_prepare.mca_data[name] == values[ii] - - if not expected_nothing: - mock_send_data.assert_called_once() - assert mock_det.custom_prepare.acquisition_done is True +def test_mcs_card_csaxs_complete_and_stop(mock_mcs_csaxs): + """Test complete method of MCSCarcCSAXS""" + mcs = mock_mcs_csaxs + mcs.acquiring._read_pv.mock_data = ACQUIRING.ACQUIRING + st = mcs.complete() + assert st.done is False + mcs.stop_all.put(0) + mcs.ready_to_read.put(READYTOREAD.PROCESSING) + mcs.stop() + with pytest.raises(Exception): + st.wait(timeout=3) + assert st.done is True + assert st.success is False + assert mcs.stop_all.get() == 1 + assert mcs.ready_to_read.get() == READYTOREAD.DONE -@pytest.mark.parametrize( - "metadata, mca_data", - [ - ( - {"scan_id": 123}, - { - "mca1": {"value": [100, 120, 140]}, - "mca3": {"value": [200, 220, 240]}, - "mca4": {"value": [300, 320, 340]}, - }, - ) - ], -) -def test_send_data_to_bec(mock_det, metadata, mca_data): - mock_det.scaninfo.scan_msg = mock.MagicMock() - mock_det.scaninfo.scan_msg.metadata = metadata - mock_det.scaninfo.scan_id = metadata["scan_id"] - mock_det.custom_prepare.mca_data = mca_data - mock_det.custom_prepare._send_data_to_bec() - device_metadata = mock_det.scaninfo.scan_msg.metadata - metadata.update({"async_update": "append", "num_lines": mock_det.num_lines.get()}) - data = messages.DeviceMessage(signals=dict(mca_data), metadata=device_metadata) - calls = mock.call( - topic=MessageEndpoints.device_async_readback( - scan_id=metadata["scan_id"], device=mock_det.name - ), - msg={"data": data}, - expire=1800, - ) - - assert mock_det.connector.xadd.call_args == calls +def test_mcs_card_csaxs_on_counter_updated(mock_mcs_csaxs): + mcs = mock_mcs_csaxs + # Called for mca1 + kwargs = {"obj": mcs.counters.mca1} + mcs._on_counter_update(1, **kwargs) + assert mcs.mcs.mca1.get() == 1 + assert mcs.bpm.current1.get() == 1 + assert mcs.counter_updated == [mcs.counters.mca1.name] + # Called for mca2 + kwargs = {"obj": mcs.counters.mca2} + mcs._on_counter_update(np.array([2, 4]), **kwargs) + assert mcs.mcs.mca2.get() == [2, 4] + assert np.isclose(mcs.bpm.current2.get(), 3) + assert mcs.counter_updated == [mcs.counters.mca1.name, mcs.counters.mca2.name] + # Called for mca3 + kwargs = {"obj": mcs.counters.mca3} + mcs._on_counter_update(1000, **kwargs) + assert mcs.mcs.mca3.get() == 1000 + assert mcs.bpm.current3.get() == 1000 + assert mcs.counter_updated == [ + mcs.counters.mca1.name, + mcs.counters.mca2.name, + mcs.counters.mca3.name, + ] + # Called for mca4 + kwargs = {"obj": mcs.counters.mca4} + mcs._on_counter_update(np.array([20, 40]), **kwargs) + assert mcs.mcs.mca4.get() == [20, 40] + assert np.isclose(mcs.bpm.current4.get(), 30) + assert mcs.counter_updated == [ + mcs.counters.mca1.name, + mcs.counters.mca2.name, + mcs.counters.mca3.name, + mcs.counters.mca4.name, + ] + # Called for mca5 + assert mcs.ready_to_read.get() == 0 + kwargs = {"obj": mcs.counters.mca5} + mcs._on_counter_update(np.array([10000, 10000]), **kwargs) + assert np.isclose(mcs.bpm.count_time.get(), 10000 / 1e7) + assert mcs.mcs.mca5.get() == [10000, 10000] -@pytest.mark.parametrize( - "scaninfo, triggersource, stopped, expected_exception", - [ - ( - {"num_points": 500, "frames_per_trigger": 1, "scan_type": "step"}, - TriggerSource.MODE3, - False, - False, - ), - ( - {"num_points": 500, "frames_per_trigger": 1, "scan_type": "fly"}, - TriggerSource.MODE3, - False, - False, - ), - ( - {"num_points": 5001, "frames_per_trigger": 2, "scan_type": "step"}, - TriggerSource.MODE3, - False, - True, - ), - ( - {"num_points": 500, "frames_per_trigger": 2, "scan_type": "random"}, - TriggerSource.MODE3, - False, - True, - ), - ], -) -def test_stage(mock_det, scaninfo, triggersource, stopped, expected_exception): - mock_det.scaninfo.num_points = scaninfo["num_points"] - mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"] - mock_det.scaninfo.scan_type = scaninfo["scan_type"] - mock_det.stopped = stopped - with mock.patch.object(mock_det.custom_prepare, "prepare_detector_backend") as mock_prep_fw: - if expected_exception: - with pytest.raises(MCSError): - mock_det.stage() - mock_prep_fw.assert_called_once() - else: - mock_det.stage() - mock_prep_fw.assert_called_once() - # Check set_trigger - mock_det.input_mode.get() == triggersource - if scaninfo["scan_type"] == "step": - assert mock_det.num_use_all.get() == int(scaninfo["frames_per_trigger"]) * int( - scaninfo["num_points"] - ) - elif scaninfo["scan_type"] == "fly": - assert mock_det.num_use_all.get() == int(scaninfo["num_points"]) - mock_det.preset_real.get() == 0 - - # # CHeck custom_prepare.arm_acquisition - # assert mock_det.custom_prepare.counter == 0 - # assert mock_det.erase_start.get() == 1 - # mock_prep_fw.assert_called_once() - # # Check _prep_det - # assert mock_det.cam.num_images.get() == int( - # scaninfo["num_points"] * scaninfo["frames_per_trigger"] - # ) - # assert mock_det.cam.num_frames.get() == 1 - - # mock_publish_file_location.assert_called_with(done=False) - # assert mock_det.cam.acquire.get() == 1 +# @pytest.fixture(scope="function") +# def mock_det(): +# name = "mcs" +# prefix = "X12SA-MCS:" +# dm = DMMock() +# with mock.patch.object(dm, "connector"): +# with ( +# mock.patch( +# "ophyd_devices.interfaces.base_classes.bec_device_base.FileWriter" +# ) as filemixin, +# mock.patch( +# "ophyd_devices.interfaces.base_classes.psi_detector_base.PSIDetectorBase._update_service_config" +# ) as mock_service_config, +# ): +# with mock.patch.object(ophyd, "cl") as mock_cl: +# mock_cl.get_pv = MockPV +# mock_cl.thread_class = threading.Thread +# with mock.patch.object(MCScSAXS, "_init"): +# det = MCScSAXS(name=name, prefix=prefix, device_manager=dm) +# patch_dual_pvs(det) +# det.TIMEOUT_FOR_SIGNALS = 0.1 +# yield det -def test_prepare_detector_backend(mock_det): - mock_det.custom_prepare.prepare_detector_backend() - assert mock_det.erase_all.get() == 1 - assert mock_det.read_mode.get() == ReadoutMode.EVENT +# def test_init(): +# """Test the _init function:""" +# name = "eiger" +# prefix = "X12SA-ES-EIGER9M:" +# dm = DMMock() +# with mock.patch.object(dm, "connector"): +# with ( +# mock.patch("ophyd_devices.interfaces.base_classes.bec_device_base.FileWriter"), +# mock.patch( +# "ophyd_devices.interfaces.base_classes.psi_detector_base.PSIDetectorBase._update_service_config" +# ), +# ): +# with mock.patch.object(ophyd, "cl") as mock_cl: +# mock_cl.get_pv = MockPV +# with ( +# mock.patch( +# "csaxs_bec.devices.epics.mcs_csaxs.MCSSetup.initialize_detector" +# ) as mock_init_det, +# mock.patch( +# "csaxs_bec.devices.epics.mcs_csaxs.MCSSetup.initialize_detector_backend" +# ) as mock_init_backend, +# ): +# MCScSAXS(name=name, prefix=prefix, device_manager=dm) +# mock_init_det.assert_called_once() +# mock_init_backend.assert_called_once() -def test_complete(mock_det): - with (mock.patch.object(mock_det.custom_prepare, "finished") as mock_finished,): - mock_det.complete() - assert mock_finished.call_count == 1 +# @pytest.mark.parametrize( +# "trigger_source, channel_advance, channel_source1, pv_channels", +# [ +# ( +# 3, +# 1, +# 0, +# { +# "user_led": 0, +# "mux_output": 5, +# "input_pol": 0, +# "output_pol": 1, +# "count_on_start": 0, +# "stop_all": 1, +# }, +# ) +# ], +# ) +# def test_initialize_detector( +# mock_det, trigger_source, channel_advance, channel_source1, pv_channels +# ): +# """Test the _init function: + +# This includes testing the functions: +# - initialize_detector +# - stop_det +# - parent.set_trigger +# --> Testing the filewriter is done in test_init_filewriter + +# Validation upon setting the correct PVs + +# """ +# mock_det.custom_prepare.initialize_detector() # call the method you want to test +# assert mock_det.channel_advance.get() == channel_advance +# assert mock_det.channel1_source.get() == channel_source1 +# assert mock_det.user_led.get() == pv_channels["user_led"] +# assert mock_det.mux_output.get() == pv_channels["mux_output"] +# assert mock_det.input_polarity.get() == pv_channels["input_pol"] +# assert mock_det.output_polarity.get() == pv_channels["output_pol"] +# assert mock_det.count_on_start.get() == pv_channels["count_on_start"] +# assert mock_det.input_mode.get() == trigger_source -def test_stop_detector_backend(mock_det): - mock_det.custom_prepare.stop_detector_backend() - assert mock_det.custom_prepare.acquisition_done is True +# def test_trigger(mock_det): +# """Test the trigger function: +# Validate that trigger calls the custom_prepare.on_trigger() function +# """ +# with mock.patch.object(mock_det.custom_prepare, "on_trigger") as mock_on_trigger: +# mock_det.trigger() +# mock_on_trigger.assert_called_once() -def test_stop(mock_det): - with ( - mock.patch.object(mock_det.custom_prepare, "stop_detector") as mock_stop_det, - mock.patch.object( - mock_det.custom_prepare, "stop_detector_backend" - ) as mock_stop_detector_backend, - ): - mock_det.stop() - mock_stop_det.assert_called_once() - mock_stop_detector_backend.assert_called_once() - assert mock_det.stopped is True +# @pytest.mark.parametrize( +# "value, num_lines, num_points, done", [(100, 5, 500, False), (500, 5, 500, True)] +# ) +# def test_progress_update(mock_det, value, num_lines, num_points, done): +# mock_det.num_lines.set(num_lines) +# mock_det.scaninfo.num_points = num_points +# calls = mock.call(sub_type="progress", value=value, max_value=num_points, done=done) +# with mock.patch.object(mock_det, "_run_subs") as mock_run_subs: +# mock_det.custom_prepare._progress_update(value=value) +# mock_run_subs.assert_called_once() +# assert mock_run_subs.call_args == calls -@pytest.mark.parametrize( - "stopped, acquisition_done, acquiring_state, expected_exception", - [ - (False, True, 0, False), - (False, False, 0, True), - (False, True, 1, True), - (True, True, 0, True), - ], -) -def test_finished(mock_det, stopped, acquisition_done, acquiring_state, expected_exception): - mock_det.custom_prepare.acquisition_done = acquisition_done - mock_det.acquiring._read_pv.mock_data = acquiring_state - mock_det.scaninfo.num_points = 500 - mock_det.num_lines.put(500) - mock_det.current_channel._read_pv.mock_data = 1 - mock_det.stopped = stopped +# @pytest.mark.parametrize( +# "values, expected_nothing", +# [([[100, 120, 140], [200, 220, 240], [300, 320, 340]], False), ([100, 200, 300], True)], +# ) +# def test_on_mca_data(mock_det, values, expected_nothing): +# """Test the on_mca_data function: +# Validate that on_mca_data calls the custom_prepare.on_mca_data() function +# """ +# with mock.patch.object(mock_det.custom_prepare, "_send_data_to_bec") as mock_send_data: +# mock_object = mock.MagicMock() +# for ii, name in enumerate(mock_det.custom_prepare.mca_names): +# mock_object.attr_name = name +# mock_det.custom_prepare._on_mca_data(obj=mock_object, value=values[ii]) +# if not expected_nothing and ii < (len(values) - 1): +# assert mock_det.custom_prepare.mca_data[name] == values[ii] - if expected_exception: - with pytest.raises(MCSTimeoutError): - mock_det.timeout = 0.1 - mock_det.custom_prepare.finished() - else: - mock_det.custom_prepare.finished() - if stopped: - assert mock_det.stopped is stopped +# if not expected_nothing: +# mock_send_data.assert_called_once() +# assert mock_det.custom_prepare.acquisition_done is True + + +# @pytest.mark.parametrize( +# "metadata, mca_data", +# [ +# ( +# {"scan_id": 123}, +# { +# "mca1": {"value": [100, 120, 140]}, +# "mca3": {"value": [200, 220, 240]}, +# "mca4": {"value": [300, 320, 340]}, +# }, +# ) +# ], +# ) +# def test_send_data_to_bec(mock_det, metadata, mca_data): +# mock_det.scaninfo.scan_msg = mock.MagicMock() +# mock_det.scaninfo.scan_msg.metadata = metadata +# mock_det.scaninfo.scan_id = metadata["scan_id"] +# mock_det.custom_prepare.mca_data = mca_data +# mock_det.custom_prepare._send_data_to_bec() +# device_metadata = mock_det.scaninfo.scan_msg.metadata +# metadata.update({"async_update": "append", "num_lines": mock_det.num_lines.get()}) +# data = messages.DeviceMessage(signals=dict(mca_data), metadata=device_metadata) +# calls = mock.call( +# topic=MessageEndpoints.device_async_readback( +# scan_id=metadata["scan_id"], device=mock_det.name +# ), +# msg={"data": data}, +# expire=1800, +# ) + +# assert mock_det.connector.xadd.call_args == calls + + +# @pytest.mark.parametrize( +# "scaninfo, triggersource, stopped, expected_exception", +# [ +# ( +# {"num_points": 500, "frames_per_trigger": 1, "scan_type": "step"}, +# TriggerSource.MODE3, +# False, +# False, +# ), +# ( +# {"num_points": 500, "frames_per_trigger": 1, "scan_type": "fly"}, +# TriggerSource.MODE3, +# False, +# False, +# ), +# ( +# {"num_points": 5001, "frames_per_trigger": 2, "scan_type": "step"}, +# TriggerSource.MODE3, +# False, +# True, +# ), +# ( +# {"num_points": 500, "frames_per_trigger": 2, "scan_type": "random"}, +# TriggerSource.MODE3, +# False, +# True, +# ), +# ], +# ) +# def test_stage(mock_det, scaninfo, triggersource, stopped, expected_exception): +# mock_det.scaninfo.num_points = scaninfo["num_points"] +# mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"] +# mock_det.scaninfo.scan_type = scaninfo["scan_type"] +# mock_det.stopped = stopped +# with mock.patch.object(mock_det.custom_prepare, "prepare_detector_backend") as mock_prep_fw: +# if expected_exception: +# with pytest.raises(MCSError): +# mock_det.stage() +# mock_prep_fw.assert_called_once() +# else: +# mock_det.stage() +# mock_prep_fw.assert_called_once() +# # Check set_trigger +# mock_det.input_mode.get() == triggersource +# if scaninfo["scan_type"] == "step": +# assert mock_det.num_use_all.get() == int(scaninfo["frames_per_trigger"]) * int( +# scaninfo["num_points"] +# ) +# elif scaninfo["scan_type"] == "fly": +# assert mock_det.num_use_all.get() == int(scaninfo["num_points"]) +# mock_det.preset_real.get() == 0 + +# # # CHeck custom_prepare.arm_acquisition +# # assert mock_det.custom_prepare.counter == 0 +# # assert mock_det.erase_start.get() == 1 +# # mock_prep_fw.assert_called_once() +# # # Check _prep_det +# # assert mock_det.cam.num_images.get() == int( +# # scaninfo["num_points"] * scaninfo["frames_per_trigger"] +# # ) +# # assert mock_det.cam.num_frames.get() == 1 + +# # mock_publish_file_location.assert_called_with(done=False) +# # assert mock_det.cam.acquire.get() == 1 + + +# def test_prepare_detector_backend(mock_det): +# mock_det.custom_prepare.prepare_detector_backend() +# assert mock_det.erase_all.get() == 1 +# assert mock_det.read_mode.get() == ReadoutMode.EVENT + + +# def test_complete(mock_det): +# with (mock.patch.object(mock_det.custom_prepare, "finished") as mock_finished,): +# mock_det.complete() +# assert mock_finished.call_count == 1 + + +# def test_stop_detector_backend(mock_det): +# mock_det.custom_prepare.stop_detector_backend() +# assert mock_det.custom_prepare.acquisition_done is True + + +# def test_stop(mock_det): +# with ( +# mock.patch.object(mock_det.custom_prepare, "stop_detector") as mock_stop_det, +# mock.patch.object( +# mock_det.custom_prepare, "stop_detector_backend" +# ) as mock_stop_detector_backend, +# ): +# mock_det.stop() +# mock_stop_det.assert_called_once() +# mock_stop_detector_backend.assert_called_once() +# assert mock_det.stopped is True + + +# @pytest.mark.parametrize( +# "stopped, acquisition_done, acquiring_state, expected_exception", +# [ +# (False, True, 0, False), +# (False, False, 0, True), +# (False, True, 1, True), +# (True, True, 0, True), +# ], +# ) +# def test_finished(mock_det, stopped, acquisition_done, acquiring_state, expected_exception): +# mock_det.custom_prepare.acquisition_done = acquisition_done +# mock_det.acquiring._read_pv.mock_data = acquiring_state +# mock_det.scaninfo.num_points = 500 +# mock_det.num_lines.put(500) +# mock_det.current_channel._read_pv.mock_data = 1 +# mock_det.stopped = stopped + +# if expected_exception: +# with pytest.raises(MCSTimeoutError): +# mock_det.timeout = 0.1 +# mock_det.custom_prepare.finished() +# else: +# mock_det.custom_prepare.finished() +# if stopped: +# assert mock_det.stopped is stopped