diff --git a/.zed/debug.json b/.zed/debug.json new file mode 100644 index 00000000..edef0478 --- /dev/null +++ b/.zed/debug.json @@ -0,0 +1,13 @@ +// Project-local debug tasks +// +// For more documentation on how to configure debug tasks, +// see: https://zed.dev/docs/debugger +[ + { + "adapter": "Debugpy", + "label": "Debug current python file", + "justMyCode": false, + "program": "${ZED_FILE}", + "request": "launch" + }, +] diff --git a/bec_lib/bec_lib/actors.py b/bec_lib/bec_lib/actors.py new file mode 100644 index 00000000..5815eb75 --- /dev/null +++ b/bec_lib/bec_lib/actors.py @@ -0,0 +1,52 @@ +"""Definitions and protocols for the classes in bec_server.actors and in plugins.""" + +from dataclasses import dataclass +from enum import StrEnum +from typing import Any, Protocol, runtime_checkable + +from typing_extensions import TypeAliasType # only for 3.11 + +from bec_lib.client import BECClient + + +class ConditionCombination(StrEnum): + Any = "Any" + All = "All" + + +class ActorCondition(Protocol): + def __call__(self, client: BECClient) -> Any: + """A callable which returns True if the condition is met and False if it is not.""" + + +@dataclass +class ActorConditionSet: + conditions: set[ActorCondition] + combination_mode: ConditionCombination + + def __call__(self, client: BECClient) -> bool: + if self.combination_mode == ConditionCombination.Any: + return any(condition(client) for condition in self.conditions) + return all(condition(client) for condition in self.conditions) + + +class ActorAction(Protocol): + def __call__(self, client: BECClient) -> None: + """A callable which an actor calls if it has met the associated condition""" + + +ActorActionTable = TypeAliasType( + "ActorActionTable", dict[ActorConditionSet | ActorCondition, ActorAction] +) + + +@runtime_checkable +class Actor(Protocol): + client: BECClient + action_table: ActorActionTable + + def __init__(self, client: BECClient, exec_id: str): + """Create the actor instance""" + + def run(self): + """The core logic loop for the actor""" diff --git a/bec_server/bec_server/actors/__init__.py b/bec_server/bec_server/actors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bec_server/bec_server/actors/actor.py b/bec_server/bec_server/actors/actor.py new file mode 100644 index 00000000..d130a94f --- /dev/null +++ b/bec_server/bec_server/actors/actor.py @@ -0,0 +1,69 @@ +"""Actors can autonomously respond to changes in beamline states.""" + +import time +from threading import Event +from typing import Iterable + +from bec_lib.actors import Actor as ActorProtocol +from bec_lib.client import BECClient +from bec_lib.endpoints import EndpointInfo, MessageEndpoints +from bec_lib.logger import bec_logger + +logger = bec_logger.logger + + +class ActorBase(ActorProtocol): + def __init__(self, client: BECClient, exec_id: str): + self.client = client + self.stop_event = Event() + self.client.connector.register(MessageEndpoints.actor_stop(exec_id), cb=self.stop) + + def evaluate(self, *_, **__): + for condition, action in self.action_table.items(): + if condition(self.client): + logger.info( + f"{self.__class__.__name__} triggered, executing action for condition: {condition}" + ) + action(self.client) + + def stop(self, *_): + self.stop_event.set() + + +class SubscriptionActor(ActorBase): + """An actor which subscribes to a list of redis endpoints, and evaluates on any message to any + of those endpoints.""" + + def __init__( + self, + client: BECClient, + exec_id: str, + endpoints: Iterable[EndpointInfo], + min_delay_s: float = 0.01, + ): + super().__init__(client, exec_id) + self.min_delay = min_delay_s + self.last_evaluated = time.monotonic() + for endpoint in set(endpoints): + client.connector.register(endpoint, cb=self.evaluate) + + def evaluate(self, *_, **__): + if (now := time.monotonic()) < self.last_evaluated + self.min_delay: + return + self.last_evaluated = now + return super().evaluate(*_, **__) + + def run(self): + self.stop_event.wait() + + +class PollingActor(ActorBase): + """An actor which evaluates its conditions after a certain time interval.""" + + def __init__(self, client: BECClient, exec_id: str, poll_interval_s: float = 0.1): + super().__init__(client, exec_id) + self.poll_interval = poll_interval_s + + def run(self): + while not self.stop_event.wait(timeout=0.1): + self.evaluate() diff --git a/bec_server/bec_server/actors/manager.py b/bec_server/bec_server/actors/manager.py new file mode 100644 index 00000000..3eb9a879 --- /dev/null +++ b/bec_server/bec_server/actors/manager.py @@ -0,0 +1,69 @@ +"""A manager for BEC Actors, based on the same infrastructure as Procedures.""" + +import threading +from concurrent.futures import Future +from typing import Any +from uuid import uuid4 + +from bec_lib.endpoints import MessageEndpoints +from bec_lib.logger import bec_logger +from bec_lib.messages import ActorExecutionMessage, ActorStartRequestMessage +from bec_server.actors.worker import ActorProcedureWorker +from bec_server.procedures.manager import ProcedureManagerBase, _resolve_dict + +logger = bec_logger.logger + + +class ActorHelper: ... + + +class ActorManager(ProcedureManagerBase[ActorStartRequestMessage, ActorExecutionMessage]): + """A specialised procedure manager for running Actors.""" + + def _define_endpoints(self): + self._reply_ep = MessageEndpoints.actor_request_response() + self._request_ep = MessageEndpoints.actor_start_request() + self._abort_ep = MessageEndpoints.actor_stop_request() + + def __init__(self, redis: str, thread_prefix: str = "actor_"): + super().__init__(redis, ActorProcedureWorker, thread_prefix) + + def _register_endpoints(self): + self._conn.register(self._request_ep, None, self._process_queue_request) + + def _unregister_endpoints(self): + self._conn.unregister(self._request_ep, None, self._process_queue_request) + + def _publish_available(self): ... + + def _startup(self): ... + + def _validate_request(self, msg: dict[str, Any] | ActorStartRequestMessage): + return _resolve_dict(msg, ActorStartRequestMessage) + + def _respond_to_valid_request(self, message: ActorStartRequestMessage): + queue = f"{message.actor_module}.{message.actor_class_name}" + exec_id = str(uuid4()) + return ActorExecutionMessage( + execution_id=exec_id, + queue=queue, + env={ + "actor_module": message.actor_module, + "actor_class_name": message.actor_class_name, + "actor_exec_id": exec_id, + }, + ) + + def _cleanup_worker_function(self, queue: str): + def cleanup_worker(fut: Future): ... + + return cleanup_worker + + +if __name__ == "__main__": + e = threading.Event() + manager = ActorManager(redis="localhost:6379") + try: + e.wait() + except KeyboardInterrupt: + ... diff --git a/bec_server/bec_server/actors/worker.py b/bec_server/bec_server/actors/worker.py new file mode 100644 index 00000000..01ba7c99 --- /dev/null +++ b/bec_server/bec_server/actors/worker.py @@ -0,0 +1,68 @@ +import importlib +import os +from contextlib import redirect_stdout +from inspect import isclass + +from bec_lib.client import BECClient +from bec_lib.logger import LogLevel, bec_logger +from bec_lib.redis_connector import RedisConnector +from bec_server.actors.actor import ActorBase +from bec_server.procedures.oop_worker_base import RedisOutputDiverter, get_env, setup +from bec_server.procedures.subprocess_worker import SubProcessWorker + +logger = bec_logger.logger + + +def actor_procedure(actor_module: str, actor_class_name: str, exec_id: str, bec: BECClient): + try: + mod = importlib.import_module(actor_module) + actor_class = getattr(mod, actor_class_name) + except ImportError: + logger.error(f"Module '{actor_module}' not found! Exiting.") + return + except AttributeError: + logger.error( + f"Module '{actor_module}' does not contain {actor_class_name}! Available classes in module: {list(filter(isclass, mod.__dict__.values()))}." + ) + return + if not issubclass(actor_class, ActorBase): + logger.error(f"{actor_class_name} is not a valid Actor! Exiting.") + return + + actor = actor_class(bec, exec_id) + actor.run() + + +class ActorProcedureWorker(SubProcessWorker): + WORKER_FILE = __file__ + + def _worker_environment(self): + return super()._worker_environment() | {"client_class": "BECClient"} + + +def get_actor_env(): + return { + "actor_module": os.environ["actor_module"], + "actor_class_name": os.environ["actor_class_name"], + "exec_id": os.environ["actor_exec_id"], + } + + +if __name__ == "__main__": + """Replaces the main contents of Worker.work() - should be called as the container entrypoint or command""" + + env, helper, client, conn = setup(get_env()) + actor_env = get_actor_env() + logger_connector = RedisConnector(env["redis_server"]) + output_diverter = RedisOutputDiverter(logger_connector, env["queue"]) + with redirect_stdout(output_diverter): + logger.add( + output_diverter, + level=LogLevel.SUCCESS, + format=bec_logger.formatting(is_container=True), + filter=bec_logger.filter(), + ) + logger.success(f"Starting ActorProcedureWorker with env: {actor_env}") + actor_procedure(bec=client, **actor_env) + conn.shutdown() + logger_connector.shutdown() diff --git a/bec_server/bec_server/procedures/manager.py b/bec_server/bec_server/procedures/manager.py index 5639b3c6..37fbb927 100644 --- a/bec_server/bec_server/procedures/manager.py +++ b/bec_server/bec_server/procedures/manager.py @@ -8,6 +8,8 @@ from concurrent.futures import Future, ThreadPoolExecutor from threading import RLock from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, TypedDict, TypeVar +from pydantic import ValidationError + from bec_lib.endpoints import EndpointInfo, MessageEndpoints from bec_lib.logger import bec_logger from bec_lib.messages import ( @@ -23,8 +25,6 @@ from bec_lib.messages import ( ) from bec_lib.procedures.helper import BackendProcedureHelper from bec_lib.redis_connector import RedisConnector -from pydantic import ValidationError - from bec_server.procedures import procedure_registry from bec_server.procedures.constants import PROCEDURE, WorkerAlreadyExists diff --git a/bec_server/bec_server/test/__init__.py b/bec_server/bec_server/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bec_server/bec_server/test/actor_test_utils.py b/bec_server/bec_server/test/actor_test_utils.py new file mode 100644 index 00000000..6ef27258 --- /dev/null +++ b/bec_server/bec_server/test/actor_test_utils.py @@ -0,0 +1,21 @@ +from bec_lib.client import BECClient +from bec_lib.endpoints import EndpointInfo, MessageOp +from bec_lib.messages import RawMessage +from bec_server.actors.actor import PollingActor, SubscriptionActor + + +def _test_condition(client: BECClient): + return True + + +def _test_action(client: BECClient): + client.connector.set_and_publish(ep, RawMessage(data={"test": "result"})) + + +ep = EndpointInfo( + endpoint="test_endpoint", message_type=RawMessage, message_op=MessageOp.SET_PUBLISH +) + + +class PollingTestActor(PollingActor): + action_table = {_test_condition: _test_action} diff --git a/bec_server/bec_server/test/helpers.py b/bec_server/bec_server/test/helpers.py new file mode 100644 index 00000000..c50005a3 --- /dev/null +++ b/bec_server/bec_server/test/helpers.py @@ -0,0 +1,13 @@ +import time +from typing import Callable + + +def wait_until(predicate: Callable[[], bool], timeout_s: float = 0.1): + # Yes I know this is actually more like retries than a timeout, + # it's just to make sure the threads have plenty of chances to switch in the test + elapsed, step = 0.0, timeout_s / 10 + while not predicate(): + time.sleep(step) + elapsed += step + if elapsed > timeout_s: + raise TimeoutError() diff --git a/bec_server/tests/tests_actors/test_actors.py b/bec_server/tests/tests_actors/test_actors.py new file mode 100644 index 00000000..24685245 --- /dev/null +++ b/bec_server/tests/tests_actors/test_actors.py @@ -0,0 +1,74 @@ +from threading import Thread +from unittest.mock import patch + +import pytest +from fakeredis import TcpFakeServer + +from bec_lib.endpoints import MessageEndpoints +from bec_lib.messages import ActorStartRequestMessage, RawMessage +from bec_lib.redis_connector import MessageObject, RedisConnector +from bec_server.actors.manager import ActorManager +from bec_server.test.actor_test_utils import ep +from bec_server.test.helpers import wait_until + + +@pytest.fixture +def fakeredis_config(): + redis_config = "localhost", 44556 + server = TcpFakeServer(redis_config, server_type="redis") + t = Thread(target=server.serve_forever, daemon=True) + try: + t.start() + yield redis_config + finally: + server.shutdown() + server.server_close() + t.join() + + +@pytest.fixture +def actor_manager_and_conn(fakeredis_config): + host, port = fakeredis_config + redis = f"{host}:{port}" + manager = ActorManager(redis) + conn = RedisConnector([redis]) + try: + yield manager, conn + finally: + manager.shutdown() + conn.shutdown() + + +def test_validate_and_spawn_called_on_request( + actor_manager_and_conn: tuple[ActorManager, RedisConnector], +): + manager, conn = actor_manager_and_conn + with ( + patch.object(manager, "_validate_request", side_effect=lambda x: x["request"]), + patch.object(manager, "spawn"), + ): + conn.xadd( + MessageEndpoints.actor_start_request(), + {"request": ActorStartRequestMessage(actor_module="test", actor_class_name="Test")}, + ) + wait_until(lambda: manager._validate_request.call_count == 1) + wait_until(lambda: manager.spawn.call_count == 1) + + +def test_polling_actor(actor_manager_and_conn: tuple[ActorManager, RedisConnector]): + manager, conn = actor_manager_and_conn + action_triggered = False + + def action_callback(msg: MessageObject): + nonlocal action_triggered + if msg.value.data == {"test": "result"}: + action_triggered = True + + conn.register(ep, cb=action_callback) + manager._process_queue_request( + msg=ActorStartRequestMessage( + actor_module="bec_server.test.actor_test_utils", actor_class_name="PollingTestActor" + ) + ) + wait_until(lambda: manager._active_workers != {}) + wait_until(lambda: action_triggered, timeout_s=3)