diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 17e3685b..58ba6d76 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -30,6 +30,7 @@ stages: - pip install -e ./bec_client - pip install -e ./file_writer - pip install -e ./scihub + - pip install -e ./data_processing formatter: stage: Formatter @@ -44,7 +45,7 @@ pylint: - pip install pylint pylint-exit anybadge script: - mkdir ./pylint - - pylint ./bec_client_lib/bec_client_lib ./scan_server/scan_server ./device_server/device_server ./scan_bundler/scan_bundler ./bec_client/bec_client ./file_writer/file_writer --output-format=text . | tee ./pylint/pylint.log || pylint-exit $? + - pylint ./data_processing/data_processing ./bec_client_lib/bec_client_lib ./scan_server/scan_server ./device_server/device_server ./scan_bundler/scan_bundler ./bec_client/bec_client ./file_writer/file_writer --output-format=text . | tee ./pylint/pylint.log || pylint-exit $? - PYLINT_SCORE=$(sed -n 's/^Your code has been rated at \([-0-9.]*\)\/.*/\1/p' ./pylint/pylint.log) - anybadge --label=Pylint --file=pylint/pylint.svg --value=$PYLINT_SCORE 2=red 4=orange 8=yellow 10=green - echo "Pylint score is $PYLINT_SCORE" @@ -93,7 +94,7 @@ tests: - pip install pytest pytest-random-order pytest-cov pytest-asyncio - apt-get install -y gcc - *install-bec-services - - coverage run --source=./bec_client_lib/bec_client_lib,./device_server/device_server,./scan_server/scan_server,./scan_bundler/scan_bundler,./bec_client/bec_client,./file_writer/file_writer,./scihub/scihub --omit=*/bec_client/bec_client/plugins/*,*/bec_client/scripts/* -m pytest -v --junitxml=report.xml --random-order ./scan_server/tests ./device_server/tests ./scan_bundler/tests ./bec_client/tests/client_tests ./file_writer/tests ./scihub/tests ./bec_client_lib/tests + - coverage run --source=./data_processing/data_processing,./bec_client_lib/bec_client_lib,./device_server/device_server,./scan_server/scan_server,./scan_bundler/scan_bundler,./bec_client/bec_client,./file_writer/file_writer,./scihub/scihub --omit=*/bec_client/bec_client/plugins/*,*/bec_client/scripts/* -m pytest -v --junitxml=report.xml --random-order ./data_processing/tests ./scan_server/tests ./device_server/tests ./scan_bundler/tests ./bec_client/tests/client_tests ./file_writer/tests ./scihub/tests ./bec_client_lib/tests - coverage report - coverage xml coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/' @@ -113,7 +114,7 @@ tests-3.9: - pip install pytest pytest-random-order pytest-cov pytest-asyncio pytest-timeout - apt-get install -y gcc - *install-bec-services - - pytest -v --junitxml=report.xml --random-order ./bec_client_lib/tests ./scan_server/tests ./device_server/tests ./scan_bundler/tests ./bec_client/tests/client_tests ./file_writer/tests + - pytest -v --junitxml=report.xml --random-order ./data_processing/tests ./bec_client_lib/tests ./scan_server/tests ./device_server/tests ./scan_bundler/tests ./bec_client/tests/client_tests ./file_writer/tests tests-3.10: extends: "tests-3.9" diff --git a/data_processing/data_processing/__init__.py b/data_processing/data_processing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/data_processing/data_processing/dap_server.py b/data_processing/data_processing/dap_server.py new file mode 100644 index 00000000..0bd2c703 --- /dev/null +++ b/data_processing/data_processing/dap_server.py @@ -0,0 +1,23 @@ +from bec_client_lib.core import BECService +from bec_client_lib.core.connector import ConnectorBase +from bec_client_lib.core.service_config import ServiceConfig + +from .worker_manager import DAPWorkerManager + + +class DAPServer(BECService): + """Data processing server class.""" + + def __init__( + self, config: ServiceConfig, connector_cls: ConnectorBase, unique_service=False + ) -> None: + super().__init__(config, connector_cls, unique_service) + self._work_manager = None + self._start_manager() + + def _start_manager(self): + self._work_manager = DAPWorkerManager(self.connector) + + def shutdown(self): + self._work_manager.shutdown() + super().shutdown() diff --git a/data_processing/data_processing/stream_processor.py b/data_processing/data_processing/stream_processor.py new file mode 100644 index 00000000..44b01f58 --- /dev/null +++ b/data_processing/data_processing/stream_processor.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +import time +from abc import ABC, abstractmethod +from collections import deque +from typing import Any, List, Optional, Tuple + +import lmfit +from bec_client_lib.core import BECMessage, MessageEndpoints +from bec_client_lib.core.redis_connector import MessageObject, RedisConnector + + +def nested_get(data: str, keys, default=None): + """ + Get a value from a nested dictionary. + + Args: + data (dict): Dictionary to get the value from. + keys (str): Keys to get the value from. Keys are separated by a dot. + default (Any, optional): Default value to return if the key is not found. Defaults to None. + + Returns: + Any: Value of the key. + + Examples: + >>> data = {"a": {"b": 1}} + >>> nested_get(data, "a.b") + 1 + """ + if "." in keys: + key, rest = keys.split(".", 1) + return nested_get(data[key], rest, default=default) + return data.get(keys, default) + + +class StreamProcessor(ABC): + """ + Abstract class for stream processors. This class is responsible for + processing stream data. Each processor is started in a separate process. + Override the process method to implement the processing logic. + + Please note that the processor stores the data in the self.data attribute. + This is done to allow the processor to access multiple data points at once, + e.g. for fitting.Make sure to reset the data attribute after processing + the data to avoid memory leaks. + """ + + def __init__(self, connector: RedisConnector, config: dict) -> None: + """ + Initialize the StreamProcessor class. + + Args: + connector (RedisConnector): Redis connector. + config (dict): Configuration for the processor. + """ + super().__init__() + self._connector = connector + self.producer = connector.producer() + self._process = None + self.queue = deque() + self.consumer = None + self.config = config + self.data = None + + def reset_data(self): + """Reset the data.""" + self.data = None + + @abstractmethod + def process(self, data: dict, metadata: dict) -> Tuple[dict, dict]: + """ + Process data and return the result. + + Args: + data (dict): Data to be processed. + metadata (dict): Metadata associated with the data. + + Returns: + Tuple[dict, dict]: Tuple containing the processed data and metadata. + """ + + @property + def status(self): + """Return the worker status.""" + return { + "process": self._process, + "config": self.config, + "started": self._process.is_alive(), + } + + def shutdown(self): + """Shutdown the worker. Terminate the process and wait for it to join.""" + self._process.terminate() + self._process.join() + + def _run_forever(self): + """Core method for the worker. This method is called in a while True loop.""" + if not self.queue: + time.sleep(0.1) + return + data = self.queue.popleft() + + # Process data + result = self._process_data(data) + + # publish the result + if not all(result): + return + + # for multiple results, publish them as a bundle + if isinstance(result, list) and len(result) > 1: + msg_bundle = BECMessage.BundleMessage() + for data, metadata in result: + msg = BECMessage.ProcessedDataMessage(data=data, metadata=metadata).dumps() + msg_bundle.append(msg) + self._publish_result(msg_bundle.dumps()) + else: + msg = BECMessage.ProcessedDataMessage(data=result[0][0], metadata=result[0][1]).dumps() + self._publish_result(msg) + + def start(self): + """Run the worker. This method is called in a separate process.""" + while True: + self._run_forever() + + def _process_data(self, data: BECMessage.BECMessage) -> List[Tuple[dict, dict]]: + """Process data.""" + if not isinstance(data, list): + data = [data] + + return [self.process(sub_data.content, sub_data.metadata) for sub_data in data] + + def start_data_consumer(self): + """Get data from redis.""" + if self.consumer and self.consumer.is_alive(): + self.consumer.shutdown() + self.consumer = self._connector.consumer( + self.config["stream"], cb=self._set_data, parent=self + ) + self.consumer.start() + + @staticmethod + def _set_data(msg: MessageObject, parent: StreamProcessor): + """Set data to the parent.""" + parent.queue.append(BECMessage.MessageReader.loads(msg.value)) + + def _publish_result(self, msg: BECMessage.BECMessage): + """Publish the result.""" + self.producer.set_and_publish( + MessageEndpoints.processed_data(self.config["output"]), + msg, + ) + + @classmethod + def run(cls, config: dict, connector_host: list[str]) -> None: + """Run the worker.""" + connector = RedisConnector(connector_host) + + worker = cls(connector, config) + worker.start_data_consumer() + worker.start() + + +class LmfitProcessor(StreamProcessor): + """Lmfit processor class.""" + + def __init__(self, connector: RedisConnector, config: dict) -> None: + """ + Initialize the LmfitProcessor class. + + Args: + connector (RedisConnector): Redis connector. + config (dict): Configuration for the processor. + """ + super().__init__(connector, config) + self.model = self._get_model() + self.scan_id = None + + def _get_model(self) -> lmfit.Model: + """Get the model from the config and convert it to an lmfit model.""" + + if not self.config: + raise ValueError("No config provided") + if not self.config.get("model"): + raise ValueError("No model provided") + + model = self.config["model"] + if not isinstance(model, str): + raise ValueError("Model must be a string") + + # check if the model is a valid lmfit model + if not hasattr(lmfit.models, model): + raise ValueError(f"Invalid model: {model}") + + model = getattr(lmfit.models, model) + + return model() + + def process(self, data: dict, metadata: dict) -> Optional[Tuple[dict, dict]]: + """ + Process data and return the result. + + Args: + data (dict): Data to be processed. + metadata (dict): Metadata associated with the data. + + Returns: + Tuple[dict, dict]: Tuple containing the processed data and metadata. Returns None if no data is provided or if the fit is skipped. + """ + + if not data: + return None + + # get the event data + x = nested_get(data.get("data", {}), self.config["input_xy"][0]) + y = nested_get(data.get("data", {}), self.config["input_xy"][1]) + + # check if the data is indeed a number + if not isinstance(x, (int, float)) or not isinstance(y, (int, float)): + return None + + # reset the data if the scan id changed + if self.scan_id != data.get("scanID"): + self.reset_data() + self.scan_id = data.get("scanID") + + # append the data to the data attribute + if self.data is None: + self.data = {"x": [], "y": []} + self.data["x"].append(x) + self.data["y"].append(y) + + # check if the data is long enough to fit + if len(self.data["x"]) < 3: + return None + + # fit the data + result = self.model.fit(self.data["y"], x=self.data["x"]) + + # add the fit result to the output + stream_output = {self.config["output"]: result.best_fit, "input": self.config["input_xy"]} + + # add the fit parameters to the metadata + metadata["fit_parameters"] = result.best_values + metadata["fit_summary"] = result.summary() + + return (stream_output, metadata) diff --git a/data_processing/data_processing/worker_manager.py b/data_processing/data_processing/worker_manager.py new file mode 100644 index 00000000..af4c773b --- /dev/null +++ b/data_processing/data_processing/worker_manager.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import multiprocessing as mp + +from bec_client_lib.core import BECMessage, MessageEndpoints, bec_logger +from bec_client_lib.core.redis_connector import RedisConnector + +from .stream_processor import LmfitProcessor + +logger = bec_logger.logger + + +class DAPWorkerManager: + """Data processing worker manager class.""" + + def __init__(self, connector: RedisConnector): + self.connector = connector + self.producer = connector.producer() + self._workers = {} + self._config = {} + self._update_config() + self._start_config_consumer() + + def _update_config(self): + """Get config from redis.""" + logger.debug("Getting config from redis") + msg = self.producer.get(MessageEndpoints.dap_config()) + if not msg: + return + self.update_config(BECMessage.DAPConfigMessage.loads(msg)) + + def _start_config_consumer(self): + """Get config from redis.""" + logger.debug("Starting config consumer") + self.consumer = self.connector.consumer( + MessageEndpoints.dap_config(), cb=self._set_config, parent=self + ) + self.consumer.start() + + @staticmethod + def _set_config(msg: BECMessage.BECMessage, parent: DAPWorkerManager) -> None: + """Set config to the parent.""" + msg = BECMessage.DAPConfigMessage.loads(msg.value) + if not msg: + return + parent.update_config(msg) + + def update_config(self, msg: BECMessage.DAPConfigMessage): + """Update the config.""" + logger.debug(f"Updating config: {msg.content}") + if not msg.content["config"]: + return + self._config = msg.content["config"] + for worker_config in self._config["workers"]: + # Check if the worker is already running and start it if not + if worker_config["id"] not in self._workers: + self._start_worker(worker_config) + continue + + # Check if the config has changed + if self._workers[worker_config["id"]]["config"] == worker_config["config"]: + logger.debug(f"Worker config has not changed: {worker_config['id']}") + continue + + # If the config has changed, terminate the worker and start a new one + logger.debug(f"Restarting worker: {worker_config['id']}") + self._workers[worker_config["id"]]["worker"].terminate() + self._start_worker(worker_config) + + # Check if any workers need to be removed + for worker_id in list(self._workers): + if worker_id not in [worker["id"] for worker in self._config["workers"]]: + logger.debug(f"Removing worker: {worker_id}") + self._workers[worker_id]["worker"].terminate() + del self._workers[worker_id] + + def _start_worker(self, config: dict): + """Start a worker.""" + logger.debug(f"Starting worker: {config}") + + self._workers[config["id"]] = { + "worker": self.run_worker(config["config"]), + "config": config["config"], + } + + def shutdown(self): + for worker in self._workers: + worker.shutdown() + + @staticmethod + def run_worker(config: dict) -> mp.Process: + """Run the worker.""" + worker = mp.Process( + target=LmfitProcessor.run, + kwargs={"config": config, "connector_host": ["localhost:6379"]}, + daemon=True, + ) + worker.start() + return worker diff --git a/data_processing/launch.py b/data_processing/launch.py new file mode 100644 index 00000000..e01bf2af --- /dev/null +++ b/data_processing/launch.py @@ -0,0 +1,35 @@ +import argparse +import threading + +from bec_client_lib.core import RedisConnector, ServiceConfig, bec_logger + +from data_processing.dap_server import DAPServer + +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--config", + default="", + help="path to the config file", + ) + clargs = parser.parse_args() + config_path = clargs.config + + config = ServiceConfig(config_path) + bec_logger.level = bec_logger.LOGLEVEL.DEBUG + logger = bec_logger.logger + + bec_server = DAPServer( + config=config, + connector_cls=RedisConnector, + ) + try: + event = threading.Event() + # pylint: disable=E1102 + logger.success("Started DAP server") + event.wait() + except KeyboardInterrupt as e: + # bec_server.connector.raise_error("KeyboardInterrupt") + bec_server.shutdown() + event.set() + raise e diff --git a/data_processing/setup.cfg b/data_processing/setup.cfg new file mode 100644 index 00000000..8fb53e99 --- /dev/null +++ b/data_processing/setup.cfg @@ -0,0 +1,24 @@ +[metadata] +name = bec_dap +version = 0.0.1 +description = BEC Online Data Processing Service +long_description = file: README.md +long_description_content_type = text/markdown +url = https://github.com/pypa/sampleproject +project_urls = + Bug Tracker = https://github.com/pypa/sampleproject/issues +classifiers = + Programming Language :: Python :: 3 + License :: OSI Approved :: MIT License + Operating System :: OS Independent + +[options] +package_dir = + = . +packages = find: +python_requires = >=3.7 + +[options.packages.find] +where = . + + diff --git a/data_processing/setup.py b/data_processing/setup.py new file mode 100644 index 00000000..39bbfa85 --- /dev/null +++ b/data_processing/setup.py @@ -0,0 +1,14 @@ +import pathlib +import subprocess + +from setuptools import setup + +current_path = pathlib.Path(__file__).parent.resolve() +utils = f"{current_path}/../bec_client_lib/" + + +if __name__ == "__main__": + setup(install_requires=["lmfit"]) + local_deps = [utils] + for dep in local_deps: + subprocess.run(f"pip install -e {dep}", shell=True, check=True) diff --git a/data_processing/tests/test_lmfit_processor.py b/data_processing/tests/test_lmfit_processor.py new file mode 100644 index 00000000..d664822a --- /dev/null +++ b/data_processing/tests/test_lmfit_processor.py @@ -0,0 +1,77 @@ +from unittest import mock + +import numpy as np +import pytest + +from data_processing.stream_processor import LmfitProcessor + + +def test_LmfitProcessor_get_model_needs_config(): + """ + Test the LmfitProcessor class get_model method. + """ + connector = mock.MagicMock() + with pytest.raises(ValueError): + LmfitProcessor(connector, {}) + + +def test_LmfitProcessor_get_model_returns_correct_model(): + """ + Test the LmfitProcessor class get_model method. + """ + connector = mock.MagicMock() + config = {"model": "GaussianModel"} + processor = LmfitProcessor(connector, config) + assert processor.model.name == "Model(gaussian)" + + +def test_LmfitProcessor_process_gaussian(): + """ + Test the LmfitProcessor class process method with a gaussian model. + """ + connector = mock.MagicMock() + config = {"model": "GaussianModel", "input_xy": ["x", "y"], "output": "gaussian_fit"} + processor = LmfitProcessor(connector, config) + processor.data = {"x": [1, 2, 3], "y": [1, 2, 3]} + data = {"data": {"x": 4, "y": 4}} + metadata = {} + result_data, result_metadata = processor.process(data, metadata) + assert np.allclose(result_data["gaussian_fit"], processor.data["y"], atol=0.1) + assert {"amplitude", "sigma", "center"} & set(result_metadata["fit_parameters"]) == { + "amplitude", + "sigma", + "center", + } + assert len(processor.data["x"]) == 4 + assert len(processor.data["y"]) == 4 + + +def test_LmfitProcessor_resets_scan_data(): + """ + Test the LmfitProcessor class make sure it resets the data when the scan id changes. + """ + connector = mock.MagicMock() + config = {"model": "GaussianModel", "input_xy": ["x", "y"], "output": "gaussian_fit"} + processor = LmfitProcessor(connector, config) + processor.data = {"x": [1, 2, 3], "y": [1, 2, 3]} + data = {"data": {"x": 4, "y": 4}, "scanID": 1} + metadata = {} + result = processor.process(data, metadata) + assert result is None + assert len(processor.data["x"]) == 1 + + +def test_LmfitProcessor_resets_scan_data_with_existing_id(): + """ + Test the LmfitProcessor class to make sure it resets the data when the scan id changes. + """ + connector = mock.MagicMock() + config = {"model": "GaussianModel", "input_xy": ["x", "y"], "output": "gaussian_fit"} + processor = LmfitProcessor(connector, config) + processor.scan_id = 1 + processor.data = {"x": [1, 2, 3], "y": [1, 2, 3]} + data = {"data": {"x": 4, "y": 4}, "scanID": 2} + metadata = {} + result = processor.process(data, metadata) + assert result is None + assert len(processor.data["x"]) == 1 diff --git a/data_processing/tests/test_nested_get.py b/data_processing/tests/test_nested_get.py new file mode 100644 index 00000000..a15ac617 --- /dev/null +++ b/data_processing/tests/test_nested_get.py @@ -0,0 +1,25 @@ +from data_processing.stream_processor import nested_get + + +def test_nested_get_default(): + """ + Test the nested_get function. + """ + data = {"a": {"b": {"c": 1}}} + assert nested_get(data, "a.b.c") == 1 + + +def test_nested_get_returns_default(): + """ + Test the nested_get function. + """ + data = {"a": {"b": {"c": 1}}} + assert nested_get(data, "a.b.d", 2) == 2 + + +def test_nested_get_with_plain_key(): + """ + Test the nested_get function. + """ + data = {"a": {"b": 1}} + assert nested_get(data, "a") == {"b": 1} diff --git a/data_processing/tests/test_stream_processor.py b/data_processing/tests/test_stream_processor.py new file mode 100644 index 00000000..bb860532 --- /dev/null +++ b/data_processing/tests/test_stream_processor.py @@ -0,0 +1,92 @@ +from unittest import mock + +import pytest +from bec_client_lib.core import BECMessage + +from data_processing.stream_processor import StreamProcessor + + +class DummyStreamProcessor(StreamProcessor): + def process(self, data: dict, metadata: dict) -> tuple: + return data, metadata + + +@pytest.fixture(scope="function") +def stream_processor(): + connector = mock.MagicMock() + config = { + "stream": "scan_segment", + "output": "gaussian_fit_worker_3", + "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"], + "model": "GaussianModel", + } + return DummyStreamProcessor(connector, config) + + +def test_stream_processor_run_forever(stream_processor): + """ + Test the StreamProcessor class run_forever method. + """ + + stream_processor.queue.append( + BECMessage.ScanMessage(point_id=1, scanID="scanID", data={"x": 1, "y": 1}) + ) + with mock.patch.object(StreamProcessor, "_process_data") as mock_process_data: + mock_process_data.return_value = [ + ({"x": 1, "y": 1}, {"scanID": "scanID"}), + ] + stream_processor._run_forever() + mock_process_data.assert_called_once() + + +def test_stream_processor_publishes_bundled_data(stream_processor): + """ + Test the StreamProcessor class run_forever method and make sure it publishes bundled data. + """ + stream_processor.queue.append( + BECMessage.ScanMessage(point_id=1, scanID="scanID", data={"x": 1, "y": 1}) + ) + with mock.patch.object(StreamProcessor, "_process_data") as mock_process_data: + mock_process_data.return_value = [ + ({"x": 1, "y": 1}, {"scanID": "scanID"}), + ({"x": 1, "y": 1}, {"scanID": "scanID"}), + ] + stream_processor._run_forever() + mock_process_data.assert_called_once() + assert stream_processor._connector.producer().set_and_publish.call_count == 1 + + +def test_stream_processor_does_not_publish_empty_data(stream_processor): + """ + Test the StreamProcessor class run_forever method and make sure does not publish empty data. + """ + stream_processor.queue.append( + BECMessage.ScanMessage(point_id=1, scanID="scanID", data={"x": 1, "y": 1}) + ) + with mock.patch.object(StreamProcessor, "_process_data") as mock_process_data: + mock_process_data.return_value = [ + None, + ] + stream_processor._run_forever() + mock_process_data.assert_called_once() + assert stream_processor._connector.producer().set_and_publish.call_count == 0 + + +def test_stream_processor_start_data_consumer(stream_processor): + """ + Test the StreamProcessor class start_data_consumer method. + """ + stream_processor.start_data_consumer() + stream_processor._connector.consumer.assert_called_once() + assert stream_processor._connector.consumer().start.call_count == 1 + + +def test_stream_processor_start_data_consumer_stops_existing_consumer(stream_processor): + """ + Test the StreamProcessor class start_data_consumer method and make sure it stops the existing consumer. + """ + orig_consumer = mock.MagicMock() + stream_processor.consumer = orig_consumer + stream_processor.consumer.is_alive.return_value = True + stream_processor.start_data_consumer() + assert orig_consumer.shutdown.call_count == 1 diff --git a/data_processing/tests/test_worker_manager.py b/data_processing/tests/test_worker_manager.py new file mode 100644 index 00000000..f65a8273 --- /dev/null +++ b/data_processing/tests/test_worker_manager.py @@ -0,0 +1,117 @@ +from unittest import mock + +from bec_client_lib.core import BECMessage, MessageEndpoints + +from data_processing.worker_manager import DAPWorkerManager + + +def test_worker_manager_retrieves_config_on_startup(): + connector = mock.MagicMock() + with mock.patch.object(DAPWorkerManager, "update_config") as mock_update_config: + config = { + "stream": "scan_segment", + "output": "gaussian_fit_worker_3", + "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"], + "model": "GaussianModel", + } + worker_config = {"id": "gaussian_fit_worker_3", "config": config} + connector.producer().get.return_value = BECMessage.DAPConfigMessage( + config={"workers": [worker_config]} + ).dumps() + worker_manager = DAPWorkerManager(connector) + mock_update_config.assert_called_once() + + +def test_worker_manager_retrieves_config_on_startup_empty(): + connector = mock.MagicMock() + with mock.patch.object(DAPWorkerManager, "update_config") as mock_update_config: + connector.producer().get.return_value = None + worker_manager = DAPWorkerManager(connector) + mock_update_config.assert_not_called() + + +def test_worker_manager_update_config(): + connector = mock.MagicMock() + with mock.patch.object(DAPWorkerManager, "_start_worker") as mock_start_worker: + connector.producer().get.return_value = None + worker_manager = DAPWorkerManager(connector) + config = { + "stream": "scan_segment", + "output": "gaussian_fit_worker_3", + "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"], + "model": "GaussianModel", + } + worker_config = {"id": "gaussian_fit_worker_3", "config": config} + worker_manager.update_config( + BECMessage.DAPConfigMessage(config={"workers": [worker_config]}) + ) + mock_start_worker.assert_called_once() + + +def test_worker_manager_update_config_no_workers(): + connector = mock.MagicMock() + with mock.patch.object(DAPWorkerManager, "_start_worker") as mock_start_worker: + connector.producer().get.return_value = None + worker_manager = DAPWorkerManager(connector) + worker_manager.update_config(BECMessage.DAPConfigMessage(config={"workers": []})) + mock_start_worker.assert_not_called() + + +def test_worker_manager_update_config_worker_already_running(): + connector = mock.MagicMock() + with mock.patch.object(DAPWorkerManager, "_start_worker") as mock_start_worker: + connector.producer().get.return_value = None + worker_manager = DAPWorkerManager(connector) + config = { + "stream": "scan_segment", + "output": "gaussian_fit_worker_3", + "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"], + "model": "GaussianModel", + } + worker_config = {"id": "gaussian_fit_worker_3", "config": config} + worker_manager._workers = {"gaussian_fit_worker_3": {"config": config, "worker": None}} + worker_manager.update_config( + BECMessage.DAPConfigMessage(config={"workers": [worker_config]}) + ) + mock_start_worker.assert_not_called() + + +def test_worker_manager_update_config_worker_already_running_different_config(): + connector = mock.MagicMock() + with mock.patch.object(DAPWorkerManager, "_start_worker") as mock_start_worker: + connector.producer().get.return_value = None + worker_manager = DAPWorkerManager(connector) + config = { + "stream": "scan_segment", + "output": "gaussian_fit_worker_3", + "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"], + "model": "GaussianModel", + } + w3_mock = mock.MagicMock() + worker_config = {"id": "gaussian_fit_worker_3", "config": config} + worker_manager._workers = {"gaussian_fit_worker_3": {"config": {}, "worker": w3_mock}} + worker_manager.update_config( + BECMessage.DAPConfigMessage(config={"workers": [worker_config]}) + ) + mock_start_worker.assert_called_once() + w3_mock.terminate.assert_called_once() + + +def test_worker_manager_update_config_remove_outdated_workers(): + connector = mock.MagicMock() + with mock.patch.object(DAPWorkerManager, "_start_worker") as mock_start_worker: + connector.producer().get.return_value = None + worker_manager = DAPWorkerManager(connector) + config = { + "stream": "scan_segment", + "output": "gaussian_fit_worker_3", + "input_xy": ["samx.samx.value", "gauss_bpm.gauss_bpm.value"], + "model": "GaussianModel", + } + w3_mock = mock.MagicMock() + worker_config = {"id": "gaussian_fit_worker_3", "config": config} + worker_manager._workers = {"gaussian_fit_worker_3": {"config": {}, "worker": w3_mock}} + worker_manager.update_config(BECMessage.DAPConfigMessage(config={"workers": []})) + mock_start_worker.assert_not_called() + w3_mock.terminate.assert_called_once() + assert worker_manager._workers == {}