diff --git a/ophyd_devices/interfaces/base_classes/psi_device_base.py b/ophyd_devices/interfaces/base_classes/psi_device_base.py index c34e67a..2729723 100644 --- a/ophyd_devices/interfaces/base_classes/psi_device_base.py +++ b/ophyd_devices/interfaces/base_classes/psi_device_base.py @@ -42,7 +42,7 @@ class PSIDeviceBase(Device): self.task_handler = TaskHandler(parent=self) self.file_utils = FileHandler() if scan_info is None: - scan_info = get_mock_scan_info() + scan_info = get_mock_scan_info(device=self) self.scan_info = scan_info self.on_init() diff --git a/ophyd_devices/tests/utils.py b/ophyd_devices/tests/utils.py index d7ec12a..231c3e0 100644 --- a/ophyd_devices/tests/utils.py +++ b/ophyd_devices/tests/utils.py @@ -1,12 +1,12 @@ """ Utilities to mock and test devices.""" -from dataclasses import dataclass from typing import TYPE_CHECKING from unittest import mock from bec_lib.devicemanager import ScanInfo from bec_lib.logger import bec_logger from bec_lib.utils.import_utils import lazy_import_from +from ophyd import Device if TYPE_CHECKING: from bec_lib.messages import ScanStatusMessage @@ -273,19 +273,27 @@ class MockPV: return data["value"] if data is not None else None -def get_mock_scan_info(): +def get_mock_scan_info(device: Device | None) -> ScanInfo: """ Get a mock scan info object. """ - return ScanInfo(msg=fake_scan_status_msg()) + return ScanInfo(msg=fake_scan_status_msg(device=device)) -def fake_scan_status_msg(): +def fake_scan_status_msg(device: Device | None = None) -> ScanStatusMessage: """ Create a fake scan status message. + + Args: + device: The device creating the fake scan status message. + """ + if device is None: + device = Device(name="mock_device") logger.warning( - ("Device is not connected to a Redis server. Fetching mocked ScanStatusMessage.") + ( + f"Device {device.name} is not connected to a Redis server. Fetching mocked ScanStatusMessage." + ) ) return ScanStatusMessage( metadata={}, diff --git a/tests/test_simulation.py b/tests/test_simulation.py index cf61855..53c1535 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -697,7 +697,7 @@ def test_async_mon_on_trigger(async_monitor): def test_async_mon_send_data_to_bec(async_monitor): """Test the _send_data_to_bec method of SimMonitorAsync.""" - async_monitor.scan_info = get_mock_scan_info() + async_monitor.scan_info = get_mock_scan_info(device=async_monitor) async_monitor.data_buffer.update({"value": [0, 5], "timestamp": [0, 0]}) with mock.patch.object(async_monitor.connector, "xadd") as mock_xadd: async_monitor._send_data_to_bec() @@ -747,7 +747,7 @@ def test_waveform(waveform): # Now also test the async readback mock_connector = waveform.connector = mock.MagicMock() mock_run_subs = waveform._run_subs = mock.MagicMock() - waveform.scan_info = get_mock_scan_info() + waveform.scan_info = get_mock_scan_info(device=waveform) waveform.scan_info.msg.scan_id = "test" status = waveform.trigger() timer = 0