286 lines
10 KiB
Python
286 lines
10 KiB
Python
# pylint: skip-file
|
|
import threading
|
|
from copy import deepcopy
|
|
from typing import Generator
|
|
from unittest import mock
|
|
|
|
import numpy as np
|
|
import ophyd
|
|
import pytest
|
|
from bec_lib import messages
|
|
from bec_lib.endpoints import MessageEndpoints
|
|
from bec_server.device_server.tests.utils import DMMock
|
|
from ophyd_devices.interfaces.base_classes.psi_device_base import DeviceStoppedError
|
|
from ophyd_devices.tests.utils import MockPV, patch_dual_pvs
|
|
|
|
from csaxs_bec.devices.epics.mcs_card.mcs_card import (
|
|
ACQUIREMODE,
|
|
ACQUIRING,
|
|
CHANNEL1SOURCE,
|
|
CHANNELADVANCE,
|
|
INPUTMODE,
|
|
OUTPUTMODE,
|
|
POLARITY,
|
|
READMODE,
|
|
MCSCard,
|
|
)
|
|
from csaxs_bec.devices.epics.mcs_card.mcs_card_csaxs import MCSCardCSAXS
|
|
|
|
|
|
@pytest.fixture(scope="function")
def mock_mcs_card() -> Generator[MCSCard, None, None]:
    """Fixture providing an MCSCard backed by mocked PVs.

    ophyd's control layer (``ophyd.cl``) is patched so that ``get_pv`` returns
    ``MockPV`` objects; the card's PVs are then patched with ``patch_dual_pvs``.
    The patch stays active for the duration of the test (the fixture yields
    inside the ``with`` block).

    Yields:
        MCSCard: the mocked card, named ``"mcs_card"`` with prefix ``"X12SA-MCS:"``.
    """
    name = "mcs_card"
    prefix = "X12SA-MCS:"
    with mock.patch.object(ophyd, "cl") as mock_cl:
        mock_cl.get_pv = MockPV
        # ``thread_class`` must be a Thread class (as ophyd's own control-layer
        # shims and the other fixture in this module use), not the threading
        # module itself, which is not callable.
        mock_cl.thread_class = threading.Thread
        mcs_card = MCSCard(name=name, prefix=prefix)
        patch_dual_pvs(mcs_card)
        yield mcs_card
|
|
|
|
|
|
def test_mcs_card(mock_mcs_card):
    """Check name, prefix and counter layout of a freshly created MCSCard."""
    card = mock_mcs_card
    assert (card.name, card.prefix) == ("mcs_card", "X12SA-MCS:")

    counter_names = card.counters.component_names
    assert len(counter_names) == 32

    first_counter = card.counters.mca1
    assert first_counter.name == "mcs_card_counters_mca1"
|
|
|
|
|
|
@pytest.fixture(scope="function")
def mock_mcs_csaxs() -> Generator[MCSCardCSAXS, None, None]:
    """Fixture providing an MCSCardCSAXS backed by mocked PVs and a mocked device manager.

    ophyd's control layer (``ophyd.cl``) is patched so that ``get_pv`` returns
    ``MockPV`` objects and threads are created with ``threading.Thread``.
    After the test, ``on_destroy`` is called on the device once the patch
    context has been left.

    Yields:
        MCSCardCSAXS: the mocked card, named ``"mcs_csaxs"`` with prefix
        ``"X12SA-MCS-CSAXS:"``.
    """
    name = "mcs_csaxs"
    prefix = "X12SA-MCS-CSAXS:"
    dm = DMMock()
    # Bind before ``try`` so that, if construction fails, the original error
    # propagates instead of an UnboundLocalError raised from ``finally``.
    mcs_card_csaxs = None
    try:
        with mock.patch.object(ophyd, "cl") as mock_cl:
            mock_cl.get_pv = MockPV
            mock_cl.thread_class = threading.Thread
            mcs_card_csaxs = MCSCardCSAXS(name=name, prefix=prefix, device_manager=dm)
            patch_dual_pvs(mcs_card_csaxs)
            yield mcs_card_csaxs
    finally:
        # Only tear down a device that was actually constructed.
        if mcs_card_csaxs is not None:
            mcs_card_csaxs.on_destroy()
|
|
|
|
|
|
def test_mcs_card_csaxs(mock_mcs_csaxs: MCSCardCSAXS):
    """Verify the initial public and internal state of a fresh MCSCardCSAXS."""
    device = mock_mcs_csaxs
    # Attribute name -> value expected right after construction.
    expected_state = {
        "name": "mcs_csaxs",
        "prefix": "X12SA-MCS-CSAXS:",
        "_acquisition_group": "monitored",
        "_num_total_triggers": 0,
        "_mcs_clock": 1e7,
        "_pv_timeout": 2.0,
        "_mca_counter_index": 0,
        "_current_data_index": 0,
        "_current_data": {},
        "NUM_MCA_CHANNELS": 32,
    }
    for attribute, expected in expected_state.items():
        assert getattr(device, attribute) == expected
|
|
|
|
|
|
def test_mcs_card_csaxs_on_connected(mock_mcs_csaxs: MCSCardCSAXS):
    """Check the configuration, subscriptions, recovery and thread start done by on_connected."""
    card = mock_mcs_csaxs
    first_counter = card.counters.mca1
    with (
        mock.patch.object(first_counter, "subscribe") as subscribe_spy,
        mock.patch.object(card, "mcs_recovery") as recovery_spy,
        mock.patch.object(card._scan_done_thread, "start") as thread_start_spy,
    ):
        card.on_connected()

    # (signal, value expected after on_connected)
    expected_settings = [
        # acquisition stopped
        (card.stop_all, 1),
        # channel advance
        (card.channel_advance, CHANNELADVANCE.EXTERNAL),
        (card.channel1_source, CHANNEL1SOURCE.EXTERNAL),
        (card.prescale, 1),
        (card.user_led, 0),
        # mux output covers every MCA channel
        (card.mux_output, card.NUM_MCA_CHANNELS),
        # input/output configuration
        (card.input_mode, INPUTMODE.MODE_3),
        (card.input_polarity, POLARITY.NORMAL),
        (card.output_mode, OUTPUTMODE.MODE_2),
        (card.output_polarity, POLARITY.NORMAL),
        (card.count_on_start, 0),
        (card.read_mode, READMODE.PASSIVE),
        (card.acquire_mode, ACQUIREMODE.MCS),
    ]
    for signal, expected in expected_settings:
        assert signal.get() == expected

    # The counter callback is registered without an immediate run.
    subscribe_spy.assert_called_with(card._on_counter_update, run=False)
    # Recovery runs exactly once with a 1 s timeout.
    recovery_spy.assert_called_once_with(timeout=1)
    # The scan-done monitoring thread is started.
    thread_start_spy.assert_called_once()
|
|
|
|
|
|
def test_mcs_card_csaxs_stage(mock_mcs_csaxs: MCSCardCSAXS):
    """Check that stage() configures the acquisition and resets leftover state."""
    card = mock_mcs_csaxs
    frames_per_trigger = 5
    points = 10
    card.scan_info.msg.scan_parameters["frames_per_trigger"] = frames_per_trigger
    card.scan_info.msg.num_points = points

    # Leave the card in a "dirty" state: non-zero current channel, erase_all low,
    # stale buffered data and callbacks, and both thread events set.
    card.current_channel._read_pv.mock_data = 2
    card.erase_all.put(0)
    card._current_data = {"mca1": [1, 2, 3]}
    card._scan_done_callbacks = [lambda: None]
    for event in (card._start_monitor_async_data_emission, card._omit_mca_callbacks):
        event.set()

    card.stage()

    assert card._staged == ophyd.Staged.yes

    # Hardware configuration; erase_all is expected high because current_channel != 0.
    assert card.erase_all.get() == 1
    assert card.preset_real.get() == 0
    assert card.num_use_all.get() == frames_per_trigger

    # Internal bookkeeping is reset for the new scan.
    assert card._num_total_triggers == frames_per_trigger * points
    assert card._current_data == {}
    assert card._scan_done_callbacks == []
    assert card._current_data_index == 0

    # Both thread events are cleared.
    assert not card._start_monitor_async_data_emission.is_set()
    assert not card._omit_mca_callbacks.is_set()
|
|
|
|
|
|
def test_mcs_card_csaxs_unstage(mock_mcs_csaxs):
    """Check that unstage() drives both stop_all and erase_all to 1."""
    card = mock_mcs_csaxs
    watched = (card.stop_all, card.erase_all)
    for sig in watched:
        sig.put(0)

    card.unstage()

    assert [sig.get() for sig in watched] == [1, 1]
|
|
|
|
|
|
def _wait_until_cleared(event: threading.Event, timeout: float) -> bool:
    """Poll ``event`` until it is cleared or ``timeout`` seconds have elapsed.

    ``threading.Event.wait`` only blocks until an event becomes *set*, so it
    cannot be used to wait for a callback that *clears* an event.

    Returns:
        True if the event was cleared within the timeout, False otherwise.
    """
    import time  # local import keeps this helper self-contained

    deadline = time.monotonic() + timeout
    while event.is_set():
        if time.monotonic() >= deadline:
            return False
        time.sleep(0.01)
    return True


def test_mcs_card_csaxs_complete_and_stop(mock_mcs_csaxs: MCSCardCSAXS):
    """
    Test complete method of MCSCardCSAXS.

    Two use cases:
    I. Acquisition is stopped externally
    II. Acquisition completes normally
    """
    mcs = mock_mcs_csaxs
    mcs.acquiring._read_pv.mock_data = ACQUIRING.ACQUIRING
    # Make sure that device on_connected has been called which starts the monitoring thread
    mcs.on_connected()

    #######################
    # I. Use case where acquisition is stopped
    #######################

    st = mcs.complete()
    assert st.done is False
    assert mcs._start_monitor_async_data_emission.is_set()

    # Status should be cancelled by stop
    mcs.stop()
    with pytest.raises(DeviceStoppedError):
        st.wait(timeout=3)

    # Callback on status failure should stop monitoring; it may run on another
    # thread, so wait (bounded) for the event to be cleared.
    assert _wait_until_cleared(mcs._start_monitor_async_data_emission, timeout=2)

    #######################
    # II. Use case where acquisition completes normally
    #######################

    mcs._current_data_index = 0
    mcs.scan_info.msg.num_points = 10
    mcs.acquiring._read_pv.mock_data = ACQUIRING.ACQUIRING

    st = mcs.complete()
    assert st.done is False
    assert mcs._start_monitor_async_data_emission.is_set()

    mcs.acquiring._read_pv.mock_data = ACQUIRING.DONE

    # This should now automatically complete the status
    mcs._current_data_index = 10
    st.wait(timeout=3)
    assert st.done is True
    assert st.success is True

    # Clean up procedure should stop the async_data monitoring
    assert _wait_until_cleared(mcs._start_monitor_async_data_emission, timeout=2)
|
|
|
|
|
|
def test_mcs_recovery(mock_mcs_csaxs: MCSCardCSAXS):
    """Check that mcs_recovery() sets erase_all, stop_all and erase_start and leaves mca callbacks enabled."""
    card = mock_mcs_csaxs

    # Simulate an ongoing acquisition: stop/erase readbacks low, erase_start not triggered.
    card.erase_all._read_pv.mock_data = 0
    card.stop_all._read_pv.mock_data = 0
    card.erase_start.put(0)

    card.mcs_recovery(timeout=0.1)

    for signal in (card.erase_all, card.stop_all, card.erase_start):
        assert signal.get() == 1
    assert not card._omit_mca_callbacks.is_set()
|
|
|
|
|
|
def test_mcs_card_csaxs_on_counter_updated(mock_mcs_csaxs: MCSCardCSAXS):
    """
    Test the _on_counter_update method of MCSCardCSAXS.

    We will test 2 use cases:
    I. Suppressed callbacks: nothing is buffered.
    II. Callbacks from all 32 mca counters: values are buffered per counter and,
        after the last counter, sent to BEC via the async ``mca`` signal.
    """
    mcs = mock_mcs_csaxs

    # I. Suppressed callbacks
    mcs._omit_mca_callbacks.set()
    mcs._on_counter_update(1, obj=mcs.counters.mca1)
    assert mcs._current_data == {}

    # II. Callback from 32 mca counters
    mcs._omit_mca_callbacks.clear()
    mcs._mca_counter_index = 0
    mcs._current_data_index = 0
    initial_mca_value = mcs.mca.get()

    last_index = mcs.NUM_MCA_CHANNELS - 1
    for ii in range(mcs.NUM_MCA_CHANNELS):
        counter = getattr(mcs.counters, f"mca{ii + 1}")
        # Alternate between array and scalar payloads to cover both forms.
        value = np.array([ii, (ii + 1) * 2]) if ii % 2 == 1 else ii
        mcs._on_counter_update(value, obj=counter, timestamp=1.0)

        if ii < last_index:
            # Intermediate counters are buffered; nothing is emitted yet.
            # Compute the expectation first so the comparison applies to both
            # branches (a bare conditional expression would make the scalar
            # case an always-true `assert [value]`).
            expected_value = value.tolist() if isinstance(value, np.ndarray) else [value]
            assert mcs._current_data_index == 0
            assert mcs._mca_counter_index == ii + 1
            assert counter.attr_name in mcs._current_data
            assert mcs._current_data[counter.attr_name]["value"] == expected_value
            assert mcs.mca.get() == initial_mca_value  # Async mca signal should not change
        else:
            # On last counter, data should be sent to BEC, and internal variables reset
            assert mcs._mca_counter_index == 0
            assert mcs._current_data_index == 1
            assert mcs._current_data == {}

    # Check that the async mca signal is properly set
    # NOTE(review): the per-counter contents of the emitted signals are not
    # asserted here; add that once the DeviceMessage signal schema is confirmed.
    emitted = mcs.mca.get()
    assert isinstance(emitted, messages.DeviceMessage)
    assert len(emitted.signals) == mcs.NUM_MCA_CHANNELS
|