mirror of
https://github.com/ivan-usov-org/bec.git
synced 2025-04-22 02:20:02 +02:00
239 lines
9.5 KiB
Python
239 lines
9.5 KiB
Python
from unittest import mock
|
|
|
|
from bec_lib.core import BECMessage, MessageEndpoints
|
|
from bec_lib.core.redis_connector import MessageObject
|
|
from data_processing.worker_manager import DAPWorkerManager
|
|
|
|
|
|
def test_worker_manager_retrieves_config_on_startup():
    """Constructing the manager triggers exactly one ``update_config`` call.

    The mocked producer returns a serialized ``DAPConfigMessage`` holding a
    single worker entry; ``update_config`` is patched so only the call itself
    is observed.
    """
    connector = mock.MagicMock()
    with mock.patch.object(DAPWorkerManager, "update_config") as mock_update_config:
        config = {
            "stream": "scan_segment",
            "output": "gaussian_fit_worker_3",
            "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
            "model": "GaussianModel",
        }
        worker_config = {"id": "gaussian_fit_worker_3", "config": config}
        connector.producer().get.return_value = BECMessage.DAPConfigMessage(
            config={"workers": [worker_config]}
        ).dumps()
        # Only the constructor's side effect is checked, so the instance is not kept.
        DAPWorkerManager(connector)
        mock_update_config.assert_called_once()
|
|
|
|
|
|
def test_worker_manager_retrieves_config_on_startup_empty():
    """Constructing the manager does not call ``update_config`` when the producer returns None."""
    connector = mock.MagicMock()
    with mock.patch.object(DAPWorkerManager, "update_config") as mock_update_config:
        connector.producer().get.return_value = None
        # Only the constructor's side effect is checked, so the instance is not kept.
        DAPWorkerManager(connector)
        mock_update_config.assert_not_called()
|
|
|
|
|
|
def test_worker_manager_update_config():
    """A config message listing one worker leads to a single ``_start_worker`` call."""
    fake_connector = mock.MagicMock()
    with mock.patch.object(DAPWorkerManager, "_start_worker") as start_worker_mock:
        # The producer's get() yields None while the manager is being constructed.
        fake_connector.producer().get.return_value = None
        manager = DAPWorkerManager(fake_connector)

        worker_entry = {
            "id": "gaussian_fit_worker_3",
            "config": {
                "stream": "scan_segment",
                "output": "gaussian_fit_worker_3",
                "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
                "model": "GaussianModel",
                "worker_cls": "LmfitProcessor",
            },
        }
        config_msg = BECMessage.DAPConfigMessage(config={"workers": [worker_entry]})
        manager.update_config(config_msg)

        start_worker_mock.assert_called_once()
|
|
|
|
|
|
def test_worker_manager_update_config_no_workers():
    """A config message with an empty worker list never calls ``_start_worker``."""
    fake_connector = mock.MagicMock()
    with mock.patch.object(DAPWorkerManager, "_start_worker") as start_worker_mock:
        # The producer's get() yields None while the manager is being constructed.
        fake_connector.producer().get.return_value = None
        manager = DAPWorkerManager(fake_connector)

        empty_msg = BECMessage.DAPConfigMessage(config={"workers": []})
        manager.update_config(empty_msg)

        start_worker_mock.assert_not_called()
|
|
|
|
|
|
def test_worker_manager_update_config_worker_already_running():
    """No ``_start_worker`` call when the worker is already registered with the same config."""
    fake_connector = mock.MagicMock()
    with mock.patch.object(DAPWorkerManager, "_start_worker") as start_worker_mock:
        # The producer's get() yields None while the manager is being constructed.
        fake_connector.producer().get.return_value = None
        manager = DAPWorkerManager(fake_connector)

        fit_config = {
            "stream": "scan_segment",
            "output": "gaussian_fit_worker_3",
            "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
            "model": "GaussianModel",
            "worker_cls": "LmfitProcessor",
        }
        worker_entry = {"id": "gaussian_fit_worker_3", "config": fit_config}
        # Pre-register the worker under the very same config that the message carries.
        manager._workers = {"gaussian_fit_worker_3": {"config": fit_config, "worker": None}}

        config_msg = BECMessage.DAPConfigMessage(config={"workers": [worker_entry]})
        manager.update_config(config_msg)

        start_worker_mock.assert_not_called()
|
|
|
|
|
|
def test_worker_manager_update_config_worker_already_running_different_config():
    """A registered worker whose stored config differs is terminated and started once."""
    fake_connector = mock.MagicMock()
    with mock.patch.object(DAPWorkerManager, "_start_worker") as start_worker_mock:
        # The producer's get() yields None while the manager is being constructed.
        fake_connector.producer().get.return_value = None
        manager = DAPWorkerManager(fake_connector)

        fit_config = {
            "stream": "scan_segment",
            "output": "gaussian_fit_worker_3",
            "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
            "model": "GaussianModel",
            "worker_cls": "LmfitProcessor",
        }
        running_worker = mock.MagicMock()
        worker_entry = {"id": "gaussian_fit_worker_3", "config": fit_config}
        # The registered worker holds an empty config, which differs from the incoming one.
        manager._workers = {"gaussian_fit_worker_3": {"config": {}, "worker": running_worker}}

        config_msg = BECMessage.DAPConfigMessage(config={"workers": [worker_entry]})
        manager.update_config(config_msg)

        start_worker_mock.assert_called_once()
        running_worker.terminate.assert_called_once()
|
|
|
|
|
|
def test_worker_manager_update_config_remove_outdated_workers():
    """A registered worker missing from the new config is terminated and dropped.

    The manager starts with one registered worker; the config message then
    lists no workers at all, so nothing is started, the existing worker is
    terminated and the registry ends up empty.
    """
    connector = mock.MagicMock()
    with mock.patch.object(DAPWorkerManager, "_start_worker") as mock_start_worker:
        connector.producer().get.return_value = None
        worker_manager = DAPWorkerManager(connector)
        w3_mock = mock.MagicMock()
        worker_manager._workers = {"gaussian_fit_worker_3": {"config": {}, "worker": w3_mock}}
        worker_manager.update_config(BECMessage.DAPConfigMessage(config={"workers": []}))
        mock_start_worker.assert_not_called()
        w3_mock.terminate.assert_called_once()
        assert worker_manager._workers == {}
|
|
|
|
|
|
class DAPMockClass:
    """Stand-in plugin class that exposes a ``run`` method."""

    def run(self):
        """Do nothing."""
        return None
|
|
|
|
|
|
class DAPMockWrongClass:
    """Stand-in plugin class that deliberately has no ``run`` method."""

    def no_run(self):
        """Do nothing."""
        return None
|
|
|
|
|
|
def test_worker_manager_update_available_plugins():
    """Check which names end up in ``_worker_plugins`` after construction.

    ``inspect.getmembers`` is patched to report one class with a ``run``
    method and one without; the test expects the former and
    ``LmfitProcessor`` to be present and the latter to be absent.
    """
    fake_connector = mock.MagicMock()
    plugins_patch = mock.patch("data_processing.worker_manager.dap_plugins")
    getmembers_patch = mock.patch("data_processing.worker_manager.inspect.getmembers")
    with plugins_patch, getmembers_patch as getmembers_mock:
        getmembers_mock.return_value = [
            ("CustomPlugin", DAPMockClass),
            ("WrongPlugin", DAPMockWrongClass),
        ]
        manager = DAPWorkerManager(fake_connector)
        registered = manager._worker_plugins
        assert "CustomPlugin" in registered
        assert "WrongPlugin" not in registered
        assert "LmfitProcessor" in registered
|
|
|
|
|
|
def test_worker_manager_start_worker():
    """``_start_worker`` registers the worker and calls ``run_worker`` with its config.

    ``_update_config`` and ``run_worker`` are both patched, so only the
    registry entry and the arguments passed to ``run_worker`` are checked.
    """
    fake_connector = mock.MagicMock()
    plugin_cls = mock.MagicMock()
    update_patch = mock.patch.object(DAPWorkerManager, "_update_config")
    run_patch = mock.patch.object(DAPWorkerManager, "run_worker")
    with update_patch, run_patch as run_worker_mock:
        manager = DAPWorkerManager(fake_connector)
        worker_entry = {
            "id": "gaussian_fit_worker_3",
            "config": {
                "stream": "scan_segment",
                "output": "gaussian_fit_worker_3",
                "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
                "model": "GaussianModel",
                "worker_cls": "LmfitProcessor",
            },
        }
        manager._start_worker(worker_entry, plugin_cls)

        assert "gaussian_fit_worker_3" in manager._workers
        stored = manager._workers["gaussian_fit_worker_3"]
        assert stored["config"] == worker_entry["config"]
        run_worker_mock.assert_called_once_with(
            worker_entry["config"],
            worker_cls=plugin_cls,
            connector_host=fake_connector.bootstrap,
        )
|
|
|
|
|
|
def test_worker_manager_run_worker():
    """``run_worker`` creates one daemon ``mp.Process`` targeting ``worker_cls.run``.

    The ``mp`` name in the worker manager module is patched, so no real
    process is created; the test checks the constructor arguments and that
    the created process object is returned.
    """
    with mock.patch("data_processing.worker_manager.mp") as mp_mock:
        plugin_cls = mock.MagicMock()
        process = DAPWorkerManager.run_worker(
            config={"stream": "scan_segment", "output": "gaussian_fit_worker_3"},
            worker_cls=plugin_cls,
            connector_host=["localhost:6379"],
        )
        # Built from fresh literals so the expectation is independent of the input objects.
        expected_kwargs = {
            "config": {"stream": "scan_segment", "output": "gaussian_fit_worker_3"},
            "connector_host": ["localhost:6379"],
        }
        mp_mock.Process.assert_called_once_with(
            target=plugin_cls.run, kwargs=expected_kwargs, daemon=True
        )
        assert process == mp_mock.Process()
|
|
|
|
|
|
def test_worker_manager_set_config():
    """``_set_config`` forwards a decoded message to ``update_config`` and skips an empty one.

    First a ``MessageObject`` wrapping a serialized ``DAPConfigMessage`` is
    passed and ``update_config`` must receive the equivalent message. Then a
    ``MessageObject`` whose value is ``None`` is passed and ``update_config``
    must not be called.
    """
    connector = mock.MagicMock()

    worker_manager = DAPWorkerManager(connector)
    msg = BECMessage.DAPConfigMessage(
        config={
            "workers": [
                {
                    "id": "gaussian_fit_worker_3",
                    "config": {
                        "stream": "scan_segment",
                        "output": "gaussian_fit_worker_3",
                        "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
                        "model": "GaussianModel",
                        "worker_cls": "LmfitProcessor",
                    },
                }
            ]
        }
    )
    with mock.patch.object(worker_manager, "update_config") as mock_update_config:
        msg_obj = MessageObject(msg.dumps(), MessageEndpoints.dap_config())
        worker_manager._set_config(msg_obj, worker_manager)
        mock_update_config.assert_called_once_with(msg)

        # Start from a clean call history before checking the empty-message case.
        mock_update_config.reset_mock()
        msg_obj = MessageObject(None, MessageEndpoints.dap_config())
        worker_manager._set_config(msg_obj, worker_manager)
        mock_update_config.assert_not_called()
|
|
|
|
|
|
def test_worker_manager_shutdown():
    """``shutdown`` calls ``terminate`` exactly once on every registered worker."""
    manager = DAPWorkerManager(mock.MagicMock())
    first_worker = mock.MagicMock()
    second_worker = mock.MagicMock()
    manager._workers = {
        "gaussian_fit_worker_1": {"worker": first_worker},
        "gaussian_fit_worker_2": {"worker": second_worker},
    }

    manager.shutdown()

    for worker in (first_worker, second_worker):
        worker.terminate.assert_called_once()
|