feat: added bec data processing service

This commit is contained in:
wakonig_k 2023-06-20 19:50:47 +02:00 committed by wakonig_k
parent e1aa5e199b
commit 17213da46b
12 changed files with 757 additions and 3 deletions

View File

@ -30,6 +30,7 @@ stages:
- pip install -e ./bec_client - pip install -e ./bec_client
- pip install -e ./file_writer - pip install -e ./file_writer
- pip install -e ./scihub - pip install -e ./scihub
- pip install -e ./data_processing
formatter: formatter:
stage: Formatter stage: Formatter
@ -44,7 +45,7 @@ pylint:
- pip install pylint pylint-exit anybadge - pip install pylint pylint-exit anybadge
script: script:
- mkdir ./pylint - mkdir ./pylint
- pylint ./bec_client_lib/bec_client_lib ./scan_server/scan_server ./device_server/device_server ./scan_bundler/scan_bundler ./bec_client/bec_client ./file_writer/file_writer --output-format=text . | tee ./pylint/pylint.log || pylint-exit $? - pylint ./data_processing/data_processing ./bec_client_lib/bec_client_lib ./scan_server/scan_server ./device_server/device_server ./scan_bundler/scan_bundler ./bec_client/bec_client ./file_writer/file_writer --output-format=text . | tee ./pylint/pylint.log || pylint-exit $?
- PYLINT_SCORE=$(sed -n 's/^Your code has been rated at \([-0-9.]*\)\/.*/\1/p' ./pylint/pylint.log) - PYLINT_SCORE=$(sed -n 's/^Your code has been rated at \([-0-9.]*\)\/.*/\1/p' ./pylint/pylint.log)
- anybadge --label=Pylint --file=pylint/pylint.svg --value=$PYLINT_SCORE 2=red 4=orange 8=yellow 10=green - anybadge --label=Pylint --file=pylint/pylint.svg --value=$PYLINT_SCORE 2=red 4=orange 8=yellow 10=green
- echo "Pylint score is $PYLINT_SCORE" - echo "Pylint score is $PYLINT_SCORE"
@ -93,7 +94,7 @@ tests:
- pip install pytest pytest-random-order pytest-cov pytest-asyncio - pip install pytest pytest-random-order pytest-cov pytest-asyncio
- apt-get install -y gcc - apt-get install -y gcc
- *install-bec-services - *install-bec-services
- coverage run --source=./bec_client_lib/bec_client_lib,./device_server/device_server,./scan_server/scan_server,./scan_bundler/scan_bundler,./bec_client/bec_client,./file_writer/file_writer,./scihub/scihub --omit=*/bec_client/bec_client/plugins/*,*/bec_client/scripts/* -m pytest -v --junitxml=report.xml --random-order ./scan_server/tests ./device_server/tests ./scan_bundler/tests ./bec_client/tests/client_tests ./file_writer/tests ./scihub/tests ./bec_client_lib/tests - coverage run --source=./data_processing/data_processing,./bec_client_lib/bec_client_lib,./device_server/device_server,./scan_server/scan_server,./scan_bundler/scan_bundler,./bec_client/bec_client,./file_writer/file_writer,./scihub/scihub --omit=*/bec_client/bec_client/plugins/*,*/bec_client/scripts/* -m pytest -v --junitxml=report.xml --random-order ./data_processing/tests ./scan_server/tests ./device_server/tests ./scan_bundler/tests ./bec_client/tests/client_tests ./file_writer/tests ./scihub/tests ./bec_client_lib/tests
- coverage report - coverage report
- coverage xml - coverage xml
coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/' coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
@ -113,7 +114,7 @@ tests-3.9:
- pip install pytest pytest-random-order pytest-cov pytest-asyncio pytest-timeout - pip install pytest pytest-random-order pytest-cov pytest-asyncio pytest-timeout
- apt-get install -y gcc - apt-get install -y gcc
- *install-bec-services - *install-bec-services
- pytest -v --junitxml=report.xml --random-order ./bec_client_lib/tests ./scan_server/tests ./device_server/tests ./scan_bundler/tests ./bec_client/tests/client_tests ./file_writer/tests - pytest -v --junitxml=report.xml --random-order ./data_processing/tests ./bec_client_lib/tests ./scan_server/tests ./device_server/tests ./scan_bundler/tests ./bec_client/tests/client_tests ./file_writer/tests
tests-3.10: tests-3.10:
extends: "tests-3.9" extends: "tests-3.9"

View File

@ -0,0 +1,23 @@
from bec_client_lib.core import BECService
from bec_client_lib.core.connector import ConnectorBase
from bec_client_lib.core.service_config import ServiceConfig
from .worker_manager import DAPWorkerManager
class DAPServer(BECService):
"""Data processing server class."""
def __init__(
self, config: ServiceConfig, connector_cls: ConnectorBase, unique_service=False
) -> None:
super().__init__(config, connector_cls, unique_service)
self._work_manager = None
self._start_manager()
def _start_manager(self):
self._work_manager = DAPWorkerManager(self.connector)
def shutdown(self):
self._work_manager.shutdown()
super().shutdown()

View File

@ -0,0 +1,247 @@
from __future__ import annotations
import time
from abc import ABC, abstractmethod
from collections import deque
from typing import Any, List, Optional, Tuple
import lmfit
from bec_client_lib.core import BECMessage, MessageEndpoints
from bec_client_lib.core.redis_connector import MessageObject, RedisConnector
def nested_get(data: str, keys, default=None):
"""
Get a value from a nested dictionary.
Args:
data (dict): Dictionary to get the value from.
keys (str): Keys to get the value from. Keys are separated by a dot.
default (Any, optional): Default value to return if the key is not found. Defaults to None.
Returns:
Any: Value of the key.
Examples:
>>> data = {"a": {"b": 1}}
>>> nested_get(data, "a.b")
1
"""
if "." in keys:
key, rest = keys.split(".", 1)
return nested_get(data[key], rest, default=default)
return data.get(keys, default)
class StreamProcessor(ABC):
"""
Abstract class for stream processors. This class is responsible for
processing stream data. Each processor is started in a separate process.
Override the process method to implement the processing logic.
Please note that the processor stores the data in the self.data attribute.
This is done to allow the processor to access multiple data points at once,
e.g. for fitting.Make sure to reset the data attribute after processing
the data to avoid memory leaks.
"""
def __init__(self, connector: RedisConnector, config: dict) -> None:
"""
Initialize the StreamProcessor class.
Args:
connector (RedisConnector): Redis connector.
config (dict): Configuration for the processor.
"""
super().__init__()
self._connector = connector
self.producer = connector.producer()
self._process = None
self.queue = deque()
self.consumer = None
self.config = config
self.data = None
def reset_data(self):
"""Reset the data."""
self.data = None
@abstractmethod
def process(self, data: dict, metadata: dict) -> Tuple[dict, dict]:
"""
Process data and return the result.
Args:
data (dict): Data to be processed.
metadata (dict): Metadata associated with the data.
Returns:
Tuple[dict, dict]: Tuple containing the processed data and metadata.
"""
@property
def status(self):
"""Return the worker status."""
return {
"process": self._process,
"config": self.config,
"started": self._process.is_alive(),
}
def shutdown(self):
"""Shutdown the worker. Terminate the process and wait for it to join."""
self._process.terminate()
self._process.join()
def _run_forever(self):
"""Core method for the worker. This method is called in a while True loop."""
if not self.queue:
time.sleep(0.1)
return
data = self.queue.popleft()
# Process data
result = self._process_data(data)
# publish the result
if not all(result):
return
# for multiple results, publish them as a bundle
if isinstance(result, list) and len(result) > 1:
msg_bundle = BECMessage.BundleMessage()
for data, metadata in result:
msg = BECMessage.ProcessedDataMessage(data=data, metadata=metadata).dumps()
msg_bundle.append(msg)
self._publish_result(msg_bundle.dumps())
else:
msg = BECMessage.ProcessedDataMessage(data=result[0][0], metadata=result[0][1]).dumps()
self._publish_result(msg)
def start(self):
"""Run the worker. This method is called in a separate process."""
while True:
self._run_forever()
def _process_data(self, data: BECMessage.BECMessage) -> List[Tuple[dict, dict]]:
"""Process data."""
if not isinstance(data, list):
data = [data]
return [self.process(sub_data.content, sub_data.metadata) for sub_data in data]
def start_data_consumer(self):
"""Get data from redis."""
if self.consumer and self.consumer.is_alive():
self.consumer.shutdown()
self.consumer = self._connector.consumer(
self.config["stream"], cb=self._set_data, parent=self
)
self.consumer.start()
@staticmethod
def _set_data(msg: MessageObject, parent: StreamProcessor):
"""Set data to the parent."""
parent.queue.append(BECMessage.MessageReader.loads(msg.value))
def _publish_result(self, msg: BECMessage.BECMessage):
"""Publish the result."""
self.producer.set_and_publish(
MessageEndpoints.processed_data(self.config["output"]),
msg,
)
@classmethod
def run(cls, config: dict, connector_host: list[str]) -> None:
"""Run the worker."""
connector = RedisConnector(connector_host)
worker = cls(connector, config)
worker.start_data_consumer()
worker.start()
class LmfitProcessor(StreamProcessor):
"""Lmfit processor class."""
def __init__(self, connector: RedisConnector, config: dict) -> None:
"""
Initialize the LmfitProcessor class.
Args:
connector (RedisConnector): Redis connector.
config (dict): Configuration for the processor.
"""
super().__init__(connector, config)
self.model = self._get_model()
self.scan_id = None
def _get_model(self) -> lmfit.Model:
"""Get the model from the config and convert it to an lmfit model."""
if not self.config:
raise ValueError("No config provided")
if not self.config.get("model"):
raise ValueError("No model provided")
model = self.config["model"]
if not isinstance(model, str):
raise ValueError("Model must be a string")
# check if the model is a valid lmfit model
if not hasattr(lmfit.models, model):
raise ValueError(f"Invalid model: {model}")
model = getattr(lmfit.models, model)
return model()
def process(self, data: dict, metadata: dict) -> Optional[Tuple[dict, dict]]:
"""
Process data and return the result.
Args:
data (dict): Data to be processed.
metadata (dict): Metadata associated with the data.
Returns:
Tuple[dict, dict]: Tuple containing the processed data and metadata. Returns None if no data is provided or if the fit is skipped.
"""
if not data:
return None
# get the event data
x = nested_get(data.get("data", {}), self.config["input_xy"][0])
y = nested_get(data.get("data", {}), self.config["input_xy"][1])
# check if the data is indeed a number
if not isinstance(x, (int, float)) or not isinstance(y, (int, float)):
return None
# reset the data if the scan id changed
if self.scan_id != data.get("scanID"):
self.reset_data()
self.scan_id = data.get("scanID")
# append the data to the data attribute
if self.data is None:
self.data = {"x": [], "y": []}
self.data["x"].append(x)
self.data["y"].append(y)
# check if the data is long enough to fit
if len(self.data["x"]) < 3:
return None
# fit the data
result = self.model.fit(self.data["y"], x=self.data["x"])
# add the fit result to the output
stream_output = {self.config["output"]: result.best_fit, "input": self.config["input_xy"]}
# add the fit parameters to the metadata
metadata["fit_parameters"] = result.best_values
metadata["fit_summary"] = result.summary()
return (stream_output, metadata)

View File

@ -0,0 +1,99 @@
from __future__ import annotations
import multiprocessing as mp
from bec_client_lib.core import BECMessage, MessageEndpoints, bec_logger
from bec_client_lib.core.redis_connector import RedisConnector
from .stream_processor import LmfitProcessor
logger = bec_logger.logger
class DAPWorkerManager:
"""Data processing worker manager class."""
def __init__(self, connector: RedisConnector):
self.connector = connector
self.producer = connector.producer()
self._workers = {}
self._config = {}
self._update_config()
self._start_config_consumer()
def _update_config(self):
"""Get config from redis."""
logger.debug("Getting config from redis")
msg = self.producer.get(MessageEndpoints.dap_config())
if not msg:
return
self.update_config(BECMessage.DAPConfigMessage.loads(msg))
def _start_config_consumer(self):
"""Get config from redis."""
logger.debug("Starting config consumer")
self.consumer = self.connector.consumer(
MessageEndpoints.dap_config(), cb=self._set_config, parent=self
)
self.consumer.start()
@staticmethod
def _set_config(msg: BECMessage.BECMessage, parent: DAPWorkerManager) -> None:
"""Set config to the parent."""
msg = BECMessage.DAPConfigMessage.loads(msg.value)
if not msg:
return
parent.update_config(msg)
def update_config(self, msg: BECMessage.DAPConfigMessage):
"""Update the config."""
logger.debug(f"Updating config: {msg.content}")
if not msg.content["config"]:
return
self._config = msg.content["config"]
for worker_config in self._config["workers"]:
# Check if the worker is already running and start it if not
if worker_config["id"] not in self._workers:
self._start_worker(worker_config)
continue
# Check if the config has changed
if self._workers[worker_config["id"]]["config"] == worker_config["config"]:
logger.debug(f"Worker config has not changed: {worker_config['id']}")
continue
# If the config has changed, terminate the worker and start a new one
logger.debug(f"Restarting worker: {worker_config['id']}")
self._workers[worker_config["id"]]["worker"].terminate()
self._start_worker(worker_config)
# Check if any workers need to be removed
for worker_id in list(self._workers):
if worker_id not in [worker["id"] for worker in self._config["workers"]]:
logger.debug(f"Removing worker: {worker_id}")
self._workers[worker_id]["worker"].terminate()
del self._workers[worker_id]
def _start_worker(self, config: dict):
"""Start a worker."""
logger.debug(f"Starting worker: {config}")
self._workers[config["id"]] = {
"worker": self.run_worker(config["config"]),
"config": config["config"],
}
def shutdown(self):
for worker in self._workers:
worker.shutdown()
@staticmethod
def run_worker(config: dict) -> mp.Process:
"""Run the worker."""
worker = mp.Process(
target=LmfitProcessor.run,
kwargs={"config": config, "connector_host": ["localhost:6379"]},
daemon=True,
)
worker.start()
return worker

35
data_processing/launch.py Normal file
View File

@ -0,0 +1,35 @@
import argparse
import threading
from bec_client_lib.core import RedisConnector, ServiceConfig, bec_logger
from data_processing.dap_server import DAPServer
if __name__ == "__main__":
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--config",
default="",
help="path to the config file",
)
clargs = parser.parse_args()
config_path = clargs.config
config = ServiceConfig(config_path)
bec_logger.level = bec_logger.LOGLEVEL.DEBUG
logger = bec_logger.logger
bec_server = DAPServer(
config=config,
connector_cls=RedisConnector,
)
try:
event = threading.Event()
# pylint: disable=E1102
logger.success("Started DAP server")
event.wait()
except KeyboardInterrupt as e:
# bec_server.connector.raise_error("KeyboardInterrupt")
bec_server.shutdown()
event.set()
raise e

24
data_processing/setup.cfg Normal file
View File

@ -0,0 +1,24 @@
[metadata]
name = bec_dap
version = 0.0.1
description = BEC Online Data Processing Service
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/pypa/sampleproject
project_urls =
Bug Tracker = https://github.com/pypa/sampleproject/issues
classifiers =
Programming Language :: Python :: 3
License :: OSI Approved :: MIT License
Operating System :: OS Independent
[options]
package_dir =
= .
packages = find:
python_requires = >=3.7
[options.packages.find]
where = .

14
data_processing/setup.py Normal file
View File

@ -0,0 +1,14 @@
import pathlib
import subprocess
from setuptools import setup
current_path = pathlib.Path(__file__).parent.resolve()
utils = f"{current_path}/../bec_client_lib/"
if __name__ == "__main__":
setup(install_requires=["lmfit"])
local_deps = [utils]
for dep in local_deps:
subprocess.run(f"pip install -e {dep}", shell=True, check=True)

View File

@ -0,0 +1,77 @@
from unittest import mock
import numpy as np
import pytest
from data_processing.stream_processor import LmfitProcessor
def test_LmfitProcessor_get_model_needs_config():
"""
Test the LmfitProcessor class get_model method.
"""
connector = mock.MagicMock()
with pytest.raises(ValueError):
LmfitProcessor(connector, {})
def test_LmfitProcessor_get_model_returns_correct_model():
"""
Test the LmfitProcessor class get_model method.
"""
connector = mock.MagicMock()
config = {"model": "GaussianModel"}
processor = LmfitProcessor(connector, config)
assert processor.model.name == "Model(gaussian)"
def test_LmfitProcessor_process_gaussian():
"""
Test the LmfitProcessor class process method with a gaussian model.
"""
connector = mock.MagicMock()
config = {"model": "GaussianModel", "input_xy": ["x", "y"], "output": "gaussian_fit"}
processor = LmfitProcessor(connector, config)
processor.data = {"x": [1, 2, 3], "y": [1, 2, 3]}
data = {"data": {"x": 4, "y": 4}}
metadata = {}
result_data, result_metadata = processor.process(data, metadata)
assert np.allclose(result_data["gaussian_fit"], processor.data["y"], atol=0.1)
assert {"amplitude", "sigma", "center"} & set(result_metadata["fit_parameters"]) == {
"amplitude",
"sigma",
"center",
}
assert len(processor.data["x"]) == 4
assert len(processor.data["y"]) == 4
def test_LmfitProcessor_resets_scan_data():
"""
Test the LmfitProcessor class make sure it resets the data when the scan id changes.
"""
connector = mock.MagicMock()
config = {"model": "GaussianModel", "input_xy": ["x", "y"], "output": "gaussian_fit"}
processor = LmfitProcessor(connector, config)
processor.data = {"x": [1, 2, 3], "y": [1, 2, 3]}
data = {"data": {"x": 4, "y": 4}, "scanID": 1}
metadata = {}
result = processor.process(data, metadata)
assert result is None
assert len(processor.data["x"]) == 1
def test_LmfitProcessor_resets_scan_data_with_existing_id():
"""
Test the LmfitProcessor class to make sure it resets the data when the scan id changes.
"""
connector = mock.MagicMock()
config = {"model": "GaussianModel", "input_xy": ["x", "y"], "output": "gaussian_fit"}
processor = LmfitProcessor(connector, config)
processor.scan_id = 1
processor.data = {"x": [1, 2, 3], "y": [1, 2, 3]}
data = {"data": {"x": 4, "y": 4}, "scanID": 2}
metadata = {}
result = processor.process(data, metadata)
assert result is None
assert len(processor.data["x"]) == 1

View File

@ -0,0 +1,25 @@
from data_processing.stream_processor import nested_get
def test_nested_get_default():
"""
Test the nested_get function.
"""
data = {"a": {"b": {"c": 1}}}
assert nested_get(data, "a.b.c") == 1
def test_nested_get_returns_default():
"""
Test the nested_get function.
"""
data = {"a": {"b": {"c": 1}}}
assert nested_get(data, "a.b.d", 2) == 2
def test_nested_get_with_plain_key():
"""
Test the nested_get function.
"""
data = {"a": {"b": 1}}
assert nested_get(data, "a") == {"b": 1}

View File

@ -0,0 +1,92 @@
from unittest import mock
import pytest
from bec_client_lib.core import BECMessage
from data_processing.stream_processor import StreamProcessor
class DummyStreamProcessor(StreamProcessor):
def process(self, data: dict, metadata: dict) -> tuple:
return data, metadata
@pytest.fixture(scope="function")
def stream_processor():
connector = mock.MagicMock()
config = {
"stream": "scan_segment",
"output": "gaussian_fit_worker_3",
"input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
"model": "GaussianModel",
}
return DummyStreamProcessor(connector, config)
def test_stream_processor_run_forever(stream_processor):
"""
Test the StreamProcessor class run_forever method.
"""
stream_processor.queue.append(
BECMessage.ScanMessage(point_id=1, scanID="scanID", data={"x": 1, "y": 1})
)
with mock.patch.object(StreamProcessor, "_process_data") as mock_process_data:
mock_process_data.return_value = [
({"x": 1, "y": 1}, {"scanID": "scanID"}),
]
stream_processor._run_forever()
mock_process_data.assert_called_once()
def test_stream_processor_publishes_bundled_data(stream_processor):
"""
Test the StreamProcessor class run_forever method and make sure it publishes bundled data.
"""
stream_processor.queue.append(
BECMessage.ScanMessage(point_id=1, scanID="scanID", data={"x": 1, "y": 1})
)
with mock.patch.object(StreamProcessor, "_process_data") as mock_process_data:
mock_process_data.return_value = [
({"x": 1, "y": 1}, {"scanID": "scanID"}),
({"x": 1, "y": 1}, {"scanID": "scanID"}),
]
stream_processor._run_forever()
mock_process_data.assert_called_once()
assert stream_processor._connector.producer().set_and_publish.call_count == 1
def test_stream_processor_does_not_publish_empty_data(stream_processor):
"""
Test the StreamProcessor class run_forever method and make sure does not publish empty data.
"""
stream_processor.queue.append(
BECMessage.ScanMessage(point_id=1, scanID="scanID", data={"x": 1, "y": 1})
)
with mock.patch.object(StreamProcessor, "_process_data") as mock_process_data:
mock_process_data.return_value = [
None,
]
stream_processor._run_forever()
mock_process_data.assert_called_once()
assert stream_processor._connector.producer().set_and_publish.call_count == 0
def test_stream_processor_start_data_consumer(stream_processor):
"""
Test the StreamProcessor class start_data_consumer method.
"""
stream_processor.start_data_consumer()
stream_processor._connector.consumer.assert_called_once()
assert stream_processor._connector.consumer().start.call_count == 1
def test_stream_processor_start_data_consumer_stops_existing_consumer(stream_processor):
"""
Test the StreamProcessor class start_data_consumer method and make sure it stops the existing consumer.
"""
orig_consumer = mock.MagicMock()
stream_processor.consumer = orig_consumer
stream_processor.consumer.is_alive.return_value = True
stream_processor.start_data_consumer()
assert orig_consumer.shutdown.call_count == 1

View File

@ -0,0 +1,117 @@
from unittest import mock
from bec_client_lib.core import BECMessage, MessageEndpoints
from data_processing.worker_manager import DAPWorkerManager
def test_worker_manager_retrieves_config_on_startup():
connector = mock.MagicMock()
with mock.patch.object(DAPWorkerManager, "update_config") as mock_update_config:
config = {
"stream": "scan_segment",
"output": "gaussian_fit_worker_3",
"input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
"model": "GaussianModel",
}
worker_config = {"id": "gaussian_fit_worker_3", "config": config}
connector.producer().get.return_value = BECMessage.DAPConfigMessage(
config={"workers": [worker_config]}
).dumps()
worker_manager = DAPWorkerManager(connector)
mock_update_config.assert_called_once()
def test_worker_manager_retrieves_config_on_startup_empty():
connector = mock.MagicMock()
with mock.patch.object(DAPWorkerManager, "update_config") as mock_update_config:
connector.producer().get.return_value = None
worker_manager = DAPWorkerManager(connector)
mock_update_config.assert_not_called()
def test_worker_manager_update_config():
connector = mock.MagicMock()
with mock.patch.object(DAPWorkerManager, "_start_worker") as mock_start_worker:
connector.producer().get.return_value = None
worker_manager = DAPWorkerManager(connector)
config = {
"stream": "scan_segment",
"output": "gaussian_fit_worker_3",
"input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
"model": "GaussianModel",
}
worker_config = {"id": "gaussian_fit_worker_3", "config": config}
worker_manager.update_config(
BECMessage.DAPConfigMessage(config={"workers": [worker_config]})
)
mock_start_worker.assert_called_once()
def test_worker_manager_update_config_no_workers():
connector = mock.MagicMock()
with mock.patch.object(DAPWorkerManager, "_start_worker") as mock_start_worker:
connector.producer().get.return_value = None
worker_manager = DAPWorkerManager(connector)
worker_manager.update_config(BECMessage.DAPConfigMessage(config={"workers": []}))
mock_start_worker.assert_not_called()
def test_worker_manager_update_config_worker_already_running():
connector = mock.MagicMock()
with mock.patch.object(DAPWorkerManager, "_start_worker") as mock_start_worker:
connector.producer().get.return_value = None
worker_manager = DAPWorkerManager(connector)
config = {
"stream": "scan_segment",
"output": "gaussian_fit_worker_3",
"input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
"model": "GaussianModel",
}
worker_config = {"id": "gaussian_fit_worker_3", "config": config}
worker_manager._workers = {"gaussian_fit_worker_3": {"config": config, "worker": None}}
worker_manager.update_config(
BECMessage.DAPConfigMessage(config={"workers": [worker_config]})
)
mock_start_worker.assert_not_called()
def test_worker_manager_update_config_worker_already_running_different_config():
connector = mock.MagicMock()
with mock.patch.object(DAPWorkerManager, "_start_worker") as mock_start_worker:
connector.producer().get.return_value = None
worker_manager = DAPWorkerManager(connector)
config = {
"stream": "scan_segment",
"output": "gaussian_fit_worker_3",
"input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
"model": "GaussianModel",
}
w3_mock = mock.MagicMock()
worker_config = {"id": "gaussian_fit_worker_3", "config": config}
worker_manager._workers = {"gaussian_fit_worker_3": {"config": {}, "worker": w3_mock}}
worker_manager.update_config(
BECMessage.DAPConfigMessage(config={"workers": [worker_config]})
)
mock_start_worker.assert_called_once()
w3_mock.terminate.assert_called_once()
def test_worker_manager_update_config_remove_outdated_workers():
connector = mock.MagicMock()
with mock.patch.object(DAPWorkerManager, "_start_worker") as mock_start_worker:
connector.producer().get.return_value = None
worker_manager = DAPWorkerManager(connector)
config = {
"stream": "scan_segment",
"output": "gaussian_fit_worker_3",
"input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"],
"model": "GaussianModel",
}
w3_mock = mock.MagicMock()
worker_config = {"id": "gaussian_fit_worker_3", "config": config}
worker_manager._workers = {"gaussian_fit_worker_3": {"config": {}, "worker": w3_mock}}
worker_manager.update_config(BECMessage.DAPConfigMessage(config={"workers": []}))
mock_start_worker.assert_not_called()
w3_mock.terminate.assert_called_once()
assert worker_manager._workers == {}