diff --git a/csaxs_bec/devices/epics/mcs_card/mcs_card_csaxs.py b/csaxs_bec/devices/epics/mcs_card/mcs_card_csaxs.py index 035c670..8e806d1 100644 --- a/csaxs_bec/devices/epics/mcs_card/mcs_card_csaxs.py +++ b/csaxs_bec/devices/epics/mcs_card/mcs_card_csaxs.py @@ -2,6 +2,7 @@ from __future__ import annotations +import threading import time from threading import RLock from typing import TYPE_CHECKING, Literal @@ -9,8 +10,8 @@ from typing import TYPE_CHECKING, Literal import numpy as np from bec_lib.logger import bec_logger from ophyd import Component as Cpt -from ophyd import EpicsSignalRO, Kind, SignalRO -from ophyd_devices import AsyncSignal, CompareStatus, ProgressSignal, TransitionStatus +from ophyd import EpicsSignalRO, Kind +from ophyd_devices import AsyncSignal, CompareStatus, ProgressSignal, StatusBase from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase from ophyd_devices.utils.bec_signals import AsyncMultiSignal @@ -367,6 +368,8 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard): self._acquisition_group: str = "monitored" self._num_total_triggers: int = 0 # Add logic that data was sent ( threading event) + self._last_data_sent_event: threading.Event = threading.Event() + self._scan_done_event: threading.Event = threading.Event() def on_connected(self): """ @@ -440,6 +443,11 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard): f"Received update for {signal.name}. Setting data for {mca_channel.name}: to {data}" ) mca_channel.put(data) + if self._scan_done_event.is_set(): + # Last data sent after scan is done TODO, improve logic as this may fail due to timing.. + # better to count the number of data that was sent... + self._last_data_sent_event.set() + # check # self._received_updates.update(data) # if len(self._received_updates) == self.num_connected_channels: # # Send out data on multi async signal @@ -448,11 +456,10 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard): def _progress_update(self, value, **kwargs) -> None: """Callback for progress updates from ophyd subscription on current_channel.""" - self.progress.put( - value=value, - max_value=self._num_total_triggers, - done=bool(value == self._num_total_triggers), - ) + scan_done = bool(value == self._num_total_triggers) + self.progress.put(value=value, max_value=self._num_total_triggers, done=scan_done) + if scan_done: + self._scan_done_event.set() def on_stage(self) -> None: """ @@ -467,6 +474,9 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard): self.num_use_all.set(triggers).wait(timeout=self._pv_timeout) # Reset data self._received_updates.clear() + # Reset last data sent event + self._scan_done_event.clear() + self._last_data_sent_event.clear() def on_unstage(self) -> None: """ @@ -475,16 +485,19 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard): self.stop_all.put(1) self.erase_all.set(0).wait(timeout=self._pv_timeout) - def on_pre_scan(self) -> None: - """ - Called before the scan starts. - """ + def _check_data_sent(self, timeout: float = 5.0) -> None: + """Check if data was sent within the timeout period.""" + self._last_data_sent_event.wait(timeout=timeout) + if not self._last_data_sent_event.is_set(): + raise TimeoutError(f"Data was not sent within {timeout} seconds after acquisition.") def on_complete(self) -> CompareStatus: """On scan completion.""" - # Check if we should get a signal based on updates from the MCA channels + # Check Acquiring is DONE + status_data_sent = self.task_handler.submit_task(self._check_data_sent, task_args=(5.0)) status = CompareStatus(self.acquiring, ACQUIRING.DONE) - self.cancel_on_stop(status) + ret_status = status & status_data_sent + self.cancel_on_stop(ret_status) return status def on_stop(self) -> None: