refactor: make active workers private

This commit is contained in:
2025-10-16 09:46:30 +02:00
committed by David Perl
parent f7353432d5
commit 4c17a5aa49
4 changed files with 24 additions and 25 deletions
@@ -62,7 +62,7 @@ def test_procedure_runner_spawns_worker(
client_logtool_and_manager: tuple[BECIPythonClient, "LogTestTool", ProcedureManager],
):
client, _, manager = client_logtool_and_manager
assert manager.active_workers == {}
assert manager._active_workers == {}
endpoint = MessageEndpoints.procedure_request()
msg = messages.ProcedureRequestMessage(
identifier="sleep", args_kwargs=((), {"time_s": 2}), queue="test"
@@ -77,8 +77,8 @@ def test_procedure_runner_spawns_worker(
manager.add_callback("test", cb)
client.connector.xadd(topic=endpoint, msg_dict=msg.model_dump())
_wait_while(lambda: manager.active_workers == {}, 5)
_wait_while(lambda: manager.active_workers != {}, 20)
_wait_while(lambda: manager._active_workers == {}, 5)
_wait_while(lambda: manager._active_workers != {}, 20)
assert logs != []
@@ -91,7 +91,7 @@ def test_happy_path_container_procedure_runner(
test_args = (1, 2, 3)
test_kwargs = {"a": "b", "c": "d"}
client, logtool, manager = client_logtool_and_manager
assert manager.active_workers == {}
assert manager._active_workers == {}
conn = client.connector
endpoint = MessageEndpoints.procedure_request()
msg = messages.ProcedureRequestMessage(
@@ -99,8 +99,8 @@ def test_happy_path_container_procedure_runner(
)
conn.xadd(topic=endpoint, msg_dict=msg.model_dump())
_wait_while(lambda: manager.active_workers == {}, 5)
_wait_while(lambda: manager.active_workers != {}, 20)
_wait_while(lambda: manager._active_workers == {}, 5)
_wait_while(lambda: manager._active_workers != {}, 20)
logtool.fetch()
assert logtool.is_present_in_any_message("procedure accepted: True, message:")
@@ -45,7 +45,7 @@ class ProcedureManager:
self._parent = parent
self.lock = RLock()
self.active_workers: dict[str, ProcedureWorkerEntry] = {}
self._active_workers: dict[str, ProcedureWorkerEntry] = {}
self.executor = ThreadPoolExecutor(
max_workers=PROCEDURE.WORKER.MAX_WORKERS, thread_name_prefix="user_procedure_"
)
@@ -86,7 +86,7 @@ class ProcedureManager:
self._callbacks[queue].append(cb)
def _run_callbacks(self, queue: str):
if (worker := self.active_workers[queue]["worker"]) is None:
if (worker := self._active_workers[queue]["worker"]) is None:
return
for cb in self._callbacks.get(queue, []):
cb(worker)
@@ -106,7 +106,7 @@ class ProcedureManager:
self._ack(True, f"Running procedure {message_obj.identifier}")
queue = message_obj.queue or PROCEDURE.WORKER.DEFAULT_QUEUE
endpoint = MessageEndpoints.procedure_execution(queue)
logger.debug(f"active workers: {self.active_workers}, worker requested: {queue}")
logger.debug(f"active workers: {self._active_workers}, worker requested: {queue}")
self._conn.rpush(
endpoint,
endpoint.message_type(
@@ -120,14 +120,14 @@ class ProcedureManager:
with self.lock:
logger.debug(f"cleaning up worker {fut} for queue {queue}...")
self._run_callbacks(queue)
del self.active_workers[queue]
del self._active_workers[queue]
with self.lock:
if queue not in self.active_workers:
if queue not in self._active_workers:
new_worker = self.executor.submit(self.spawn, queue=queue)
new_worker.add_done_callback(_log_on_end)
new_worker.add_done_callback(cleanup_worker)
self.active_workers[queue] = {"worker": None, "future": new_worker}
self._active_workers[queue] = {"worker": None, "future": new_worker}
def spawn(self, queue: str):
"""Spawn a procedure worker future which listens to a given queue, i.e. procedure queue list in Redis.
@@ -135,13 +135,13 @@ class ProcedureManager:
Args:
queue (str): name of the queue to spawn a worker for"""
if queue in self.active_workers and self.active_workers[queue]["worker"] is not None:
if queue in self._active_workers and self._active_workers[queue]["worker"] is not None:
raise WorkerAlreadyExists(
f"Queue {queue} already has an active worker in {self.active_workers}!"
f"Queue {queue} already has an active worker in {self._active_workers}!"
)
with self._worker_cls(self._server, queue, PROCEDURE.WORKER.QUEUE_TIMEOUT_S) as worker:
with self.lock:
self.active_workers[queue]["worker"] = worker
self._active_workers[queue]["worker"] = worker
worker.work()
def shutdown(self):
@@ -152,7 +152,7 @@ class ProcedureManager:
)
self._conn.shutdown()
# cancel futures by hand to give us the opportunity to detatch them from redis if they have started
for entry in self.active_workers.values():
for entry in self._active_workers.values():
cancelled = entry["future"].cancel()
if not cancelled:
# unblock any waiting workers and let them shutdown
@@ -160,7 +160,7 @@ class ProcedureManager:
# redis unblock executor.client_id
worker.abort()
futures.wait(
(entry["future"] for entry in self.active_workers.values()),
(entry["future"] for entry in self._active_workers.values()),
timeout=PROCEDURE.MANAGER_SHUTDOWN_TIMEOUT_S,
)
self.executor.shutdown()
@@ -1,7 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import cast
from bec_lib.endpoints import MessageEndpoints
@@ -112,7 +112,7 @@ def test_process_request_happy_paths(process_request_manager, message: Procedure
assert queue in endpoint.endpoint
assert execution_msg.identifier == message.identifier
process_request_manager.spawn.assert_called()
assert queue in process_request_manager.active_workers.keys()
assert queue in process_request_manager._active_workers.keys()
def test_process_request_failure(process_request_manager):
@@ -120,7 +120,7 @@ def test_process_request_failure(process_request_manager):
process_request_manager._ack.assert_not_called()
process_request_manager._conn.rpush.assert_not_called()
process_request_manager.spawn.assert_not_called()
assert process_request_manager.active_workers == {}
assert process_request_manager._active_workers == {}
class UnlockableWorker(ProcedureWorker):
@@ -162,13 +162,13 @@ def test_spawn(redis_connector, procedure_manager: ProcedureManager):
procedure_manager._validate_request = MagicMock(side_effect=lambda msg: msg)
# trigger the running of the test message
procedure_manager.process_queue_request(message) # type: ignore
assert queue in procedure_manager.active_workers.keys()
assert queue in procedure_manager._active_workers.keys()
# spawn method should be added as a future
_wait_until(procedure_manager.active_workers[queue]["future"].running)
_wait_until(procedure_manager._active_workers[queue]["future"].running)
# and then create the worker
_wait_until(lambda: procedure_manager.active_workers[queue].get("worker") is not None)
worker = procedure_manager.active_workers[queue]["worker"]
_wait_until(lambda: procedure_manager._active_workers[queue].get("worker") is not None)
worker = procedure_manager._active_workers[queue]["worker"]
assert isinstance(worker, UnlockableWorker)
_wait_until(lambda: worker.status == ProcedureWorkerStatus.RUNNING)
@@ -185,7 +185,7 @@ def test_spawn(redis_connector, procedure_manager: ProcedureManager):
worker.event_2.set()
_wait_until(lambda: worker.status == ProcedureWorkerStatus.FINISHED)
# spawn deletes the worker queue
_wait_until(lambda: len(procedure_manager.active_workers) == 0)
_wait_until(lambda: len(procedure_manager._active_workers) == 0)
@patch("bec_server.scan_server.procedures.worker_base.RedisConnector", MagicMock())