diff --git a/scan_server/scan_plugins/FermatSpiralScan2.py b/scan_server/scan_plugins/FermatSpiralScan2.py new file mode 100644 index 00000000..1dd37fa5 --- /dev/null +++ b/scan_server/scan_plugins/FermatSpiralScan2.py @@ -0,0 +1,41 @@ +from scan_server.scans import ScanArgType, ScanBase, get_fermat_spiral_pos + + +class FermatSpiralScan2(ScanBase): + scan_name = "fermat_scan2" + scan_report_hint = "table" + required_kwargs = ["exp_time", "step"] + arg_input = [ScanArgType.DEVICE, ScanArgType.FLOAT, ScanArgType.FLOAT] + arg_bundle_size = len(arg_input) + + def __init__(self, *args, parameter=None, **kwargs): + """ + A scan following Fermat's spiral. + + Args: + *args: pairs of device / start position / end position / steps arguments + relative: Start from an absolute or relative position + burst: number of acquisition per point + + Returns: + + Examples: + >>> scans.fermat_scan(dev.motor1, -5, 5, dev.motor2, -5, 5, step=0.5, exp_time=0.1, relative=True) + + """ + super().__init__(parameter=parameter, **kwargs) + self.axis = [] + self.step = parameter.get("kwargs", {}).get("step", 0.1) + self.spiral_type = parameter.get("kwargs", {}).get("spiral_type", 0) + + def _calculate_positions(self): + params = list(self.caller_args.values()) + self.positions = get_fermat_spiral_pos( + params[0][0], + params[0][1], + params[1][0], + params[1][1], + step=self.step, + spiral_type=self.spiral_type, + center=False, + ) diff --git a/scan_server/scan_server/bkqueue.py b/scan_server/scan_server/bkqueue.py index fb00f924..ca0bcdcf 100644 --- a/scan_server/scan_server/bkqueue.py +++ b/scan_server/scan_server/bkqueue.py @@ -77,26 +77,26 @@ class QueueManager: def scan_interception(self, scan_mod_msg: BECMessage.ScanQueueModificationMessage) -> None: action = scan_mod_msg.content["action"] - self.__getattribute__("_set_" + action)(scanID=scan_mod_msg.content["scanID"]) + getattr("set_" + action)(scanID=scan_mod_msg.content["scanID"]) - def _set_pause(self, scanID=None, queue="primary") -> None: + def set_pause(self, scanID=None, queue="primary") -> None: self.queues[queue].status = ScanQueueStatus.PAUSED self.queues[queue].worker_status = InstructionQueueStatus.PAUSED - def _set_deferred_pause(self, scanID=None, queue="primary") -> None: + def set_deferred_pause(self, scanID=None, queue="primary") -> None: self.queues[queue].status = ScanQueueStatus.PAUSED self.queues[queue].worker_status = InstructionQueueStatus.DEFERRED_PAUSE - def _set_continue(self, scanID=None, queue="primary") -> None: + def set_continue(self, scanID=None, queue="primary") -> None: self.queues[queue].status = ScanQueueStatus.RUNNING self.queues[queue].worker_status = InstructionQueueStatus.RUNNING - def _set_abort(self, scanID=None, queue="primary") -> None: + def set_abort(self, scanID=None, queue="primary") -> None: self.queues[queue].status = ScanQueueStatus.PAUSED self.queues[queue].worker_status = InstructionQueueStatus.STOPPED self.queues[queue].remove_queue_item(scanID=scanID) - def _set_clear(self, scanID=None, queue="primary") -> None: + def set_clear(self, scanID=None, queue="primary") -> None: self.queues[queue].status = ScanQueueStatus.PAUSED self.queues[queue].worker_status = InstructionQueueStatus.PAUSED self.queues[queue].clear() diff --git a/scan_server/scan_server/scan_assembler.py b/scan_server/scan_server/scan_assembler.py index 8cd79a14..00389dcd 100644 --- a/scan_server/scan_server/scan_assembler.py +++ b/scan_server/scan_server/scan_assembler.py @@ -14,12 +14,14 @@ class ScanAssembler: self.parent = parent self.device_manager = self.parent.device_manager self.connector = self.parent.connector - self._scans = self.parent.scan_dict # TODO should these be the same dict, or a copy? + self.scan_manager = ( + self.parent.scan_manager + ) # TODO should these be the same dict, or a copy? def assemble_device_instructions(self, msg: BECMessage.ScanQueueMessage): scan = msg.content.get("scan_type") - cls_name = self._scans[scan]["class"] - scan_cls = getattr(ScanServerScans, cls_name) + cls_name = self.scan_manager.available_scans[scan]["class"] + scan_cls = self.scan_manager.scan_dict[cls_name] logger.info(f"Preparing instructions of request of type {scan} / {scan_cls.__name__}") diff --git a/scan_server/scan_server/scan_manager.py b/scan_server/scan_server/scan_manager.py new file mode 100644 index 00000000..a65b19db --- /dev/null +++ b/scan_server/scan_server/scan_manager.py @@ -0,0 +1,79 @@ +import glob +import importlib +import importlib.util +import inspect +import os +from pathlib import Path + +import msgpack +from bec_utils import MessageEndpoints, bec_logger + +from . import scans as ScanServerScans + +logger = bec_logger.logger + + +class ScanManager: + DEFAULT_PLUGIN_PATH = Path( + os.path.dirname(os.path.abspath(__file__)) + "/../scan_plugins" + ).resolve() + + def __init__(self, *, parent): + """ + Scan Manager loads and manages the available scans. + """ + self.parent = parent + self.available_scans = {} + self.scan_dict = {} + self._plugins = {} + self.load_plugins() + self.update_available_scans() + self.publish_available_scans() + + def load_plugins(self): + plugin_path = self.DEFAULT_PLUGIN_PATH + files = glob.glob(os.path.join(plugin_path, "*.py")) + for file in files: + if file.endswith("__init__.py"): + continue + filename = os.path.basename(file).split(".py")[0] + module_spec = importlib.util.spec_from_file_location("scan_plugins", file) + plugin_module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(plugin_module) + module_members = inspect.getmembers(plugin_module) + for name, cls in module_members: + if name == filename: + self._plugins[name] = cls + logger.info(f"Loading scan plugin {name}") + + def update_available_scans(self): + """load all scans and plugin scans""" + members = inspect.getmembers(ScanServerScans) + for member_name, cls in self._plugins.items(): + members.append((member_name, cls)) + + for name, scan_cls in members: + try: + is_scan = issubclass(scan_cls, ScanServerScans.RequestBase) + except TypeError: + is_scan = False + + if not is_scan or not scan_cls.scan_name: + logger.debug(f"Ignoring {name}") + continue + if scan_cls.scan_name in self.available_scans: + logger.error(f"{scan_cls.scan_name} already exists. Skipping.") + continue + self.scan_dict[scan_cls.__name__] = scan_cls + self.available_scans[scan_cls.scan_name] = { + "class": scan_cls.__name__, + "arg_input": scan_cls.arg_input, + "required_kwargs": scan_cls.required_kwargs, + "scan_report_hint": scan_cls.scan_report_hint, + "doc": scan_cls.__doc__ or scan_cls.__init__.__doc__, + } + + def publish_available_scans(self): + self.parent.producer.set( + MessageEndpoints.available_scans(), msgpack.dumps(self.available_scans) + ) diff --git a/scan_server/scan_server/scan_server.py b/scan_server/scan_server/scan_server.py index cd2ad6c5..2bb9708f 100644 --- a/scan_server/scan_server/scan_server.py +++ b/scan_server/scan_server/scan_server.py @@ -1,17 +1,13 @@ from __future__ import annotations -import inspect - -import msgpack from bec_utils import BECMessage, BECService, MessageEndpoints, bec_logger from bec_utils.connector import ConnectorBase -import scan_server.scans as ScanServerScans - from .bkqueue import QueueManager from .devicemanager import DeviceManagerScanServer from .scan_assembler import ScanAssembler from .scan_guard import ScanGuard +from .scan_manager import ScanManager from .scan_worker import ScanWorker logger = bec_logger.logger @@ -23,20 +19,19 @@ class ScanServer(BECService): scan_guard = None scan_server = None scan_assembler = None + scan_manager = None def __init__(self, bootstrap_server: list, connector_cls: ConnectorBase, scibec_url: str): super().__init__(bootstrap_server, connector_cls) self.scan_number = 0 - self.scan_dict = {} self.scibec_url = scibec_url self.producer = self.connector.producer() - self._update_available_scans() + self._start_scan_manager() self._start_queue_manager() self._start_device_manager() self._start_scan_guard() self._start_scan_assembler() self._start_scan_server() - self._publish_available_scans() self._start_alarm_handler() def _start_device_manager(self): @@ -47,6 +42,9 @@ class ScanServer(BECService): self.scan_worker = ScanWorker(parent=self) self.scan_worker.start() + def _start_scan_manager(self): + self.scan_manager = ScanManager(parent=self) + def _start_queue_manager(self): self.queue_manager = QueueManager(parent=self) @@ -56,28 +54,6 @@ class ScanServer(BECService): def _start_scan_guard(self): self.scan_guard = ScanGuard(parent=self) - def _update_available_scans(self): - for name, val in inspect.getmembers(ScanServerScans): # TODO: use vars() ? - try: - is_scan = issubclass(val, ScanServerScans.RequestBase) - except TypeError: - is_scan = False - - if not is_scan or not val.scan_name: - logger.debug(f"Ignoring {name}") - continue - - self.scan_dict[val.scan_name] = { - "class": val.__name__, - "arg_input": val.arg_input, - "required_kwargs": val.required_kwargs, - "scan_report_hint": val.scan_report_hint, - "doc": val.__doc__ or val.__init__.__doc__, - } - - def _publish_available_scans(self): - self.producer.set(MessageEndpoints.available_scans(), msgpack.dumps(self.scan_dict)) - def _start_alarm_handler(self): self._alarm_consumer = self.connector.consumer( MessageEndpoints.alarm(), @@ -88,11 +64,11 @@ class ScanServer(BECService): @staticmethod def _alarm_callback(msg, parent: ScanServer, **_kwargs): - md = BECMessage.AlarmMessage.loads(msg.value).metadata - scanID = md.get("scanID") - queue = md.get("stream") + metadata = BECMessage.AlarmMessage.loads(msg.value).metadata + scanID = metadata.get("scanID") + queue = metadata.get("stream") if scanID and queue: - parent.queue_manager._set_abort( + parent.queue_manager.set_abort( scanID=msg.metadata["scanID"], queue=msg.metadata["stream"] ) diff --git a/scan_server/tests/test_queue.py b/scan_server/tests/test_queue.py index 6af3246a..270e3c05 100644 --- a/scan_server/tests/test_queue.py +++ b/scan_server/tests/test_queue.py @@ -64,35 +64,35 @@ def test_queuemanager_add_to_queue(queue): def test_set_pause(): queue_manager = get_queuemanager() - queue_manager._set_pause(queue="primary") + queue_manager.set_pause(queue="primary") assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED assert queue_manager.producer.message_sent.get("queue") == MessageEndpoints.scan_queue_status() def test_set_deferred_pause(): queue_manager = get_queuemanager() - queue_manager._set_deferred_pause(queue="primary") + queue_manager.set_deferred_pause(queue="primary") assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED assert queue_manager.producer.message_sent.get("queue") == MessageEndpoints.scan_queue_status() def test_set_continue(): queue_manager = get_queuemanager() - queue_manager._set_continue(queue="primary") + queue_manager.set_continue(queue="primary") assert queue_manager.queues["primary"].status == ScanQueueStatus.RUNNING assert queue_manager.producer.message_sent.get("queue") == MessageEndpoints.scan_queue_status() def test_set_abort(): queue_manager = get_queuemanager() - queue_manager._set_abort(queue="primary") + queue_manager.set_abort(queue="primary") assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED assert queue_manager.producer.message_sent.get("queue") == MessageEndpoints.scan_queue_status() def test_set_clear_sends_message(): queue_manager = get_queuemanager() - queue_manager._set_clear(queue="primary") + queue_manager.set_clear(queue="primary") assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED assert queue_manager.producer.message_sent.get("queue") == MessageEndpoints.scan_queue_status() @@ -124,5 +124,5 @@ def test_set_clear(): metadata={"RID": "something"}, ) queue_manager.add_to_queue(scan_queue="primary", msg=msg) - queue_manager._set_clear(queue="primary") + queue_manager.set_clear(queue="primary") assert len(queue_manager.queues["primary"].queue) == 0