diff --git a/bec_lib/bec_lib/endpoints.py b/bec_lib/bec_lib/endpoints.py index 4affcb4c..393ba991 100644 --- a/bec_lib/bec_lib/endpoints.py +++ b/bec_lib/bec_lib/endpoints.py @@ -21,7 +21,7 @@ class MessageOp(list[str], enum.Enum): SET_PUBLISH = ["register", "set_and_publish", "delete", "get", "keys"] SEND = ["send", "register"] STREAM = ["xadd", "xrange", "xread", "register_stream", "keys", "get_last"] - LIST = ["lpush", "lrange", "rpush", "ltrim", "keys"] + LIST = ["lpush", "lrange", "rpush", "ltrim", "keys", "delete"] SET = ["set", "get", "delete", "keys"] @@ -499,6 +499,23 @@ class MessageEndpoints: message_op=MessageOp.LIST, ) + @staticmethod + def scan_queue_schedule(schedule_name: str) -> EndpointInfo: + """ + Endpoint for scan queue schedule. This endpoint is used to store messages.ScanQueueScheduleMessage messages + in a redis list. + + Args: + schedule_name (str): Name of the schedule. + + Returns: + EndpointInfo: Endpoint for scan queue schedule. + """ + endpoint = f"internal/queue/queue_schedule/{schedule_name}" + return EndpointInfo( + endpoint=endpoint, message_type=messages.ScanQueueMessage, message_op=MessageOp.LIST + ) + # scan info @staticmethod def scan_number() -> EndpointInfo: diff --git a/bec_lib/bec_lib/scan_manager.py b/bec_lib/bec_lib/scan_manager.py index 61f56d2e..1db2474c 100644 --- a/bec_lib/bec_lib/scan_manager.py +++ b/bec_lib/bec_lib/scan_manager.py @@ -6,6 +6,7 @@ as the requests and scans that are currently running or have been completed. from __future__ import annotations import uuid +from typing import TYPE_CHECKING from typeguard import typechecked @@ -18,9 +19,12 @@ from bec_lib.scan_items import ScanStorage logger = bec_logger.logger +if TYPE_CHECKING: + from bec_lib.redis_connector import RedisConnector + class ScanManager: - def __init__(self, connector): + def __init__(self, connector: RedisConnector): """ ScanManager is a class that provides a convenient way to interact with the scan queue as well as the requests and scans that are currently running or have been completed. @@ -225,6 +229,64 @@ class ScanManager: msg = msg.value self.scan_storage.add_scan_baseline(msg) + @typechecked + def add_scan_to_queue_schedule( + self, schedule_name: str, msg: messages.ScanQueueMessage + ) -> None: + """ + Add a scan to the queue schedule + + Args: + schedule_name (str): name of the queue schedule + msg (messages.ScanQueueMessage): scan message + """ + self.connector.rpush(MessageEndpoints.scan_queue_schedule(schedule_name=schedule_name), msg) + + @typechecked + def get_scan_queue_schedule(self, schedule_name: str) -> list: + """ + Get the scan queue schedule + + Args: + schedule_name (str): name of the queue schedule + + Returns: + list: list of scan messages + """ + return self.connector.lrange( + MessageEndpoints.scan_queue_schedule(schedule_name=schedule_name), 0, -1 + ) + + @typechecked + def clear_scan_queue_schedule(self, schedule_name: str) -> None: + """ + Clear the scan queue schedule + + Args: + schedule_name (str): name of the queue schedule + """ + self.connector.delete(MessageEndpoints.scan_queue_schedule(schedule_name=schedule_name)) + + def get_scan_queue_schedule_names(self) -> list: + """ + Get the names of the scan queue schedules + + Returns: + list: list of schedule names + """ + keys = self.connector.keys(MessageEndpoints.scan_queue_schedule(schedule_name="*")) + if not keys: + return [] + return [key.decode().split("/")[-1] for key in keys] + + def clear_all_scan_queue_schedules(self) -> None: + """ + Clear all scan queue schedules + """ + keys = self.get_scan_queue_schedule_names() + for key in keys: + self.clear_scan_queue_schedule(key) + def __str__(self) -> str: try: return "\n".join(self.queue_storage.describe_queue()) diff --git a/bec_lib/tests/test_scan_manager.py b/bec_lib/tests/test_scan_manager.py index 1c65d01f..88755c41 100644 --- a/bec_lib/tests/test_scan_manager.py +++ b/bec_lib/tests/test_scan_manager.py @@ -1,11 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING from unittest import mock import pytest +from typeguard import TypeCheckError from bec_lib import messages from bec_lib.endpoints import MessageEndpoints from bec_lib.scan_manager import ScanManager +if TYPE_CHECKING: + from bec_lib.redis_connector import RedisConnector + @pytest.fixture def scan_manager(): @@ -15,6 +22,13 @@ def scan_manager(): manager.shutdown() +@pytest.fixture +def scan_manager_with_fakeredis(connected_connector: RedisConnector): + manager = ScanManager(connector=connected_connector) + yield manager + manager.shutdown() + + def test_scan_manager_next_scan_number(scan_manager): scan_manager.connector.get.return_value = messages.VariableMessage(value=3) assert scan_manager.next_scan_number == 3 @@ -135,3 +149,43 @@ def test_scan_manager_request_scan_continuation_scan_id(scan_manager, scan_id): MessageEndpoints.scan_queue_modification_request(), messages.ScanQueueModificationMessage(scan_id=scan_id, action="continue", parameter={}), ) + + +def test_scan_manager_add_scan_to_queue_schedule(scan_manager_with_fakeredis): + """ + Test the interaction with queue schedules + + Args: + scan_manager_with_fakeredis: The scan manager fixture with a fakeredis connection + """ + manager: ScanManager = scan_manager_with_fakeredis + msg = messages.ScanQueueMessage(scan_type="mv", parameter={"args": {"samx": [5], "samy": [5]}}) + manager.add_scan_to_queue_schedule("new_schedule", msg) + + with pytest.raises(TypeCheckError): + manager.add_scan_to_queue_schedule("new_schedule", {}) + + assert manager.get_scan_queue_schedule("new_schedule") == [msg] + + msg2 = messages.ScanQueueMessage(scan_type="mv", parameter={"args": {"samx": [6], "samy": [6]}}) + manager.add_scan_to_queue_schedule("new_schedule", msg2) + + assert manager.get_scan_queue_schedule("new_schedule") == [msg, msg2] + + manager.add_scan_to_queue_schedule("new_schedule2", msg) + + assert manager.get_scan_queue_schedule("new_schedule2") == [msg] + + assert manager.get_scan_queue_schedule_names() == ["new_schedule", "new_schedule2"] + + manager.clear_scan_queue_schedule("new_schedule2") + + assert manager.get_scan_queue_schedule_names() == ["new_schedule"] + + assert manager.get_scan_queue_schedule("new_schedule2") == [] + + assert manager.get_scan_queue_schedule("new_schedule") == [msg, msg2] + + manager.clear_all_scan_queue_schedules() + + assert manager.get_scan_queue_schedule_names() == []