171 lines
6.4 KiB
Python
171 lines
6.4 KiB
Python
from unittest import mock
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from bec_lib import MessageEndpoints, ServiceConfig, messages
|
|
from bec_lib.redis_connector import MessageObject
|
|
|
|
from bec_plugins.services.NIDAQ_writer import NIDAQWriterService
|
|
|
|
|
|
def test_nidaq_starts_consumers():
    """Constructing the service calls each of its three start hooks exactly once.

    The scan-status consumer, NI data consumer and NI writer start methods are
    patched on the class, so constructing the service only records the calls.
    """
    with mock.patch.object(
        NIDAQWriterService, "_start_scan_status_consumer"
    ) as scan_start, mock.patch.object(
        NIDAQWriterService, "_start_ni_data_consumer"
    ) as data_start, mock.patch.object(
        NIDAQWriterService, "_start_ni_writer"
    ) as writer_start:
        NIDAQWriterService(
            config=ServiceConfig(redis={"host": "test", "port": 6379}),
            connector_cls=mock.MagicMock(),
        )
        for start_hook in (scan_start, data_start, writer_start):
            start_hook.assert_called_once()
|
|
|
|
|
|
class NIWriterMock(NIDAQWriterService):
    """NIDAQWriterService test double whose NI start hooks do nothing.

    Only ``_start_ni_data_consumer`` and ``_start_ni_writer`` are overridden
    (as no-ops); everything else, including the scan-status consumer start-up,
    is inherited unchanged.
    """

    def _start_ni_data_consumer(self) -> None:
        """Skip starting the NI data consumer."""

    def _start_ni_writer(self) -> None:
        """Skip starting the NI writer."""
|
|
|
|
|
|
@pytest.fixture(scope="function")
def nidaq():
    """Yield a fresh NIWriterMock for each test.

    The Redis config uses a placeholder host and the connector class is a
    MagicMock, so connector interactions can be inspected via
    ``service.connector`` (presumably no real Redis connection is made).
    """
    redis_config = ServiceConfig(redis={"host": "test", "port": 6379})
    yield NIWriterMock(config=redis_config, connector_cls=mock.MagicMock())
|
|
|
|
|
|
def test_nidaq_scan_status_consumer(nidaq):
    """The service registers a consumer on the scan-status endpoint and starts it."""
    status_topic = MessageEndpoints.scan_status()
    nidaq.connector.consumer.assert_called_once_with(
        status_topic, cb=nidaq._scan_status_callback, parent=nidaq
    )
    nidaq._scan_status_consumer.start.assert_called_once()
|
|
|
|
|
|
def test_scan_status_callback(nidaq):
    """_scan_status_callback forwards the deserialized status message to handle_scan_status."""
    status_msg = messages.ScanStatusMessage(scanID="test", status="open", info={})
    raw_msg = MessageObject(topic="test", value=status_msg.dumps())
    with mock.patch.object(nidaq, "handle_scan_status") as handler:
        nidaq._scan_status_callback(raw_msg, nidaq)
        handler.assert_called_once_with(status_msg)
|
|
|
|
|
|
def test_nidaq_doesnt_read_data_when_scan_is_not_running(nidaq):
    """With no scan running, _read_data never asks the writer mixin for a filename."""
    nidaq.scan_is_running = False
    with mock.patch.object(nidaq, "writer_mixin") as writer:
        nidaq._read_data()
        writer.compile_full_filename.assert_not_called()
|
|
|
|
|
|
def test_nidaq_reads_data(nidaq):
    """During a scan (Redis stream disabled), _read_data compiles the filename and calls handle_ni_data."""
    nidaq.scan_is_running, nidaq.use_redis_stream = True, False
    with mock.patch.object(nidaq, "writer_mixin") as writer, mock.patch.object(
        nidaq, "handle_ni_data"
    ) as data_handler:
        nidaq._read_data()
        writer.compile_full_filename.assert_called_once()
        data_handler.assert_called_once()
|
|
|
|
|
|
def test_nidaq_reads_data_from_strea(nidaq):
    """During a scan with the Redis stream enabled, _read_data compiles the filename and calls handle_ni_data.

    NOTE(review): the function name has a typo ("strea"); it is kept so the
    pytest test ID stays stable.
    """
    nidaq.scan_is_running, nidaq.use_redis_stream = True, True
    with mock.patch.object(nidaq, "writer_mixin") as writer, mock.patch.object(
        nidaq, "handle_ni_data"
    ) as data_handler:
        nidaq._read_data()
        writer.compile_full_filename.assert_called_once()
        data_handler.assert_called_once()
|
|
|
|
|
|
@pytest.mark.parametrize("scan_status", ["open", "closed", "aborted", "halted", None])
def test_nidaq_handle_scan_status(nidaq, scan_status):
    """Only an "open" status marks the scan as running and records its scan number."""
    nidaq.handle_scan_status(
        messages.ScanStatusMessage(
            scanID="test", status=scan_status, info={"scan_number": 5}
        )
    )
    if scan_status != "open":
        assert not nidaq.scan_is_running
        assert nidaq.scan_number is None
        return
    assert nidaq.scan_is_running
    assert nidaq.scan_number == 5
|
|
|
|
|
|
def test_nidaq_handle_ni_data(nidaq):
    """handle_ni_data joins each signal's chunks across messages and queues the result.

    Two messages carry values 0..9 and 10..19 per signal; the queued entry must
    hold 0..19 for both signals.
    """
    data = [
        messages.DeviceMessage(signals={"signal1": list(range(10)), "signal2": list(range(10))}),
        messages.DeviceMessage(
            signals={"signal1": list(range(10, 20)), "signal2": list(range(10, 20))}
        ),
    ]

    nidaq.handle_ni_data(data)
    signal = nidaq.queue.get()

    expected = np.arange(20)
    # assert_array_equal checks shape as well as values and prints a readable
    # diff. `all(a == b)` would silently pass on a broadcastable wrong shape
    # (e.g. (1, 20)) and raise an unrelated error on a non-broadcastable one.
    np.testing.assert_array_equal(signal["signal1"], expected)
    np.testing.assert_array_equal(signal["signal2"], expected)
|
|
|
|
|
|
def test_nidaq_write_data_without_filename(nidaq):
    """With the fixture's default filename (presumably unset), write_data opens no HDF5 file."""
    payload = {name: np.asarray(range(20)) for name in ("signal1", "signal2")}
    with mock.patch("bec_plugins.services.NIDAQ_writer.NIDAQ_writer.h5py") as h5py_stub:
        nidaq.write_data(payload)
        h5py_stub.File.assert_not_called()
|
|
|
|
|
|
def test_nidaq_write_data_with_filename(nidaq):
    """Once a filename is set, write_data opens that file once with mode "a"."""
    payload = {name: np.asarray(range(20)) for name in ("signal1", "signal2")}
    nidaq.filename = "test.h5"
    with mock.patch("bec_plugins.services.NIDAQ_writer.NIDAQ_writer.h5py") as h5py_stub:
        nidaq.write_data(payload)
        h5py_stub.File.assert_called_once_with("test.h5", "a")
|
|
|
|
|
|
def test_nidaq_write_data_reshape(nidaq):
    """With reshaping enabled, write_data still opens the target file once with mode "a"."""
    payload = {name: np.asarray(range(20)) for name in ("signal1", "signal2")}
    nidaq.filename = "test.h5"
    nidaq.reshape_dataset = True
    with mock.patch("bec_plugins.services.NIDAQ_writer.NIDAQ_writer.h5py") as h5py_stub:
        nidaq.write_data(payload)
        h5py_stub.File.assert_called_once_with("test.h5", "a")
|
|
|
|
|
|
def test_nidaq_write_data_without_reshape(nidaq):
    """Without reshaping, each write_data call stores the signals in a new numbered group.

    The first write creates "dataset_0" with one chunked, unlimited-length
    dataset per signal. Once the file reports "dataset_0" among its keys, the
    next write creates "dataset_1".
    """
    signal = {"signal1": np.asarray(range(20)), "signal2": np.asarray(range(20))}
    nidaq.filename = "test.h5"
    nidaq.reshape_dataset = False
    with mock.patch("bec_plugins.services.NIDAQ_writer.NIDAQ_writer.h5py") as mock_h5py:
        # Reach the child mocks via return_value instead of calling File() /
        # create_group(): calling them would record bogus no-arg calls on the
        # very mocks whose call lists are asserted below.
        file_handle = mock_h5py.File.return_value.__enter__.return_value
        group = file_handle.create_group.return_value

        nidaq.write_data(signal)

        mock_h5py.File.assert_called_once_with("test.h5", "a")
        file_handle.create_group.assert_called_once_with("dataset_0")
        calls = group.create_dataset.call_args_list
        assert calls[0] == mock.call(
            "signal1", data=signal["signal1"], chunks=True, maxshape=(None,)
        )
        assert calls[1] == mock.call(
            "signal2", data=signal["signal2"], chunks=True, maxshape=(None,)
        )

        # With "dataset_0" already present, the next write must use a new group.
        file_handle.keys.return_value = ["dataset_0"]
        nidaq.write_data(signal)
        assert file_handle.create_group.call_args_list == [
            mock.call("dataset_0"),
            mock.call("dataset_1"),
        ]
|
|
|
|
|
|
def test_nidaq_write_data_reshapes_data(nidaq):
    """With reshaping enabled and both signals already in the file, write_data resizes instead of creating.

    Every key lookup on the mocked file returns the same dataset mock, so the
    two expected resize calls (presumably one per signal) land on that one mock.
    """
    signal = {"signal1": np.asarray(range(20)), "signal2": np.asarray(range(20))}
    nidaq.filename = "test.h5"
    nidaq.reshape_dataset = True
    with mock.patch("bec_plugins.services.NIDAQ_writer.NIDAQ_writer.h5py") as mock_h5py:
        # Use return_value rather than calling File(): a call here would be
        # recorded and make an exact assert_called_once_with impossible.
        file_handle = mock_h5py.File.return_value.__enter__.return_value
        # Make `name in file_handle` report both signal names as present.
        file_handle.__contains__.side_effect = signal.__contains__

        nidaq.write_data(signal)

        dataset = file_handle.__getitem__.return_value
        assert dataset.resize.call_count == 2
        mock_h5py.File.assert_called_once_with("test.h5", "a")
|