fix(transition-status): improve transition status called with no transitions

This commit is contained in:
2025-11-28 17:11:32 +01:00
committed by Christian Appel
parent 13d658241a
commit 57ff40566b
2 changed files with 27 additions and 8 deletions

View File

@@ -6,12 +6,11 @@ import threading
import traceback import traceback
import uuid import uuid
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Callable, Literal from typing import TYPE_CHECKING, Callable, Literal, Union
from bec_lib.file_utils import get_full_path from bec_lib.file_utils import get_full_path
from bec_lib.logger import bec_logger from bec_lib.logger import bec_logger
from bec_lib.utils.import_utils import lazy_import_from from bec_lib.utils.import_utils import lazy_import_from
from ophyd import Device, Signal
from ophyd.status import DeviceStatus as _DeviceStatus from ophyd.status import DeviceStatus as _DeviceStatus
from ophyd.status import MoveStatus as _MoveStatus from ophyd.status import MoveStatus as _MoveStatus
from ophyd.status import Status as _Status from ophyd.status import Status as _Status
@@ -19,6 +18,7 @@ from ophyd.status import StatusBase as _StatusBase
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from bec_lib.messages import ScanStatusMessage from bec_lib.messages import ScanStatusMessage
from ophyd import Device, Signal
else: else:
# TODO: put back normal import when Pydantic gets faster # TODO: put back normal import when Pydantic gets faster
ScanStatusMessage = lazy_import_from("bec_lib.messages", ("ScanStatusMessage",)) ScanStatusMessage = lazy_import_from("bec_lib.messages", ("ScanStatusMessage",))
@@ -53,7 +53,13 @@ class StatusBase(_StatusBase):
"""Base class for all status objects.""" """Base class for all status objects."""
def __init__( def __init__(
self, obj: Device | None = None, *, timeout=None, settle_time=0, done=None, success=None self,
obj: Union["Device", None] = None,
*,
timeout=None,
settle_time=0,
done=None,
success=None,
): ):
self.obj = obj self.obj = obj
super().__init__(timeout=timeout, settle_time=settle_time, done=done, success=success) super().__init__(timeout=timeout, settle_time=settle_time, done=done, success=success)
@@ -129,7 +135,7 @@ class Status(_Status):
class DeviceStatus(_DeviceStatus): class DeviceStatus(_DeviceStatus):
"""Thin wrapper around DeviceStatus to add __and__ operator and add stop on failure option, defaults to False""" """Thin wrapper around DeviceStatus to add __and__ operator."""
def __and__(self, other): def __and__(self, other):
"""Returns a new 'composite' status object, AndStatus""" """Returns a new 'composite' status object, AndStatus"""
@@ -149,7 +155,7 @@ class SubscriptionStatus(StatusBase):
def __init__( def __init__(
self, self,
obj: Device | Signal, obj: Union["Device", "Signal"],
callback: Callable, callback: Callable,
event_type=None, event_type=None,
timeout=None, timeout=None,
@@ -328,6 +334,8 @@ class TransitionStatus(SubscriptionStatus):
): ):
self._signal = signal self._signal = signal
self._transitions = tuple(transitions) self._transitions = tuple(transitions)
if not transitions:
raise ValueError("Transitions {transitions}must contain at least one value")
self._index = 0 self._index = 0
self._strict = strict self._strict = strict
self._failure_states = failure_states if failure_states else [] self._failure_states = failure_states if failure_states else []
@@ -397,7 +405,13 @@ class TaskStatus(StatusBase):
"""Thin wrapper around StatusBase to add information about tasks""" """Thin wrapper around StatusBase to add information about tasks"""
def __init__( def __init__(
self, obj: Device | Signal, *, timeout=None, settle_time=0, done=None, success=None self,
obj: Union["Device", "Signal"],
*,
timeout=None,
settle_time=0,
done=None,
success=None,
): ):
super().__init__( super().__init__(
obj=obj, timeout=timeout, settle_time=settle_time, done=done, success=success obj=obj, timeout=timeout, settle_time=settle_time, done=done, success=success
@@ -423,7 +437,7 @@ class TaskStatus(StatusBase):
class TaskHandler: class TaskHandler:
"""Handler to manage asynchronous tasks""" """Handler to manage asynchronous tasks"""
def __init__(self, parent: Device): def __init__(self, parent: "Device"):
"""Initialize the handler""" """Initialize the handler"""
self._tasks = {} self._tasks = {}
self._parent = parent self._parent = parent

View File

@@ -965,11 +965,16 @@ def test_patched_status_objects():
with pytest.raises(ValueError): with pytest.raises(ValueError):
and_st.wait(timeout=10) and_st.wait(timeout=10)
# DeviceStatus & StatusBase # DeviceStatus & Status
dev = Device(name="device") dev = Device(name="device")
dev_status = DeviceStatus(device=dev) dev_status = DeviceStatus(device=dev)
st = Status()
and_st = st and dev_status
assert dev_status.device == dev assert dev_status.device == dev
dev_status.set_exception(RuntimeError("device error")) dev_status.set_exception(RuntimeError("device error"))
with pytest.raises(RuntimeError):
and_st.wait(timeout=10)
# Combine DeviceStatus with StatusBase and form AndStatus # Combine DeviceStatus with StatusBase and form AndStatus
st = StatusBase(obj=dev) st = StatusBase(obj=dev)