diff --git a/ophyd_devices/interfaces/base_classes/psi_device_base.py b/ophyd_devices/interfaces/base_classes/psi_device_base.py new file mode 100644 index 0000000..ddc528c --- /dev/null +++ b/ophyd_devices/interfaces/base_classes/psi_device_base.py @@ -0,0 +1,176 @@ +""" +Consider using the bec_device_base name for the base class. +I will use this name instead here to simplify comparisons between the two approaches +""" + +from __future__ import annotations + +from ophyd import Device, DeviceStatus, Staged, StatusBase + +from ophyd_devices.tests.utils import get_mock_scan_info +from ophyd_devices.utils.psi_device_base_utils import FileHandler, TaskHandler + + +class PSIDeviceBase(Device): + """ + Base class for all PSI ophyd devices to ensure consistent configuration + and communication with BEC services. + """ + + # These are all possible subscription types that the device_manager supports + # and automatically subscribes to + SUB_READBACK = "readback" + SUB_VALUE = "value" + SUB_DONE_MOVING = "done_moving" + SUB_MOTOR_IS_MOVING = "motor_is_moving" + SUB_PROGRESS = "progress" + SUB_FILE_EVENT = "file_event" + SUB_DEVICE_MONITOR_1D = "device_monitor_1d" + SUB_DEVICE_MONITOR_2D = "device_monitor_2d" + _default_sub = SUB_VALUE + + def __init__(self, name: str, scan_info=None, **kwargs): + """ + Initialize the PSI Device Base class. + + Args: + name (str) : Name of the device + scan_info (ScanInfo): The scan info to use. + """ + super().__init__(name=name, **kwargs) + self._stopped = False + self.task_handler = TaskHandler(parent=self) + if scan_info is None: + scan_info = get_mock_scan_info() + self.scan_info = scan_info + self.file_utils = FileHandler() + self.on_init() + + ######################################## + # Additional Properties and Attributes # + ######################################## + + @property + def destroyed(self) -> bool: + """Check if the device has been destroyed.""" + return self._destroyed + + @property + def staged(self) -> Staged: + """Check if the device has been staged.""" + return self._staged + + @property + def stopped(self) -> bool: + """Check if the device has been stopped.""" + return self._stopped + + @stopped.setter + def stopped(self, value: bool): + self._stopped = value + + ######################################## + # Wrapper around Device class methods # + ######################################## + + def stage(self) -> list[object] | StatusBase: + """Stage the device.""" + if self.staged != Staged.no: + return super().stage() + self.stopped = False + super_staged = super().stage() + status = self.on_stage() # pylint: disable=assignment-from-no-return + if isinstance(status, StatusBase): + return status + return super_staged + + def unstage(self) -> list[object] | StatusBase: + """Unstage the device.""" + super_unstage = super().unstage() + status = self.on_unstage() # pylint: disable=assignment-from-no-return + if isinstance(status, StatusBase): + return status + return super_unstage + + def pre_scan(self) -> StatusBase | None: + """Pre-scan function.""" + status = self.on_pre_scan() # pylint: disable=assignment-from-no-return + return status + + def trigger(self) -> DeviceStatus: + """Trigger the device.""" + super_trigger = super().trigger() + status = self.on_trigger() # pylint: disable=assignment-from-no-return + return status if status else super_trigger + + def complete(self) -> DeviceStatus: + """Complete the device.""" + status = self.on_complete() # pylint: disable=assignment-from-no-return + if isinstance(status, StatusBase): + return status + status = DeviceStatus(self) + status.set_finished() + return status + + def kickoff(self) -> DeviceStatus: + """Kickoff the device.""" + status = self.on_kickoff() # pylint: disable=assignment-from-no-return + if isinstance(status, StatusBase): + return status + status = DeviceStatus(self) + status.set_finished() + return status + + # pylint: disable=arguments-differ + def stop(self, success: bool = False) -> None: + """Stop the device. + + Args: + success (bool): True if the device was stopped successfully. + """ + self.on_stop() + super().stop(success=success) + self.stopped = True + + ######################################## + # Beamline Specific Implementations # + ######################################## + + def on_init(self) -> None: + """ + Called when the device is initialized. + + No siganls are connected at this point, + thus should not be set here but in on_connected instead. + """ + + def on_connected(self) -> None: + """ + Called after the device is connected and its signals are connected. + Default values for signals should be set here. + """ + + def on_stage(self) -> DeviceStatus | None: + """ + Called while staging the device. + + Information about the upcoming scan can be accessed from the scan_info object. + """ + + def on_unstage(self) -> DeviceStatus | None: + """Called while unstaging the device.""" + + def on_pre_scan(self) -> DeviceStatus | None: + """Called right before the scan starts on all devices automatically.""" + + def on_trigger(self) -> DeviceStatus | None: + """Called when the device is triggered.""" + + def on_complete(self) -> DeviceStatus | None: + """Called to inquire if a device has completed a scans.""" + + def on_kickoff(self) -> DeviceStatus | None: + """Called to kickoff a device for a fly scan. Has to be called explicitly.""" + + def on_stop(self) -> None: + """Called when the device is stopped.""" diff --git a/ophyd_devices/interfaces/protocols/bec_protocols.py b/ophyd_devices/interfaces/protocols/bec_protocols.py index b7b3b06..2b24eae 100644 --- a/ophyd_devices/interfaces/protocols/bec_protocols.py +++ b/ophyd_devices/interfaces/protocols/bec_protocols.py @@ -1,7 +1,9 @@ -""" This module provides a range of protocols that describe the expected interface for different types of devices. +""" This module provides a range of protocols that describe the expected +interface for different types of devices. -The protocols below can be used as teamplates for functionality to be implemeted by different type of devices. -They further facilitate runtime checks on devices and provide a minimum set of properties required for a device to be loadable by BEC. +The protocols below can be used as teamplates for functionality to be implemeted +by different type of devices. They further facilitate runtime checks on devices +and provide a minimum set of properties required for a device to be loadable by BEC. The protocols are: - BECBaseProtocol: Protocol for devices in BEC. All devices must at least implement this protocol. @@ -11,17 +13,15 @@ The protocols are: - BECPositionerProtocol: Protocol for positioners. - BECFlyerProtocol: Protocol with for flyers. -Keep in mind, that a device of type flyer should generally also implement the BECDeviceProtocol that provides the required functionality for scans. -Flyers in addition, also implement the BECFlyerProtocol. Similarly, positioners should also implement the BECDeviceProtocol and BECPositionerProtocol. +Keep in mind, that a device of type flyer should generally also implement the BECDeviceProtocol +with the functionality needed for scans. In addition, flyers also implement the BECFlyerProtocol. +Similarly, positioners should also implement the BECDeviceProtocol and BECPositionerProtocol. """ from typing import Protocol, runtime_checkable -from bec_lib.file_utils import FileWriter -from ophyd import Component, DeviceStatus, Kind, Staged - -from ophyd_devices.utils import bec_scaninfo_mixin +from ophyd import DeviceStatus, Kind, Staged @runtime_checkable @@ -349,8 +349,8 @@ class BECPositionerProtocol(BECDeviceProtocol, Protocol): def move(self, position: float) -> DeviceStatus: """Move method for positioners. - The returned DeviceStatus is marked as done once the positioner has reached the target position. - DeviceStatus.wait() can be used to block until the move is completed. + The returned DeviceStatus is marked as done once the positioner has reached the target + position. DeviceStatus.wait() can be used to block until the move is completed. Args: position: position to move to diff --git a/ophyd_devices/sim/sim_camera.py b/ophyd_devices/sim/sim_camera.py index 361ce20..90022be 100644 --- a/ophyd_devices/sim/sim_camera.py +++ b/ophyd_devices/sim/sim_camera.py @@ -1,161 +1,23 @@ -import traceback -from threading import Thread +""" Simulated 2D camera device""" import numpy as np from bec_lib.logger import bec_logger from ophyd import Component as Cpt -from ophyd import DeviceStatus, Kind +from ophyd import Device, Kind, StatusBase -from ophyd_devices.interfaces.base_classes.psi_detector_base import ( - CustomDetectorMixin, - PSIDetectorBase, -) +from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase from ophyd_devices.sim.sim_data import SimulatedDataCamera from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal from ophyd_devices.sim.sim_utils import H5Writer -from ophyd_devices.utils.errors import DeviceStopError logger = bec_logger.logger -class SimCameraSetup(CustomDetectorMixin): - """Mixin class for the SimCamera device.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._thread_trigger = None - self._thread_complete = None - self.file_path = None - - def on_trigger(self) -> None: - """Trigger the camera to acquire images. - - This method can be called from BEC during a scan. It will acquire images and send them to BEC. - Whether the trigger is send from BEC is determined by the softwareTrigger argument in the device config. - - Here, we also run a callback on SUB_MONITOR to send the image data the device_monitor endpoint in BEC. - """ - status = DeviceStatus(self.parent) - - def on_trigger_call(status: DeviceStatus) -> None: - try: - for _ in range(self.parent.burst.get()): - data = self.parent.image.get() - # pylint: disable=protected-access - self.parent._run_subs(sub_type=self.parent.SUB_MONITOR, value=data) - if self.parent.stopped: - raise DeviceStopError(f"{self.parent.name} was stopped") - if self.parent.write_to_disk.get(): - self.parent.h5_writer.receive_data(data) - status.set_finished() - # pylint: disable=broad-except - except Exception as exc: - content = traceback.format_exc() - logger.warning( - f"Error in on_trigger_call in device {self.parent.name}. Error traceback: {content}" - ) - status.set_exception(exc) - - self._thread_trigger = Thread(target=on_trigger_call, args=(status,)) - self._thread_trigger.start() - return status - - def on_stage(self) -> None: - """Stage the camera for upcoming scan - - This method is called from BEC in preparation of a scan. - It receives metadata about the scan from BEC, - compiles it and prepares the camera for the scan. - - FYI: No data is written to disk in the simulation, but upon each trigger it - is published to the device_monitor endpoint in REDIS. - """ - self.file_path = self.parent.filewriter.compile_full_filename(f"{self.parent.name}") - - self.parent.frames.set( - self.parent.scaninfo.num_points * self.parent.scaninfo.frames_per_trigger - ) - self.parent.exp_time.set(self.parent.scaninfo.exp_time) - self.parent.burst.set(self.parent.scaninfo.frames_per_trigger) - if self.parent.write_to_disk.get(): - self.parent.h5_writer.on_stage(file_path=self.file_path, h5_entry="/entry/data/data") - self.parent._run_subs( - sub_type=self.parent.SUB_FILE_EVENT, - file_path=self.file_path, - done=False, - successful=False, - hinted_location={"data": "/entry/data/data"}, - ) - self.parent.stopped = False - - def on_complete(self) -> None: - """Complete the motion of the simulated device.""" - status = DeviceStatus(self.parent) - - def on_complete_call(status: DeviceStatus) -> None: - try: - if self.parent.write_to_disk.get(): - self.parent.h5_writer.on_complete() - self.parent._run_subs( - sub_type=self.parent.SUB_FILE_EVENT, - file_path=self.file_path, - done=True, - successful=True, - hinted_location={"data": "/entry/data/data"}, - ) - if self.parent.stopped: - raise DeviceStopError(f"{self.parent.name} was stopped") - status.set_finished() - # pylint: disable=broad-except - except Exception as exc: - content = traceback.format_exc() - logger.warning( - f"Error in on_complete call in device {self.parent.name}. Error traceback: {content}" - ) - status.set_exception(exc) - - self._thread_complete = Thread(target=on_complete_call, args=(status,), daemon=True) - self._thread_complete.start() - return status - - def on_unstage(self): - """Unstage the camera device.""" - if self.parent.write_to_disk.get(): - self.parent.h5_writer.on_unstage() - - def on_stop(self) -> None: - """Stop the camera acquisition.""" - if self._thread_trigger: - self._thread_trigger.join() - if self._thread_complete: - self._thread_complete.join() - self.on_unstage() - self._thread_trigger = None - self._thread_complete = None - - -class SimCamera(PSIDetectorBase): - """A simulated device mimic any 2D camera. - - It's image is a computed signal, which is configurable by the user and from the command line. - The corresponding simulation class is sim_cls=SimulatedDataCamera, more details on defaults within the simulation class. - - >>> camera = SimCamera(name="camera") - - Parameters - ---------- - name (string) : Name of the device. This is the only required argmuent, passed on to all signals of the device. - precision (integer) : Precision of the readback in digits, written to .describe(). Default is 3 digits. - sim_init (dict) : Dictionary to initiate parameters of the simulation, check simulation type defaults for more details. - parent : Parent device, optional, is used internally if this signal/device is part of a larger device. - kind : A member the Kind IntEnum (or equivalent integer), optional. Default is Kind.normal. See Kind for options. - device_manager : DeviceManager from BEC, optional . Within startup of simulation, device_manager is passed on automatically. - - """ +class SimCameraControl(Device): + """SimCamera Control layer""" USER_ACCESS = ["sim", "registered_proxies"] - custom_prepare_cls = SimCameraSetup sim_cls = SimulatedDataCamera SHAPE = (100, 100) BIT_DEPTH = np.uint16 @@ -178,16 +40,13 @@ class SimCamera(PSIDetectorBase): ) write_to_disk = Cpt(SetableSignal, name="write_to_disk", value=False, kind=Kind.config) - def __init__( - self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs - ): + def __init__(self, name, *, parent=None, sim_init: dict = None, device_manager=None, **kwargs): self.sim_init = sim_init + self.device_manager = device_manager self._registered_proxies = {} self.sim = self.sim_cls(parent=self, **kwargs) self.h5_writer = H5Writer() - super().__init__( - name=name, parent=parent, kind=kind, device_manager=device_manager, **kwargs - ) + super().__init__(name=name, parent=parent, **kwargs) if self.sim_init: self.sim.set_init(self.sim_init) @@ -195,3 +54,104 @@ class SimCamera(PSIDetectorBase): def registered_proxies(self) -> None: """Dictionary of registered signal_names and proxies.""" return self._registered_proxies + + +class SimCamera(PSIDeviceBase, SimCameraControl): + """A simulated device mimic any 2D camera. + + It's image is a computed signal, which is configurable by the user and from the command line. + The corresponding simulation class is sim_cls=SimulatedDataCamera, more details on defaults within the simulation class. + + >>> camera = SimCamera(name="camera") + + Parameters + ---------- + name (string) : Name of the device. This is the only required argmuent, passed on to all signals of the device. + precision (integer) : Precision of the readback in digits, written to .describe(). Default is 3 digits. + sim_init (dict) : Dictionary to initiate parameters of the simulation, check simulation type defaults for more details. + parent : Parent device, optional, is used internally if this signal/device is part of a larger device. + kind : A member the Kind IntEnum (or equivalent integer), optional. Default is Kind.normal. See Kind for options. + + """ + + def __init__(self, name: str, scan_info=None, device_manager=None, **kwargs): + super().__init__(name=name, scan_info=scan_info, device_manager=device_manager, **kwargs) + self.file_path = None + + def on_trigger(self) -> StatusBase: + """Trigger the camera to acquire images. + + This method can be called from BEC during a scan. It will acquire images and send them to BEC. + Whether the trigger is send from BEC is determined by the softwareTrigger argument in the device config. + + Here, we also run a callback on SUB_MONITOR to send the image data the device_monitor endpoint in BEC. + """ + + def trigger_cam() -> None: + """Trigger the camera to acquire images.""" + for _ in range(self.burst.get()): + data = self.image.get() + # pylint: disable=protected-access + self._run_subs(sub_type=self.SUB_MONITOR, value=data) + if self.write_to_disk.get(): + self.h5_writer.receive_data(data) + + status = self.task_handler.submit_task(trigger_cam) + return status + + def on_stage(self) -> None: + """Stage the camera for upcoming scan + + This method is called from BEC in preparation of a scan. + It receives metadata about the scan from BEC, + compiles it and prepares the camera for the scan. + + FYI: No data is written to disk in the simulation, but upon each trigger it + is published to the device_monitor endpoint in REDIS. + """ + self.file_path = self.file_utils.get_file_path( + scan_status_msg=self.scan_info.msg, name=self.name + ) + self.frames.set( + self.scan_info.msg.num_points * self.scan_info.msg.scan_parameters["frames_per_trigger"] + ) + self.exp_time.set(self.scan_info.msg.scan_parameters["exp_time"]) + self.burst.set(self.scan_info.msg.scan_parameters["frames_per_trigger"]) + if self.write_to_disk.get(): + self.h5_writer.on_stage(file_path=self.file_path, h5_entry="/entry/data/data") + # pylint: disable=protected-access + self._run_subs( + sub_type=self.SUB_FILE_EVENT, + file_path=self.file_path, + done=False, + successful=False, + hinted_location={"data": "/entry/data/data"}, + ) + + def on_complete(self) -> StatusBase: + """Complete the motion of the simulated device.""" + + def complete_cam(): + """Complete the camera acquisition.""" + if self.write_to_disk.get(): + self.h5_writer.on_complete() + self._run_subs( + sub_type=self.SUB_FILE_EVENT, + file_path=self.file_path, + done=True, + successful=True, + hinted_location={"data": "/entry/data/data"}, + ) + + status = self.task_handler.submit_task(complete_cam) + return status + + def on_unstage(self) -> None: + """Unstage the camera device.""" + if self.write_to_disk.get(): + self.h5_writer.on_unstage() + + def on_stop(self) -> None: + """Stop the camera acquisition.""" + self.task_handler.shutdown() + self.on_unstage() diff --git a/ophyd_devices/sim/sim_data.py b/ophyd_devices/sim/sim_data.py index de0767d..9f25e99 100644 --- a/ophyd_devices/sim/sim_data.py +++ b/ophyd_devices/sim/sim_data.py @@ -106,7 +106,7 @@ class SimulatedDataBase(ABC): def execute_simulation_method(self, *args, method=None, signal_name: str = "", **kwargs) -> any: """ Execute either the provided method or reroutes the method execution - to a device proxy in case it is registered in self.parentregistered_proxies. + to a device proxy in case it is registered in self.parent.registered_proxies. """ if self.registered_proxies and self.parent.device_manager: for proxy_name, signal in self.registered_proxies.items(): diff --git a/ophyd_devices/sim/sim_monitor.py b/ophyd_devices/sim/sim_monitor.py index 2219cd9..d4bb760 100644 --- a/ophyd_devices/sim/sim_monitor.py +++ b/ophyd_devices/sim/sim_monitor.py @@ -1,22 +1,16 @@ """Module for simulated monitor devices.""" -import traceback -from threading import Thread - import numpy as np from bec_lib import messages from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger from ophyd import Component as Cpt -from ophyd import Device, DeviceStatus, Kind +from ophyd import Device, Kind, StatusBase -from ophyd_devices.interfaces.base_classes.psi_detector_base import ( - CustomDetectorMixin, - PSIDetectorBase, -) +from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase from ophyd_devices.sim.sim_data import SimulatedDataMonitor from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal -from ophyd_devices.utils.errors import DeviceStopError +from ophyd_devices.utils import bec_utils logger = bec_logger.logger @@ -26,18 +20,25 @@ class SimMonitor(ReadOnlySignal): A simulated device mimic any 1D Axis (position, temperature, beam). It's readback is a computed signal, which is configurable by the user and from the command line. - The corresponding simulation class is sim_cls=SimulatedDataMonitor, more details on defaults within the simulation class. + The corresponding simulation class is sim_cls=SimulatedDataMonitor, more details on defaults + within the simulation class. >>> monitor = SimMonitor(name="monitor") Parameters ---------- - name (string) : Name of the device. This is the only required argmuent, passed on to all signals of the device. - precision (integer) : Precision of the readback in digits, written to .describe(). Default is 3 digits. - sim_init (dict) : Dictionary to initiate parameters of the simulation, check simulation type defaults for more details. - parent : Parent device, optional, is used internally if this signal/device is part of a larger device. - kind : A member the Kind IntEnum (or equivalent integer), optional. Default is Kind.normal. See Kind for options. - device_manager : DeviceManager from BEC, optional . Within startup of simulation, device_manager is passed on automatically. + name (string) : Name of the device. This is the only required argmuent, + passed on to all signals of the device. + precision (integer) : Precision of the readback in digits, written to .describe(). + Default is 3 digits. + sim_init (dict) : Dictionary to initiate parameters of the simulation, + check simulation type defaults for more details. + parent : Parent device, optional, is used internally if this + signal/device is part of a larger device. + kind : A member the Kind IntEnum (or equivalent integer), optional. + Default is Kind.normal. See Kind for options. + device_manager : DeviceManager from BEC, optional . Within startup of simulation, + device_manager is passed on automatically. """ @@ -81,135 +82,11 @@ class SimMonitor(ReadOnlySignal): return self._registered_proxies -class SimMonitorAsyncPrepare(CustomDetectorMixin): - """Custom prepare for the SimMonitorAsync class.""" - - def __init__(self, *args, parent: Device = None, **kwargs) -> None: - super().__init__(*args, parent=parent, **kwargs) - self._stream_ttl = 1800 - self._random_send_interval = None - self._counter = 0 - self._thread_trigger = None - self._thread_complete = None - self.prep_random_interval() - self.parent.current_trigger.subscribe(self._progress_update, run=False) - - def clear_buffer(self): - """Clear the data buffer.""" - self.parent.data_buffer["value"].clear() - self.parent.data_buffer["timestamp"].clear() - - def prep_random_interval(self): - """Prepare counter and random interval to send data to BEC.""" - self._random_send_interval = np.random.randint(1, 10) - self.parent.current_trigger.set(0).wait() - self._counter = self.parent.current_trigger.get() - - def on_stage(self): - """Prepare the device for staging.""" - self.clear_buffer() - self.prep_random_interval() - - def on_complete(self): - """Prepare the device for completion.""" - status = DeviceStatus(self.parent) - - def on_complete_call(status: DeviceStatus) -> None: - try: - if self.parent.data_buffer["value"]: - self._send_data_to_bec() - if self.parent.stopped: - raise DeviceStopError(f"{self.parent.name} was stopped") - status.set_finished() - # pylint: disable=broad-except - except Exception as exc: - content = traceback.format_exc() - status.set_exception(exc=exc) - logger.warning(f"Error in {self.parent.name} on_complete; Traceback: {content}") - - self._thread_complete = Thread(target=on_complete_call, args=(status,)) - self._thread_complete.start() - return status - - def _send_data_to_bec(self) -> None: - """Sends bundled data to BEC""" - if self.parent.scaninfo.scan_msg is None: - return - metadata = self.parent.scaninfo.scan_msg.metadata - metadata.update({"async_update": self.parent.async_update.get()}) - - msg = messages.DeviceMessage( - signals={self.parent.readback.name: self.parent.data_buffer}, - metadata=self.parent.scaninfo.scan_msg.metadata, - ) - self.parent.connector.xadd( - MessageEndpoints.device_async_readback( - scan_id=self.parent.scaninfo.scan_id, device=self.parent.name - ), - {"data": msg}, - expire=self._stream_ttl, - ) - self.clear_buffer() - - def on_trigger(self): - """Prepare the device for triggering.""" - status = DeviceStatus(self.parent) - - def on_trigger_call(status: DeviceStatus) -> None: - try: - self.parent.data_buffer["value"].append(self.parent.readback.get()) - self.parent.data_buffer["timestamp"].append(self.parent.readback.timestamp) - self._counter += 1 - self.parent.current_trigger.set(self._counter).wait() - if self._counter % self._random_send_interval == 0: - self._send_data_to_bec() - if self.parent.stopped: - raise DeviceStopError(f"{self.parent.name} was stopped") - status.set_finished() - # pylint: disable=broad-except - except Exception as exc: - content = traceback.format_exc() - logger.warning( - f"Error in on_trigger_call in device {self.parent.name}; Traceback: {content}" - ) - status.set_exception(exc=exc) - - self._thread_trigger = Thread(target=on_trigger_call, args=(status,)) - self._thread_trigger.start() - return status - - def _progress_update(self, value: int, **kwargs): - """Update the progress of the device.""" - max_value = self.parent.scaninfo.num_points - # pylint: disable=protected-access - self.parent._run_subs( - sub_type=self.parent.SUB_PROGRESS, - value=value, - max_value=max_value, - done=bool(max_value == value), - ) - - def on_stop(self): - """Stop the device.""" - if self._thread_trigger: - self._thread_trigger.join() - if self._thread_complete: - self._thread_complete.join() - self._thread_trigger = None - self._thread_complete = None - - -class SimMonitorAsync(PSIDetectorBase): - """ - A simulated device to mimic the behaviour of an asynchronous monitor. - - During a scan, this device will send data not in sync with the point ID to BEC, - but buffer data and send it in random intervals.s - """ +class SimMonitorAsyncControl(Device): + """SimMonitor Sync Control Device""" USER_ACCESS = ["sim", "registered_proxies", "async_update"] - custom_prepare_cls = SimMonitorAsyncPrepare sim_cls = SimulatedDataMonitor BIT_DEPTH = np.uint32 @@ -221,17 +98,17 @@ class SimMonitorAsync(PSIDetectorBase): SUB_PROGRESS = "progress" _default_sub = SUB_READBACK - def __init__( - self, name, *, sim_init: dict = None, parent=None, kind=None, device_manager=None, **kwargs - ): + def __init__(self, name, *, sim_init: dict = None, parent=None, device_manager=None, **kwargs): + if device_manager: + self.device_manager = device_manager + else: + self.device_manager = bec_utils.DMMock() + self.connector = self.device_manager.connector self.sim_init = sim_init - self.device_manager = device_manager self.sim = self.sim_cls(parent=self, **kwargs) self._registered_proxies = {} - super().__init__( - name=name, parent=parent, kind=kind, device_manager=device_manager, **kwargs - ) + super().__init__(name=name, parent=parent, **kwargs) self.sim.sim_state[self.name] = self.sim.sim_state.pop(self.readback.name, None) self.readback.name = self.name self._data_buffer = {"value": [], "timestamp": []} @@ -247,3 +124,101 @@ class SimMonitorAsync(PSIDetectorBase): def registered_proxies(self) -> None: """Dictionary of registered signal_names and proxies.""" return self._registered_proxies + + +class SimMonitorAsync(PSIDeviceBase, SimMonitorAsyncControl): + """ + A simulated device to mimic the behaviour of an asynchronous monitor. + + During a scan, this device will send data not in sync with the point ID to BEC, + but buffer data and send it in random intervals.s + """ + + def __init__( + self, name: str, scan_info=None, parent: Device = None, device_manager=None, **kwargs + ) -> None: + super().__init__( + name=name, scan_info=scan_info, parent=parent, device_manager=device_manager, **kwargs + ) + self._stream_ttl = 1800 + self._random_send_interval = None + self._counter = 0 + self.prep_random_interval() + + def on_connected(self): + self.current_trigger.subscribe(self._progress_update, run=False) + + def clear_buffer(self): + """Clear the data buffer.""" + self.data_buffer["value"].clear() + self.data_buffer["timestamp"].clear() + + def prep_random_interval(self): + """Prepare counter and random interval to send data to BEC.""" + self._random_send_interval = np.random.randint(1, 10) + self.current_trigger.set(0).wait() + self._counter = self.current_trigger.get() + + def on_stage(self): + """Prepare the device for staging.""" + self.clear_buffer() + self.prep_random_interval() + + def on_complete(self) -> StatusBase: + """Prepare the device for completion.""" + + def complete_action(): + if self.data_buffer["value"]: + self._send_data_to_bec() + + status = self.task_handler.submit_task(complete_action) + return status + + def _send_data_to_bec(self) -> None: + """Sends bundled data to BEC""" + if self.scan_info.msg is None: + return + metadata = self.scan_info.msg.metadata + metadata.update({"async_update": self.async_update.get()}) + + msg = messages.DeviceMessage( + signals={self.readback.name: self.data_buffer}, metadata=self.scan_info.msg.metadata + ) + self.connector.xadd( + MessageEndpoints.device_async_readback( + scan_id=self.scan_info.msg.scan_id, device=self.name + ), + {"data": msg}, + expire=self._stream_ttl, + ) + self.clear_buffer() + + def on_trigger(self): + """Prepare the device for triggering.""" + + def trigger_action(): + """Trigger actions""" + self.data_buffer["value"].append(self.readback.get()) + self.data_buffer["timestamp"].append(self.readback.timestamp) + self._counter += 1 + self.current_trigger.set(self._counter).wait() + if self._counter % self._random_send_interval == 0: + self._send_data_to_bec() + + status = self.task_handler.submit_task(trigger_action) + return status + + def _progress_update(self, value: int, **kwargs): + """Update the progress of the device.""" + max_value = self.scan_info.msg.num_points + # pylint: disable=protected-access + self._run_subs( + sub_type=self.SUB_PROGRESS, + value=value, + max_value=max_value, + done=bool(max_value == value), + ) + + def on_stop(self): + """Stop the device.""" + self.task_handler.shutdown() diff --git a/ophyd_devices/sim/sim_waveform.py b/ophyd_devices/sim/sim_waveform.py index 2df832a..aa43b30 100644 --- a/ophyd_devices/sim/sim_waveform.py +++ b/ophyd_devices/sim/sim_waveform.py @@ -15,7 +15,6 @@ from ophyd import Device, DeviceStatus, Kind, Staged from ophyd_devices.sim.sim_data import SimulatedDataWaveform from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal from ophyd_devices.utils import bec_utils -from ophyd_devices.utils.bec_scaninfo_mixin import BecScaninfoMixin from ophyd_devices.utils.errors import DeviceStopError logger = bec_logger.logger @@ -67,7 +66,15 @@ class SimWaveform(Device): async_update = Cpt(SetableSignal, value="append", kind=Kind.config) def __init__( - self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs + self, + name, + *, + kind=None, + parent=None, + sim_init: dict = None, + device_manager=None, + scan_info=None, + **kwargs, ): self.sim_init = sim_init self._registered_proxies = {} @@ -83,9 +90,8 @@ class SimWaveform(Device): self._stream_ttl = 1800 # 30 min max self.stopped = False self._staged = Staged.no - self.scaninfo = None self._trigger_thread = None - self._update_scaninfo() + self.scan_info = scan_info if self.sim_init: self.sim.set_init(self.sim_init) @@ -124,7 +130,7 @@ class SimWaveform(Device): def _send_async_update(self): """Send the async update to BEC.""" - metadata = self.scaninfo.scan_msg.metadata + metadata = self.scan_info.msg.metadata async_update_type = self.async_update.get() if async_update_type not in ["extend", "append"]: raise ValueError(f"Invalid async_update type: {async_update_type}") @@ -134,19 +140,15 @@ class SimWaveform(Device): signals={self.waveform.name: {"value": self.waveform.get(), "timestamp": time.time()}}, metadata=metadata, ) - # logger.warning(f"Adding async update to {self.name} and {self.scaninfo.scan_id}") + # logger.warning(f"Adding async update to {self.name} and {self.scan_info.msg.scan_id}") self.connector.xadd( - MessageEndpoints.device_async_readback(scan_id=self.scaninfo.scan_id, device=self.name), + MessageEndpoints.device_async_readback( + scan_id=self.scan_info.msg.scan_id, device=self.name + ), {"data": msg}, expire=self._stream_ttl, ) - def _update_scaninfo(self) -> None: - """Update scaninfo from BecScaninfoMixing - This depends on device manager and operation/sim_mode - """ - self.scaninfo = BecScaninfoMixin(self.device_manager) - def stage(self) -> list[object]: """Stage the camera for upcoming scan @@ -160,17 +162,18 @@ class SimWaveform(Device): if self._staged is Staged.yes: return super().stage() - self.scaninfo.load_scan_metadata() self.file_path.set( os.path.join( - self.file_path.get(), self.file_pattern.get().format(self.scaninfo.scan_number) + self.file_path.get(), self.file_pattern.get().format(self.scan_info.msg.scan_number) ) ) - self.frames.set(self.scaninfo.num_points * self.scaninfo.frames_per_trigger) - self.exp_time.set(self.scaninfo.exp_time) - self.burst.set(self.scaninfo.frames_per_trigger) + self.frames.set( + self.scan_info.msg.num_points * self.scan_info.msg.scan_parameters["frames_per_trigger"] + ) + self.exp_time.set(self.scan_info.msg.scan_parameters["exp_time"]) + self.burst.set(self.scan_info.msg.scan_parameters["frames_per_trigger"]) self.stopped = False - logger.warning(f"Staged {self.name}, scan_id : {self.scaninfo.scan_id}") + logger.warning(f"Staged {self.name}, scan_id : {self.scan_info.msg.scan_id}") return super().stage() def unstage(self) -> list[object]: diff --git a/ophyd_devices/tests/utils.py b/ophyd_devices/tests/utils.py index ff1502c..d7ec12a 100644 --- a/ophyd_devices/tests/utils.py +++ b/ophyd_devices/tests/utils.py @@ -1,5 +1,21 @@ +""" Utilities to mock and test devices.""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING from unittest import mock +from bec_lib.devicemanager import ScanInfo +from bec_lib.logger import bec_logger +from bec_lib.utils.import_utils import lazy_import_from + +if TYPE_CHECKING: + from bec_lib.messages import ScanStatusMessage +else: + # TODO: put back normal import when Pydantic gets faster + ScanStatusMessage = lazy_import_from("bec_lib.messages", ("ScanStatusMessage",)) + +logger = bec_logger.logger + def patch_dual_pvs(device): """Patch dual PVs""" @@ -255,3 +271,128 @@ class MockPV: use_monitor=use_monitor, ) return data["value"] if data is not None else None + + +def get_mock_scan_info(): + """ + Get a mock scan info object. + """ + return ScanInfo(msg=fake_scan_status_msg()) + + +def fake_scan_status_msg(): + """ + Create a fake scan status message. + """ + logger.warning( + ("Device is not connected to a Redis server. Fetching mocked ScanStatusMessage.") + ) + return ScanStatusMessage( + metadata={}, + scan_id="mock_scan_id", + status="closed", + scan_number=0, + session_id=None, + num_points=11, + scan_name="mock_line_scan", + scan_type="step", + dataset_number=0, + scan_report_devices=["samx"], + user_metadata={}, + readout_priority={ + "monitored": ["bpm4a", "samx"], + "baseline": ["eyex"], + "async": ["waveform"], + "continuous": [], + "on_request": ["flyer_sim"], + }, + scan_parameters={ + "exp_time": 0, + "frames_per_trigger": 1, + "settling_time": 0, + "readout_time": 0, + "optim_trajectory": None, + "return_to_start": True, + "relative": True, + "system_config": {"file_suffix": None, "file_directory": None}, + }, + request_inputs={ + "arg_bundle": ["samx", -10, 10], + "inputs": {}, + "kwargs": { + "steps": 11, + "relative": True, + "system_config": {"file_suffix": None, "file_directory": None}, + }, + }, + info={ + "readout_priority": { + "monitored": ["bpm4a", "samx"], + "baseline": ["eyex"], + "async": ["waveform"], + "continuous": [], + "on_request": ["flyer_sim"], + }, + "file_suffix": None, + "file_directory": None, + "user_metadata": {}, + "RID": "a1d86f61-191c-4460-bcd6-f33c61b395ea", + "scan_id": "3edb8219-75a7-4791-8f86-d5ca112b771a", + "queue_id": "0f3639ee-899f-4ad1-9e71-f40514c937ef", + "scan_motors": ["samx"], + "num_points": 11, + "positions": [ + [-10.0], + [-8.0], + [-6.0], + [-4.0], + [-2.0], + [0.0], + [2.0], + [4.0], + [6.0], + [8.0], + [10.0], + ], + "file_path": "./data/test_file", + "scan_name": "mock_line_scan", + "scan_type": "step", + "scan_number": 0, + "dataset_number": 0, + "exp_time": 0, + "frames_per_trigger": 1, + "settling_time": 0, + "readout_time": 0, + "scan_report_devices": ["samx"], + "monitor_sync": "bec", + "scan_parameters": { + "exp_time": 0, + "frames_per_trigger": 1, + "settling_time": 0, + "readout_time": 0, + "optim_trajectory": None, + "return_to_start": True, + "relative": True, + "system_config": {"file_suffix": None, "file_directory": None}, + }, + "request_inputs": { + "arg_bundle": ["samx", -10, 10], + "inputs": {}, + "kwargs": { + "steps": 11, + "relative": True, + "system_config": {"file_suffix": None, "file_directory": None}, + }, + }, + "scan_msgs": [ + "metadata={'file_suffix': None, 'file_directory': None, 'user_metadata': {}, 'RID': 'a1d86f61-191c-4460-bcd6-f33c61b395ea'} scan_type='mock_line_scan' parameter={'args': {'samx': [-10, 10]}, 'kwargs': {'steps': 11, 'relative': True, 'system_config': {'file_suffix': None, 'file_directory': None}}} queue='primary'" + ], + "args": {"samx": [-10, 10]}, + "kwargs": { + "steps": 11, + "relative": True, + "system_config": {"file_suffix": None, "file_directory": None}, + }, + }, + timestamp=1737100681.694211, + ) diff --git a/ophyd_devices/utils/bec_utils.py b/ophyd_devices/utils/bec_utils.py index 825a512..3319b1d 100644 --- a/ophyd_devices/utils/bec_utils.py +++ b/ophyd_devices/utils/bec_utils.py @@ -1,3 +1,5 @@ +""" Utility class linked to BEC""" + import time from bec_lib import bec_logger @@ -11,8 +13,9 @@ logger = bec_logger.logger DEFAULT_EPICSSIGNAL_VALUE = object() -# TODO maybe specify here that this DeviceMock is for usage in the DeviceServer class DeviceMock: + """Mock for Device""" + def __init__(self, name: str, value: float = 0.0): self.name = name self.read_buffer = value @@ -21,13 +24,16 @@ class DeviceMock: self._enabled = True def read(self): + """Return the current value of the device""" return {self.name: {"value": self.read_buffer}} def readback(self): + """Return the current value of the device""" return self.read_buffer @property def read_only(self) -> bool: + """Get the read only status of the device""" return self._read_only @read_only.setter @@ -36,6 +42,7 @@ class DeviceMock: @property def enabled(self) -> bool: + """Get the enabled status of the device""" return self._enabled @enabled.setter @@ -44,10 +51,12 @@ class DeviceMock: @property def user_parameter(self): + """Get the user parameter of the device""" return self._config["userParameter"] @property def obj(self): + """Get the device object""" return self diff --git a/ophyd_devices/utils/psi_device_base_utils.py b/ophyd_devices/utils/psi_device_base_utils.py new file mode 100644 index 0000000..437a7a5 --- /dev/null +++ b/ophyd_devices/utils/psi_device_base_utils.py @@ -0,0 +1,193 @@ +""" Utility handler to run tasks (function, conditions) in an asynchronous fashion.""" + +import ctypes +import threading +import traceback +import uuid +from enum import Enum +from typing import TYPE_CHECKING + +from bec_lib.file_utils import get_full_file_path +from bec_lib.logger import bec_logger +from bec_lib.utils.import_utils import lazy_import_from +from ophyd import Device, DeviceStatus, StatusBase + +if TYPE_CHECKING: + from bec_lib.messages import ScanStatusMessage +else: + # TODO: put back normal import when Pydantic gets faster + ScanStatusMessage = lazy_import_from("bec_lib.messages", ("ScanStatusMessage",)) + + +logger = bec_logger.logger + +set_async_exc = ctypes.pythonapi.PyThreadState_SetAsyncExc + + +class TaskState(str, Enum): + """Possible task states""" + + NOT_STARTED = "not_started" + RUNNING = "running" + TIMEOUT = "timeout" + ERROR = "error" + COMPLETED = "completed" + KILLED = "killed" + + +class TaskKilledError(Exception): + """Exception raised when a task thread is killed""" + + +class TaskStatus(DeviceStatus): + """Thin wrapper around StatusBase to add information about tasks""" + + def __init__(self, device: Device, *, timeout=None, settle_time=0, done=None, success=None): + super().__init__( + device=device, timeout=timeout, settle_time=settle_time, done=done, success=success + ) + self._state = TaskState.NOT_STARTED + self._task_id = str(uuid.uuid4()) + + @property + def exc(self) -> Exception: + """Get the exception of the task""" + return self.exception() + + @property + def state(self) -> str: + """Get the state of the task""" + return self._state.value + + @state.setter + def state(self, value: TaskState): + self._state = value + + @property + def task_id(self) -> bool: + """Get the task ID""" + return self._task_id + + +class TaskHandler: + """Handler to manage asynchronous tasks""" + + def __init__(self, parent: Device): + """Initialize the handler""" + self._tasks = {} + self._parent = parent + + def submit_task(self, task: callable, run: bool = True) -> TaskStatus: + """Submit a task to the task handler. + + Args: + task: The task to run. + run: Whether to run the task immediately. + """ + task_status = TaskStatus(device=self._parent) + thread = threading.Thread( + target=self._wrap_task, + args=(task, task_status), + name=f"task {task_status.task_id}", + daemon=True, + ) + self._tasks.update({task_status.task_id: (task_status, thread)}) + if run is True: + self.start_task(task_status) + return task_status + + def start_task(self, task_status: TaskStatus) -> None: + """Start a task, + + Args: + task_status: The task status object. + """ + thread = self._tasks[task_status.task_id][1] + if thread.is_alive(): + logger.warning(f"Task with ID {task_status.task_id} is already running.") + return + thread.start() + task_status.state = TaskState.RUNNING + + def _wrap_task(self, task: callable, task_status: TaskStatus): + """Wrap the task in a function""" + try: + task() + except TimeoutError as exc: + content = traceback.format_exc() + logger.warning( + ( + f"Timeout Exception in task handler for task {task_status.task_id}," + f" Traceback: {content}" + ) + ) + task_status.set_exception(exc) + task_status.state = TaskState.TIMEOUT + except TaskKilledError as exc: + exc = exc.__class__( + f"Task {task_status.task_id} was killed. ThreadID:" + f" {self._tasks[task_status.task_id][1].ident}" + ) + content = traceback.format_exc() + logger.warning( + ( + f"TaskKilled Exception in task handler for task {task_status.task_id}," + f" Traceback: {content}" + ) + ) + task_status.set_exception(exc) + task_status.state = TaskState.KILLED + except Exception as exc: # pylint: disable=broad-except + content = traceback.format_exc() + logger.warning( + f"Exception in task handler for task {task_status.task_id}, Traceback: {content}" + ) + task_status.set_exception(exc) + task_status.state = TaskState.ERROR + else: + task_status.set_finished() + task_status.state = TaskState.COMPLETED + finally: + self._tasks.pop(task_status.task_id) + + def kill_task(self, task_status: TaskStatus) -> None: + """Kill the thread + + task_status: The task status object. + """ + thread = self._tasks[task_status.task_id][1] + exception_cls = TaskKilledError + + ident = ctypes.c_long(thread.ident) + exc = ctypes.py_object(exception_cls) + try: + res = set_async_exc(ident, exc) + if res == 0: + raise ValueError("Invalid thread ID") + elif res > 1: + set_async_exc(ident, None) + logger.warning(f"Exception raise while kille Thread {ident}; return value: {res}") + except Exception as e: # pylint: disable=broad-except + logger.warning(f"Exception raised while killing thread {ident}: {e}") + + def shutdown(self): + """Shutdown all tasks of task handler""" + for info in self._tasks.values(): + self.kill_task(info[0]) + self._tasks.clear() + + +class FileHandler: + """Utility class for file operations.""" + + def get_file_path( + self, scan_status_msg: ScanStatusMessage, name: str, create_dir: bool = True + ) -> str: + """Get the file path. + + Args: + scan_info_msg: The scan info message. + name: The name of the file. + create_dir: Whether to create the directory. + """ + return get_full_file_path(scan_status_msg=scan_status_msg, name=name, create_dir=create_dir) diff --git a/tests/test_base_classes.py b/tests/test_base_classes.py index 52ac6d3..02726bd 100644 --- a/tests/test_base_classes.py +++ b/tests/test_base_classes.py @@ -6,200 +6,215 @@ import pytest from ophyd import DeviceStatus, Staged from ophyd.utils.errors import RedundantStaging -from ophyd_devices.interfaces.base_classes.bec_device_base import BECDeviceBase, CustomPrepare -from ophyd_devices.utils.bec_scaninfo_mixin import BecScaninfoMixin +from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase from ophyd_devices.utils.errors import DeviceStopError, DeviceTimeoutError @pytest.fixture def detector_base(): - yield BECDeviceBase(name="test_detector") + yield PSIDeviceBase(name="test_detector") def test_detector_base_init(detector_base): assert detector_base.stopped is False assert detector_base.name == "test_detector" - assert "base_path" in detector_base.filewriter.service_config - assert isinstance(detector_base.scaninfo, BecScaninfoMixin) - assert issubclass(detector_base.custom_prepare_cls, CustomPrepare) + assert detector_base.staged == Staged.no + assert detector_base.destroyed == False def test_stage(detector_base): - detector_base._staged = Staged.yes - with pytest.raises(RedundantStaging): - detector_base.stage() + assert detector_base._staged == Staged.no assert detector_base.stopped is False detector_base._staged = Staged.no - with ( - mock.patch.object(detector_base.custom_prepare, "on_stage") as mock_on_stage, - mock.patch.object(detector_base.scaninfo, "load_scan_metadata") as mock_load_metadata, - ): + with mock.patch.object(detector_base, "on_stage") as mock_on_stage: rtr = detector_base.stage() assert isinstance(rtr, list) - mock_on_stage.assert_called_once() - mock_load_metadata.assert_called_once() + assert mock_on_stage.called is True + with pytest.raises(RedundantStaging): + detector_base.stage() + detector_base._staged = Staged.no + detector_base.stopped = True + detector_base.stage() assert detector_base.stopped is False + assert mock_on_stage.call_count == 2 -def test_pre_scan(detector_base): - with mock.patch.object(detector_base.custom_prepare, "on_pre_scan") as mock_on_pre_scan: - detector_base.pre_scan() - mock_on_pre_scan.assert_called_once() +# def test_stage(detector_base): +# detector_base._staged = Staged.yes +# with pytest.raises(RedundantStaging): +# detector_base.stage() +# assert detector_base.stopped is False +# detector_base._staged = Staged.no +# with ( +# mock.patch.object(detector_base.custom_prepare, "on_stage") as mock_on_stage, +# mock.patch.object(detector_base.scaninfo, "load_scan_metadata") as mock_load_metadata, +# ): +# rtr = detector_base.stage() +# assert isinstance(rtr, list) +# mock_on_stage.assert_called_once() +# mock_load_metadata.assert_called_once() +# assert detector_base.stopped is False -def test_trigger(detector_base): - status = DeviceStatus(detector_base) - with mock.patch.object( - detector_base.custom_prepare, "on_trigger", side_effect=[None, status] - ) as mock_on_trigger: - st = detector_base.trigger() - assert isinstance(st, DeviceStatus) - time.sleep(0.1) - assert st.done is True - st = detector_base.trigger() - assert st.done is False - assert id(st) == id(status) +# def test_pre_scan(detector_base): +# with mock.patch.object(detector_base.custom_prepare, "on_pre_scan") as mock_on_pre_scan: +# detector_base.pre_scan() +# mock_on_pre_scan.assert_called_once() -def test_unstage(detector_base): - detector_base.stopped = True - with ( - mock.patch.object(detector_base.custom_prepare, "on_unstage") as mock_on_unstage, - mock.patch.object(detector_base, "check_scan_id") as mock_check_scan_id, - ): - rtr = detector_base.unstage() - assert isinstance(rtr, list) - assert mock_check_scan_id.call_count == 1 - assert mock_on_unstage.call_count == 1 - detector_base.stopped = False - rtr = detector_base.unstage() - assert isinstance(rtr, list) - assert mock_check_scan_id.call_count == 2 - assert mock_on_unstage.call_count == 2 +# def test_trigger(detector_base): +# status = DeviceStatus(detector_base) +# with mock.patch.object( +# detector_base.custom_prepare, "on_trigger", side_effect=[None, status] +# ) as mock_on_trigger: +# st = detector_base.trigger() +# assert isinstance(st, DeviceStatus) +# time.sleep(0.1) +# assert st.done is True +# st = detector_base.trigger() +# assert st.done is False +# assert id(st) == id(status) -def test_complete(detector_base): - status = DeviceStatus(detector_base) - with mock.patch.object( - detector_base.custom_prepare, "on_complete", side_effect=[None, status] - ) as mock_on_complete: - st = detector_base.complete() - assert isinstance(st, DeviceStatus) - time.sleep(0.1) - assert st.done is True - st = detector_base.complete() - assert st.done is False - assert id(st) == id(status) +# def test_unstage(detector_base): +# detector_base.stopped = True +# with ( +# mock.patch.object(detector_base.custom_prepare, "on_unstage") as mock_on_unstage, +# mock.patch.object(detector_base, "check_scan_id") as mock_check_scan_id, +# ): +# rtr = detector_base.unstage() +# assert isinstance(rtr, list) +# assert mock_check_scan_id.call_count == 1 +# assert mock_on_unstage.call_count == 1 +# detector_base.stopped = False +# rtr = detector_base.unstage() +# assert isinstance(rtr, list) +# assert mock_check_scan_id.call_count == 2 +# assert mock_on_unstage.call_count == 2 -def test_stop(detector_base): - with mock.patch.object(detector_base.custom_prepare, "on_stop") as mock_on_stop: - detector_base.stop() - mock_on_stop.assert_called_once() - assert detector_base.stopped is True +# def test_complete(detector_base): +# status = DeviceStatus(detector_base) +# with mock.patch.object( +# detector_base.custom_prepare, "on_complete", side_effect=[None, status] +# ) as mock_on_complete: +# st = detector_base.complete() +# assert isinstance(st, DeviceStatus) +# time.sleep(0.1) +# assert st.done is True +# st = detector_base.complete() +# assert st.done is False +# assert id(st) == id(status) -def test_check_scan_id(detector_base): - detector_base.scaninfo.scan_id = "abcde" - detector_base.stopped = False - detector_base.check_scan_id() - assert detector_base.stopped is True - detector_base.stopped = False - detector_base.check_scan_id() - assert detector_base.stopped is False +# def test_stop(detector_base): +# with mock.patch.object(detector_base.custom_prepare, "on_stop") as mock_on_stop: +# detector_base.stop() +# mock_on_stop.assert_called_once() +# assert detector_base.stopped is True -def test_wait_for_signal(detector_base): - my_value = False - - def my_callback(): - return my_value - - detector_base - status = detector_base.custom_prepare.wait_with_status( - [(my_callback, True)], - check_stopped=True, - timeout=5, - interval=0.01, - exception_on_timeout=None, - ) - time.sleep(0.1) - assert status.done is False - # Check first that it is stopped when detector_base.stop() is called - detector_base.stop() - # some delay to allow the stop to take effect - time.sleep(0.15) - assert status.done is True - assert status.exception().args == DeviceStopError(f"{detector_base.name} was stopped").args - detector_base.stopped = False - status = detector_base.custom_prepare.wait_with_status( - [(my_callback, True)], - check_stopped=True, - timeout=5, - interval=0.01, - exception_on_timeout=None, - ) - # Check that thread resolves when expected value is set - my_value = True - # some delay to allow the stop to take effect - time.sleep(0.15) - assert status.done is True - assert status.success is True - assert status.exception() is None - - detector_base.stopped = False - # Check that wait for status runs into timeout with expectd exception - my_value = "random_value" - exception = TimeoutError("Timeout") - status = detector_base.custom_prepare.wait_with_status( - [(my_callback, True)], - check_stopped=True, - timeout=0.01, - interval=0.01, - exception_on_timeout=exception, - ) - time.sleep(0.2) - assert status.done is True - assert id(status.exception()) == id(exception) - assert status.success is False +# def test_check_scan_id(detector_base): +# detector_base.scaninfo.scan_id = "abcde" +# detector_base.stopped = False +# detector_base.check_scan_id() +# assert detector_base.stopped is True +# detector_base.stopped = False +# detector_base.check_scan_id() +# assert detector_base.stopped is False -def test_wait_for_signal_returns_exception(detector_base): - my_value = False +# def test_wait_for_signal(detector_base): +# my_value = False - def my_callback(): - return my_value +# def my_callback(): +# return my_value - # Check that wait for status runs into timeout with expectd exception +# detector_base +# status = detector_base.custom_prepare.wait_with_status( +# [(my_callback, True)], +# check_stopped=True, +# timeout=5, +# interval=0.01, +# exception_on_timeout=None, +# ) +# time.sleep(0.1) +# assert status.done is False +# # Check first that it is stopped when detector_base.stop() is called +# detector_base.stop() +# # some delay to allow the stop to take effect +# time.sleep(0.15) +# assert status.done is True +# assert status.exception().args == DeviceStopError(f"{detector_base.name} was stopped").args +# detector_base.stopped = False +# status = detector_base.custom_prepare.wait_with_status( +# [(my_callback, True)], +# check_stopped=True, +# timeout=5, +# interval=0.01, +# exception_on_timeout=None, +# ) +# # Check that thread resolves when expected value is set +# my_value = True +# # some delay to allow the stop to take effect +# time.sleep(0.15) +# assert status.done is True +# assert status.success is True +# assert status.exception() is None - exception = TimeoutError("Timeout") - status = detector_base.custom_prepare.wait_with_status( - [(my_callback, True)], - check_stopped=True, - timeout=0.01, - interval=0.01, - exception_on_timeout=exception, - ) - time.sleep(0.2) - assert status.done is True - assert id(status.exception()) == id(exception) - assert status.success is False +# detector_base.stopped = False +# # Check that wait for status runs into timeout with expectd exception +# my_value = "random_value" +# exception = TimeoutError("Timeout") +# status = detector_base.custom_prepare.wait_with_status( +# [(my_callback, True)], +# check_stopped=True, +# timeout=0.01, +# interval=0.01, +# exception_on_timeout=exception, +# ) +# time.sleep(0.2) +# assert status.done is True +# assert id(status.exception()) == id(exception) +# assert status.success is False - detector_base.stopped = False - # Check that standard exception is thrown - status = detector_base.custom_prepare.wait_with_status( - [(my_callback, True)], - check_stopped=True, - timeout=0.01, - interval=0.01, - exception_on_timeout=None, - ) - time.sleep(0.2) - assert status.done is True - assert ( - status.exception().args - == DeviceTimeoutError( - f"Timeout error for {detector_base.name} while waiting for signals {[(my_callback, True)]}" - ).args - ) - assert status.success is False + +# def test_wait_for_signal_returns_exception(detector_base): +# my_value = False + +# def my_callback(): +# return my_value + +# # Check that wait for status runs into timeout with expectd exception + +# exception = TimeoutError("Timeout") +# status = detector_base.custom_prepare.wait_with_status( +# [(my_callback, True)], +# check_stopped=True, +# timeout=0.01, +# interval=0.01, +# exception_on_timeout=exception, +# ) +# time.sleep(0.2) +# assert status.done is True +# assert id(status.exception()) == id(exception) +# assert status.success is False + +# detector_base.stopped = False +# # Check that standard exception is thrown +# status = detector_base.custom_prepare.wait_with_status( +# [(my_callback, True)], +# check_stopped=True, +# timeout=0.01, +# interval=0.01, +# exception_on_timeout=None, +# ) +# time.sleep(0.2) +# assert status.done is True +# assert ( +# status.exception().args +# == DeviceTimeoutError( +# f"Timeout error for {detector_base.name} while waiting for signals {[(my_callback, True)]}" +# ).args +# ) +# assert status.success is False diff --git a/tests/test_simulation.py b/tests/test_simulation.py index b8d84e4..589701f 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -17,7 +17,6 @@ from ophyd import Device, Signal from ophyd.status import wait as status_wait from ophyd_devices.interfaces.protocols.bec_protocols import ( - BECBaseProtocol, BECDeviceProtocol, BECFlyerProtocol, BECPositionerProtocol, @@ -31,6 +30,7 @@ from ophyd_devices.sim.sim_positioner import SimLinearTrajectoryPositioner, SimP from ophyd_devices.sim.sim_signals import ReadOnlySignal from ophyd_devices.sim.sim_utils import H5Writer, LinearTrajectory from ophyd_devices.sim.sim_waveform import SimWaveform +from ophyd_devices.tests.utils import get_mock_scan_info from ophyd_devices.utils.bec_device_base import BECDevice, BECDeviceBase @@ -423,7 +423,6 @@ def test_h5proxy(h5proxy_fixture): ) camera._registered_proxies.update({h5proxy.name: camera.image.name}) camera.sim.params = {"noise": "none", "noise_multiplier": 0} - camera.scaninfo.sim_mode = True # pylint: disable=no-member camera.image_shape.set(data.shape[1:]) camera.stage() @@ -544,15 +543,15 @@ def test_cam_stage_h5writer(camera): mock.patch.object(camera, "h5_writer") as mock_h5_writer, mock.patch.object(camera, "_run_subs") as mock_run_subs, ): - camera.scaninfo.num_points = 10 - camera.scaninfo.frames_per_trigger = 1 - camera.scaninfo.exp_time = 1 + camera.scan_info.msg.num_points = 10 + camera.scan_info.msg.scan_parameters["frames_per_trigger"] = 1 + camera.scan_info.msg.scan_parameters["exp_time"] = 1 camera.stage() assert mock_h5_writer.on_stage.call_count == 0 camera.unstage() camera.write_to_disk.put(True) camera.stage() - calls = [mock.call(file_path="", h5_entry="/entry/data/data")] + calls = [mock.call(file_path="./data/test_file_camera.h5", h5_entry="/entry/data/data")] assert mock_h5_writer.on_stage.mock_calls == calls # mock_h5_writer.prepare @@ -622,17 +621,17 @@ def test_async_monitor_stage(async_monitor): def test_async_monitor_prep_random_interval(async_monitor): """Test the stage method of SimMonitorAsync.""" - async_monitor.custom_prepare.prep_random_interval() - assert async_monitor.custom_prepare._counter == 0 + async_monitor.prep_random_interval() + assert async_monitor._counter == 0 assert async_monitor.current_trigger.get() == 0 - assert 0 < async_monitor.custom_prepare._random_send_interval < 10 + assert 0 < async_monitor._random_send_interval < 10 def test_async_monitor_complete(async_monitor): """Test the on_complete method of SimMonitorAsync.""" with ( - mock.patch.object(async_monitor.custom_prepare, "_send_data_to_bec") as mock_send, - mock.patch.object(async_monitor.custom_prepare, "prep_random_interval") as mock_prep, + mock.patch.object(async_monitor, "_send_data_to_bec") as mock_send, + mock.patch.object(async_monitor, "prep_random_interval") as mock_prep, ): status = async_monitor.complete() status_wait(status) @@ -649,11 +648,11 @@ def test_async_monitor_complete(async_monitor): def test_async_mon_on_trigger(async_monitor): """Test the on_trigger method of SimMonitorAsync.""" - with (mock.patch.object(async_monitor.custom_prepare, "_send_data_to_bec") as mock_send,): - async_monitor.custom_prepare.on_stage() - upper_limit = async_monitor.custom_prepare._random_send_interval + with (mock.patch.object(async_monitor, "_send_data_to_bec") as mock_send,): + async_monitor.on_stage() + upper_limit = async_monitor._random_send_interval for ii in range(1, upper_limit + 1): - status = async_monitor.custom_prepare.on_trigger() + status = async_monitor.on_trigger() status_wait(status) assert async_monitor.current_trigger.get() == ii assert mock_send.call_count == 1 @@ -661,10 +660,10 @@ def test_async_mon_on_trigger(async_monitor): def test_async_mon_send_data_to_bec(async_monitor): """Test the _send_data_to_bec method of SimMonitorAsync.""" - async_monitor.scaninfo.scan_msg = SimpleNamespace(metadata={}) + async_monitor.scan_info = get_mock_scan_info() async_monitor.data_buffer.update({"value": [0, 5], "timestamp": [0, 0]}) with mock.patch.object(async_monitor.connector, "xadd") as mock_xadd: - async_monitor.custom_prepare._send_data_to_bec() + async_monitor._send_data_to_bec() dev_msg = messages.DeviceMessage( signals={async_monitor.readback.name: async_monitor.data_buffer}, metadata={"async_update": async_monitor.async_update.get()}, @@ -673,10 +672,10 @@ def test_async_mon_send_data_to_bec(async_monitor): call = [ mock.call( MessageEndpoints.device_async_readback( - scan_id=async_monitor.scaninfo.scan_id, device=async_monitor.name + scan_id=async_monitor.scan_info.msg.scan_id, device=async_monitor.name ), {"data": dev_msg}, - expire=async_monitor.custom_prepare._stream_ttl, + expire=async_monitor._stream_ttl, ) ] assert mock_xadd.mock_calls == call @@ -711,8 +710,8 @@ def test_waveform(waveform): # Now also test the async readback mock_connector = waveform.connector = mock.MagicMock() mock_run_subs = waveform._run_subs = mock.MagicMock() - waveform.scaninfo.scan_msg = SimpleNamespace(metadata={}) - waveform.scaninfo.scan_id = "test" + waveform.scan_info = get_mock_scan_info() + waveform.scan_info.msg.scan_id = "test" status = waveform.trigger() timer = 0 while not status.done: