# Files
# debye_bec/tests/tests_services/test_file_writer_service.py
#
# 174 lines
# 6.5 KiB
# Python

from unittest import mock
import numpy as np
import pytest
from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints
from bec_lib.redis_connector import MessageObject
from bec_lib.service_config import ServiceConfig
from debye_bec.services.NIDAQ_writer import NIDAQWriterService
def test_nidaq_starts_consumers():
    """Constructing NIDAQWriterService calls each of its three start-up hooks exactly once."""
    start_hooks = (
        "_start_scan_status_consumer",
        "_start_ni_data_consumer",
        "_start_ni_writer",
    )
    # patch.multiple with DEFAULT creates one MagicMock per attribute and, used as a
    # context manager, yields them in a dict keyed by attribute name.
    with mock.patch.multiple(
        NIDAQWriterService, **{hook: mock.DEFAULT for hook in start_hooks}
    ) as patched_hooks:
        NIDAQWriterService(
            config=ServiceConfig(redis={"host": "test", "port": 6379}),
            connector_cls=mock.MagicMock(),
        )
    for hook in start_hooks:
        patched_hooks[hook].assert_called_once()
class NIWriterMock(NIDAQWriterService):
    """NIDAQWriterService whose NI data consumer and NI writer start-up hooks are no-ops.

    `_start_scan_status_consumer` is not overridden. Tests built on this class
    therefore still get the scan-status consumer set up, but neither of the other
    two start-up hooks does anything.
    """

    def _start_ni_data_consumer(self) -> None:
        """Do nothing instead of starting the NI data consumer."""

    def _start_ni_writer(self) -> None:
        """Do nothing instead of starting the NI writer."""
@pytest.fixture(scope="function")
def nidaq():
    """Provide a fresh NIWriterMock per test, built with a MagicMock connector class.

    The connector class is a MagicMock, so presumably no real Redis connection is
    opened; the "test" host is never contacted. No teardown is performed after the
    test.
    """
    test_config = ServiceConfig(redis={"host": "test", "port": 6379})
    yield NIWriterMock(config=test_config, connector_cls=mock.MagicMock())
def test_nidaq_scan_status_consumer(nidaq):
    """The service registers a scan-status consumer on its connector and starts it once."""
    scan_status_endpoint = MessageEndpoints.scan_status()
    nidaq.connector.consumer.assert_called_once_with(
        scan_status_endpoint,
        cb=nidaq._scan_status_callback,
        parent=nidaq,
    )
    # NOTE(review): `_scan_status_consumer` is presumably the object returned by
    # connector.consumer(...); confirm against the service implementation.
    nidaq._scan_status_consumer.start.assert_called_once()
def test_scan_status_callback(nidaq):
    """The callback forwards a message equal to the serialized original to handle_scan_status."""
    expected_msg = messages.ScanStatusMessage(scan_id="test", status="open", info={})
    raw_msg = MessageObject(topic="test", value=expected_msg.dumps())
    with mock.patch.object(nidaq, "handle_scan_status") as handler_mock:
        nidaq._scan_status_callback(raw_msg, nidaq)
    handler_mock.assert_called_once_with(expected_msg)
def test_nidaq_doesnt_read_data_when_scan_is_not_running(nidaq):
    """With scan_is_running False, _read_data never asks writer_mixin for a filename."""
    nidaq.scan_is_running = False
    with mock.patch.object(nidaq, "writer_mixin") as writer_mock:
        nidaq._read_data()
    writer_mock.compile_full_filename.assert_not_called()
def test_nidaq_reads_data(nidaq):
    """During a scan with use_redis_stream False, _read_data builds a filename and handles data once."""
    nidaq.scan_is_running = True
    nidaq.use_redis_stream = False
    with mock.patch.object(nidaq, "writer_mixin") as writer_mock, mock.patch.object(
        nidaq, "handle_ni_data"
    ) as handler_mock:
        nidaq._read_data()
    writer_mock.compile_full_filename.assert_called_once()
    handler_mock.assert_called_once()
def test_nidaq_reads_data_from_strea(nidaq):
    """During a scan with use_redis_stream True, _read_data builds a filename and handles data once."""
    nidaq.scan_is_running = True
    nidaq.use_redis_stream = True
    with mock.patch.object(nidaq, "writer_mixin") as writer_mock, mock.patch.object(
        nidaq, "handle_ni_data"
    ) as handler_mock:
        nidaq._read_data()
    writer_mock.compile_full_filename.assert_called_once()
    handler_mock.assert_called_once()
@pytest.mark.parametrize("scan_status", ["open", "closed", "aborted", "halted"])
def test_nidaq_handle_scan_status(nidaq, scan_status):
    """Only an "open" status marks the scan as running and records its scan number.

    Every other status tested leaves the service not running and with no scan number.
    """
    status_msg = messages.ScanStatusMessage(
        scan_id="test", status=scan_status, info={"scan_number": 5}
    )
    nidaq.handle_scan_status(status_msg)

    scan_opened = scan_status == "open"
    assert bool(nidaq.scan_is_running) is scan_opened
    if not scan_opened:
        assert nidaq.scan_number is None
    else:
        assert nidaq.scan_number == 5
def test_nidaq_handle_ni_data(nidaq):
    """handle_ni_data joins per-signal values from successive messages and queues the result.

    Two DeviceMessages carry values 0..9 and then 10..19 for each signal. The single
    queued item must hold 0..19 for both signals.
    """
    data = [
        messages.DeviceMessage(
            signals={"signal1": {"value": list(range(10))}, "signal2": {"value": list(range(10))}}
        ),
        messages.DeviceMessage(
            signals={
                "signal1": {"value": list(range(10, 20))},
                "signal2": {"value": list(range(10, 20))},
            }
        ),
    ]
    nidaq.handle_ni_data(data)
    # NOTE(review): `queue.get()` blocks indefinitely if nothing was queued. Consider a
    # timeout once the queue type is confirmed.
    signal = nidaq.queue.get()

    # `all(a == b)` passes vacuously for empty arrays and fails with confusing errors
    # on shape/rank mismatches. assert_array_equal checks shape and values and
    # reports the mismatching elements.
    expected = np.arange(20)
    np.testing.assert_array_equal(signal["signal1"], expected)
    np.testing.assert_array_equal(signal["signal2"], expected)
def test_nidaq_write_data_without_filename(nidaq):
    """write_data must not open an HDF5 file when the fixture's filename was not set.

    Assumes the fixture leaves `filename` unset by default; the test never assigns it.
    """
    payload = {name: np.asarray(range(20)) for name in ("signal1", "signal2")}
    with mock.patch("debye_bec.services.NIDAQ_writer.NIDAQ_writer.h5py") as h5py_mock:
        nidaq.write_data(payload)
    h5py_mock.File.assert_not_called()
def test_nidaq_write_data_with_filename(nidaq):
    """With a filename set, write_data opens exactly that file once in append ("a") mode."""
    payload = {name: np.asarray(range(20)) for name in ("signal1", "signal2")}
    nidaq.filename = "test.h5"
    with mock.patch("debye_bec.services.NIDAQ_writer.NIDAQ_writer.h5py") as h5py_mock:
        nidaq.write_data(payload)
    h5py_mock.File.assert_called_once_with("test.h5", "a")
def test_nidaq_write_data_reshape(nidaq):
    """With reshape_dataset enabled, write_data still opens the file once in append mode."""
    payload = {name: np.asarray(range(20)) for name in ("signal1", "signal2")}
    nidaq.filename = "test.h5"
    nidaq.reshape_dataset = True
    with mock.patch("debye_bec.services.NIDAQ_writer.NIDAQ_writer.h5py") as h5py_mock:
        nidaq.write_data(payload)
    h5py_mock.File.assert_called_once_with("test.h5", "a")
def test_nidaq_write_data_without_reshape(nidaq):
    """Without reshaping, each write_data call creates a new numbered group of datasets.

    The first write goes into "dataset_0", with one chunked, resizable
    (maxshape=(None,)) dataset per signal. Once the file reports "dataset_0" as an
    existing key, the next write creates "dataset_1".
    """
    signal = {"signal1": np.asarray(range(20)), "signal2": np.asarray(range(20))}
    nidaq.filename = "test.h5"
    nidaq.reshape_dataset = False
    with mock.patch("debye_bec.services.NIDAQ_writer.NIDAQ_writer.h5py") as mock_h5py:
        # Reach the mocks via `return_value` instead of calling them. Calling
        # `File()` / `create_group()` here would record spurious calls on the very
        # mocks asserted against below.
        file_handle = mock_h5py.File.return_value.__enter__.return_value
        group = file_handle.create_group.return_value

        nidaq.write_data(signal)
        mock_h5py.File.assert_called_once_with("test.h5", "a")
        file_handle.create_group.assert_called_once_with("dataset_0")
        calls = group.create_dataset.call_args_list
        assert calls[0] == mock.call(
            "signal1", data=signal["signal1"], chunks=True, maxshape=(None,)
        )
        assert calls[1] == mock.call(
            "signal2", data=signal["signal2"], chunks=True, maxshape=(None,)
        )

        # The second write must stay inside the patch, or it would hit the real h5py.
        file_handle.keys.return_value = ["dataset_0"]
        nidaq.write_data(signal)
        assert mock.call("dataset_1") in file_handle.create_group.call_args_list
def test_nidaq_write_data_reshapes_data(nidaq):
    """With reshape_dataset on and the signals already present in the file, datasets get resized."""
    signal = {"signal1": np.asarray(range(20)), "signal2": np.asarray(range(20))}
    nidaq.filename = "test.h5"
    nidaq.reshape_dataset = True
    with mock.patch("debye_bec.services.NIDAQ_writer.NIDAQ_writer.h5py") as mock_h5py:
        # Use the `return_value` chain rather than calling `File()` here. Calling it
        # would record a spurious File() call before write_data even runs.
        file_handle = mock_h5py.File.return_value.__enter__.return_value
        # `name in file_handle` is True exactly for the signal names, so write_data
        # sees every signal as an existing dataset.
        file_handle.__contains__.side_effect = signal.__contains__
        nidaq.write_data(signal)

    # MagicMock.__getitem__ returns the same child mock for every key. This `dataset`
    # therefore stands for *all* datasets, and the count below is the total number of
    # resize() calls across both signals, not signal1's alone.
    # NOTE(review): presumably one resize per signal; confirm against write_data.
    dataset = file_handle["signal1"]
    assert len(dataset.resize.call_args_list) == 2
    assert mock.call("test.h5", "a") in mock_h5py.File.call_args_list