diff --git a/device_server/device_server/device_server.py b/device_server/device_server/device_server.py index 11c3bd2b..57006ed1 100644 --- a/device_server/device_server/device_server.py +++ b/device_server/device_server/device_server.py @@ -8,13 +8,13 @@ from io import StringIO from typing import Any import ophyd -from ophyd import Staged -from ophyd.utils import errors as ophyd_errors - from bec_lib.core import Alarms, BECMessage, BECService, MessageEndpoints, bec_logger from bec_lib.core.BECMessage import BECStatus from bec_lib.core.connector import ConnectorBase from bec_lib.core.devicemanager import OnFailure +from ophyd import Staged +from ophyd.utils import errors as ophyd_errors + from device_server.devices import is_serializable, rgetattr from device_server.devices.devicemanager import DeviceManagerDS @@ -333,10 +333,24 @@ class DeviceServer(BECService): status.add_callback(self._status_callback) def _complete_device(self, instr: BECMessage.DeviceInstructionMessage) -> None: - obj = self.device_manager.devices.get(instr.content["device"]).obj - status = obj.complete() - status.__dict__["instruction"] = instr - status.add_callback(self._status_callback) + if instr.content["device"] is None: + devices = self.device_manager.devices.enabled_devices + else: + devices = instr.content["device"] + if not isinstance(devices, list): + devices = [devices] + for dev in devices: + obj = self.device_manager.devices.get(dev).obj + if not hasattr(obj, "complete"): + # if the device does not have a complete method, we assume that it is done + status = ophyd.StatusBase() + status.obj = obj + status.set_finished() + else: + logger.info(f"Completing device: {dev}") + status = obj.complete() + status.__dict__["instruction"] = instr + status.add_callback(self._status_callback) def _set_device(self, instr: BECMessage.DeviceInstructionMessage) -> None: device_obj = self.device_manager.devices.get(instr.content["device"]) diff --git a/scan_server/scan_server/scan_worker.py b/scan_server/scan_server/scan_worker.py index 2730b826..f9af5620 100644 --- a/scan_server/scan_server/scan_worker.py +++ b/scan_server/scan_server/scan_worker.py @@ -444,15 +444,47 @@ class ScanWorker(threading.Thread): ) def _complete_devices(self, instr: DeviceMsg) -> None: + if instr.content.get("device") is None: + devices = [dev.name for dev in self.device_manager.devices.enabled_devices] + else: + devices = instr.content.get("device") self.device_manager.producer.send( MessageEndpoints.device_instructions(), DeviceMsg( - device=instr.content.get("device"), + device=devices, action="complete", parameter=instr.content["parameter"], metadata=instr.metadata, ).dumps(), ) + self._wait_for_complete(instr) + + def _wait_for_complete(self, instr: DeviceMsg) -> None: + if instr.content.get("device") is None: + devices = [dev.name for dev in self.device_manager.devices.enabled_devices] + else: + devices = instr.content.get("device") + metadata = instr.metadata + while True: + complete_status = self._get_device_status(MessageEndpoints.device_req_status, devices) + self._check_for_interruption() + device_status = [ + BECMessage.DeviceReqStatusMessage.loads(dev) for dev in complete_status + ] + + if None in device_status: + continue + devices_are_ready = all( + bool(dev.content.get("success")) is True for dev in device_status + ) + matching_scanID = all( + dev.metadata.get("scanID") == metadata["scanID"] for dev in device_status + ) + matching_DIID = all( + dev.metadata.get("DIID") == metadata["DIID"] for dev in device_status + ) + if devices_are_ready and matching_scanID and matching_DIID: + break def _baseline_reading(self, instr: DeviceMsg) -> None: baseline_devices = [ diff --git a/scan_server/scan_server/scans.py b/scan_server/scan_server/scans.py index b969d79a..90da9041 100644 --- a/scan_server/scan_server/scans.py +++ b/scan_server/scan_server/scans.py @@ -453,6 +453,7 @@ class ScanBase(RequestBase, PathOptimizerMixin): """finalize the scan""" yield from self.return_to_start() yield from self.stubs.wait(wait_type="read", group="primary", wait_group="readout_primary") + yield from self.stubs.complete(device=None) def unstage(self): """call the unstage procedure""" diff --git a/scan_server/tests/test_scans.py b/scan_server/tests/test_scans.py index ef6194e8..8d16afbf 100644 --- a/scan_server/tests/test_scans.py +++ b/scan_server/tests/test_scans.py @@ -699,6 +699,10 @@ def test_scan_updated_move(mv_msg, reference_msg_list): }, metadata={"readout_priority": "monitored", "DIID": 23}, ), + BMessage.DeviceInstructionMessage( + **{"device": None, "action": "complete", "parameter": {}}, + metadata={"readout_priority": "monitored", "DIID": 31}, + ), BMessage.DeviceInstructionMessage( device=None, action="unstage", @@ -988,6 +992,10 @@ def test_fermat_scan(scan_msg, reference_scan_list): }, metadata={"readout_priority": "monitored", "DIID": 16}, ), + BMessage.DeviceInstructionMessage( + **{"device": None, "action": "complete", "parameter": {}}, + metadata={"readout_priority": "monitored", "DIID": 31}, + ), BMessage.DeviceInstructionMessage( device=None, action="unstage", @@ -1129,6 +1137,10 @@ def test_device_rpc(): parameter={"type": "read", "group": "primary", "wait_group": "readout_primary"}, metadata={"readout_priority": "monitored", "DIID": 6}, ), + BMessage.DeviceInstructionMessage( + **{"device": None, "action": "complete", "parameter": {}}, + metadata={"readout_priority": "monitored", "DIID": 31}, + ), BMessage.DeviceInstructionMessage( device=None, action="unstage", @@ -1815,6 +1827,10 @@ def test_scan_base_set_position_offset(): parameter={"type": "read", "group": "primary", "wait_group": "readout_primary"}, metadata={"readout_priority": "monitored", "DIID": 16}, ), + BMessage.DeviceInstructionMessage( + **{"device": None, "action": "complete", "parameter": {}}, + metadata={"readout_priority": "monitored", "DIID": 31}, + ), BMessage.DeviceInstructionMessage( device=None, action="unstage", @@ -2244,17 +2260,21 @@ def test_list_scan_raises_for_different_lengths(): parameter={"type": "read", "group": "primary", "wait_group": "readout_primary"}, metadata={"readout_priority": "monitored", "DIID": 20}, ), + BMessage.DeviceInstructionMessage( + **{"device": None, "action": "complete", "parameter": {}}, + metadata={"readout_priority": "monitored", "DIID": 21}, + ), BMessage.DeviceInstructionMessage( device=None, action="unstage", parameter={}, - metadata={"readout_priority": "monitored", "DIID": 21}, + metadata={"readout_priority": "monitored", "DIID": 22}, ), BMessage.DeviceInstructionMessage( device=None, action="close_scan", parameter={}, - metadata={"readout_priority": "monitored", "DIID": 22}, + metadata={"readout_priority": "monitored", "DIID": 23}, ), ], ) @@ -2365,17 +2385,21 @@ def test_time_scan(scan_msg, reference_scan_list): parameter={"type": "read", "group": "primary", "wait_group": "readout_primary"}, metadata={"readout_priority": "monitored", "DIID": 10, "RID": "1234"}, ), + BMessage.DeviceInstructionMessage( + **{"device": None, "action": "complete", "parameter": {}}, + metadata={"readout_priority": "monitored", "DIID": 11, "RID": "1234"}, + ), BMessage.DeviceInstructionMessage( device=None, action="unstage", parameter={}, - metadata={"readout_priority": "monitored", "DIID": 11, "RID": "1234"}, + metadata={"readout_priority": "monitored", "DIID": 12, "RID": "1234"}, ), BMessage.DeviceInstructionMessage( device=None, action="close_scan", parameter={}, - metadata={"readout_priority": "monitored", "DIID": 12, "RID": "1234"}, + metadata={"readout_priority": "monitored", "DIID": 13, "RID": "1234"}, ), ], )