diff --git a/ophyd_devices/utils/psi_device_base_utils.py b/ophyd_devices/utils/psi_device_base_utils.py index 38f88a1..10d4296 100644 --- a/ophyd_devices/utils/psi_device_base_utils.py +++ b/ophyd_devices/utils/psi_device_base_utils.py @@ -46,6 +46,64 @@ OP_MAP = { } +class AndStatusWithList(DeviceStatus): + """ + Custom implementation of the AndStatus that combines the + option to add multiple statuses as a list, and in addition + allows for adding the Device as an object to access its + methods. + + Args""" + + def __init__( + self, + device: Device, + status: StatusBase | DeviceStatus | list[StatusBase | DeviceStatus], + **kwargs, + ): + self.all_statuses = status if isinstance(status, list) else [status] + super().__init__(device=device, **kwargs) + self._trace_attributes["all"] = [st._trace_attributes for st in self.all_statuses] + + def inner(status): + with self._lock: + if self._externally_initiated_completion: + return + if self.done: # Return if status is already done.. It must be resolved already + return + + for st in self.all_statuses: + with st._lock: + if st.done and not st.success: + self.set_exception(st.exception()) # st._exception + return + + if all(st.done for st in self.all_statuses) and all( + st.success for st in self.all_statuses + ): + self.set_finished() + + for st in self.all_statuses: + with st._lock: + st.add_callback(inner) + + def __repr__(self): + return "".format(self=self) + + def __str__(self): + return "".format(self=self) + + def __contains__(self, status: StatusBase | DeviceStatus) -> bool: + for child in self.all_statuses: + if child == status: + return True + if isinstance(child, AndStatusWithList): + if status in child: + return True + + return False + + class CompareStatus(SubscriptionStatus): """ Status class to compare a signal value against a given value.