diff --git a/ophyd_devices/interfaces/base_classes/psi_device_base.py b/ophyd_devices/interfaces/base_classes/psi_device_base.py index dcb7d71..def17ea 100644 --- a/ophyd_devices/interfaces/base_classes/psi_device_base.py +++ b/ophyd_devices/interfaces/base_classes/psi_device_base.py @@ -6,14 +6,16 @@ from __future__ import annotations import inspect import time -from typing import Any, Callable +from typing import TYPE_CHECKING, Callable -from bec_lib.devicemanager import ScanInfo from ophyd import Device, DeviceStatus, Staged, StatusBase from ophyd_devices.tests.utils import get_mock_scan_info from ophyd_devices.utils.psi_device_base_utils import FileHandler, TaskHandler +if TYPE_CHECKING: # pragma: no cover + from bec_lib.devicemanager import DeviceManagerBase, ScanInfo + class DeviceStoppedError(Exception): """Exception raised when a device is stopped""" @@ -37,7 +39,15 @@ class PSIDeviceBase(Device): SUB_DEVICE_MONITOR_2D = "device_monitor_2d" _default_sub = SUB_VALUE - def __init__(self, *, name: str, prefix: str = "", scan_info: ScanInfo | None = None, **kwargs): # type: ignore + def __init__( + self, + *, + name: str, + prefix: str = "", + scan_info: ScanInfo | None = None, + device_manager: DeviceManagerBase | None = None, + **kwargs, + ): """ Initialize the PSI Device Base class. @@ -49,9 +59,10 @@ class PSIDeviceBase(Device): # This is to avoid issues with ophyd.OphydObject.__init__ when the parent is ophyd.Device # and the device_manager is passed to it. This will cause a TypeError. sig = inspect.signature(super().__init__) - if "device_manager" not in sig.parameters: - kwargs.pop("device_manager", None) - super().__init__(prefix=prefix, name=name, **kwargs) + if "device_manager" in sig.parameters: + super().__init__(device_manager=device_manager, prefix=prefix, name=name, **kwargs) + else: + super().__init__(prefix=prefix, name=name, **kwargs) self._stopped = False self.task_handler = TaskHandler(parent=self) self.file_utils = FileHandler() diff --git a/tests/test_psi_device_base.py b/tests/test_psi_device_base.py index ff1fefb..b40bddf 100644 --- a/tests/test_psi_device_base.py +++ b/tests/test_psi_device_base.py @@ -7,6 +7,7 @@ from ophyd import Device from ophyd.status import StatusBase from ophyd_devices.interfaces.base_classes.psi_device_base import DeviceStoppedError, PSIDeviceBase +from ophyd_devices.sim.sim_camera import SimCamera from ophyd_devices.sim.sim_positioner import SimPositioner # pylint: disable=redefined-outer-name @@ -18,17 +19,24 @@ class SimPositionerDevice(PSIDeviceBase, SimPositioner): class SimDevice(PSIDeviceBase, Device): - """Test Device with ohyd.Device as base class""" + """Simulated Device with PSI Device Base""" + + +@pytest.fixture +def device_positioner(): + """Fixture for Device""" + yield SimPositionerDevice(name="device") @pytest.fixture def device(): """Fixture for Device""" - yield SimPositionerDevice(name="device") + yield SimDevice(name="device", prefix="test:") -def test_psi_device_base_wait_for_signals(device): +def test_psi_device_base_wait_for_signals(device_positioner): """Test wait_for_signals method""" + device: SimPositionerDevice = device_positioner device.motor_is_moving.set(1).wait() def check_motor_is_moving(): @@ -67,8 +75,9 @@ def test_psi_device_base_init_with_device_manager(): dm = mock.MagicMock() device = SimPositionerDevice(name="device", device_manager=dm) assert device.device_manager is dm - device2 = SimDevice(name="device2", device_manager=dm) - assert getattr(device2, "device_manager", None) is None + # device_manager should b passed to SimCamera through PSIDeviceBase + device_2 = SimCamera(name="device", device_manager=dm) + assert device_2.device_manager is dm def test_on_stage_hook(device):