mirror of
https://github.com/bec-project/bec.git
synced 2026-06-01 07:48:30 +02:00
feat: add beamline actor manager and worker
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
// Project-local debug tasks
|
||||
//
|
||||
// For more documentation on how to configure debug tasks,
|
||||
// see: https://zed.dev/docs/debugger
|
||||
[
|
||||
{
|
||||
"adapter": "Debugpy",
|
||||
"label": "Debug current python file",
|
||||
"justMyCode": false,
|
||||
"program": "${ZED_FILE}",
|
||||
"request": "launch"
|
||||
},
|
||||
]
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Definitions and protocols for the classes in bec_server.actors and in plugins."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from typing_extensions import TypeAliasType # only for 3.11
|
||||
|
||||
from bec_lib.client import BECClient
|
||||
|
||||
|
||||
class ConditionCombination(StrEnum):
|
||||
Any = "Any"
|
||||
All = "All"
|
||||
|
||||
|
||||
class ActorCondition(Protocol):
|
||||
def __call__(self, client: BECClient) -> Any:
|
||||
"""A callable which returns True if the condition is met and False if it is not."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorConditionSet:
|
||||
conditions: set[ActorCondition]
|
||||
combination_mode: ConditionCombination
|
||||
|
||||
def __call__(self, client: BECClient) -> bool:
|
||||
if self.combination_mode == ConditionCombination.Any:
|
||||
return any(condition(client) for condition in self.conditions)
|
||||
return all(condition(client) for condition in self.conditions)
|
||||
|
||||
|
||||
class ActorAction(Protocol):
|
||||
def __call__(self, client: BECClient) -> None:
|
||||
"""A callable which an actor calls if it has met the associated condition"""
|
||||
|
||||
|
||||
ActorActionTable = TypeAliasType(
|
||||
"ActorActionTable", dict[ActorConditionSet | ActorCondition, ActorAction]
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Actor(Protocol):
|
||||
client: BECClient
|
||||
action_table: ActorActionTable
|
||||
|
||||
def __init__(self, client: BECClient, exec_id: str):
|
||||
"""Create the actor instance"""
|
||||
|
||||
def run(self):
|
||||
"""The core logic loop for the actor"""
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Actors can autonomously respond to changes in beamline states."""
|
||||
|
||||
import time
|
||||
from threading import Event
|
||||
from typing import Iterable
|
||||
|
||||
from bec_lib.actors import Actor as ActorProtocol
|
||||
from bec_lib.client import BECClient
|
||||
from bec_lib.endpoints import EndpointInfo, MessageEndpoints
|
||||
from bec_lib.logger import bec_logger
|
||||
|
||||
logger = bec_logger.logger
|
||||
|
||||
|
||||
class ActorBase(ActorProtocol):
|
||||
def __init__(self, client: BECClient, exec_id: str):
|
||||
self.client = client
|
||||
self.stop_event = Event()
|
||||
self.client.connector.register(MessageEndpoints.actor_stop(exec_id), cb=self.stop)
|
||||
|
||||
def evaluate(self, *_, **__):
|
||||
for condition, action in self.action_table.items():
|
||||
if condition(self.client):
|
||||
logger.info(
|
||||
f"{self.__class__.__name__} triggered, executing action for condition: {condition}"
|
||||
)
|
||||
action(self.client)
|
||||
|
||||
def stop(self, *_):
|
||||
self.stop_event.set()
|
||||
|
||||
|
||||
class SubscriptionActor(ActorBase):
|
||||
"""An actor which subscribes to a list of redis endpoints, and evaluates on any message to any
|
||||
of those endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: BECClient,
|
||||
exec_id: str,
|
||||
endpoints: Iterable[EndpointInfo],
|
||||
min_delay_s: float = 0.01,
|
||||
):
|
||||
super().__init__(client, exec_id)
|
||||
self.min_delay = min_delay_s
|
||||
self.last_evaluated = time.monotonic()
|
||||
for endpoint in set(endpoints):
|
||||
client.connector.register(endpoint, cb=self.evaluate)
|
||||
|
||||
def evaluate(self, *_, **__):
|
||||
if (now := time.monotonic()) < self.last_evaluated + self.min_delay:
|
||||
return
|
||||
self.last_evaluated = now
|
||||
return super().evaluate(*_, **__)
|
||||
|
||||
def run(self):
|
||||
self.stop_event.wait()
|
||||
|
||||
|
||||
class PollingActor(ActorBase):
|
||||
"""An actor which evaluates its conditions after a certain time interval."""
|
||||
|
||||
def __init__(self, client: BECClient, exec_id: str, poll_interval_s: float = 0.1):
|
||||
super().__init__(client, exec_id)
|
||||
self.poll_interval = poll_interval_s
|
||||
|
||||
def run(self):
|
||||
while not self.stop_event.wait(timeout=0.1):
|
||||
self.evaluate()
|
||||
@@ -0,0 +1,69 @@
|
||||
"""A manager for BEC Actors, based on the same infrastructure as Procedures."""
|
||||
|
||||
import threading
|
||||
from concurrent.futures import Future
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from bec_lib.endpoints import MessageEndpoints
|
||||
from bec_lib.logger import bec_logger
|
||||
from bec_lib.messages import ActorExecutionMessage, ActorStartRequestMessage
|
||||
from bec_server.actors.worker import ActorProcedureWorker
|
||||
from bec_server.procedures.manager import ProcedureManagerBase, _resolve_dict
|
||||
|
||||
logger = bec_logger.logger
|
||||
|
||||
|
||||
class ActorHelper: ...
|
||||
|
||||
|
||||
class ActorManager(ProcedureManagerBase[ActorStartRequestMessage, ActorExecutionMessage]):
|
||||
"""A specialised procedure manager for running Actors."""
|
||||
|
||||
def _define_endpoints(self):
|
||||
self._reply_ep = MessageEndpoints.actor_request_response()
|
||||
self._request_ep = MessageEndpoints.actor_start_request()
|
||||
self._abort_ep = MessageEndpoints.actor_stop_request()
|
||||
|
||||
def __init__(self, redis: str, thread_prefix: str = "actor_"):
|
||||
super().__init__(redis, ActorProcedureWorker, thread_prefix)
|
||||
|
||||
def _register_endpoints(self):
|
||||
self._conn.register(self._request_ep, None, self._process_queue_request)
|
||||
|
||||
def _unregister_endpoints(self):
|
||||
self._conn.unregister(self._request_ep, None, self._process_queue_request)
|
||||
|
||||
def _publish_available(self): ...
|
||||
|
||||
def _startup(self): ...
|
||||
|
||||
def _validate_request(self, msg: dict[str, Any] | ActorStartRequestMessage):
|
||||
return _resolve_dict(msg, ActorStartRequestMessage)
|
||||
|
||||
def _respond_to_valid_request(self, message: ActorStartRequestMessage):
|
||||
queue = f"{message.actor_module}.{message.actor_class_name}"
|
||||
exec_id = str(uuid4())
|
||||
return ActorExecutionMessage(
|
||||
execution_id=exec_id,
|
||||
queue=queue,
|
||||
env={
|
||||
"actor_module": message.actor_module,
|
||||
"actor_class_name": message.actor_class_name,
|
||||
"actor_exec_id": exec_id,
|
||||
},
|
||||
)
|
||||
|
||||
def _cleanup_worker_function(self, queue: str):
|
||||
def cleanup_worker(fut: Future): ...
|
||||
|
||||
return cleanup_worker
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
e = threading.Event()
|
||||
manager = ActorManager(redis="localhost:6379")
|
||||
try:
|
||||
e.wait()
|
||||
except KeyboardInterrupt:
|
||||
...
|
||||
@@ -0,0 +1,68 @@
|
||||
import importlib
|
||||
import os
|
||||
from contextlib import redirect_stdout
|
||||
from inspect import isclass
|
||||
|
||||
from bec_lib.client import BECClient
|
||||
from bec_lib.logger import LogLevel, bec_logger
|
||||
from bec_lib.redis_connector import RedisConnector
|
||||
from bec_server.actors.actor import ActorBase
|
||||
from bec_server.procedures.oop_worker_base import RedisOutputDiverter, get_env, setup
|
||||
from bec_server.procedures.subprocess_worker import SubProcessWorker
|
||||
|
||||
logger = bec_logger.logger
|
||||
|
||||
|
||||
def actor_procedure(actor_module: str, actor_class_name: str, exec_id: str, bec: BECClient):
|
||||
try:
|
||||
mod = importlib.import_module(actor_module)
|
||||
actor_class = getattr(mod, actor_class_name)
|
||||
except ImportError:
|
||||
logger.error(f"Module '{actor_module}' not found! Exiting.")
|
||||
return
|
||||
except AttributeError:
|
||||
logger.error(
|
||||
f"Module '{actor_module}' does not contain {actor_class_name}! Available classes in module: {list(filter(isclass, mod.__dict__.values()))}."
|
||||
)
|
||||
return
|
||||
if not issubclass(actor_class, ActorBase):
|
||||
logger.error(f"{actor_class_name} is not a valid Actor! Exiting.")
|
||||
return
|
||||
|
||||
actor = actor_class(bec, exec_id)
|
||||
actor.run()
|
||||
|
||||
|
||||
class ActorProcedureWorker(SubProcessWorker):
|
||||
WORKER_FILE = __file__
|
||||
|
||||
def _worker_environment(self):
|
||||
return super()._worker_environment() | {"client_class": "BECClient"}
|
||||
|
||||
|
||||
def get_actor_env():
|
||||
return {
|
||||
"actor_module": os.environ["actor_module"],
|
||||
"actor_class_name": os.environ["actor_class_name"],
|
||||
"exec_id": os.environ["actor_exec_id"],
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Replaces the main contents of Worker.work() - should be called as the container entrypoint or command"""
|
||||
|
||||
env, helper, client, conn = setup(get_env())
|
||||
actor_env = get_actor_env()
|
||||
logger_connector = RedisConnector(env["redis_server"])
|
||||
output_diverter = RedisOutputDiverter(logger_connector, env["queue"])
|
||||
with redirect_stdout(output_diverter):
|
||||
logger.add(
|
||||
output_diverter,
|
||||
level=LogLevel.SUCCESS,
|
||||
format=bec_logger.formatting(is_container=True),
|
||||
filter=bec_logger.filter(),
|
||||
)
|
||||
logger.success(f"Starting ActorProcedureWorker with env: {actor_env}")
|
||||
actor_procedure(bec=client, **actor_env)
|
||||
conn.shutdown()
|
||||
logger_connector.shutdown()
|
||||
@@ -8,6 +8,8 @@ from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from threading import RLock
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, TypedDict, TypeVar
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from bec_lib.endpoints import EndpointInfo, MessageEndpoints
|
||||
from bec_lib.logger import bec_logger
|
||||
from bec_lib.messages import (
|
||||
@@ -23,8 +25,6 @@ from bec_lib.messages import (
|
||||
)
|
||||
from bec_lib.procedures.helper import BackendProcedureHelper
|
||||
from bec_lib.redis_connector import RedisConnector
|
||||
from pydantic import ValidationError
|
||||
|
||||
from bec_server.procedures import procedure_registry
|
||||
from bec_server.procedures.constants import PROCEDURE, WorkerAlreadyExists
|
||||
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
from bec_lib.client import BECClient
|
||||
from bec_lib.endpoints import EndpointInfo, MessageOp
|
||||
from bec_lib.messages import RawMessage
|
||||
from bec_server.actors.actor import PollingActor, SubscriptionActor
|
||||
|
||||
|
||||
def _test_condition(client: BECClient):
|
||||
return True
|
||||
|
||||
|
||||
def _test_action(client: BECClient):
|
||||
client.connector.set_and_publish(ep, RawMessage(data={"test": "result"}))
|
||||
|
||||
|
||||
ep = EndpointInfo(
|
||||
endpoint="test_endpoint", message_type=RawMessage, message_op=MessageOp.SET_PUBLISH
|
||||
)
|
||||
|
||||
|
||||
class PollingTestActor(PollingActor):
|
||||
action_table = {_test_condition: _test_action}
|
||||
@@ -0,0 +1,13 @@
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def wait_until(predicate: Callable[[], bool], timeout_s: float = 0.1):
|
||||
# Yes I know this is actually more like retries than a timeout,
|
||||
# it's just to make sure the threads have plenty of chances to switch in the test
|
||||
elapsed, step = 0.0, timeout_s / 10
|
||||
while not predicate():
|
||||
time.sleep(step)
|
||||
elapsed += step
|
||||
if elapsed > timeout_s:
|
||||
raise TimeoutError()
|
||||
@@ -0,0 +1,74 @@
|
||||
from threading import Thread
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fakeredis import TcpFakeServer
|
||||
|
||||
from bec_lib.endpoints import MessageEndpoints
|
||||
from bec_lib.messages import ActorStartRequestMessage, RawMessage
|
||||
from bec_lib.redis_connector import MessageObject, RedisConnector
|
||||
from bec_server.actors.manager import ActorManager
|
||||
from bec_server.test.actor_test_utils import ep
|
||||
from bec_server.test.helpers import wait_until
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fakeredis_config():
|
||||
redis_config = "localhost", 44556
|
||||
server = TcpFakeServer(redis_config, server_type="redis")
|
||||
t = Thread(target=server.serve_forever, daemon=True)
|
||||
try:
|
||||
t.start()
|
||||
yield redis_config
|
||||
finally:
|
||||
server.shutdown()
|
||||
server.server_close()
|
||||
t.join()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def actor_manager_and_conn(fakeredis_config):
|
||||
host, port = fakeredis_config
|
||||
redis = f"{host}:{port}"
|
||||
manager = ActorManager(redis)
|
||||
conn = RedisConnector([redis])
|
||||
try:
|
||||
yield manager, conn
|
||||
finally:
|
||||
manager.shutdown()
|
||||
conn.shutdown()
|
||||
|
||||
|
||||
def test_validate_and_spawn_called_on_request(
|
||||
actor_manager_and_conn: tuple[ActorManager, RedisConnector],
|
||||
):
|
||||
manager, conn = actor_manager_and_conn
|
||||
with (
|
||||
patch.object(manager, "_validate_request", side_effect=lambda x: x["request"]),
|
||||
patch.object(manager, "spawn"),
|
||||
):
|
||||
conn.xadd(
|
||||
MessageEndpoints.actor_start_request(),
|
||||
{"request": ActorStartRequestMessage(actor_module="test", actor_class_name="Test")},
|
||||
)
|
||||
wait_until(lambda: manager._validate_request.call_count == 1)
|
||||
wait_until(lambda: manager.spawn.call_count == 1)
|
||||
|
||||
|
||||
def test_polling_actor(actor_manager_and_conn: tuple[ActorManager, RedisConnector]):
|
||||
manager, conn = actor_manager_and_conn
|
||||
action_triggered = False
|
||||
|
||||
def action_callback(msg: MessageObject):
|
||||
nonlocal action_triggered
|
||||
if msg.value.data == {"test": "result"}:
|
||||
action_triggered = True
|
||||
|
||||
conn.register(ep, cb=action_callback)
|
||||
manager._process_queue_request(
|
||||
msg=ActorStartRequestMessage(
|
||||
actor_module="bec_server.test.actor_test_utils", actor_class_name="PollingTestActor"
|
||||
)
|
||||
)
|
||||
wait_until(lambda: manager._active_workers != {})
|
||||
wait_until(lambda: action_triggered, timeout_s=3)
|
||||
Reference in New Issue
Block a user