refactor: add compare and transition status for nidaq and mo1_bragg

This commit is contained in:
gac-x01da
2025-06-20 09:16:27 +02:00
parent fa434794c3
commit b3672cf5f5
2 changed files with 80 additions and 124 deletions

View File

@@ -15,13 +15,15 @@ from bec_lib.devicemanager import ScanInfo
from bec_lib.logger import bec_logger
from ophyd import Component as Cpt
from ophyd import DeviceStatus, Signal, StatusBase
from ophyd.status import SubscriptionStatus
from ophyd.status import SubscriptionStatus, WaitTimeoutError
from ophyd_devices import ProgressSignal
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
from ophyd_devices.utils.errors import DeviceStopError
from pydantic import BaseModel, Field
from typeguard import typechecked
from ophyd_devices import TransitionStatus, CompareStatus
from debye_bec.devices.mo1_bragg.mo1_bragg_devices import Mo1BraggPositioner
# pylint: disable=unused-import
@@ -96,7 +98,7 @@ class Mo1Bragg(PSIDeviceBase, Mo1BraggPositioner):
"""
super().__init__(name=name, scan_info=scan_info, prefix=prefix, **kwargs)
self.scan_parameter = ScanParameter()
self.timeout_for_pvwait = 2.5
self.timeout_for_pvwait = 7.5
########################################
# Beamline Specific Implementations #
@@ -123,7 +125,10 @@ class Mo1Bragg(PSIDeviceBase, Mo1BraggPositioner):
Information about the upcoming scan can be accessed from the scan_info (self.scan_info.msg) object.
"""
self._check_scan_msg(ScanControlLoadMessage.PENDING)
# self._check_scan_msg(ScanControlLoadMessage.PENDING)
status = CompareStatus(self.scan_control.scan_msg, ScanControlLoadMessage.PENDING)
self.cancel_on_stop(status)
status.wait(timeout=self.timeout_for_pvwait)
scan_name = self.scan_info.msg.scan_name
self._update_scan_parameter()
@@ -204,11 +209,9 @@ class Mo1Bragg(PSIDeviceBase, Mo1BraggPositioner):
# Load the scan parameters to the controller
self.scan_control.scan_load.put(1)
# Wait for params to be checked from controller
self.wait_for_signal(
self.scan_control.scan_msg,
ScanControlLoadMessage.SUCCESS,
timeout=2 * self.timeout_for_pvwait,
)
status = CompareStatus(self.scan_control.scan_msg, ScanControlLoadMessage.SUCCESS)
self.cancel_on_stop(status)
status.wait(self.timeout_for_pvwait)
return None
def on_unstage(self) -> DeviceStatus | StatusBase | None:
@@ -216,32 +219,21 @@ class Mo1Bragg(PSIDeviceBase, Mo1BraggPositioner):
if self.stopped is True:
logger.warning(f"Resetting stopped in unstage for device {self.name}.")
self._stopped = False
current_state = self.scan_control.scan_msg.get()
# Case 1, message is already ScanControlLoadMessage.PENDING
if current_state == ScanControlLoadMessage.PENDING:
return None
# Case 2, probably called after scan, backend should resolve on its own. Timeout to wait
if current_state in [ScanControlLoadMessage.STARTED, ScanControlLoadMessage.SUCCESS]:
if self.scan_control.scan_msg.get() in [ScanControlLoadMessage.STARTED, ScanControlLoadMessage.SUCCESS]:
status = CompareStatus(self.scan_control.scan_msg, ScanControlLoadMessage.PENDING)
self.cancel_on_stop(status)
try:
self.wait_for_signal(
self.scan_control.scan_msg,
ScanControlLoadMessage.PENDING,
timeout=self.timeout_for_pvwait,
)
return
except TimeoutError:
status.wait(2)
return None
except WaitTimeoutError:
logger.warning(
f"Timeout in on_unstage of {self.name} after {self.timeout_for_pvwait}s, current scan_control_message : {self.scan_control.scan_msg.get()}"
)
def callback(*, old_value, value, **kwargs):
if value == ScanControlLoadMessage.PENDING:
return True
return False
status = SubscriptionStatus(self.scan_control.scan_msg, callback=callback)
self.scan_control.scan_val_reset.put(1)
status.wait(timeout=self.timeout_for_pvwait)
f"Timeout in on_unstage of {self.name} after {self.timeout_for_pvwait}s, current scan_control_message : {self.scan_control.scan_msg.get()}"
)
else:
status = CompareStatus(self.scan_control.scan_msg, ScanControlLoadMessage.PENDING)
self.cancel_on_stop(status)
self.scan_control.scan_val_reset.put(1)
status.wait(timeout=self.timeout_for_pvwait)
return None
def on_pre_scan(self) -> DeviceStatus | StatusBase | None:
@@ -252,20 +244,8 @@ class Mo1Bragg(PSIDeviceBase, Mo1BraggPositioner):
def on_complete(self) -> DeviceStatus | StatusBase | None:
"""Called to inquire if a device has completed a scans."""
def wait_for_complete():
"""Wait for the scan to complete. No timeout is set."""
start_time = time.time()
while True:
if self.stopped is True:
raise DeviceStopError(
f"Device {self.name} was stopped while waiting for scan to complete"
)
if self.scan_control.scan_done.get() == 1:
return
time.sleep(0.1)
status = self.task_handler.submit_task(wait_for_complete)
status = CompareStatus(self.scan_control.scan_done, 1)
self.cancel_on_stop(status)
return status
def on_kickoff(self) -> DeviceStatus | StatusBase | None:
@@ -277,13 +257,13 @@ class Mo1Bragg(PSIDeviceBase, Mo1BraggPositioner):
if scan_duration < 0.1
else self.scan_control.scan_start_timer.put
)
def callback(*, old_value, value, **kwargs):
if old_value == ScanControlScanStatus.READY and value == ScanControlScanStatus.RUNNING:
return True
return False
status = SubscriptionStatus(self.scan_control.scan_status, callback=callback)
status = TransitionStatus(
self.scan_control.scan_status,
transitions=[ScanControlScanStatus.READY, ScanControlScanStatus.RUNNING],
strict=True,
raise_states=[ScanControlScanStatus.PARAMETER_WRONG],
)
self.cancel_on_stop(status)
start_func(1)
return status

View File

@@ -11,8 +11,10 @@ from ophyd_devices import ProgressSignal
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
from ophyd_devices.sim.sim_signals import SetableSignal
from ophyd_devices import CompareStatus, TransitionStatus
from debye_bec.devices.nidaq.nidaq_enums import (
EncoderTypes,
EncoderFactors,
NIDAQCompression,
NidaqState,
ReadoutRange,
@@ -310,6 +312,7 @@ class NidaqControl(Device):
enc = Cpt(SetableSignal, value=0, kind=Kind.normal)
energy = Cpt(SetableSignal, value=0, kind=Kind.normal)
rle = Cpt(SetableSignal, value=0, kind=Kind.normal)
### Control PVs ###
@@ -323,14 +326,14 @@ class NidaqControl(Device):
sampling_rate = Cpt(EpicsSignal, suffix="NIDAQ-SamplingRateRequested", kind=Kind.config)
scan_duration = Cpt(EpicsSignal, suffix="NIDAQ-SamplingDuration", kind=Kind.config)
readout_range = Cpt(EpicsSignal, suffix="NIDAQ-ReadoutRange", kind=Kind.config)
encoder_type = Cpt(EpicsSignal, suffix="NIDAQ-EncoderType", kind=Kind.config)
encoder_factor = Cpt(EpicsSignal, suffix="NIDAQ-EncoderFactor", kind=Kind.config)
stop_call = Cpt(EpicsSignal, suffix="NIDAQ-Stop", kind=Kind.config)
power = Cpt(EpicsSignal, suffix="NIDAQ-Power", kind=Kind.config)
heartbeat = Cpt(EpicsSignal, suffix="NIDAQ-Heartbeat", kind=Kind.config, auto_monitor=True)
time_left = Cpt(EpicsSignalRO, suffix="NIDAQ-TimeLeft", kind=Kind.config, auto_monitor=True)
ai_chans = Cpt(EpicsSignal, suffix="NIDAQ-AIChans", kind=Kind.config)
ci_chans = Cpt(EpicsSignal, suffix="NIDAQ-CIChans6614", kind=Kind.config)
ci_chans = Cpt(EpicsSignal, suffix="NIDAQ-CIChans", kind=Kind.config)
di_chans = Cpt(EpicsSignal, suffix="NIDAQ-DIChans", kind=Kind.config)
@@ -462,12 +465,20 @@ class Nidaq(PSIDeviceBase, NidaqControl):
elif readout_range == 10:
self.readout_range.put(ReadoutRange.TEN_V)
if encoder_type in "X_1":
self.encoder_type.put(EncoderTypes.X_1)
elif encoder_type in "X_2":
self.encoder_type.put(EncoderTypes.X_2)
elif encoder_type in "X_4":
self.encoder_type.put(EncoderTypes.X_4)
if encoder_type in "1/16":
self.encoder_factor.put(EncoderFactors.X1_16)
elif encoder_type in "1/8":
self.encoder_factor.put(EncoderFactors.X1_8)
elif encoder_type in "1/4":
self.encoder_factor.put(EncoderFactors.X1_4)
elif encoder_type in "1/2":
self.encoder_factor.put(EncoderFactors.X1_2)
elif encoder_type in "1":
self.encoder_factor.put(EncoderFactors.X1)
elif encoder_type in "2":
self.encoder_factor.put(EncoderFactors.X2)
elif encoder_type in "4":
self.encoder_factor.put(EncoderFactors.X4)
if enable_compression is True:
self.enable_compression.put(NIDAQCompression.ON)
@@ -491,26 +502,21 @@ class Nidaq(PSIDeviceBase, NidaqControl):
Called after the device is connected and its signals are connected.
Default values for signals should be set here.
"""
def heartbeat_callback(*, old_value, value, **kwargs):
return ((old_value) == 0 and (value == 1)) or ((old_value) == 1 and (value == 0))
status = SubscriptionStatus(self.heartbeat, callback=heartbeat_callback)
status = TransitionStatus(self.heartbeat, transitions=[0,1], strict=False)
self.cancel_on_stop(status)
try:
status.wait(timeout=self.timeout_wait_for_signal) # Raises if timeout is reached
except WaitTimeoutError:
logger.warning(f"Device {self.name} was not alive, trying to put power on")
status = TransitionStatus(self.heartbeat, transitions=[0,1], strict=False)
self.cancel_on_stop(status)
self.power.put(1)
status.wait(timeout=self.timeout_wait_for_signal)
if not self.wait_for_condition(
condition=lambda: self.state.get() == NidaqState.STANDBY,
timeout=self.timeout_wait_for_signal,
check_stopped=True,
):
raise NidaqError(
f"Device {self.name} has not been reached in state STANDBY, current state {NidaqState(self.state.get())}"
)
status = CompareStatus(self.state, NidaqState.STANDBY)
self.cancel_on_stop(status)
status.wait(timeout=self.timeout_wait_for_signal)
self.scan_duration.set(0).wait(timeout=self._timeout_wait_for_pv)
self.time_left.subscribe(self._progress_update, run=False)
@@ -524,14 +530,9 @@ class Nidaq(PSIDeviceBase, NidaqControl):
if not self._check_if_scan_name_is_valid():
return None
if not self.wait_for_condition(
condition=lambda: self.state.get() == NidaqState.STANDBY,
timeout=self.timeout_wait_for_signal,
check_stopped=True,
):
raise NidaqError(
f"Device {self.name} has not been reached in state STANDBY, current state {NidaqState(self.state.get())}"
)
status = CompareStatus(self.state, NidaqState.STANDBY)
self.cancel_on_stop(status)
status.wait(timeout=self.timeout_wait_for_signal)
# If scan is not part of the valid_scan_names,
if self.scan_info.msg.scan_name != "nidaq_continuous_scan":
self.scan_type.set(ScanType.TRIGGERED).wait(timeout=self._timeout_wait_for_pv)
@@ -546,18 +547,14 @@ class Nidaq(PSIDeviceBase, NidaqControl):
timeout=self._timeout_wait_for_pv
)
# Stage call to IOC
status = CompareStatus(self.state, NidaqState.STAGE)
self.cancel_on_stop(status)
self.stage_call.set(1).wait(timeout=self._timeout_wait_for_pv)
if not self.wait_for_condition(
condition=lambda: self.state.get() == NidaqState.STAGE,
timeout=self.timeout_wait_for_signal,
check_stopped=True,
):
raise NidaqError(
f"Device {self.name} has not been reached in state STAGE, current state {NidaqState(self.state.get())}"
)
status.wait(timeout=self.timeout_wait_for_signal)
if self.scan_info.msg.scan_name != "nidaq_continuous_scan":
status = self.on_kickoff()
self.cancel_on_stop(status)
status.wait(timeout=self._timeout_wait_for_pv)
logger.info(f"Device {self.name} was staged: {NidaqState(self.state.get())}")
@@ -569,17 +566,12 @@ class Nidaq(PSIDeviceBase, NidaqControl):
def on_unstage(self) -> DeviceStatus | StatusBase | None:
"""Called while unstaging the device. Check that the Nidaq goes into Standby"""
def _get_state():
return self.state.get() == NidaqState.STANDBY
# TODO We need to wait longer if rle is disabled
if not self.wait_for_condition(
condition=_get_state, timeout=self.timeout_wait_for_signal, check_stopped=False
):
raise NidaqError(
f"Device {self.name} has not been reached in state STANDBY, current state {NidaqState(self.state.get())}"
)
self.enable_compression.set(1).wait(self._timeout_wait_for_pv)
status = CompareStatus(self.state, NidaqState.STANDBY)
self.cancel_on_stop(status)
status.wait(timeout=self.timeout_wait_for_signal)
status = self.enable_compression.set(1)
self.cancel_on_stop(status)
status.wait(self._timeout_wait_for_pv)
logger.info(f"Device {self.name} was unstaged: {NidaqState(self.state.get())}")
def on_pre_scan(self) -> DeviceStatus | StatusBase | None:
@@ -597,15 +589,10 @@ class Nidaq(PSIDeviceBase, NidaqControl):
logger.info(f"Device {self.name} ready to be kicked off for nidaq_continuous_scan")
return None
def _wait_for_state():
return self.state.get() == NidaqState.KICKOFF
if not self.wait_for_condition(
_wait_for_state, timeout=self.timeout_wait_for_signal, check_stopped=True
):
raise NidaqError(
f"Device {self.name} failed to reach state KICKOFF during pre scan, current state {NidaqState(self.state.get())}"
)
status = CompareStatus(self.state, NidaqState.KICKOFF)
self.cancel_on_stop(status)
status.wait(timeout=self._timeout_wait_for_pv)
logger.info(
f"Device {self.name} ready to take data after pre_scan: {NidaqState(self.state.get())}"
)
@@ -623,21 +610,10 @@ class Nidaq(PSIDeviceBase, NidaqControl):
if not self._check_if_scan_name_is_valid():
return None
def _check_state(self) -> bool:
while True:
if self.stopped is True:
raise NidaqError(f"Device {self.name} was stopped")
if self.state.get() == NidaqState.STANDBY:
return
# if time.time() > timeout_time:
# raise TimeoutError(f"Device {self.name} ran into timeout")
time.sleep(0.1)
status = CompareStatus(self.state, NidaqState.STANDBY)
self.cancel_on_stop(status)
if self.scan_info.msg.scan_name != "nidaq_continuous_scan":
self.on_stop()
status = self.task_handler.submit_task(task=_check_state, task_args=(self,))
else:
status = self.task_handler.submit_task(task=_check_state, task_args=(self,))
return status
def _progress_update(self, value, **kwargs) -> None: