feat: add beamline actor manager and worker

This commit is contained in:
2026-04-24 12:52:48 +02:00
committed by David Perl
parent fb0d40a31e
commit cbd35bd9d3
11 changed files with 381 additions and 2 deletions
+13
View File
@@ -0,0 +1,13 @@
// Project-local debug tasks
//
// For more documentation on how to configure debug tasks,
// see: https://zed.dev/docs/debugger
[
{
"adapter": "Debugpy",
"label": "Debug current python file",
"justMyCode": false,
"program": "${ZED_FILE}",
"request": "launch"
},
]
+52
View File
@@ -0,0 +1,52 @@
"""Definitions and protocols for the classes in bec_server.actors and in plugins."""
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Protocol, runtime_checkable
from typing_extensions import TypeAliasType # only for 3.11
from bec_lib.client import BECClient
class ConditionCombination(StrEnum):
Any = "Any"
All = "All"
class ActorCondition(Protocol):
def __call__(self, client: BECClient) -> Any:
"""A callable which returns True if the condition is met and False if it is not."""
@dataclass
class ActorConditionSet:
conditions: set[ActorCondition]
combination_mode: ConditionCombination
def __call__(self, client: BECClient) -> bool:
if self.combination_mode == ConditionCombination.Any:
return any(condition(client) for condition in self.conditions)
return all(condition(client) for condition in self.conditions)
class ActorAction(Protocol):
def __call__(self, client: BECClient) -> None:
"""A callable which an actor calls if it has met the associated condition"""
ActorActionTable = TypeAliasType(
"ActorActionTable", dict[ActorConditionSet | ActorCondition, ActorAction]
)
@runtime_checkable
class Actor(Protocol):
client: BECClient
action_table: ActorActionTable
def __init__(self, client: BECClient, exec_id: str):
"""Create the actor instance"""
def run(self):
"""The core logic loop for the actor"""
+69
View File
@@ -0,0 +1,69 @@
"""Actors can autonomously respond to changes in beamline states."""
import time
from threading import Event
from typing import Iterable
from bec_lib.actors import Actor as ActorProtocol
from bec_lib.client import BECClient
from bec_lib.endpoints import EndpointInfo, MessageEndpoints
from bec_lib.logger import bec_logger
logger = bec_logger.logger
class ActorBase(ActorProtocol):
def __init__(self, client: BECClient, exec_id: str):
self.client = client
self.stop_event = Event()
self.client.connector.register(MessageEndpoints.actor_stop(exec_id), cb=self.stop)
def evaluate(self, *_, **__):
for condition, action in self.action_table.items():
if condition(self.client):
logger.info(
f"{self.__class__.__name__} triggered, executing action for condition: {condition}"
)
action(self.client)
def stop(self, *_):
self.stop_event.set()
class SubscriptionActor(ActorBase):
"""An actor which subscribes to a list of redis endpoints, and evaluates on any message to any
of those endpoints."""
def __init__(
self,
client: BECClient,
exec_id: str,
endpoints: Iterable[EndpointInfo],
min_delay_s: float = 0.01,
):
super().__init__(client, exec_id)
self.min_delay = min_delay_s
self.last_evaluated = time.monotonic()
for endpoint in set(endpoints):
client.connector.register(endpoint, cb=self.evaluate)
def evaluate(self, *_, **__):
if (now := time.monotonic()) < self.last_evaluated + self.min_delay:
return
self.last_evaluated = now
return super().evaluate(*_, **__)
def run(self):
self.stop_event.wait()
class PollingActor(ActorBase):
"""An actor which evaluates its conditions after a certain time interval."""
def __init__(self, client: BECClient, exec_id: str, poll_interval_s: float = 0.1):
super().__init__(client, exec_id)
self.poll_interval = poll_interval_s
def run(self):
while not self.stop_event.wait(timeout=0.1):
self.evaluate()
+69
View File
@@ -0,0 +1,69 @@
"""A manager for BEC Actors, based on the same infrastructure as Procedures."""
import threading
from concurrent.futures import Future
from typing import Any
from uuid import uuid4
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from bec_lib.messages import ActorExecutionMessage, ActorStartRequestMessage
from bec_server.actors.worker import ActorProcedureWorker
from bec_server.procedures.manager import ProcedureManagerBase, _resolve_dict
logger = bec_logger.logger
class ActorHelper: ...
class ActorManager(ProcedureManagerBase[ActorStartRequestMessage, ActorExecutionMessage]):
"""A specialised procedure manager for running Actors."""
def _define_endpoints(self):
self._reply_ep = MessageEndpoints.actor_request_response()
self._request_ep = MessageEndpoints.actor_start_request()
self._abort_ep = MessageEndpoints.actor_stop_request()
def __init__(self, redis: str, thread_prefix: str = "actor_"):
super().__init__(redis, ActorProcedureWorker, thread_prefix)
def _register_endpoints(self):
self._conn.register(self._request_ep, None, self._process_queue_request)
def _unregister_endpoints(self):
self._conn.unregister(self._request_ep, None, self._process_queue_request)
def _publish_available(self): ...
def _startup(self): ...
def _validate_request(self, msg: dict[str, Any] | ActorStartRequestMessage):
return _resolve_dict(msg, ActorStartRequestMessage)
def _respond_to_valid_request(self, message: ActorStartRequestMessage):
queue = f"{message.actor_module}.{message.actor_class_name}"
exec_id = str(uuid4())
return ActorExecutionMessage(
execution_id=exec_id,
queue=queue,
env={
"actor_module": message.actor_module,
"actor_class_name": message.actor_class_name,
"actor_exec_id": exec_id,
},
)
def _cleanup_worker_function(self, queue: str):
def cleanup_worker(fut: Future): ...
return cleanup_worker
if __name__ == "__main__":
e = threading.Event()
manager = ActorManager(redis="localhost:6379")
try:
e.wait()
except KeyboardInterrupt:
...
+68
View File
@@ -0,0 +1,68 @@
import importlib
import os
from contextlib import redirect_stdout
from inspect import isclass
from bec_lib.client import BECClient
from bec_lib.logger import LogLevel, bec_logger
from bec_lib.redis_connector import RedisConnector
from bec_server.actors.actor import ActorBase
from bec_server.procedures.oop_worker_base import RedisOutputDiverter, get_env, setup
from bec_server.procedures.subprocess_worker import SubProcessWorker
logger = bec_logger.logger
def actor_procedure(actor_module: str, actor_class_name: str, exec_id: str, bec: BECClient):
try:
mod = importlib.import_module(actor_module)
actor_class = getattr(mod, actor_class_name)
except ImportError:
logger.error(f"Module '{actor_module}' not found! Exiting.")
return
except AttributeError:
logger.error(
f"Module '{actor_module}' does not contain {actor_class_name}! Available classes in module: {list(filter(isclass, mod.__dict__.values()))}."
)
return
if not issubclass(actor_class, ActorBase):
logger.error(f"{actor_class_name} is not a valid Actor! Exiting.")
return
actor = actor_class(bec, exec_id)
actor.run()
class ActorProcedureWorker(SubProcessWorker):
WORKER_FILE = __file__
def _worker_environment(self):
return super()._worker_environment() | {"client_class": "BECClient"}
def get_actor_env():
return {
"actor_module": os.environ["actor_module"],
"actor_class_name": os.environ["actor_class_name"],
"exec_id": os.environ["actor_exec_id"],
}
if __name__ == "__main__":
"""Replaces the main contents of Worker.work() - should be called as the container entrypoint or command"""
env, helper, client, conn = setup(get_env())
actor_env = get_actor_env()
logger_connector = RedisConnector(env["redis_server"])
output_diverter = RedisOutputDiverter(logger_connector, env["queue"])
with redirect_stdout(output_diverter):
logger.add(
output_diverter,
level=LogLevel.SUCCESS,
format=bec_logger.formatting(is_container=True),
filter=bec_logger.filter(),
)
logger.success(f"Starting ActorProcedureWorker with env: {actor_env}")
actor_procedure(bec=client, **actor_env)
conn.shutdown()
logger_connector.shutdown()
+2 -2
View File
@@ -8,6 +8,8 @@ from concurrent.futures import Future, ThreadPoolExecutor
from threading import RLock
from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, TypedDict, TypeVar
from pydantic import ValidationError
from bec_lib.endpoints import EndpointInfo, MessageEndpoints
from bec_lib.logger import bec_logger
from bec_lib.messages import (
@@ -23,8 +25,6 @@ from bec_lib.messages import (
)
from bec_lib.procedures.helper import BackendProcedureHelper
from bec_lib.redis_connector import RedisConnector
from pydantic import ValidationError
from bec_server.procedures import procedure_registry
from bec_server.procedures.constants import PROCEDURE, WorkerAlreadyExists
@@ -0,0 +1,21 @@
from bec_lib.client import BECClient
from bec_lib.endpoints import EndpointInfo, MessageOp
from bec_lib.messages import RawMessage
from bec_server.actors.actor import PollingActor, SubscriptionActor
def _test_condition(client: BECClient):
return True
def _test_action(client: BECClient):
client.connector.set_and_publish(ep, RawMessage(data={"test": "result"}))
ep = EndpointInfo(
endpoint="test_endpoint", message_type=RawMessage, message_op=MessageOp.SET_PUBLISH
)
class PollingTestActor(PollingActor):
action_table = {_test_condition: _test_action}
+13
View File
@@ -0,0 +1,13 @@
import time
from typing import Callable
def wait_until(predicate: Callable[[], bool], timeout_s: float = 0.1):
# Yes I know this is actually more like retries than a timeout,
# it's just to make sure the threads have plenty of chances to switch in the test
elapsed, step = 0.0, timeout_s / 10
while not predicate():
time.sleep(step)
elapsed += step
if elapsed > timeout_s:
raise TimeoutError()
@@ -0,0 +1,74 @@
from threading import Thread
from unittest.mock import patch
import pytest
from fakeredis import TcpFakeServer
from bec_lib.endpoints import MessageEndpoints
from bec_lib.messages import ActorStartRequestMessage, RawMessage
from bec_lib.redis_connector import MessageObject, RedisConnector
from bec_server.actors.manager import ActorManager
from bec_server.test.actor_test_utils import ep
from bec_server.test.helpers import wait_until
@pytest.fixture
def fakeredis_config():
redis_config = "localhost", 44556
server = TcpFakeServer(redis_config, server_type="redis")
t = Thread(target=server.serve_forever, daemon=True)
try:
t.start()
yield redis_config
finally:
server.shutdown()
server.server_close()
t.join()
@pytest.fixture
def actor_manager_and_conn(fakeredis_config):
host, port = fakeredis_config
redis = f"{host}:{port}"
manager = ActorManager(redis)
conn = RedisConnector([redis])
try:
yield manager, conn
finally:
manager.shutdown()
conn.shutdown()
def test_validate_and_spawn_called_on_request(
actor_manager_and_conn: tuple[ActorManager, RedisConnector],
):
manager, conn = actor_manager_and_conn
with (
patch.object(manager, "_validate_request", side_effect=lambda x: x["request"]),
patch.object(manager, "spawn"),
):
conn.xadd(
MessageEndpoints.actor_start_request(),
{"request": ActorStartRequestMessage(actor_module="test", actor_class_name="Test")},
)
wait_until(lambda: manager._validate_request.call_count == 1)
wait_until(lambda: manager.spawn.call_count == 1)
def test_polling_actor(actor_manager_and_conn: tuple[ActorManager, RedisConnector]):
manager, conn = actor_manager_and_conn
action_triggered = False
def action_callback(msg: MessageObject):
nonlocal action_triggered
if msg.value.data == {"test": "result"}:
action_triggered = True
conn.register(ep, cb=action_callback)
manager._process_queue_request(
msg=ActorStartRequestMessage(
actor_module="bec_server.test.actor_test_utils", actor_class_name="PollingTestActor"
)
)
wait_until(lambda: manager._active_workers != {})
wait_until(lambda: action_triggered, timeout_s=3)