# bec/data_processing/tests/test_worker_manager.py

from unittest import mock
from bec_lib.core import BECMessage, MessageEndpoints
from bec_lib.core.redis_connector import MessageObject
from data_processing.worker_manager import DAPWorkerManager
def test_worker_manager_retrieves_config_on_startup():
    """A DAP config already stored when the manager is built is passed to update_config."""
    fake_connector = mock.MagicMock()
    fit_settings = {
        "stream": "scan_segment",
        "output": "gaussian_fit_worker_3",
        "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
        "model": "GaussianModel",
    }
    stored_msg = BECMessage.DAPConfigMessage(
        config={"workers": [{"id": "gaussian_fit_worker_3", "config": fit_settings}]}
    )
    # The manager reads its initial config through connector.producer().get(...).
    fake_connector.producer().get.return_value = stored_msg.dumps()
    with mock.patch.object(DAPWorkerManager, "update_config") as update_config_spy:
        _manager = DAPWorkerManager(fake_connector)
        update_config_spy.assert_called_once()
def test_worker_manager_retrieves_config_on_startup_empty():
    """When no config is stored, construction must not call update_config."""
    fake_connector = mock.MagicMock()
    # Simulate "nothing stored" for the startup config lookup.
    fake_connector.producer().get.return_value = None
    with mock.patch.object(DAPWorkerManager, "update_config") as update_config_spy:
        _manager = DAPWorkerManager(fake_connector)
        update_config_spy.assert_not_called()
def test_worker_manager_update_config():
    """A config message with one previously unknown worker starts exactly one worker."""
    fake_connector = mock.MagicMock()
    fake_connector.producer().get.return_value = None
    with mock.patch.object(DAPWorkerManager, "_start_worker") as start_worker_spy:
        manager = DAPWorkerManager(fake_connector)
        fit_settings = {
            "stream": "scan_segment",
            "output": "gaussian_fit_worker_3",
            "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
            "model": "GaussianModel",
            "worker_cls": "LmfitProcessor",
        }
        new_msg = BECMessage.DAPConfigMessage(
            config={"workers": [{"id": "gaussian_fit_worker_3", "config": fit_settings}]}
        )
        manager.update_config(new_msg)
        start_worker_spy.assert_called_once()
def test_worker_manager_update_config_no_workers():
    """An update carrying an empty worker list starts nothing."""
    fake_connector = mock.MagicMock()
    fake_connector.producer().get.return_value = None
    with mock.patch.object(DAPWorkerManager, "_start_worker") as start_worker_spy:
        manager = DAPWorkerManager(fake_connector)
        empty_msg = BECMessage.DAPConfigMessage(config={"workers": []})
        manager.update_config(empty_msg)
        start_worker_spy.assert_not_called()
def test_worker_manager_update_config_worker_already_running():
    """A registered worker whose config equals the incoming one is not restarted."""
    fake_connector = mock.MagicMock()
    fake_connector.producer().get.return_value = None
    with mock.patch.object(DAPWorkerManager, "_start_worker") as start_worker_spy:
        manager = DAPWorkerManager(fake_connector)
        worker_id = "gaussian_fit_worker_3"
        fit_settings = {
            "stream": "scan_segment",
            "output": "gaussian_fit_worker_3",
            "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
            "model": "GaussianModel",
            "worker_cls": "LmfitProcessor",
        }
        incoming_entry = {"id": worker_id, "config": fit_settings}
        # Seed the registry with the very same config object the update will carry.
        manager._workers = {worker_id: {"config": fit_settings, "worker": None}}
        manager.update_config(BECMessage.DAPConfigMessage(config={"workers": [incoming_entry]}))
        start_worker_spy.assert_not_called()
def test_worker_manager_update_config_worker_already_running_different_config():
    """A registered worker with a different config is terminated and a new one started."""
    fake_connector = mock.MagicMock()
    fake_connector.producer().get.return_value = None
    with mock.patch.object(DAPWorkerManager, "_start_worker") as start_worker_spy:
        manager = DAPWorkerManager(fake_connector)
        worker_id = "gaussian_fit_worker_3"
        fit_settings = {
            "stream": "scan_segment",
            "output": "gaussian_fit_worker_3",
            "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
            "model": "GaussianModel",
            "worker_cls": "LmfitProcessor",
        }
        stale_worker = mock.MagicMock()
        incoming_entry = {"id": worker_id, "config": fit_settings}
        # The registered config ({}) differs from the incoming one.
        manager._workers = {worker_id: {"config": {}, "worker": stale_worker}}
        manager.update_config(BECMessage.DAPConfigMessage(config={"workers": [incoming_entry]}))
        start_worker_spy.assert_called_once()
        stale_worker.terminate.assert_called_once()
def test_worker_manager_update_config_remove_outdated_workers():
    """Workers absent from the new config are terminated and removed.

    The manager is seeded with one registered worker. An update carrying an
    empty worker list must:
    - terminate that worker,
    - start nothing,
    - leave ``_workers`` empty.
    """
    connector = mock.MagicMock()
    with mock.patch.object(DAPWorkerManager, "_start_worker") as mock_start_worker:
        connector.producer().get.return_value = None
        worker_manager = DAPWorkerManager(connector)
        w3_mock = mock.MagicMock()
        worker_manager._workers = {"gaussian_fit_worker_3": {"config": {}, "worker": w3_mock}}
        worker_manager.update_config(BECMessage.DAPConfigMessage(config={"workers": []}))
        mock_start_worker.assert_not_called()
        w3_mock.terminate.assert_called_once()
        assert worker_manager._workers == {}
class DAPMockClass:
    """Stand-in plugin class that defines a ``run`` method."""

    def run(self):
        """Intentionally a no-op."""
        return None
class DAPMockWrongClass:
    """Stand-in plugin class that deliberately lacks a ``run`` method."""

    def no_run(self):
        """Intentionally a no-op; the name is anything but ``run``."""
        return None
def test_worker_manager_update_available_plugins():
    """Plugin discovery keeps CustomPlugin, drops WrongPlugin, and includes LmfitProcessor."""
    fake_connector = mock.MagicMock()
    discovered = [
        ("CustomPlugin", DAPMockClass),
        ("WrongPlugin", DAPMockWrongClass),
    ]
    plugins_patch = mock.patch("data_processing.worker_manager.dap_plugins")
    getmembers_patch = mock.patch(
        "data_processing.worker_manager.inspect.getmembers", return_value=discovered
    )
    with plugins_patch, getmembers_patch:
        manager = DAPWorkerManager(fake_connector)
        assert "CustomPlugin" in manager._worker_plugins
        assert "WrongPlugin" not in manager._worker_plugins
        assert "LmfitProcessor" in manager._worker_plugins
def test_worker_manager_start_worker():
    """_start_worker registers the worker config and hands it to run_worker."""
    fake_connector = mock.MagicMock()
    plugin_cls = mock.MagicMock()
    with mock.patch.object(DAPWorkerManager, "_update_config"):
        with mock.patch.object(DAPWorkerManager, "run_worker") as run_worker_spy:
            manager = DAPWorkerManager(fake_connector)
            worker_spec = {
                "id": "gaussian_fit_worker_3",
                "config": {
                    "stream": "scan_segment",
                    "output": "gaussian_fit_worker_3",
                    "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
                    "model": "GaussianModel",
                    "worker_cls": "LmfitProcessor",
                },
            }
            manager._start_worker(worker_spec, plugin_cls)
            registered = manager._workers
            assert "gaussian_fit_worker_3" in registered
            assert registered["gaussian_fit_worker_3"]["config"] == worker_spec["config"]
            run_worker_spy.assert_called_once_with(
                worker_spec["config"],
                worker_cls=plugin_cls,
                connector_host=fake_connector.bootstrap,
            )
def test_worker_manager_run_worker():
    """run_worker builds a daemon ``mp.Process`` targeting ``worker_cls.run`` and returns it.

    The return value is checked by identity against ``mock_mp.Process.return_value``.
    Writing ``mock_mp.Process()`` in the assertion would call the mocked constructor
    a second time and record a spurious call.
    """
    with mock.patch("data_processing.worker_manager.mp") as mock_mp:
        worker_cls = mock.MagicMock()
        ret = DAPWorkerManager.run_worker(
            config={"stream": "scan_segment", "output": "gaussian_fit_worker_3"},
            worker_cls=worker_cls,
            connector_host=["localhost:6379"],
        )
        # Fresh literals (not the objects passed in) so a mutated config would be caught.
        mock_mp.Process.assert_called_once_with(
            target=worker_cls.run,
            kwargs={
                "config": {"stream": "scan_segment", "output": "gaussian_fit_worker_3"},
                "connector_host": ["localhost:6379"],
            },
            daemon=True,
        )
        assert ret is mock_mp.Process.return_value
def test_worker_manager_set_config():
    """_set_config forwards a decoded DAPConfigMessage to update_config.

    A ``MessageObject`` whose payload is ``None`` must not trigger update_config.
    """
    connector = mock.MagicMock()
    worker_manager = DAPWorkerManager(connector)
    msg = BECMessage.DAPConfigMessage(
        config={
            "workers": [
                {
                    "id": "gaussian_fit_worker_3",
                    "config": {
                        "stream": "scan_segment",
                        "output": "gaussian_fit_worker_3",
                        "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
                        "model": "GaussianModel",
                        "worker_cls": "LmfitProcessor",
                    },
                }
            ]
        }
    )
    with mock.patch.object(worker_manager, "update_config") as mock_update_config:
        msg_obj = MessageObject(msg.dumps(), MessageEndpoints.dap_config())
        worker_manager._set_config(msg_obj, worker_manager)
        mock_update_config.assert_called_once_with(msg)

        # Keep this check inside the patch, otherwise the real update_config would run.
        mock_update_config.reset_mock()
        msg_obj = MessageObject(None, MessageEndpoints.dap_config())
        worker_manager._set_config(msg_obj, worker_manager)
        mock_update_config.assert_not_called()
def test_worker_manager_shutdown():
    """shutdown() calls terminate() once on every registered worker."""
    fake_connector = mock.MagicMock()
    manager = DAPWorkerManager(fake_connector)
    first_worker = mock.MagicMock()
    second_worker = mock.MagicMock()
    manager._workers = {
        "gaussian_fit_worker_1": {"worker": first_worker},
        "gaussian_fit_worker_2": {"worker": second_worker},
    }
    manager.shutdown()
    for worker in (first_worker, second_worker):
        worker.terminate.assert_called_once()