From b92a79b0c063d07bd811a35b4a72104a22f2b60e Mon Sep 17 00:00:00 2001 From: Mathias Guijarro Date: Wed, 31 Jan 2024 13:43:01 +0100 Subject: [PATCH] refactor!(connector): unify connector/redis_connector in one class --- .../bec_client/callbacks/move_device.py | 8 +- .../bec_client/callbacks/scan_progress.py | 2 +- .../tests/client_tests/test_move_callback.py | 18 +- .../tests/client_tests/test_scan_progress.py | 10 +- bec_client/tests/end-2-end/test_scans.py | 4 +- bec_lib/bec_lib/__init__.py | 1 - bec_lib/bec_lib/alarm_handler.py | 10 +- bec_lib/bec_lib/async_data.py | 12 +- bec_lib/bec_lib/bec_plotter.py | 16 +- bec_lib/bec_lib/bec_service.py | 18 +- bec_lib/bec_lib/channel_monitor.py | 8 +- bec_lib/bec_lib/client.py | 8 +- bec_lib/bec_lib/config_helper.py | 9 +- bec_lib/bec_lib/connector.py | 215 ++--- bec_lib/bec_lib/dap_plugin_objects.py | 8 +- bec_lib/bec_lib/dap_plugins.py | 2 +- bec_lib/bec_lib/device.py | 32 +- bec_lib/bec_lib/devicemanager.py | 63 +- bec_lib/bec_lib/logbook_connector.py | 7 +- bec_lib/bec_lib/logger.py | 4 +- bec_lib/bec_lib/observer.py | 2 +- bec_lib/bec_lib/queue_items.py | 8 +- bec_lib/bec_lib/redis_connector.py | 587 ++++--------- bec_lib/bec_lib/scan_items.py | 2 +- bec_lib/bec_lib/scan_manager.py | 87 +- bec_lib/bec_lib/scan_report.py | 2 +- bec_lib/bec_lib/scans.py | 12 +- bec_lib/bec_lib/service_config.py | 6 +- bec_lib/bec_lib/tests/utils.py | 81 +- bec_lib/setup.py | 13 +- bec_lib/tests/test_bec_plotter.py | 8 +- bec_lib/tests/test_bec_service.py | 8 +- bec_lib/tests/test_channel_monitor.py | 6 +- bec_lib/tests/test_config_helper.py | 12 +- bec_lib/tests/test_dap_plugins.py | 16 +- bec_lib/tests/test_device_manager.py | 11 +- bec_lib/tests/test_devices.py | 32 +- bec_lib/tests/test_observer.py | 12 +- bec_lib/tests/test_redis_connector.py | 774 ++++++++---------- bec_lib/tests/test_scan_items.py | 2 +- bec_lib/tests/test_scan_report.py | 2 +- bec_lib/util_scripts/init_config.py | 3 +- .../data_processing/dap_service_manager.py | 18 +- .../data_processing/lmfit1d_service.py | 2 +- .../tests/test_dap_service_manager.py | 2 +- data_processing/tests/test_lmfit1d_service.py | 2 +- device_server/device_server/device_server.py | 80 +- .../devices/config_update_handler.py | 12 +- .../device_server/devices/devicemanager.py | 44 +- device_server/device_server/rpc_mixin.py | 4 +- device_server/tests/test_device_manager_ds.py | 16 +- device_server/tests/test_device_server.py | 18 +- device_server/tests/test_rpc_mixin.py | 6 +- file_writer/file_writer/file_writer.py | 2 +- .../file_writer/file_writer_manager.py | 38 +- file_writer/tests/test_file_writer_manager.py | 24 +- scan_bundler/scan_bundler/bec_emitter.py | 17 +- scan_bundler/scan_bundler/bluesky_emitter.py | 8 +- scan_bundler/scan_bundler/emitter.py | 18 +- scan_bundler/scan_bundler/scan_bundler.py | 74 +- scan_bundler/tests/test_bec_emitter.py | 8 +- scan_bundler/tests/test_bluesky_emitter.py | 4 +- scan_bundler/tests/test_emitter.py | 20 +- scan_bundler/tests/test_scan_bundler.py | 18 +- scan_server/scan_plugins/LamNIFermatScan.py | 2 +- scan_server/scan_plugins/owis_grid.py | 2 +- scan_server/scan_plugins/sgalil_grid.py | 2 +- scan_server/scan_server/device_validation.py | 10 +- scan_server/scan_server/scan_guard.py | 18 +- scan_server/scan_server/scan_manager.py | 2 +- scan_server/scan_server/scan_queue.py | 17 +- scan_server/scan_server/scan_server.py | 19 +- scan_server/scan_server/scan_stubs.py | 16 +- scan_server/scan_server/scan_worker.py | 54 +- scan_server/scan_server/scans.py | 26 +- scan_server/tests/test_scan_guard.py | 24 +- scan_server/tests/test_scan_server_queue.py | 34 +- scan_server/tests/test_scan_stubs.py | 28 +- scan_server/tests/test_scan_worker.py | 72 +- scan_server/tests/test_scans.py | 6 +- scihub/scihub/scibec/config_handler.py | 14 +- scihub/scihub/scibec/scibec_connector.py | 12 +- .../scihub/scibec/scibec_metadata_handler.py | 12 +- scihub/scihub/scilog/scilog.py | 3 +- scihub/tests/test_scibec_config_handler.py | 4 +- scihub/tests/test_scibec_connector.py | 4 +- 86 files changed, 1212 insertions(+), 1745 deletions(-) diff --git a/bec_client/bec_client/callbacks/move_device.py b/bec_client/bec_client/callbacks/move_device.py index bf17ea80..501c1cdc 100644 --- a/bec_client/bec_client/callbacks/move_device.py +++ b/bec_client/bec_client/callbacks/move_device.py @@ -38,16 +38,16 @@ class ReadbackDataMixin: def get_request_done_msgs(self): """get all request-done messages""" - pipe = self.device_manager.producer.pipeline() + pipe = self.device_manager.connector.pipeline() for dev in self.devices: - self.device_manager.producer.get(MessageEndpoints.device_req_status(dev), pipe) - return self.device_manager.producer.execute_pipeline(pipe) + self.device_manager.connector.get(MessageEndpoints.device_req_status(dev), pipe) + return self.device_manager.connector.execute_pipeline(pipe) def wait_for_RID(self, request): """wait for the readback's metadata to match the request ID""" while True: msgs = [ - self.device_manager.producer.get(MessageEndpoints.device_readback(dev)) + self.device_manager.connector.get(MessageEndpoints.device_readback(dev)) for dev in self.devices ] if all(msg.metadata.get("RID") == request.metadata["RID"] for msg in msgs if msg): diff --git a/bec_client/bec_client/callbacks/scan_progress.py b/bec_client/bec_client/callbacks/scan_progress.py index cb7ae59a..aff3241b 100644 --- a/bec_client/bec_client/callbacks/scan_progress.py +++ b/bec_client/bec_client/callbacks/scan_progress.py @@ -27,7 +27,7 @@ class LiveUpdatesScanProgress(LiveUpdatesTable): Update the progressbar based on the device status message. Returns True if the scan is finished. """ self.check_alarms() - status = self.bec.producer.get(MessageEndpoints.device_progress(device_names[0])) + status = self.bec.connector.get(MessageEndpoints.device_progress(device_names[0])) if not status: logger.debug("waiting for new data point") await asyncio.sleep(0.1) diff --git a/bec_client/tests/client_tests/test_move_callback.py b/bec_client/tests/client_tests/test_move_callback.py index b9473a0c..cf39790d 100644 --- a/bec_client/tests/client_tests/test_move_callback.py +++ b/bec_client/tests/client_tests/test_move_callback.py @@ -13,7 +13,7 @@ from bec_client.callbacks.move_device import ( @pytest.fixture def readback_data_mixin(bec_client): - with mock.patch.object(bec_client.device_manager, "producer"): + with mock.patch.object(bec_client.device_manager, "connector"): yield ReadbackDataMixin(bec_client.device_manager, ["samx", "samy"]) @@ -102,7 +102,7 @@ async def test_move_callback_with_report_instruction(bec_client): def test_readback_data_mixin(readback_data_mixin): - readback_data_mixin.device_manager.producer.get.side_effect = [ + readback_data_mixin.device_manager.connector.get.side_effect = [ messages.DeviceMessage( signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}}, metadata={"device": "samx"}, @@ -121,7 +121,7 @@ def test_readback_data_mixin_multiple_hints(readback_data_mixin): "samx_setpoint", "samx", ] - readback_data_mixin.device_manager.producer.get.side_effect = [ + readback_data_mixin.device_manager.connector.get.side_effect = [ messages.DeviceMessage( signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}}, metadata={"device": "samx"}, @@ -137,7 +137,7 @@ def test_readback_data_mixin_multiple_hints(readback_data_mixin): def test_readback_data_mixin_multiple_no_hints(readback_data_mixin): readback_data_mixin.device_manager.devices.samx._info["hints"]["fields"] = [] - readback_data_mixin.device_manager.producer.get.side_effect = [ + readback_data_mixin.device_manager.connector.get.side_effect = [ messages.DeviceMessage( signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}}, metadata={"device": "samx"}, @@ -153,18 +153,18 @@ def test_readback_data_mixin_multiple_no_hints(readback_data_mixin): def test_get_request_done_msgs(readback_data_mixin): res = readback_data_mixin.get_request_done_msgs() - readback_data_mixin.device_manager.producer.pipeline.assert_called_once() + readback_data_mixin.device_manager.connector.pipeline.assert_called_once() assert ( mock.call( MessageEndpoints.device_req_status("samx"), - readback_data_mixin.device_manager.producer.pipeline.return_value, + readback_data_mixin.device_manager.connector.pipeline.return_value, ) - in readback_data_mixin.device_manager.producer.get.call_args_list + in readback_data_mixin.device_manager.connector.get.call_args_list ) assert ( mock.call( MessageEndpoints.device_req_status("samy"), - readback_data_mixin.device_manager.producer.pipeline.return_value, + readback_data_mixin.device_manager.connector.pipeline.return_value, ) - in readback_data_mixin.device_manager.producer.get.call_args_list + in readback_data_mixin.device_manager.connector.get.call_args_list ) diff --git a/bec_client/tests/client_tests/test_scan_progress.py b/bec_client/tests/client_tests/test_scan_progress.py index 0f760dfe..84608e65 100644 --- a/bec_client/tests/client_tests/test_scan_progress.py +++ b/bec_client/tests/client_tests/test_scan_progress.py @@ -15,7 +15,7 @@ async def test_update_progressbar_continues_without_device_data(): live_update = LiveUpdatesScanProgress(bec=bec, report_instruction={}, request=request) progressbar = mock.MagicMock() - bec.producer.get.return_value = None + bec.connector.get.return_value = None res = await live_update._update_progressbar(progressbar, "async_dev1") assert res is False @@ -29,7 +29,7 @@ async def test_update_progressbar_continues_when_scanID_doesnt_match(): live_update.scan_item = mock.MagicMock() live_update.scan_item.scanID = "scanID2" - bec.producer.get.return_value = messages.ProgressMessage( + bec.connector.get.return_value = messages.ProgressMessage( value=1, max_value=10, done=False, metadata={"scanID": "scanID"} ) res = await live_update._update_progressbar(progressbar, "async_dev1") @@ -45,7 +45,7 @@ async def test_update_progressbar_continues_when_msg_specifies_no_value(): live_update.scan_item = mock.MagicMock() live_update.scan_item.scanID = "scanID" - bec.producer.get.return_value = messages.ProgressMessage( + bec.connector.get.return_value = messages.ProgressMessage( value=None, max_value=None, done=None, metadata={"scanID": "scanID"} ) res = await live_update._update_progressbar(progressbar, "async_dev1") @@ -61,7 +61,7 @@ async def test_update_progressbar_updates_max_value(): live_update.scan_item = mock.MagicMock() live_update.scan_item.scanID = "scanID" - bec.producer.get.return_value = messages.ProgressMessage( + bec.connector.get.return_value = messages.ProgressMessage( value=10, max_value=20, done=False, metadata={"scanID": "scanID"} ) res = await live_update._update_progressbar(progressbar, "async_dev1") @@ -79,7 +79,7 @@ async def test_update_progressbar_returns_true_when_max_value_is_reached(): live_update.scan_item = mock.MagicMock() live_update.scan_item.scanID = "scanID" - bec.producer.get.return_value = messages.ProgressMessage( + bec.connector.get.return_value = messages.ProgressMessage( value=10, max_value=10, done=True, metadata={"scanID": "scanID"} ) res = await live_update._update_progressbar(progressbar, "async_dev1") diff --git a/bec_client/tests/end-2-end/test_scans.py b/bec_client/tests/end-2-end/test_scans.py index 214f2ab2..ec65fe94 100644 --- a/bec_client/tests/end-2-end/test_scans.py +++ b/bec_client/tests/end-2-end/test_scans.py @@ -463,11 +463,11 @@ def test_file_writer(client): md={"datasetID": 325}, ) assert len(scan.scan.data) == 100 - msg = bec.device_manager.producer.get(MessageEndpoints.public_file(scan.scan.scanID, "master")) + msg = bec.device_manager.connector.get(MessageEndpoints.public_file(scan.scan.scanID, "master")) while True: if msg: break - msg = bec.device_manager.producer.get( + msg = bec.device_manager.connector.get( MessageEndpoints.public_file(scan.scan.scanID, "master") ) diff --git a/bec_lib/bec_lib/__init__.py b/bec_lib/bec_lib/__init__.py index a7deb4ec..54e3a0ab 100644 --- a/bec_lib/bec_lib/__init__.py +++ b/bec_lib/bec_lib/__init__.py @@ -3,7 +3,6 @@ from bec_lib.bec_service import BECService from bec_lib.channel_monitor import channel_monitor_launch from bec_lib.client import BECClient from bec_lib.config_helper import ConfigHelper -from bec_lib.connector import ProducerConnector from bec_lib.device import DeviceBase, DeviceStatus, Status from bec_lib.devicemanager import DeviceConfigError, DeviceContainer, DeviceManagerBase from bec_lib.endpoints import MessageEndpoints diff --git a/bec_lib/bec_lib/alarm_handler.py b/bec_lib/bec_lib/alarm_handler.py index ae08fabf..185a7792 100644 --- a/bec_lib/bec_lib/alarm_handler.py +++ b/bec_lib/bec_lib/alarm_handler.py @@ -48,23 +48,21 @@ class AlarmBase(Exception): class AlarmHandler: def __init__(self, connector: RedisConnector) -> None: self.connector = connector - self.alarm_consumer = None self.alarms_stack = deque(maxlen=100) self._raised_alarms = deque(maxlen=100) self._lock = threading.RLock() def start(self): """start the alarm handler and its subscriptions""" - self.alarm_consumer = self.connector.consumer( + self.connector.register( topics=MessageEndpoints.alarm(), name="AlarmHandler", - cb=self._alarm_consumer_callback, + cb=self._alarm_register_callback, parent=self, ) - self.alarm_consumer.start() @staticmethod - def _alarm_consumer_callback(msg, *, parent, **_kwargs): + def _alarm_register_callback(msg, *, parent, **_kwargs): parent.add_alarm(msg.value) @threadlocked @@ -136,4 +134,4 @@ class AlarmHandler: def shutdown(self): """shutdown the alarm handler""" - self.alarm_consumer.shutdown() + self.connector.shutdown() diff --git a/bec_lib/bec_lib/async_data.py b/bec_lib/bec_lib/async_data.py index 03c3bf94..b78fffa4 100644 --- a/bec_lib/bec_lib/async_data.py +++ b/bec_lib/bec_lib/async_data.py @@ -8,12 +8,12 @@ from bec_lib.endpoints import MessageEndpoints if TYPE_CHECKING: from bec_lib import messages - from bec_lib.redis_connector import RedisProducer + from bec_lib.connector import ConnectorBase class AsyncDataHandler: - def __init__(self, producer: RedisProducer): - self.producer = producer + def __init__(self, connector: ConnectorBase): + self.connector = connector def get_async_data_for_scan(self, scan_id: str) -> dict[list]: """ @@ -25,7 +25,9 @@ class AsyncDataHandler: Returns: dict[list]: the async data for the scan sorted by device name """ - async_device_keys = self.producer.keys(MessageEndpoints.device_async_readback(scan_id, "*")) + async_device_keys = self.connector.keys( + MessageEndpoints.device_async_readback(scan_id, "*") + ) async_data = {} for device_key in async_device_keys: key = device_key.decode() @@ -50,7 +52,7 @@ class AsyncDataHandler: list: the async data for the device """ key = MessageEndpoints.device_async_readback(scan_id, device_name) - msgs = self.producer.xrange(key, min="-", max="+") + msgs = self.connector.xrange(key, min="-", max="+") if not msgs: return [] return self.process_async_data(msgs) diff --git a/bec_lib/bec_lib/bec_plotter.py b/bec_lib/bec_lib/bec_plotter.py index f899bc2c..3f25e465 100644 --- a/bec_lib/bec_lib/bec_plotter.py +++ b/bec_lib/bec_lib/bec_plotter.py @@ -54,14 +54,14 @@ class BECWidgetsConnector: def __init__(self, gui_id: str, bec_client: BECClient = None) -> None: self._client = bec_client self.gui_id = gui_id - # TODO replace with a global producer + # TODO replace with a global connector if self._client is None: if "bec" in builtins.__dict__: self._client = builtins.bec else: self._client = BECClient() self._client.start() - self._producer = self._client.connector.producer() + self._connector = self._client.connector def set_plot_config(self, plot_id: str, config: dict) -> None: """ @@ -72,7 +72,7 @@ class BECWidgetsConnector: config (dict): The config to set. """ msg = messages.GUIConfigMessage(config=config) - self._producer.set_and_publish(MessageEndpoints.gui_config(plot_id), msg) + self._connector.set_and_publish(MessageEndpoints.gui_config(plot_id), msg) def close(self, plot_id: str) -> None: """ @@ -82,7 +82,7 @@ class BECWidgetsConnector: plot_id (str): The id of the plot. """ msg = messages.GUIInstructionMessage(action="close", parameter={}) - self._producer.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg) + self._connector.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg) def config_dialog(self, plot_id: str) -> None: """ @@ -92,7 +92,7 @@ class BECWidgetsConnector: plot_id (str): The id of the plot. """ msg = messages.GUIInstructionMessage(action="config_dialog", parameter={}) - self._producer.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg) + self._connector.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg) def send_data(self, plot_id: str, data: dict) -> None: """ @@ -103,9 +103,9 @@ class BECWidgetsConnector: data (dict): The data to send. """ msg = messages.GUIDataMessage(data=data) - self._producer.set_and_publish(topic=MessageEndpoints.gui_data(plot_id), msg=msg) + self._connector.set_and_publish(topic=MessageEndpoints.gui_data(plot_id), msg=msg) # TODO bec_dispatcher can only handle set_and_publish ATM - # self._producer.xadd(topic=MessageEndpoints.gui_data(plot_id),msg= {"data": msg}) + # self._connector.xadd(topic=MessageEndpoints.gui_data(plot_id),msg= {"data": msg}) def clear(self, plot_id: str) -> None: """ @@ -115,7 +115,7 @@ class BECWidgetsConnector: plot_id (str): The id of the plot. """ msg = messages.GUIInstructionMessage(action="clear", parameter={}) - self._producer.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg) + self._connector.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg) class BECPlotter: diff --git a/bec_lib/bec_lib/bec_service.py b/bec_lib/bec_lib/bec_service.py index 129a2bb6..77768940 100644 --- a/bec_lib/bec_lib/bec_service.py +++ b/bec_lib/bec_lib/bec_service.py @@ -41,7 +41,6 @@ class BECService: self.connector = connector_cls(self.bootstrap_server) self._unique_service = unique_service self.wait_for_server = wait_for_server - self.producer = self.connector.producer() self.__service_id = str(uuid.uuid4()) self._user = getpass.getuser() self._hostname = socket.gethostname() @@ -110,11 +109,11 @@ class BECService: ) def _update_existing_services(self): - service_keys = self.producer.keys(MessageEndpoints.service_status("*")) + service_keys = self.connector.keys(MessageEndpoints.service_status("*")) if not service_keys: return services = [service.decode().split(":", maxsplit=1)[0] for service in service_keys] - msgs = [self.producer.get(service) for service in services] + msgs = [self.connector.get(service) for service in services] self._services_info = {msg.content["name"]: msg for msg in msgs if msg is not None} def _update_service_info(self): @@ -124,7 +123,7 @@ class BECService: self._service_info_event.wait(timeout=3) def _send_service_status(self): - self.producer.set_and_publish( + self.connector.set_and_publish( topic=MessageEndpoints.service_status(self._service_id), msg=messages.StatusMessage( name=self._service_name, @@ -189,7 +188,7 @@ class BECService: ) ) msg = messages.ServiceMetricMessage(name=self.__class__.__name__, metrics=data) - self.producer.send(MessageEndpoints.metrics(self._service_id), msg) + self.connector.send(MessageEndpoints.metrics(self._service_id), msg) self._metrics_emitter_event.wait(timeout=1) def set_global_var(self, name: str, val: Any) -> None: @@ -200,7 +199,7 @@ class BECService: val (Any): Value of the variable """ - self.producer.set(MessageEndpoints.global_vars(name), messages.VariableMessage(value=val)) + self.connector.set(MessageEndpoints.global_vars(name), messages.VariableMessage(value=val)) def get_global_var(self, name: str) -> Any: """Get a global variable from Redis @@ -211,7 +210,7 @@ class BECService: Returns: Any: Value of the variable """ - msg = self.producer.get(MessageEndpoints.global_vars(name)) + msg = self.connector.get(MessageEndpoints.global_vars(name)) if msg: return msg.content.get("value") return None @@ -223,12 +222,12 @@ class BECService: name (str): Name of the variable """ - self.producer.delete(MessageEndpoints.global_vars(name)) + self.connector.delete(MessageEndpoints.global_vars(name)) def global_vars(self) -> str: """Get all available global variables""" # sadly, this cannot be a property as it causes side effects with IPython's tab completion - available_keys = self.producer.keys(MessageEndpoints.global_vars("*")) + available_keys = self.connector.keys(MessageEndpoints.global_vars("*")) def get_endpoint_from_topic(topic: str) -> str: return topic.decode().split(MessageEndpoints.global_vars(""))[-1] @@ -252,6 +251,7 @@ class BECService: def shutdown(self): """shutdown the BECService""" + self.connector.shutdown() self._service_info_event.set() if self._service_info_thread: self._service_info_thread.join() diff --git a/bec_lib/bec_lib/channel_monitor.py b/bec_lib/bec_lib/channel_monitor.py index 957908bd..bb09d9cb 100644 --- a/bec_lib/bec_lib/channel_monitor.py +++ b/bec_lib/bec_lib/channel_monitor.py @@ -16,11 +16,11 @@ def channel_callback(msg, **_kwargs): print(json.dumps(out, indent=4, default=lambda o: "")) -def _start_consumer(config_path, topic): +def _start_register(config_path, topic): config = ServiceConfig(config_path) connector = RedisConnector(config.redis) - consumer = connector.consumer(topics=topic, cb=channel_callback) - consumer.start() + register = connector.register(topics=topic, cb=channel_callback) + register.start() event = threading.Event() event.wait() @@ -38,4 +38,4 @@ def channel_monitor_launch(): config_path = clargs.config topic = clargs.channel - _start_consumer(config_path, topic) + _start_register(config_path, topic) diff --git a/bec_lib/bec_lib/client.py b/bec_lib/bec_lib/client.py index daffb9c8..93027df8 100644 --- a/bec_lib/bec_lib/client.py +++ b/bec_lib/bec_lib/client.py @@ -91,7 +91,7 @@ class BECClient(BECService, UserScriptsMixin): @property def active_account(self) -> str: """get the currently active target (e)account""" - return self.producer.get(MessageEndpoints.account()) + return self.connector.get(MessageEndpoints.account()) def start(self): """start the client""" @@ -133,13 +133,13 @@ class BECClient(BECService, UserScriptsMixin): @property def pre_scan_hooks(self): """currently stored pre-scan hooks""" - return self.producer.lrange(MessageEndpoints.pre_scan_macros(), 0, -1) + return self.connector.lrange(MessageEndpoints.pre_scan_macros(), 0, -1) @pre_scan_hooks.setter def pre_scan_hooks(self, hooks: list): - self.producer.delete(MessageEndpoints.pre_scan_macros()) + self.connector.delete(MessageEndpoints.pre_scan_macros()) for hook in hooks: - self.producer.lpush(MessageEndpoints.pre_scan_macros(), hook) + self.connector.lpush(MessageEndpoints.pre_scan_macros(), hook) def _load_scans(self): self.scans = Scans(self) diff --git a/bec_lib/bec_lib/config_helper.py b/bec_lib/bec_lib/config_helper.py index 56926e52..5d85e993 100644 --- a/bec_lib/bec_lib/config_helper.py +++ b/bec_lib/bec_lib/config_helper.py @@ -26,7 +26,6 @@ logger = bec_logger.logger class ConfigHelper: def __init__(self, connector: RedisConnector, service_name: str = None) -> None: self.connector = connector - self.producer = connector.producer() self._service_name = service_name def update_session_with_file(self, file_path: str, save_recovery: bool = True) -> None: @@ -71,7 +70,7 @@ class ConfigHelper: print(f"Config was written to {file_path}.") def _save_config_to_file(self, file_path: str, raise_on_error: bool = True) -> bool: - config = self.producer.get(MessageEndpoints.device_config()) + config = self.connector.get(MessageEndpoints.device_config()) if not config: if raise_on_error: raise DeviceConfigError("No config found in the session.") @@ -99,7 +98,7 @@ class ConfigHelper: if action in ["update", "add", "set"] and not config: raise DeviceConfigError(f"Config cannot be empty for an {action} request.") RID = str(uuid.uuid4()) - self.producer.send( + self.connector.send( MessageEndpoints.device_config_request(), DeviceConfigMessage(action=action, config=config, metadata={"RID": RID}), ) @@ -145,7 +144,7 @@ class ConfigHelper: elapsed_time = 0 max_time = timeout while True: - service_messages = self.producer.lrange(MessageEndpoints.service_response(RID), 0, -1) + service_messages = self.connector.lrange(MessageEndpoints.service_response(RID), 0, -1) if not service_messages: time.sleep(0.005) elapsed_time += 0.005 @@ -185,7 +184,7 @@ class ConfigHelper: """ start = 0 while True: - msg = self.producer.get(MessageEndpoints.device_config_request_response(RID)) + msg = self.connector.get(MessageEndpoints.device_config_request_response(RID)) if msg is None: time.sleep(0.01) start += 0.01 diff --git a/bec_lib/bec_lib/connector.py b/bec_lib/bec_lib/connector.py index ccfed001..418dbc1a 100644 --- a/bec_lib/bec_lib/connector.py +++ b/bec_lib/bec_lib/connector.py @@ -6,7 +6,8 @@ import threading import traceback from bec_lib.logger import bec_logger -from bec_lib.messages import BECMessage +from bec_lib.messages import BECMessage, LogMessage +from bec_lib.endpoints import MessageEndpoints logger = bec_logger.logger @@ -33,154 +34,98 @@ class MessageObject: return f"MessageObject(topic={self.topic}, value={self._value})" -class ConnectorBase(abc.ABC): - """ - ConnectorBase implements producer and consumer clients for communicating with a broker. - One ought to inherit from this base class and provide at least customized producer and consumer methods. +class StoreInterface(abc.ABC): + """StoreBase defines the interface for storing data""" - """ - - def __init__(self, bootstrap_server: list): - self.bootstrap = bootstrap_server - self._threads = [] - - def producer(self, **kwargs) -> ProducerConnector: - raise NotImplementedError - - def consumer(self, **kwargs) -> ConsumerConnectorThreaded: - raise NotImplementedError - - def shutdown(self): - for t in self._threads: - t.signal_event.set() - t.join() - - def raise_warning(self, msg): - raise NotImplementedError - - def send_log(self, msg): - raise NotImplementedError - - def poll_messages(self): - """Poll for new messages, receive them and execute callbacks""" + def __init__(self, store): pass + def pipeline(self): + pass -class ProducerConnector(abc.ABC): + def execute_pipeline(self): + pass + + def lpush( + self, topic: str, msg: str, pipe=None, max_size: int = None, expire: int = None + ) -> None: + raise NotImplementedError + + def lset(self, topic: str, index: int, msg: str, pipe=None) -> None: + raise NotImplementedError + + def rpush(self, topic: str, msg: str, pipe=None) -> int: + raise NotImplementedError + + def lrange(self, topic: str, start: int, end: int, pipe=None): + raise NotImplementedError + + def set(self, topic: str, msg, pipe=None, expire: int = None) -> None: + raise NotImplementedError + + def keys(self, pattern: str) -> list: + raise NotImplementedError + + def delete(self, topic, pipe=None): + raise NotImplementedError + + def get(self, topic: str, pipe=None): + raise NotImplementedError + + def xadd(self, topic: str, msg: dict, max_size=None, pipe=None, expire: int = None): + raise NotImplementedError + + def xread( + self, + topic: str, + id: str = None, + count: int = None, + block: int = None, + pipe=None, + from_start=False, + ) -> list: + raise NotImplementedError + + def xrange(self, topic: str, min: str, max: str, count: int = None, pipe=None): + raise NotImplementedError + + +class PubSubInterface(abc.ABC): def raw_send(self, topic: str, msg: bytes) -> None: raise NotImplementedError def send(self, topic: str, msg: BECMessage) -> None: raise NotImplementedError - -class ConsumerConnector(abc.ABC): - def __init__( - self, bootstrap_server, cb, topics=None, pattern=None, group_id=None, event=None, **kwargs - ): - """ - ConsumerConnector class defines the communication with the broker for consuming messages. - An implementation ought to inherit from this class and implement the initialize_connector and poll_messages methods. - - Args: - bootstrap_server: list of bootstrap servers, e.g. ["localhost:9092", "localhost:9093"] - topics: the topic(s) to which the connector should attach - event: external event to trigger start and stop of the connector - cb: callback function; will be triggered from within poll_messages - kwargs: additional keyword arguments - - """ - self.bootstrap = bootstrap_server - self.topics = topics - self.pattern = pattern - self.group_id = group_id - self.connector = None - self.cb = cb - self.kwargs = kwargs - - if not self.topics and not self.pattern: - raise ConsumerConnectorError("Either a topic or a patter must be specified.") - - def initialize_connector(self) -> None: - """ - initialize the connector instance self.connector - The connector will be initialized once the thread is started - """ + def register(self, topics=None, pattern=None, cb=None, start_thread=True, **kwargs): raise NotImplementedError - def poll_messages(self) -> None: - """ - Poll messages from self.connector and call the callback function self.cb + def poll_messages(self, timeout=None): + """Poll for new messages, receive them and execute callbacks""" + raise NotImplementedError - """ - raise NotImplementedError() - - -class ConsumerConnectorThreaded(ConsumerConnector, threading.Thread): - def __init__( - self, - bootstrap_server, - cb, - topics=None, - pattern=None, - group_id=None, - event=None, - name=None, - **kwargs, - ): - """ - ConsumerConnectorThreaded class defines the threaded communication with the broker for consuming messages. - An implementation ought to inherit from this class and implement the initialize_connector and poll_messages methods. - Once started, the connector is expected to poll new messages until the signal_event is set. - - Args: - bootstrap_server: list of bootstrap servers, e.g. ["localhost:9092", "localhost:9093"] - topics: the topic(s) to which the connector should attach - event: external event to trigger start and stop of the connector - cb: callback function; will be triggered from within poll_messages - kwargs: additional keyword arguments - - """ - super().__init__( - bootstrap_server=bootstrap_server, - topics=topics, - pattern=pattern, - group_id=group_id, - event=event, - cb=cb, - **kwargs, - ) - if name is not None: - thread_kwargs = {"name": name, "daemon": True} - else: - thread_kwargs = {"daemon": True} - super(ConsumerConnector, self).__init__(**thread_kwargs) - self.signal_event = event if event is not None else threading.Event() - - def run(self): - self.initialize_connector() - - while True: - try: - self.poll_messages() - except Exception as e: - logger.error(traceback.format_exc()) - _thread.interrupt_main() - raise e - finally: - if self.signal_event.is_set(): - self.shutdown() - break + def run_messages_loop(self): + raise NotImplementedError def shutdown(self): - self.signal_event.set() + raise NotImplementedError - # def stop(self) -> None: - # """ - # Stop consumer - # Returns: - # """ - # self.signal_event.set() - # self.connector.close() - # self.join() +class ConnectorBase(PubSubInterface, StoreInterface): + def raise_warning(self, msg): + raise NotImplementedError + + def log_warning(self, msg): + """send a warning""" + self.send(MessageEndpoints.log(), LogMessage(log_type="warning", log_msg=msg)) + + def log_message(self, msg): + """send a log message""" + self.send(MessageEndpoints.log(), LogMessage(log_type="log", log_msg=msg)) + + def log_error(self, msg): + """send an error as log""" + self.send(MessageEndpoints.log(), LogMessage(log_type="error", log_msg=msg)) + + def set_and_publish(self, topic: str, msg, pipe=None, expire: int = None) -> None: + raise NotImplementedError diff --git a/bec_lib/bec_lib/dap_plugin_objects.py b/bec_lib/bec_lib/dap_plugin_objects.py index 838a5fce..cfdc10bc 100644 --- a/bec_lib/bec_lib/dap_plugin_objects.py +++ b/bec_lib/bec_lib/dap_plugin_objects.py @@ -78,7 +78,7 @@ class DAPPluginObjectBase: converted_kwargs[key] = val kwargs = converted_kwargs request_id = str(uuid.uuid4()) - self._client.producer.set_and_publish( + self._client.connector.set_and_publish( MessageEndpoints.dap_request(), messages.DAPRequestMessage( dap_cls=self._plugin_info["class"], @@ -110,7 +110,7 @@ class DAPPluginObjectBase: while True: if time.time() - start_time > timeout: raise TimeoutError("Timeout waiting for DAP response.") - response = self._client.producer.get(MessageEndpoints.dap_response(request_id)) + response = self._client.connector.get(MessageEndpoints.dap_response(request_id)) if not response: time.sleep(0.005) continue @@ -128,7 +128,7 @@ class DAPPluginObjectBase: return self._plugin_config["class_args"] = self._plugin_info.get("class_args") self._plugin_config["class_kwargs"] = self._plugin_info.get("class_kwargs") - self._client.producer.set_and_publish( + self._client.connector.set_and_publish( MessageEndpoints.dap_request(), messages.DAPRequestMessage( dap_cls=self._plugin_info["class"], @@ -149,7 +149,7 @@ class DAPPluginObject(DAPPluginObjectBase): """ Get the data from last run. """ - msg = self._client.producer.get_last(MessageEndpoints.processed_data(self._service_name)) + msg = self._client.connector.get_last(MessageEndpoints.processed_data(self._service_name)) if not msg: return None return self._convert_result(msg) diff --git a/bec_lib/bec_lib/dap_plugins.py b/bec_lib/bec_lib/dap_plugins.py index fe7b9a80..62e95f4d 100644 --- a/bec_lib/bec_lib/dap_plugins.py +++ b/bec_lib/bec_lib/dap_plugins.py @@ -38,7 +38,7 @@ class DAPPlugins: service for service in available_services if service.startswith("DAPServer/") ] for service in dap_services: - available_plugins = self._parent.producer.get( + available_plugins = self._parent.connector.get( MessageEndpoints.dap_available_plugins(service) ) if available_plugins is None: diff --git a/bec_lib/bec_lib/device.py b/bec_lib/bec_lib/device.py index f5a6e074..fcc7a8fd 100644 --- a/bec_lib/bec_lib/device.py +++ b/bec_lib/bec_lib/device.py @@ -10,7 +10,7 @@ from typeguard import typechecked from bec_lib import messages from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger -from bec_lib.redis_connector import RedisProducer +from bec_lib.redis_connector import RedisConnector logger = bec_logger.logger @@ -61,15 +61,15 @@ class ReadoutPriority(str, enum.Enum): class Status: - def __init__(self, producer: RedisProducer, RID: str) -> None: + def __init__(self, connector: RedisConnector, RID: str) -> None: """ Status object for RPC calls Args: - producer (RedisProducer): Redis producer + connector (RedisConnector): Redis connector RID (str): Request ID """ - self._producer = producer + self._connector = connector self._RID = RID def __eq__(self, __value: object) -> bool: @@ -91,7 +91,7 @@ class Status: raise TimeoutError() while True: - request_status = self._producer.lrange( + request_status = self._connector.lrange( MessageEndpoints.device_req_status(self._RID), 0, -1 ) if request_status: @@ -251,7 +251,7 @@ class DeviceBase: if not isinstance(return_val, dict): return return_val if return_val.get("type") == "status" and return_val.get("RID"): - return Status(self.root.parent.producer, return_val.get("RID")) + return Status(self.root.parent.connector, return_val.get("RID")) return return_val def _get_rpc_response(self, request_id, rpc_id) -> Any: @@ -267,7 +267,7 @@ class DeviceBase: f" {scan_queue_request.response.content['message']}" ) while True: - msg = self.root.parent.producer.get(MessageEndpoints.device_rpc(rpc_id)) + msg = self.root.parent.connector.get(MessageEndpoints.device_rpc(rpc_id)) if msg: break time.sleep(0.01) @@ -296,7 +296,7 @@ class DeviceBase: msg = self._prepare_rpc_msg(rpc_id, request_id, device, func_call, *args, **kwargs) # send RPC message - self.root.parent.producer.send(MessageEndpoints.scan_queue_request(), msg) + self.root.parent.connector.send(MessageEndpoints.scan_queue_request(), msg) # wait for RPC response if not wait_for_rpc_response: @@ -496,7 +496,7 @@ class DeviceBase: # def read(self, cached, filter_readback=True): # """get the last reading from a device""" - # val = self.parent.producer.get(MessageEndpoints.device_read(self.name)) + # val = self.parent.connector.get(MessageEndpoints.device_read(self.name)) # if not val: # return None # if filter_readback: @@ -505,7 +505,7 @@ class DeviceBase: # # def readback(self, filter_readback=True): # """get the last readback value from a device""" - # val = self.parent.producer.get(MessageEndpoints.device_readback(self.name)) + # val = self.parent.connector.get(MessageEndpoints.device_readback(self.name)) # if not val: # return None # if filter_readback: @@ -515,7 +515,7 @@ class DeviceBase: # @property # def device_status(self): # """get the current status of the device""" - # val = self.parent.producer.get(MessageEndpoints.device_status(self.name)) + # val = self.parent.connector.get(MessageEndpoints.device_status(self.name)) # if val is None: # return val # val = DeviceStatusMessage.loads(val) @@ -524,7 +524,7 @@ class DeviceBase: # @property # def signals(self): # """get the last signals from a device""" - # val = self.parent.producer.get(MessageEndpoints.device_read(self.name)) + # val = self.parent.connector.get(MessageEndpoints.device_read(self.name)) # if val is None: # return None # self._signals = DeviceMessage.loads(val).content["signals"] @@ -593,11 +593,11 @@ class OphydInterfaceBase(DeviceBase): if is_config_signal: return self.read_configuration(cached=cached) if use_readback: - val = self.root.parent.producer.get( + val = self.root.parent.connector.get( MessageEndpoints.device_readback(self.root.name) ) else: - val = self.root.parent.producer.get(MessageEndpoints.device_read(self.root.name)) + val = self.root.parent.connector.get(MessageEndpoints.device_read(self.root.name)) if not val: return None @@ -623,7 +623,7 @@ class OphydInterfaceBase(DeviceBase): if is_signal and not is_config_signal: return self.read(cached=True) - val = self.root.parent.producer.get( + val = self.root.parent.connector.get( MessageEndpoints.device_read_configuration(self.root.name) ) if not val: @@ -766,7 +766,7 @@ class AdjustableMixin: """ Returns the device limits. """ - limit_msg = self.root.parent.producer.get(MessageEndpoints.device_limits(self.root.name)) + limit_msg = self.root.parent.connector.get(MessageEndpoints.device_limits(self.root.name)) if not limit_msg: return [0, 0] limits = [ diff --git a/bec_lib/bec_lib/devicemanager.py b/bec_lib/bec_lib/devicemanager.py index 7edb7c2c..327aebe8 100644 --- a/bec_lib/bec_lib/devicemanager.py +++ b/bec_lib/bec_lib/devicemanager.py @@ -370,8 +370,7 @@ class DeviceManagerBase: _request_config_parsed = None # parsed config request _response = None # response message - _connector_base_consumer = {} - producer = None + _connector_base_register = {} config_helper = None _device_cls = DeviceBase _status_cb = [] @@ -464,7 +463,7 @@ class DeviceManagerBase: """ if not msg.metadata.get("RID"): return - self.producer.lpush( + self.connector.lpush( MessageEndpoints.service_response(msg.metadata["RID"]), messages.ServiceResponseMessage( # pylint: disable=no-member @@ -487,25 +486,20 @@ class DeviceManagerBase: self._remove_device(dev) def _start_connectors(self, bootstrap_server) -> None: - self._start_base_consumer() - self.producer = self.connector.producer() - self._start_custom_connectors(bootstrap_server) + self._start_base_register() - def _start_base_consumer(self) -> None: + def _start_base_register(self) -> None: """ Start consuming messages for all base topics. This method will be called upon startup. Returns: """ - self._connector_base_consumer["device_config_update"] = self.connector.consumer( + self.connector.register( MessageEndpoints.device_config_update(), cb=self._device_config_update_callback, parent=self, ) - # self._connector_base_consumer["log"].start() - self._connector_base_consumer["device_config_update"].start() - @staticmethod def _log_callback(msg, *, parent, **kwargs) -> None: """ @@ -541,48 +535,11 @@ class DeviceManagerBase: self._load_session() def _get_redis_device_config(self) -> list: - devices = self.producer.get(MessageEndpoints.device_config()) + devices = self.connector.get(MessageEndpoints.device_config()) if not devices: return [] return devices.content["resource"] - def _stop_base_consumer(self): - """ - Stop all base consumers by setting the corresponding event - Returns: - - """ - if self.connector is not None: - for _, con in self._connector_base_consumer.items(): - con.signal_event.set() - con.join() - - def _stop_consumer(self): - """ - Stop all consumers - Returns: - - """ - self._stop_base_consumer() - self._stop_custom_consumer() - - def _start_custom_connectors(self, bootstrap_server) -> None: - """ - Override this method in a derived class to start custom connectors upon initialization. - Args: - bootstrap_server: Kafka bootstrap server - - Returns: - - """ - - def _stop_custom_consumer(self) -> None: - """ - Stop all custom consumers. Override this method in a derived class. - Returns: - - """ - def _add_device(self, dev: dict, msg: messages.DeviceInfoMessage): name = msg.content["device"] info = msg.content["info"] @@ -621,8 +578,7 @@ class DeviceManagerBase: logger.error(f"Failed to load device {dev}: {content}") def _get_device_info(self, device_name) -> DeviceInfoMessage: - msg = self.producer.get(MessageEndpoints.device_info(device_name)) - return msg + return self.connector.get(MessageEndpoints.device_info(device_name)) def check_request_validity(self, msg: DeviceConfigMessage) -> None: """ @@ -663,10 +619,7 @@ class DeviceManagerBase: """ Shutdown all connectors. """ - try: - self.connector.shutdown() - except RuntimeError as runtime_error: - logger.error(f"Failed to shutdown connector. {runtime_error}") + self.connector.shutdown() def __del__(self): self.shutdown() diff --git a/bec_lib/bec_lib/logbook_connector.py b/bec_lib/bec_lib/logbook_connector.py index 4b75d57e..c857db4f 100644 --- a/bec_lib/bec_lib/logbook_connector.py +++ b/bec_lib/bec_lib/logbook_connector.py @@ -24,7 +24,6 @@ if TYPE_CHECKING: class LogbookConnector: def __init__(self, connector: RedisConnector) -> None: self.connector = connector - self.producer = connector.producer() self.connected = False self._scilog_module = None self._connect() @@ -34,12 +33,12 @@ class LogbookConnector: if "scilog" not in sys.modules: return - msg = self.producer.get(MessageEndpoints.logbook()) + msg = self.connector.get(MessageEndpoints.logbook()) if not msg: return msg = msgpack.loads(msg) - account = self.producer.get(MessageEndpoints.account()) + account = self.connector.get(MessageEndpoints.account()) if not account: return account = account.decode() @@ -54,7 +53,7 @@ class LogbookConnector: try: logbooks = self.log.get_logbooks(readACL={"inq": [account]}) except HTTPError: - self.producer.set(MessageEndpoints.logbook(), b"") + self.connector.set(MessageEndpoints.logbook(), b"") return if len(logbooks) > 1: logger.warning("Found two logbooks. Taking the first one.") diff --git a/bec_lib/bec_lib/logger.py b/bec_lib/bec_lib/logger.py index f8b0f965..5e006cbd 100644 --- a/bec_lib/bec_lib/logger.py +++ b/bec_lib/bec_lib/logger.py @@ -45,7 +45,6 @@ class BECLogger: self.bootstrap_server = None self.connector = None self.service_name = None - self.producer = None self.logger = loguru_logger self._log_level = LogLevel.INFO self.level = self._log_level @@ -73,7 +72,6 @@ class BECLogger: self.bootstrap_server = bootstrap_server self.connector = connector_cls(bootstrap_server) self.service_name = service_name - self.producer = self.connector.producer() self._configured = True self._update_sinks() @@ -82,7 +80,7 @@ class BECLogger: return msg = json.loads(msg) msg["service_name"] = self.service_name - self.producer.send( + self.connector.send( topic=MessageEndpoints.log(), msg=bec_lib.messages.LogMessage(log_type=msg["record"]["level"]["name"], log_msg=msg), ) diff --git a/bec_lib/bec_lib/observer.py b/bec_lib/bec_lib/observer.py index f71ef8a6..9e7740db 100644 --- a/bec_lib/bec_lib/observer.py +++ b/bec_lib/bec_lib/observer.py @@ -152,7 +152,7 @@ class ObserverManager: def _get_installed_observer(self): # get current observer list from Redis - observer_msg = self.device_manager.producer.get(MessageEndpoints.observer()) + observer_msg = self.device_manager.connector.get(MessageEndpoints.observer()) if observer_msg is None: return [] return [Observer.from_dict(obs) for obs in observer_msg.content["observer"]] diff --git a/bec_lib/bec_lib/queue_items.py b/bec_lib/bec_lib/queue_items.py index 4001f86a..bf039dbd 100644 --- a/bec_lib/bec_lib/queue_items.py +++ b/bec_lib/bec_lib/queue_items.py @@ -122,12 +122,16 @@ class QueueStorage: if history < 0: history *= -1 - return self.scan_manager.producer.lrange(MessageEndpoints.scan_queue_history(), 0, history) + return self.scan_manager.connector.lrange( + MessageEndpoints.scan_queue_history(), + 0, + history, + ) @property def current_scan_queue(self) -> dict: """get the current scan queue from redis""" - msg = self.scan_manager.producer.get(MessageEndpoints.scan_queue_status()) + msg = self.scan_manager.connector.get(MessageEndpoints.scan_queue_status()) if msg: self._current_scan_queue = msg.content["queue"] return self._current_scan_queue diff --git a/bec_lib/bec_lib/redis_connector.py b/bec_lib/bec_lib/redis_connector.py index 40150ead..96d8e1bf 100644 --- a/bec_lib/bec_lib/redis_connector.py +++ b/bec_lib/bec_lib/redis_connector.py @@ -1,19 +1,18 @@ from __future__ import annotations -import time +import collections +import queue +import sys +import threading import warnings from functools import wraps from typing import TYPE_CHECKING +import louie import redis +import redis.client -from bec_lib.connector import ( - ConnectorBase, - ConsumerConnector, - ConsumerConnectorThreaded, - MessageObject, - ProducerConnector, -) +from bec_lib.connector import ConnectorBase, MessageObject from bec_lib.endpoints import MessageEndpoints from bec_lib.messages import AlarmMessage, BECMessage, LogMessage from bec_lib.serialization import MsgpackSerialization @@ -31,7 +30,6 @@ def catch_connection_error(func): return func(*args, **kwargs) except redis.exceptions.ConnectionError: warnings.warn("Failed to connect to redis. Is the server running?") - time.sleep(0.1) return None return wrapper @@ -40,149 +38,80 @@ def catch_connection_error(func): class RedisConnector(ConnectorBase): def __init__(self, bootstrap: list, redis_cls=None): super().__init__(bootstrap) - self.redis_cls = redis_cls self.host, self.port = ( bootstrap[0].split(":") if isinstance(bootstrap, list) else bootstrap.split(":") ) - self._notifications_producer = RedisProducer( - host=self.host, port=self.port, redis_cls=self.redis_cls - ) - def producer(self, **kwargs): - return RedisProducer(host=self.host, port=self.port, redis_cls=self.redis_cls) + if redis_cls: + self._redis_conn = redis_cls(host=self.host, port=self.port) + else: + self._redis_conn = redis.Redis(host=self.host, port=self.port) - # pylint: disable=too-many-arguments - def consumer( - self, - topics=None, - pattern=None, - group_id=None, - event=None, - cb=None, - threaded=True, - name=None, - **kwargs, - ): - if cb is None: - raise ValueError("The callback function must be specified.") + # main pubsub connection + self._pubsub_conn = self._redis_conn.pubsub() + self._pubsub_conn.ignore_subscribe_messages = True + # keep track of topics and callbacks + self._topics_cb = collections.defaultdict(list) - if threaded: - if topics is None and pattern is None: - raise ValueError("Topics must be set for threaded consumer") - listener = RedisConsumerThreaded( - self.host, - self.port, - topics, - pattern, - group_id, - event, - cb, - redis_cls=self.redis_cls, - name=name, - **kwargs, - ) - self._threads.append(listener) - return listener - return RedisConsumer( - self.host, - self.port, - topics, - pattern, - group_id, - event, - cb, - redis_cls=self.redis_cls, - **kwargs, - ) + self._events_listener_thread = None + self._events_dispatcher_thread = None + self._messages_queue = queue.Queue() + self._stop_events_listener_thread = threading.Event() - def stream_consumer( - self, - topics=None, - pattern=None, - group_id=None, - event=None, - cb=None, - from_start=False, - newest_only=False, - **kwargs, - ): - """ - Threaded stream consumer for redis streams. + self.stream_keys = {} - Args: - topics (str, list): topics to subscribe to - pattern (str, list): pattern to subscribe to - group_id (str): group id - event (threading.Event): event to stop the consumer - cb (function): callback function - from_start (bool): read from start. Defaults to False. - newest_only (bool): read only the newest message. Defaults to False. - """ - if cb is None: - raise ValueError("The callback function must be specified.") - - if pattern: - raise ValueError("Pattern is currently not supported for stream consumer.") - - if topics is None and pattern is None: - raise ValueError("Topics must be set for stream consumer.") - listener = RedisStreamConsumerThreaded( - self.host, - self.port, - topics, - pattern, - group_id, - event, - cb, - redis_cls=self.redis_cls, - from_start=from_start, - newest_only=newest_only, - **kwargs, - ) - self._threads.append(listener) - return listener + def shutdown(self): + if self._events_listener_thread: + self._stop_events_listener_thread.set() + self._events_listener_thread.join() + self._events_listener_thread = None + if self._events_dispatcher_thread: + self._messages_queue.put(StopIteration) + self._events_dispatcher_thread.join() + self._events_dispatcher_thread = None + # release all connections + self._pubsub_conn.close() + self._redis_conn.close() @catch_connection_error def log_warning(self, msg): """send a warning""" - self._notifications_producer.send( - MessageEndpoints.log(), LogMessage(log_type="warning", log_msg=msg) - ) + self.send(MessageEndpoints.log(), LogMessage(log_type="warning", log_msg=msg)) @catch_connection_error def log_message(self, msg): """send a log message""" - self._notifications_producer.send( - MessageEndpoints.log(), LogMessage(log_type="log", log_msg=msg) - ) + self.send(MessageEndpoints.log(), LogMessage(log_type="log", log_msg=msg)) @catch_connection_error def log_error(self, msg): """send an error as log""" - self._notifications_producer.send( - MessageEndpoints.log(), LogMessage(log_type="error", log_msg=msg) - ) + self.send(MessageEndpoints.log(), LogMessage(log_type="error", log_msg=msg)) @catch_connection_error - def raise_alarm(self, severity: Alarms, alarm_type: str, source: str, msg: str, metadata: dict): + def raise_alarm( + self, + severity: Alarms, + alarm_type: str, + source: str, + msg: str, + metadata: dict, + ): """raise an alarm""" - self._notifications_producer.set_and_publish( - MessageEndpoints.alarm(), - AlarmMessage( - severity=severity, alarm_type=alarm_type, source=source, msg=msg, metadata=metadata - ), + alarm_msg = AlarmMessage( + severity=severity, + alarm_type=alarm_type, + source=source, + msg=msg, + metadata=metadata, ) + self.set_and_publish(MessageEndpoints.alarm(), alarm_msg) + def pipeline(self): + """Create a new pipeline""" + return self._redis_conn.pipeline() -class RedisProducer(ProducerConnector): - def __init__(self, host: str, port: int, redis_cls=None) -> None: - # pylint: disable=invalid-name - if redis_cls: - self.r = redis_cls(host=host, port=port) - return - self.r = redis.Redis(host=host, port=port) - self.stream_keys = {} - + @catch_connection_error def execute_pipeline(self, pipeline): """Execute the pipeline and returns the results with decoded BECMessages""" ret = [] @@ -197,7 +126,7 @@ class RedisProducer(ProducerConnector): @catch_connection_error def raw_send(self, topic: str, msg: bytes, pipe=None): """send to redis without any check on message type""" - client = pipe if pipe is not None else self.r + client = pipe if pipe is not None else self._redis_conn client.publish(topic, msg) def send(self, topic: str, msg: BECMessage, pipe=None) -> None: @@ -206,6 +135,95 @@ class RedisProducer(ProducerConnector): raise TypeError(f"Message {msg} is not a BECMessage") self.raw_send(topic, MsgpackSerialization.dumps(msg), pipe) + def register(self, topics=None, patterns=None, cb=None, start_thread=True, **kwargs): + if self._events_listener_thread is None: + # create the thread that will get all messages for this connector; + # under the hood, it uses asyncio - this lets the possibility to stop + # the loop on demand + self._events_listener_thread = threading.Thread( + target=self._get_messages_loop, + args=(self._pubsub_conn,), + ) + self._events_listener_thread.start() + # make a weakref from the callable, using louie; + # it can create safe refs for simple functions as well as methods + cb_ref = louie.saferef.safe_ref(cb) + + if patterns is not None: + if isinstance(patterns, str): + patterns = [patterns] + + self._pubsub_conn.psubscribe(patterns) + for pattern in patterns: + self._topics_cb[pattern].append((cb_ref, kwargs)) + else: + if isinstance(topics, str): + topics = [topics] + + self._pubsub_conn.subscribe(topics) + for topic in topics: + self._topics_cb[topic].append((cb_ref, kwargs)) + + if start_thread and self._events_dispatcher_thread is None: + # start dispatcher thread + self._events_dispatcher_thread = threading.Thread(target=self.dispatch_events) + self._events_dispatcher_thread.start() + + def _get_messages_loop(self, pubsub) -> None: + """ + Start a listening coroutine to deal with redis events and wait for completion + """ + while not self._stop_events_listener_thread.is_set(): + try: + msg = pubsub.get_message(timeout=1) + except Exception: + sys.excepthook(*sys.exc_info()) + else: + if msg is not None: + self._messages_queue.put(msg) + + def _handle_message(self, msg): + if msg["type"].endswith("subscribe"): + # ignore subscribe messages + return False + channel = msg["channel"].decode() + if msg["pattern"] is not None: + callbacks = self._topics_cb[msg["pattern"].decode()] + else: + callbacks = self._topics_cb[channel] + msg = MessageObject( + topic=channel, + value=MsgpackSerialization.loads(msg["data"]), + ) + for cb_ref, kwargs in callbacks: + cb = cb_ref() + if cb: + try: + cb(msg, **kwargs) + except Exception: + sys.excepthook(*sys.exc_info()) + return True + + def poll_messages(self, timeout=None) -> None: + while True: + try: + msg = self._messages_queue.get(timeout=timeout) + except queue.Empty: + raise TimeoutError( + f"{self}: poll_messages: did not receive a message within {timeout} seconds" + ) + else: + if msg is StopIteration: + return False + if self._handle_message(msg): + return True + else: + continue + + def dispatch_events(self): + while self.poll_messages(): + ... + @catch_connection_error def lpush( self, topic: str, msg: str, pipe=None, max_size: int = None, expire: int = None @@ -229,7 +247,7 @@ class RedisProducer(ProducerConnector): @catch_connection_error def lset(self, topic: str, index: int, msg: str, pipe=None) -> None: - client = pipe if pipe is not None else self.r + client = pipe if pipe is not None else self._redis_conn if isinstance(msg, BECMessage): msg = MsgpackSerialization.dumps(msg) return client.lset(topic, index, msg) @@ -241,7 +259,7 @@ class RedisProducer(ProducerConnector): values at the tail of the list stored at key. If key does not exist, it is created as empty list before performing the push operation. When key holds a value that is not a list, an error is returned.""" - client = pipe if pipe is not None else self.r + client = pipe if pipe is not None else self._redis_conn if isinstance(msg, BECMessage): msg = MsgpackSerialization.dumps(msg) return client.rpush(topic, msg) @@ -254,7 +272,7 @@ class RedisProducer(ProducerConnector): of the list stored at key. The offsets start and stop are zero-based indexes, with 0 being the first element of the list (the head of the list), 1 being the next element and so on.""" - client = pipe if pipe is not None else self.r + client = pipe if pipe is not None else self._redis_conn cmd_result = client.lrange(topic, start, end) if pipe: return cmd_result @@ -268,23 +286,21 @@ class RedisProducer(ProducerConnector): ret.append(msg) return ret - @catch_connection_error def set_and_publish(self, topic: str, msg, pipe=None, expire: int = None) -> None: """piped combination of self.publish and self.set""" client = pipe if pipe is not None else self.pipeline() - if isinstance(msg, BECMessage): - msg = MsgpackSerialization.dumps(msg) - client.publish(topic, msg) - client.set(topic, msg) - if expire: - client.expire(topic, expire) + if not isinstance(msg, BECMessage): + raise TypeError(f"Message {msg} is not a BECMessage") + msg = MsgpackSerialization.dumps(msg) + self.set(topic, msg, pipe=client, expire=expire) + self.raw_send(topic, msg, pipe=client) if not pipe: client.execute() @catch_connection_error def set(self, topic: str, msg, pipe=None, expire: int = None) -> None: """set redis value""" - client = pipe if pipe is not None else self.r + client = pipe if pipe is not None else self._redis_conn if isinstance(msg, BECMessage): msg = MsgpackSerialization.dumps(msg) client.set(topic, msg, ex=expire) @@ -292,23 +308,18 @@ class RedisProducer(ProducerConnector): @catch_connection_error def keys(self, pattern: str) -> list: """returns all keys matching a pattern""" - return self.r.keys(pattern) - - @catch_connection_error - def pipeline(self): - """create a new pipeline""" - return self.r.pipeline() + return self._redis_conn.keys(pattern) @catch_connection_error def delete(self, topic, pipe=None): """delete topic""" - client = pipe if pipe is not None else self.r + client = pipe if pipe is not None else self._redis_conn client.delete(topic) @catch_connection_error def get(self, topic: str, pipe=None): """retrieve entry, either via hgetall or get""" - client = pipe if pipe is not None else self.r + client = pipe if pipe is not None else self._redis_conn data = client.get(topic) if pipe: return data @@ -339,7 +350,7 @@ class RedisProducer(ProducerConnector): elif expire: client = self.pipeline() else: - client = self.r + client = self._redis_conn for key, msg in msg_dict.items(): msg_dict[key] = MsgpackSerialization.dumps(msg) @@ -356,7 +367,7 @@ class RedisProducer(ProducerConnector): @catch_connection_error def get_last(self, topic: str, key="data"): """retrieve last entry from stream""" - client = self.r + client = self._redis_conn try: _, msg_dict = client.xrevrange(topic, "+", "-", count=1)[0] except TypeError: @@ -370,7 +381,12 @@ class RedisProducer(ProducerConnector): @catch_connection_error def xread( - self, topic: str, id: str = None, count: int = None, block: int = None, from_start=False + self, + topic: str, + id: str = None, + count: int = None, + block: int = None, + from_start=False, ) -> list: """ read from stream @@ -395,13 +411,13 @@ class RedisProducer(ProducerConnector): >>> key = msg[0][1][0][0] >>> next_msg = redis.xread("test", key, count=1) """ - client = self.r + client = self._redis_conn if from_start: self.stream_keys[topic] = "0-0" if topic not in self.stream_keys: if id is None: try: - msg = self.r.xrevrange(topic, "+", "-", count=1) + msg = client.xrevrange(topic, "+", "-", count=1) if msg: self.stream_keys[topic] = msg[0][0] out = {} @@ -438,7 +454,7 @@ class RedisProducer(ProducerConnector): max (str): max id. Use "+" to read to end count (int, optional): number of messages to read. Defaults to None. """ - client = self.r + client = self._redis_conn msgs = [] for reading in client.xrange(topic, min, max, count=count): index, msg_dict = reading @@ -446,270 +462,3 @@ class RedisProducer(ProducerConnector): {k.decode(): MsgpackSerialization.loads(msg) for k, msg in msg_dict.items()} ) return msgs - - -class RedisConsumerMixin: - def _init_topics_and_pattern(self, topics, pattern): - if topics: - if not isinstance(topics, list): - topics = [topics] - if pattern: - if not isinstance(pattern, list): - pattern = [pattern] - return topics, pattern - - def _init_redis_cls(self, redis_cls): - # pylint: disable=invalid-name - if redis_cls: - self.r = redis_cls(host=self.host, port=self.port) - else: - self.r = redis.Redis(host=self.host, port=self.port) - - @catch_connection_error - def initialize_connector(self) -> None: - if self.pattern is not None: - self.pubsub.psubscribe(self.pattern) - else: - self.pubsub.subscribe(self.topics) - - -class RedisConsumer(RedisConsumerMixin, ConsumerConnector): - # pylint: disable=too-many-arguments - def __init__( - self, - host, - port, - topics=None, - pattern=None, - group_id=None, - event=None, - cb=None, - redis_cls=None, - **kwargs, - ): - self.host = host - self.port = port - - bootstrap_server = "".join([host, ":", port]) - topics, pattern = self._init_topics_and_pattern(topics, pattern) - super().__init__( - bootstrap_server=bootstrap_server, - topics=topics, - pattern=pattern, - group_id=group_id, - event=event, - cb=cb, - **kwargs, - ) - self.error_message_sent = False - self._init_redis_cls(redis_cls) - self.pubsub = self.r.pubsub() - self.initialize_connector() - - @catch_connection_error - def poll_messages(self) -> None: - """ - Poll messages from self.connector and call the callback function self.cb - """ - message = self.pubsub.get_message(ignore_subscribe_messages=True) - if message is not None: - msg = MessageObject( - topic=message["channel"], value=MsgpackSerialization.loads(message["data"]) - ) - return self.cb(msg, **self.kwargs) - - time.sleep(0.01) - return None - - def shutdown(self): - """shutdown the consumer""" - self.pubsub.close() - - -class RedisStreamConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded): - # pylint: disable=too-many-arguments - def __init__( - self, - host, - port, - topics=None, - pattern=None, - group_id=None, - event=None, - cb=None, - redis_cls=None, - from_start=False, - newest_only=False, - **kwargs, - ): - self.host = host - self.port = port - self.from_start = from_start - self.newest_only = newest_only - - bootstrap_server = "".join([host, ":", port]) - topics, pattern = self._init_topics_and_pattern(topics, pattern) - super().__init__( - bootstrap_server=bootstrap_server, - topics=topics, - pattern=pattern, - group_id=group_id, - event=event, - cb=cb, - **kwargs, - ) - - self._init_redis_cls(redis_cls) - - self.sleep_times = [0.005, 0.1] - self.last_received_msg = 0 - self.idle_time = 30 - self.error_message_sent = False - self.stream_keys = {} - - def initialize_connector(self) -> None: - pass - - def _init_topics_and_pattern(self, topics, pattern): - if topics: - if not isinstance(topics, list): - topics = [topics] - if pattern: - if not isinstance(pattern, list): - pattern = [pattern] - return topics, pattern - - def get_id(self, topic: str) -> str: - """ - Get the stream key for the given topic. - - Args: - topic (str): topic to get the stream key for - """ - if topic not in self.stream_keys: - return "0-0" - return self.stream_keys.get(topic) - - def get_newest_message(self, container: list, append=True) -> None: - """ - Get the newest message from the stream and update the stream key. If - append is True, append the message to the container. - - Args: - container (list): container to append the message to - append (bool, optional): append to container. Defaults to True. - """ - for topic in self.topics: - msg = self.r.xrevrange(topic, "+", "-", count=1) - if msg: - if append: - container.append((topic, msg[0][1])) - self.stream_keys[topic] = msg[0][0] - else: - self.stream_keys[topic] = "0-0" - - @catch_connection_error - def poll_messages(self) -> None: - """ - Poll messages from self.connector and call the callback function self.cb - - """ - if self.pattern is not None: - topics = [key.decode() for key in self.r.scan_iter(match=self.pattern, _type="stream")] - else: - topics = self.topics - messages = [] - if self.newest_only: - self.get_newest_message(messages) - elif not self.from_start and not self.stream_keys: - self.get_newest_message(messages, append=False) - else: - streams = {topic: self.get_id(topic) for topic in topics} - read_msgs = self.r.xread(streams, count=1) - if read_msgs: - for msg in read_msgs: - topic = msg[0].decode() - messages.append((topic, msg[1][0][1])) - self.stream_keys[topic] = msg[1][-1][0] - - if messages: - if MessageEndpoints.log() not in topics: - # no need to update the update frequency just for logs - self.last_received_msg = time.time() - for topic, msg in messages: - try: - msg = MsgpackSerialization.loads(msg[b"data"]) - except RuntimeError: - msg = msg[b"data"] - msg_obj = MessageObject(topic=topic, value=msg) - self.cb(msg_obj, **self.kwargs) - else: - sleep_time = int(bool(time.time() - self.last_received_msg > self.idle_time)) - if self.sleep_times[sleep_time]: - time.sleep(self.sleep_times[sleep_time]) - - -class RedisConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded): - # pylint: disable=too-many-arguments - def __init__( - self, - host, - port, - topics=None, - pattern=None, - group_id=None, - event=None, - cb=None, - redis_cls=None, - name=None, - **kwargs, - ): - self.host = host - self.port = port - - bootstrap_server = "".join([host, ":", port]) - topics, pattern = self._init_topics_and_pattern(topics, pattern) - super().__init__( - bootstrap_server=bootstrap_server, - topics=topics, - pattern=pattern, - group_id=group_id, - event=event, - cb=cb, - name=name, - **kwargs, - ) - - self._init_redis_cls(redis_cls) - self.pubsub = self.r.pubsub() - - self.sleep_times = [0.005, 0.1] - self.last_received_msg = 0 - self.idle_time = 30 - self.error_message_sent = False - - @catch_connection_error - def poll_messages(self) -> None: - """ - Poll messages from self.connector and call the callback function self.cb - - Note: pubsub messages are supposed to be BECMessage objects only - """ - messages = self.pubsub.get_message(ignore_subscribe_messages=True) - if messages is not None: - if f"{MessageEndpoints.log()}".encode() not in messages["channel"]: - # no need to update the update frequency just for logs - self.last_received_msg = time.time() - msg = MessageObject( - topic=messages["channel"].decode(), - value=MsgpackSerialization.loads(messages["data"]), - ) - self.cb(msg, **self.kwargs) - else: - sleep_time = int(bool(time.time() - self.last_received_msg > self.idle_time)) - if self.sleep_times[sleep_time]: - time.sleep(self.sleep_times[sleep_time]) - - def shutdown(self): - super().shutdown() - self.pubsub.close() diff --git a/bec_lib/bec_lib/scan_items.py b/bec_lib/bec_lib/scan_items.py index 69bc82bc..cdf6620d 100644 --- a/bec_lib/bec_lib/scan_items.py +++ b/bec_lib/bec_lib/scan_items.py @@ -46,7 +46,7 @@ class ScanItem: self.data = ScanData() self.async_data = {} self.baseline = ScanData() - self._async_data_handler = AsyncDataHandler(scan_manager.producer) + self._async_data_handler = AsyncDataHandler(scan_manager.connector) self.open_scan_defs = set() self.open_queue_group = None self.num_points = None diff --git a/bec_lib/bec_lib/scan_manager.py b/bec_lib/bec_lib/scan_manager.py index 2c604163..8d86ed1f 100644 --- a/bec_lib/bec_lib/scan_manager.py +++ b/bec_lib/bec_lib/scan_manager.py @@ -25,44 +25,31 @@ class ScanManager: connector (BECConnector): BECConnector instance """ self.connector = connector - self.producer = self.connector.producer() self.queue_storage = QueueStorage(scan_manager=self) self.request_storage = RequestStorage(scan_manager=self) self.scan_storage = ScanStorage(scan_manager=self) - self._scan_queue_consumer = self.connector.consumer( + self.connector.register( topics=MessageEndpoints.scan_queue_status(), cb=self._scan_queue_status_callback, - parent=self, ) - self._scan_queue_request_consumer = self.connector.consumer( + self.connector.register( topics=MessageEndpoints.scan_queue_request(), cb=self._scan_queue_request_callback, - parent=self, ) - self._scan_queue_request_response_consumer = self.connector.consumer( + self.connector.register( topics=MessageEndpoints.scan_queue_request_response(), cb=self._scan_queue_request_response_callback, - parent=self, ) - self._scan_status_consumer = self.connector.consumer( - topics=MessageEndpoints.scan_status(), cb=self._scan_status_callback, parent=self + self.connector.register( + topics=MessageEndpoints.scan_status(), cb=self._scan_status_callback ) - self._scan_segment_consumer = self.connector.consumer( - topics=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self + self.connector.register( + topics=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback ) - self._baseline_consumer = self.connector.consumer( - topics=MessageEndpoints.scan_baseline(), cb=self._baseline_callback, parent=self - ) - - self._scan_queue_consumer.start() - self._scan_queue_request_consumer.start() - self._scan_queue_request_response_consumer.start() - self._scan_status_consumer.start() - self._scan_segment_consumer.start() - self._baseline_consumer.start() + self.connector.register(topics=MessageEndpoints.scan_baseline(), cb=self._baseline_callback) def update_with_queue_status(self, queue: messages.ScanQueueStatusMessage) -> None: """update storage with a new queue status message""" @@ -84,7 +71,7 @@ class ScanManager: action = "deferred_pause" if deferred_pause else "pause" logger.info(f"Requesting {action}") - return self.producer.send( + return self.connector.send( MessageEndpoints.scan_queue_modification_request(), messages.ScanQueueModificationMessage(scanID=scanID, action=action, parameter={}), ) @@ -99,7 +86,7 @@ class ScanManager: if scanID is None: scanID = self.scan_storage.current_scanID logger.info("Requesting scan abortion") - self.producer.send( + self.connector.send( MessageEndpoints.scan_queue_modification_request(), messages.ScanQueueModificationMessage(scanID=scanID, action="abort", parameter={}), ) @@ -114,7 +101,7 @@ class ScanManager: if scanID is None: scanID = self.scan_storage.current_scanID logger.info("Requesting scan halt") - self.producer.send( + self.connector.send( MessageEndpoints.scan_queue_modification_request(), messages.ScanQueueModificationMessage(scanID=scanID, action="halt", parameter={}), ) @@ -129,7 +116,7 @@ class ScanManager: if scanID is None: scanID = self.scan_storage.current_scanID logger.info("Requesting scan continuation") - self.producer.send( + self.connector.send( MessageEndpoints.scan_queue_modification_request(), messages.ScanQueueModificationMessage(scanID=scanID, action="continue", parameter={}), ) @@ -137,7 +124,7 @@ class ScanManager: def request_queue_reset(self): """request a scan queue reset""" logger.info("Requesting a queue reset") - self.producer.send( + self.connector.send( MessageEndpoints.scan_queue_modification_request(), messages.ScanQueueModificationMessage(scanID=None, action="clear", parameter={}), ) @@ -151,7 +138,7 @@ class ScanManager: logger.info("Requesting to abort and repeat a scan") position = "replace" if replace else "append" - self.producer.send( + self.connector.send( MessageEndpoints.scan_queue_modification_request(), messages.ScanQueueModificationMessage( scanID=scanID, action="restart", parameter={"position": position, "RID": requestID} @@ -162,7 +149,7 @@ class ScanManager: @property def next_scan_number(self): """get the next scan number from redis""" - num = self.producer.get(MessageEndpoints.scan_number()) + num = self.connector.get(MessageEndpoints.scan_number()) if num is None: logger.warning("Failed to retrieve scan number from redis.") return -1 @@ -172,63 +159,51 @@ class ScanManager: @typechecked def next_scan_number(self, val: int): """set the next scan number in redis""" - return self.producer.set(MessageEndpoints.scan_number(), val) + return self.connector.set(MessageEndpoints.scan_number(), val) @property def next_dataset_number(self): """get the next dataset number from redis""" - return int(self.producer.get(MessageEndpoints.dataset_number())) + return int(self.connector.get(MessageEndpoints.dataset_number())) @next_dataset_number.setter @typechecked def next_dataset_number(self, val: int): """set the next dataset number in redis""" - return self.producer.set(MessageEndpoints.dataset_number(), val) + return self.connector.set(MessageEndpoints.dataset_number(), val) - @staticmethod - def _scan_queue_status_callback(msg, *, parent: ScanManager, **_kwargs) -> None: + def _scan_queue_status_callback(self, msg, **_kwargs) -> None: queue_status = msg.value if not queue_status: return - parent.update_with_queue_status(queue_status) + self.update_with_queue_status(queue_status) - @staticmethod - def _scan_queue_request_callback(msg, *, parent: ScanManager, **_kwargs) -> None: + def _scan_queue_request_callback(self, msg, **_kwargs) -> None: request = msg.value - parent.request_storage.update_with_request(request) + self.request_storage.update_with_request(request) - @staticmethod - def _scan_queue_request_response_callback(msg, *, parent: ScanManager, **_kwargs) -> None: + def _scan_queue_request_response_callback(self, msg, **_kwargs) -> None: response = msg.value logger.debug(response) - parent.request_storage.update_with_response(response) + self.request_storage.update_with_response(response) - @staticmethod - def _scan_status_callback(msg, *, parent: ScanManager, **_kwargs) -> None: + def _scan_status_callback(self, msg, **_kwargs) -> None: scan = msg.value - parent.scan_storage.update_with_scan_status(scan) + self.scan_storage.update_with_scan_status(scan) - @staticmethod - def _scan_segment_callback(msg, *, parent: ScanManager, **_kwargs) -> None: + def _scan_segment_callback(self, msg, **_kwargs) -> None: scan_msgs = msg.value if not isinstance(scan_msgs, list): scan_msgs = [scan_msgs] for scan_msg in scan_msgs: - parent.scan_storage.add_scan_segment(scan_msg) + self.scan_storage.add_scan_segment(scan_msg) - @staticmethod - def _baseline_callback(msg, *, parent: ScanManager, **_kwargs) -> None: + def _baseline_callback(self, msg, **_kwargs) -> None: msg = msg.value - parent.scan_storage.add_scan_baseline(msg) + self.scan_storage.add_scan_baseline(msg) def __str__(self) -> str: return "\n".join(self.queue_storage.describe_queue()) def shutdown(self): - """stop the scan manager's threads""" - self._scan_queue_consumer.shutdown() - self._scan_queue_request_consumer.shutdown() - self._scan_queue_request_response_consumer.shutdown() - self._scan_status_consumer.shutdown() - self._scan_segment_consumer.shutdown() - self._baseline_consumer.shutdown() + pass diff --git a/bec_lib/bec_lib/scan_report.py b/bec_lib/bec_lib/scan_report.py index 045aedbb..0f1a2b25 100644 --- a/bec_lib/bec_lib/scan_report.py +++ b/bec_lib/bec_lib/scan_report.py @@ -89,7 +89,7 @@ class ScanReport: def _get_mv_status(self) -> bool: """get the status of a move request""" motors = list(self.request.request.content["parameter"]["args"].keys()) - request_status = self._client.device_manager.producer.lrange( + request_status = self._client.device_manager.connector.lrange( MessageEndpoints.device_req_status(self.request.requestID), 0, -1 ) if len(request_status) == len(motors): diff --git a/bec_lib/bec_lib/scans.py b/bec_lib/bec_lib/scans.py index 661980d5..5c75f190 100644 --- a/bec_lib/bec_lib/scans.py +++ b/bec_lib/bec_lib/scans.py @@ -100,9 +100,9 @@ class ScanObject: return None return self.scan_info.get("scan_report_hint") - def _start_consumer(self, request: messages.ScanQueueMessage) -> ConsumerConnector: - """Start a consumer for the given request""" - consumer = self.client.device_manager.connector.consumer( + def _start_register(self, request: messages.ScanQueueMessage) -> ConsumerConnector: + """Start a register for the given request""" + register = self.client.device_manager.connector.register( [ MessageEndpoints.device_readback(dev) for dev in request.content["parameter"]["args"].keys() @@ -110,11 +110,11 @@ class ScanObject: threaded=False, cb=(lambda msg: msg), ) - return consumer + return register def _send_scan_request(self, request: messages.ScanQueueMessage) -> None: """Send a scan request to the scan server""" - self.client.device_manager.producer.send(MessageEndpoints.scan_queue_request(), request) + self.client.device_manager.connector.send(MessageEndpoints.scan_queue_request(), request) class Scans: @@ -136,7 +136,7 @@ class Scans: def _import_scans(self): """Import scans from the scan server""" - available_scans = self.parent.producer.get(MessageEndpoints.available_scans()) + available_scans = self.parent.connector.get(MessageEndpoints.available_scans()) if available_scans is None: logger.warning("No scans available. Are redis and the BEC server running?") return diff --git a/bec_lib/bec_lib/service_config.py b/bec_lib/bec_lib/service_config.py index 3fd5e25a..53113a6f 100644 --- a/bec_lib/bec_lib/service_config.py +++ b/bec_lib/bec_lib/service_config.py @@ -14,7 +14,9 @@ logger = bec_logger.logger DEFAULT_SERVICE_CONFIG = { "redis": {"host": "localhost", "port": 6379}, - "service_config": {"file_writer": {"plugin": "default_NeXus_format", "base_path": "./"}}, + "service_config": { + "file_writer": {"plugin": "default_NeXus_format", "base_path": os.path.dirname(__file__)} + }, } @@ -32,7 +34,7 @@ class ServiceConfig: self._update_config(service_config=config, redis=redis) self.service_config = self.config.get( - "service_config", {"file_writer": {"plugin": "default_NeXus_format", "base_path": "./"}} + "service_config", DEFAULT_SERVICE_CONFIG["service_config"] ) def _update_config(self, **kwargs): diff --git a/bec_lib/bec_lib/tests/utils.py b/bec_lib/bec_lib/tests/utils.py index 4434ed30..1525e652 100644 --- a/bec_lib/bec_lib/tests/utils.py +++ b/bec_lib/bec_lib/tests/utils.py @@ -56,7 +56,7 @@ def queue_is_empty(queue) -> bool: # pragma: no cover def get_queue(bec): # pragma: no cover - return bec.queue.producer.get(MessageEndpoints.scan_queue_status()) + return bec.queue.connector.get(MessageEndpoints.scan_queue_status()) def wait_for_empty_queue(bec): # pragma: no cover @@ -484,7 +484,6 @@ def bec_client(): with open(f"{dir_path}/tests/test_config.yaml", "r", encoding="utf-8") as f: builtins.__dict__["test_session"] = create_session_from_config(yaml.safe_load(f)) device_manager._session = builtins.__dict__["test_session"] - device_manager.producer = device_manager.connector.producer() client.wait_for_service = lambda service_name: None device_manager._load_session() for name, dev in device_manager.devices.items(): @@ -497,37 +496,23 @@ def bec_client(): class PipelineMock: # pragma: no cover _pipe_buffer = [] - _producer = None + _connector = None - def __init__(self, producer) -> None: - self._producer = producer + def __init__(self, connector) -> None: + self._connector = connector def execute(self): - if not self._producer.store_data: + if not self._connector.store_data: self._pipe_buffer = [] return [] res = [ - getattr(self._producer, method)(*args, **kwargs) + getattr(self._connector, method)(*args, **kwargs) for method, args, kwargs in self._pipe_buffer ] self._pipe_buffer = [] return res -class ConsumerMock: # pragma: no cover - def __init__(self) -> None: - self.signal_event = SignalMock() - - def start(self): - pass - - def join(self): - pass - - def shutdown(self): - pass - - class SignalMock: # pragma: no cover def __init__(self) -> None: self.is_set = False @@ -536,12 +521,36 @@ class SignalMock: # pragma: no cover self.is_set = True -class ProducerMock: # pragma: no cover - def __init__(self, store_data=True) -> None: +class ConnectorMock(ConnectorBase): # pragma: no cover + def __init__(self, bootstrap_server="localhost:0000", store_data=True): + super().__init__(bootstrap_server) self.message_sent = [] self._get_buffer = {} self.store_data = store_data + def raise_alarm( + self, severity: Alarms, alarm_type: str, source: str, msg: dict, metadata: dict + ): + pass + + def log_error(self, *args, **kwargs): + pass + + def shutdown(self): + pass + + def register(self, *args, **kwargs): + pass + + def set(self, *args, **kwargs): + pass + + def set_and_publish(self, *args, **kwargs): + pass + + def keys(self, *args, **kwargs): + return [] + def set(self, topic, msg, pipe=None, expire: int = None): if pipe: pipe._pipe_buffer.append(("set", (topic, msg), {"expire": expire})) @@ -592,9 +601,6 @@ class ProducerMock: # pragma: no cover self._get_buffer.pop(topic, None) return val - def keys(self, pattern: str) -> list: - return [] - def pipeline(self): return PipelineMock(self) @@ -609,29 +615,6 @@ class ProducerMock: # pragma: no cover return -class ConnectorMock(ConnectorBase): # pragma: no cover - def __init__(self, bootstrap_server: list, store_data=True): - super().__init__(bootstrap_server) - self.store_data = store_data - - def consumer(self, *args, **kwargs) -> ConsumerMock: - return ConsumerMock() - - def producer(self, *args, **kwargs): - return ProducerMock(self.store_data) - - def raise_alarm( - self, severity: Alarms, alarm_type: str, source: str, msg: dict, metadata: dict - ): - pass - - def log_error(self, *args, **kwargs): - pass - - def shutdown(self): - pass - - def create_session_from_config(config: dict) -> dict: device_configs = [] session_id = str(uuid.uuid4()) diff --git a/bec_lib/setup.py b/bec_lib/setup.py index f2af0505..944594b3 100644 --- a/bec_lib/setup.py +++ b/bec_lib/setup.py @@ -5,6 +5,8 @@ __version__ = "1.12.1" if __name__ == "__main__": setup( install_requires=[ + "hiredis", + "louie", "numpy", "scipy", "msgpack", @@ -22,7 +24,16 @@ if __name__ == "__main__": "lmfit", ], extras_require={ - "dev": ["pytest", "pytest-random-order", "coverage", "pandas", "black", "pylint"] + "dev": [ + "pytest", + "pytest-random-order", + "pytest-redis", + "pytest-timeout", + "coverage", + "pandas", + "black", + "pylint", + ] }, entry_points={"console_scripts": ["bec-channel-monitor = bec_lib:channel_monitor_launch"]}, package_data={"bec_lib.tests": ["*.yaml"], "bec_lib.configs": ["*.yaml", "*.json"]}, diff --git a/bec_lib/tests/test_bec_plotter.py b/bec_lib/tests/test_bec_plotter.py index a6e2f062..855739ca 100644 --- a/bec_lib/tests/test_bec_plotter.py +++ b/bec_lib/tests/test_bec_plotter.py @@ -17,7 +17,7 @@ def test_bec_widgets_connector_set_plot_config(bec_client): config = {"x": "test", "y": "test", "color": "test", "size": "test", "shape": "test"} connector.set_plot_config(plot_id="plot_id", config=config) msg = messages.GUIConfigMessage(config=config) - bec_client.connector.producer().set_and_publish.assert_called_once_with( + bec_client.connector.set_and_publish.assert_called_once_with( MessageEndpoints.gui_config("plot_id"), msg ) is None @@ -26,7 +26,7 @@ def test_bec_widgets_connector_close(bec_client): connector = BECWidgetsConnector(gui_id="gui_id", bec_client=bec_client) connector.close("plot_id") msg = messages.GUIInstructionMessage(action="close", parameter={}) - bec_client.connector.producer().set_and_publish.assert_called_once_with( + bec_client.connector.set_and_publish.assert_called_once_with( MessageEndpoints.gui_instructions("plot_id"), msg ) @@ -36,7 +36,7 @@ def test_bec_widgets_connector_send_data(bec_client): data = {"x": [1, 2, 3], "y": [1, 2, 3]} connector.send_data("plot_id", data) msg = messages.GUIDataMessage(data=data) - bec_client.connector.producer().set_and_publish.assert_called_once_with( + bec_client.connector.set_and_publish.assert_called_once_with( topic=MessageEndpoints.gui_data("plot_id"), msg=msg ) @@ -45,7 +45,7 @@ def test_bec_widgets_connector_clear(bec_client): connector = BECWidgetsConnector(gui_id="gui_id", bec_client=bec_client) connector.clear("plot_id") msg = messages.GUIInstructionMessage(action="clear", parameter={}) - bec_client.connector.producer().set_and_publish.assert_called_once_with( + bec_client.connector.set_and_publish.assert_called_once_with( MessageEndpoints.gui_instructions("plot_id"), msg ) diff --git a/bec_lib/tests/test_bec_service.py b/bec_lib/tests/test_bec_service.py index 602b3706..2cd0e9a8 100644 --- a/bec_lib/tests/test_bec_service.py +++ b/bec_lib/tests/test_bec_service.py @@ -124,8 +124,8 @@ def test_bec_service_update_existing_services(): messages.StatusMessage(name="service2", status=BECStatus.IDLE, info={}, metadata={}), ] connector_cls = mock.MagicMock() - connector_cls().producer().keys.return_value = service_keys - connector_cls().producer().get.side_effect = [msg for msg in service_msgs] + connector_cls().keys.return_value = service_keys + connector_cls().get.side_effect = [msg for msg in service_msgs] service = BECService( config=f"{os.path.dirname(bec_lib.__file__)}/tests/test_service_config.yaml", connector_cls=connector_cls, @@ -144,8 +144,8 @@ def test_bec_service_update_existing_services_ignores_wrong_msgs(): None, ] connector_cls = mock.MagicMock() - connector_cls().producer().keys.return_value = service_keys - connector_cls().producer().get.side_effect = [service_msgs[0], None] + connector_cls().keys.return_value = service_keys + connector_cls().get.side_effect = [service_msgs[0], None] service = BECService( config=f"{os.path.dirname(bec_lib.__file__)}/tests/test_service_config.yaml", connector_cls=connector_cls, diff --git a/bec_lib/tests/test_channel_monitor.py b/bec_lib/tests/test_channel_monitor.py index 94dcbc4e..1b106a72 100644 --- a/bec_lib/tests/test_channel_monitor.py +++ b/bec_lib/tests/test_channel_monitor.py @@ -13,7 +13,7 @@ def test_channel_monitor_callback(): mock_print.assert_called_once() -def test_channel_monitor_start_consumer(): +def test_channel_monitor_start_register(): with mock.patch("bec_lib.channel_monitor.argparse") as mock_argparse: with mock.patch("bec_lib.channel_monitor.ServiceConfig") as mock_config: with mock.patch("bec_lib.channel_monitor.RedisConnector") as mock_connector: @@ -26,6 +26,6 @@ def test_channel_monitor_start_consumer(): mock_config.return_value = mock.MagicMock() mock_connector.return_value = mock.MagicMock() channel_monitor_launch() - mock_connector().consumer.assert_called_once() - mock_connector().consumer.return_value.start.assert_called_once() + mock_connector().register.assert_called_once() + mock_connector().register.return_value.start.assert_called_once() mock_threading.Event().wait.assert_called_once() diff --git a/bec_lib/tests/test_config_helper.py b/bec_lib/tests/test_config_helper.py index 2ccede70..70e0f529 100644 --- a/bec_lib/tests/test_config_helper.py +++ b/bec_lib/tests/test_config_helper.py @@ -49,7 +49,7 @@ def test_config_helper_save_current_session(): connector = mock.MagicMock() config_helper = ConfigHelper(connector) - connector.producer().get.return_value = messages.AvailableResourceMessage( + connector.get.return_value = messages.AvailableResourceMessage( resource=[ { "id": "648c817f67d3c7cd6a354e8e", @@ -158,9 +158,7 @@ def test_send_config_request_raises_for_rejected_update(config_helper): def test_wait_for_config_reply(): connector = mock.MagicMock() config_helper = ConfigHelper(connector) - connector.producer().get.return_value = messages.RequestResponseMessage( - accepted=True, message="test" - ) + connector.get.return_value = messages.RequestResponseMessage(accepted=True, message="test") res = config_helper.wait_for_config_reply("test") assert res == messages.RequestResponseMessage(accepted=True, message="test") @@ -169,7 +167,7 @@ def test_wait_for_config_reply(): def test_wait_for_config_raises_timeout(): connector = mock.MagicMock() config_helper = ConfigHelper(connector) - connector.producer().get.return_value = None + connector.get.return_value = None with pytest.raises(DeviceConfigError): config_helper.wait_for_config_reply("test", timeout=0.3) @@ -178,7 +176,7 @@ def test_wait_for_config_raises_timeout(): def test_wait_for_service_response(): connector = mock.MagicMock() config_helper = ConfigHelper(connector) - connector.producer().lrange.side_effect = [ + connector.lrange.side_effect = [ [], [ messages.ServiceResponseMessage( @@ -196,7 +194,7 @@ def test_wait_for_service_response(): def test_wait_for_service_response_raises_timeout(): connector = mock.MagicMock() config_helper = ConfigHelper(connector) - connector.producer().lrange.return_value = [] + connector.lrange.return_value = [] with pytest.raises(DeviceConfigError): config_helper.wait_for_service_response("test", timeout=0.3) diff --git a/bec_lib/tests/test_dap_plugins.py b/bec_lib/tests/test_dap_plugins.py index 9532e160..cbd0d0e6 100644 --- a/bec_lib/tests/test_dap_plugins.py +++ b/bec_lib/tests/test_dap_plugins.py @@ -349,7 +349,7 @@ def dap(dap_plugin_message): } client = mock.MagicMock() client.service_status = dap_services - client.producer.get.return_value = dap_plugin_message + client.connector.get.return_value = dap_plugin_message dap_plugins = DAPPlugins(client) yield dap_plugins @@ -367,7 +367,7 @@ def test_dap_plugins_construction(dap): def test_dap_plugin_fit(dap): with mock.patch.object(dap.GaussianModel, "_wait_for_dap_response") as mock_wait: dap.GaussianModel.fit() - dap._parent.producer.set_and_publish.assert_called_once() + dap._parent.connector.set_and_publish.assert_called_once() mock_wait.assert_called_once() @@ -380,7 +380,7 @@ def test_dap_auto_run(dap): def test_dap_wait_for_dap_response_waits_for_RID(dap): - dap._parent.producer.get.return_value = messages.DAPResponseMessage( + dap._parent.connector.get.return_value = messages.DAPResponseMessage( success=True, data={}, metadata={"RID": "wrong_ID"} ) with pytest.raises(TimeoutError): @@ -388,7 +388,7 @@ def test_dap_wait_for_dap_response_waits_for_RID(dap): def test_dap_wait_for_dap_respnse_returns(dap): - dap._parent.producer.get.return_value = messages.DAPResponseMessage( + dap._parent.connector.get.return_value = messages.DAPResponseMessage( success=True, data={}, metadata={"RID": "1234"} ) val = dap.GaussianModel._wait_for_dap_response(request_id="1234", timeout=0.1) @@ -429,11 +429,11 @@ def test_dap_select_raises_on_wrong_device(dap): def test_dap_get_data(dap): - dap._parent.producer.get_last.return_value = messages.ProcessedDataMessage( + dap._parent.connector.get_last.return_value = messages.ProcessedDataMessage( data=[{"x": [1, 2, 3], "y": [4, 5, 6]}, {"fit_parameters": {"amplitude": 1}}] ) data = dap.GaussianModel.get_data() - dap._parent.producer.get_last.assert_called_once_with( + dap._parent.connector.get_last.assert_called_once_with( MessageEndpoints.processed_data("GaussianModel") ) @@ -443,13 +443,13 @@ def test_dap_get_data(dap): def test_dap_update_dap_config_not_called_without_device(dap): dap.GaussianModel._update_dap_config(request_id="1234") - dap._parent.producer.set_and_publish.assert_not_called() + dap._parent.connector.set_and_publish.assert_not_called() def test_dap_update_dap_config(dap): dap.GaussianModel._plugin_config["selected_device"] = ["samx", "samx"] dap.GaussianModel._update_dap_config(request_id="1234") - dap._parent.producer.set_and_publish.assert_called_with( + dap._parent.connector.set_and_publish.assert_called_with( MessageEndpoints.dap_request(), messages.DAPRequestMessage( dap_cls="LmfitService1D", diff --git a/bec_lib/tests/test_device_manager.py b/bec_lib/tests/test_device_manager.py index 63ff5113..16408b27 100644 --- a/bec_lib/tests/test_device_manager.py +++ b/bec_lib/tests/test_device_manager.py @@ -91,15 +91,14 @@ def test_get_config_calls_load(dm): dm, "_get_redis_device_config", return_value={"devices": [{}]} ) as get_redis_config: with mock.patch.object(dm, "_load_session") as load_session: - with mock.patch.object(dm, "producer") as producer: - dm._get_config() - get_redis_config.assert_called_once() - load_session.assert_called_once() + dm._get_config() + get_redis_config.assert_called_once() + load_session.assert_called_once() def test_get_redis_device_config(dm): - with mock.patch.object(dm, "producer") as producer: - producer.get.return_value = messages.AvailableResourceMessage(resource={"devices": [{}]}) + with mock.patch.object(dm, "connector") as connector: + connector.get.return_value = messages.AvailableResourceMessage(resource={"devices": [{}]}) assert dm._get_redis_device_config() == {"devices": [{}]} diff --git a/bec_lib/tests/test_devices.py b/bec_lib/tests/test_devices.py index 9d48e746..a35d1516 100644 --- a/bec_lib/tests/test_devices.py +++ b/bec_lib/tests/test_devices.py @@ -23,7 +23,7 @@ def test_nested_device_root(dev): def test_read(dev): - with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get: + with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get: mock_get.return_value = messages.DeviceMessage( signals={ "samx": {"value": 0, "timestamp": 1701105880.1711318}, @@ -42,7 +42,7 @@ def test_read(dev): def test_read_filtered_hints(dev): - with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get: + with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get: mock_get.return_value = messages.DeviceMessage( signals={ "samx": {"value": 0, "timestamp": 1701105880.1711318}, @@ -57,7 +57,7 @@ def test_read_filtered_hints(dev): def test_read_use_read(dev): - with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get: + with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get: data = { "samx": {"value": 0, "timestamp": 1701105880.1711318}, "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492}, @@ -72,7 +72,7 @@ def test_read_use_read(dev): def test_read_nested_device(dev): - with mock.patch.object(dev.dyn_signals.root.parent.producer, "get") as mock_get: + with mock.patch.object(dev.dyn_signals.root.parent.connector, "get") as mock_get: data = { "dyn_signals_messages_message1": {"value": 0, "timestamp": 1701105880.0716832}, "dyn_signals_messages_message2": {"value": 0, "timestamp": 1701105880.071722}, @@ -93,7 +93,7 @@ def test_read_nested_device(dev): ) def test_read_kind_hinted(dev, kind, cached): with mock.patch.object(dev.samx.readback, "_run") as mock_run: - with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get: + with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get: data = { "samx": {"value": 0, "timestamp": 1701105880.1711318}, "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492}, @@ -138,7 +138,7 @@ def test_read_configuration_cached(dev, is_signal, is_config_signal, method): with mock.patch.object( dev.samx.readback, "_get_rpc_signal_info", return_value=(is_signal, is_config_signal, True) ): - with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get: + with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get: mock_get.return_value = messages.DeviceMessage( signals={ "samx": {"value": 0, "timestamp": 1701105880.1711318}, @@ -182,7 +182,7 @@ def test_get_rpc_func_name_read(dev): ) def test_get_rpc_func_name_readback_get(dev, kind, cached): with mock.patch.object(dev.samx.readback, "_run") as mock_rpc: - with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get: + with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get: mock_get.return_value = messages.DeviceMessage( signals={ "samx": {"value": 0, "timestamp": 1701105880.1711318}, @@ -219,9 +219,7 @@ def test_handle_rpc_response_returns_status(dev, bec_client): msg = messages.DeviceRPCMessage( device="samx", return_val={"type": "status", "RID": "request_id"}, out="done", success=True ) - assert dev.samx._handle_rpc_response(msg) == Status( - bec_client.device_manager.producer, "request_id" - ) + assert dev.samx._handle_rpc_response(msg) == Status(bec_client.device_manager, "request_id") def test_handle_rpc_response_raises(dev): @@ -348,7 +346,7 @@ def test_device_update_user_parameter(device_obj, user_param, val, out, raised_e def test_status_wait(): - producer = mock.MagicMock() + connector = mock.MagicMock() def lrange_mock(*args, **kwargs): yield False @@ -358,8 +356,8 @@ def test_status_wait(): return next(lmock) lmock = lrange_mock() - producer.lrange = get_lrange - status = Status(producer, "test") + connector.lrange = get_lrange + status = Status(connector, "test") status.wait() @@ -561,7 +559,7 @@ def test_show_all(): def test_adjustable_mixin_limits(): adj = AdjustableMixin() adj.root = mock.MagicMock() - adj.root.parent.producer.get.return_value = messages.DeviceMessage( + adj.root.parent.connector.get.return_value = messages.DeviceMessage( signals={"low": -12, "high": 12}, metadata={} ) assert adj.limits == [-12, 12] @@ -570,7 +568,7 @@ def test_adjustable_mixin_limits(): def test_adjustable_mixin_limits_missing(): adj = AdjustableMixin() adj.root = mock.MagicMock() - adj.root.parent.producer.get.return_value = None + adj.root.parent.connector.get.return_value = None assert adj.limits == [0, 0] @@ -585,7 +583,7 @@ def test_adjustable_mixin_set_low_limit(): adj = AdjustableMixin() adj.update_config = mock.MagicMock() adj.root = mock.MagicMock() - adj.root.parent.producer.get.return_value = messages.DeviceMessage( + adj.root.parent.connector.get.return_value = messages.DeviceMessage( signals={"low": -12, "high": 12}, metadata={} ) adj.low_limit = -20 @@ -596,7 +594,7 @@ def test_adjustable_mixin_set_high_limit(): adj = AdjustableMixin() adj.update_config = mock.MagicMock() adj.root = mock.MagicMock() - adj.root.parent.producer.get.return_value = messages.DeviceMessage( + adj.root.parent.connector.get.return_value = messages.DeviceMessage( signals={"low": -12, "high": 12}, metadata={} ) adj.high_limit = 20 diff --git a/bec_lib/tests/test_observer.py b/bec_lib/tests/test_observer.py index 5f34df91..6564fcff 100644 --- a/bec_lib/tests/test_observer.py +++ b/bec_lib/tests/test_observer.py @@ -97,9 +97,9 @@ def device_manager(dm_with_devices): def test_observer_manager_None(device_manager): - with mock.patch.object(device_manager.producer, "get", return_value=None) as producer_get: + with mock.patch.object(device_manager.connector, "get", return_value=None) as connector_get: observer_manager = ObserverManager(device_manager=device_manager) - producer_get.assert_called_once_with(MessageEndpoints.observer()) + connector_get.assert_called_once_with(MessageEndpoints.observer()) assert len(observer_manager._observer) == 0 @@ -115,9 +115,9 @@ def test_observer_manager_msg(device_manager): } ] ) - with mock.patch.object(device_manager.producer, "get", return_value=msg) as producer_get: + with mock.patch.object(device_manager.connector, "get", return_value=msg) as connector_get: observer_manager = ObserverManager(device_manager=device_manager) - producer_get.assert_called_once_with(MessageEndpoints.observer()) + connector_get.assert_called_once_with(MessageEndpoints.observer()) assert len(observer_manager._observer) == 1 @@ -139,7 +139,7 @@ def test_observer_manager_msg(device_manager): ], ) def test_add_observer(device_manager, observer, raises_error): - with mock.patch.object(device_manager.producer, "get", return_value=None) as producer_get: + with mock.patch.object(device_manager.connector, "get", return_value=None) as connector_get: observer_manager = ObserverManager(device_manager=device_manager) observer_manager.add_observer(observer) with pytest.raises(AttributeError): @@ -185,7 +185,7 @@ def test_add_observer_existing_device(device_manager, observer, raises_error): "limits": [380, None], } ) - with mock.patch.object(device_manager.producer, "get", return_value=None) as producer_get: + with mock.patch.object(device_manager.connector, "get", return_value=None) as connector_get: observer_manager = ObserverManager(device_manager=device_manager) observer_manager.add_observer(default_observer) if raises_error: diff --git a/bec_lib/tests/test_redis_connector.py b/bec_lib/tests/test_redis_connector.py index 9c83e37e..1b43b31b 100644 --- a/bec_lib/tests/test_redis_connector.py +++ b/bec_lib/tests/test_redis_connector.py @@ -12,113 +12,71 @@ from bec_lib.messages import AlarmMessage, BECMessage, LogMessage from bec_lib.redis_connector import ( MessageObject, RedisConnector, - RedisConsumer, - RedisConsumerMixin, - RedisConsumerThreaded, - RedisProducer, - RedisStreamConsumerThreaded, ) from bec_lib.serialization import MsgpackSerialization -@pytest.fixture -def producer(): - with mock.patch("bec_lib.redis_connector.redis.Redis"): - prod = RedisProducer("localhost", 1) - yield prod - - @pytest.fixture def connector(): with mock.patch("bec_lib.redis_connector.redis.Redis"): connector = RedisConnector("localhost:1") + try: + yield connector + finally: + connector.shutdown() + + +@pytest.fixture +def connected_connector(redis_proc): + connector = RedisConnector(f"localhost:{redis_proc.port}") + try: yield connector - - -@pytest.fixture -def consumer(): - with mock.patch("bec_lib.redis_connector.redis.Redis"): - consumer = RedisConsumer("localhost", "1", topics="topics") - yield consumer - - -@pytest.fixture -def consumer_threaded(): - with mock.patch("bec_lib.redis_connector.redis.Redis"): - consumer_threaded = RedisConsumerThreaded("localhost", "1", topics="topics") - yield consumer_threaded - - -@pytest.fixture -def mixin(): - with mock.patch("bec_lib.redis_connector.redis.Redis"): - mixin = RedisConsumerMixin - yield mixin - - -def test_redis_connector_producer(connector): - ret = connector.producer() - assert isinstance(ret, RedisProducer) + finally: + connector.shutdown() @pytest.mark.parametrize( - "topics, threaded", [["topics", True], ["topics", False], [None, True], [None, False]] + "topics, threaded", + [["topics", True], ["topics", False], [None, True], [None, False]], ) -def test_redis_connector_consumer(connector, threaded, topics): - pattern = None - len_of_threads = len(connector._threads) - - if threaded: - if topics is None and pattern is None: - with pytest.raises(ValueError) as exc_info: - ret = connector.consumer( - topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ... - ) - - assert exc_info.value.args[0] == "Topics must be set for threaded consumer" - else: - ret = connector.consumer( - topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ... +def test_redis_connector_register(connected_connector, threaded, topics): + breakpoint() + connector = connected_connector + if topics is None: + with pytest.raises(TypeError): + ret = connector.register( + topics=topics, cb=lambda *args, **kwargs: ..., start_thread=threaded ) - assert len(connector._threads) == len_of_threads + 1 - assert isinstance(ret, RedisConsumerThreaded) - else: - if not topics: - with pytest.raises(ConsumerConnectorError): - ret = connector.consumer( - topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ... - ) - return - ret = connector.consumer(topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ...) - assert isinstance(ret, RedisConsumer) + ret = connector.register( + topics=topics, cb=lambda *args, **kwargs: ..., start_thread=threaded + ) + if threaded: + assert connector._events_listener_thread is not None def test_redis_connector_log_warning(connector): - connector._notifications_producer.send = mock.MagicMock() - - connector.log_warning("msg") - connector._notifications_producer.send.assert_called_once_with( - MessageEndpoints.log(), LogMessage(log_type="warning", log_msg="msg") - ) + with mock.patch.object(connector, "send", return_value=None): + connector.log_warning("msg") + connector.send.assert_called_once_with( + MessageEndpoints.log(), LogMessage(log_type="warning", log_msg="msg") + ) def test_redis_connector_log_message(connector): - connector._notifications_producer.send = mock.MagicMock() - - connector.log_message("msg") - connector._notifications_producer.send.assert_called_once_with( - MessageEndpoints.log(), LogMessage(log_type="log", log_msg="msg") - ) + with mock.patch.object(connector, "send", return_value=None): + connector.log_message("msg") + connector.send.assert_called_once_with( + MessageEndpoints.log(), LogMessage(log_type="log", log_msg="msg") + ) def test_redis_connector_log_error(connector): - connector._notifications_producer.send = mock.MagicMock() - - connector.log_error("msg") - connector._notifications_producer.send.assert_called_once_with( - MessageEndpoints.log(), LogMessage(log_type="error", log_msg="msg") - ) + with mock.patch.object(connector, "send", return_value=None): + connector.log_error("msg") + connector.send.assert_called_once_with( + MessageEndpoints.log(), LogMessage(log_type="error", log_msg="msg") + ) @pytest.mark.parametrize( @@ -130,20 +88,24 @@ def test_redis_connector_log_error(connector): ], ) def test_redis_connector_raise_alarm(connector, severity, alarm_type, source, msg, metadata): - connector._notifications_producer.set_and_publish = mock.MagicMock() + with mock.patch.object(connector, "set_and_publish", return_value=None): + connector.raise_alarm(severity, alarm_type, source, msg, metadata) - connector.raise_alarm(severity, alarm_type, source, msg, metadata) - - connector._notifications_producer.set_and_publish.assert_called_once_with( - MessageEndpoints.alarm(), - AlarmMessage( - severity=severity, alarm_type=alarm_type, source=source, msg=msg, metadata=metadata - ), - ) + connector.set_and_publish.assert_called_once_with( + MessageEndpoints.alarm(), + AlarmMessage( + severity=severity, + alarm_type=alarm_type, + source=source, + msg=msg, + metadata=metadata, + ), + ) @dataclass(eq=False) class TestMessage(BECMessage): + __test__ = False # just for pytest to ignore this class msg_type = "test_message" msg: str # have to add this field here, @@ -160,30 +122,36 @@ bec_messages.TestMessage = TestMessage @pytest.mark.parametrize( "topic , msg", [["topic1", TestMessage("msg1")], ["topic2", TestMessage("msg2")]] ) -def test_redis_producer_send(producer, topic, msg): - producer.send(topic, msg) - producer.r.publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg)) +def test_redis_connector_send(connector, topic, msg): + connector.send(topic, msg) + connector._redis_conn.publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg)) - producer.send(topic, msg, pipe=producer.pipeline()) - producer.r.pipeline().publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg)) + connector.send(topic, msg, pipe=connector.pipeline()) + connector._redis_conn.pipeline().publish.assert_called_once_with( + topic, MsgpackSerialization.dumps(msg) + ) @pytest.mark.parametrize( "topic, msgs, max_size, expire", - [["topic1", "msgs", None, None], ["topic1", "msgs", 10, None], ["topic1", "msgs", None, 100]], + [ + ["topic1", "msgs", None, None], + ["topic1", "msgs", 10, None], + ["topic1", "msgs", None, 100], + ], ) -def test_redis_producer_lpush(producer, topic, msgs, max_size, expire): +def test_redis_connector_lpush(connector, topic, msgs, max_size, expire): pipe = None - producer.lpush(topic, msgs, pipe, max_size, expire) + connector.lpush(topic, msgs, pipe, max_size, expire) - producer.r.pipeline().lpush.assert_called_once_with(topic, msgs) + connector._redis_conn.pipeline().lpush.assert_called_once_with(topic, msgs) if max_size: - producer.r.pipeline().ltrim.assert_called_once_with(topic, 0, max_size) + connector._redis_conn.pipeline().ltrim.assert_called_once_with(topic, 0, max_size) if expire: - producer.r.pipeline().expire.assert_called_once_with(topic, expire) + connector._redis_conn.pipeline().expire.assert_called_once_with(topic, expire) if not pipe: - producer.r.pipeline().execute.assert_called_once() + connector._redis_conn.pipeline().execute.assert_called_once() @pytest.mark.parametrize( @@ -194,68 +162,76 @@ def test_redis_producer_lpush(producer, topic, msgs, max_size, expire): ["topic1", TestMessage("msgs"), None, 100], ], ) -def test_redis_producer_lpush_BECMessage(producer, topic, msgs, max_size, expire): +def test_redis_connector_lpush_BECMessage(connector, topic, msgs, max_size, expire): pipe = None - producer.lpush(topic, msgs, pipe, max_size, expire) + connector.lpush(topic, msgs, pipe, max_size, expire) - producer.r.pipeline().lpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs)) + connector._redis_conn.pipeline().lpush.assert_called_once_with( + topic, MsgpackSerialization.dumps(msgs) + ) if max_size: - producer.r.pipeline().ltrim.assert_called_once_with(topic, 0, max_size) + connector._redis_conn.pipeline().ltrim.assert_called_once_with(topic, 0, max_size) if expire: - producer.r.pipeline().expire.assert_called_once_with(topic, expire) + connector._redis_conn.pipeline().expire.assert_called_once_with(topic, expire) if not pipe: - producer.r.pipeline().execute.assert_called_once() + connector._redis_conn.pipeline().execute.assert_called_once() @pytest.mark.parametrize( - "topic , index , msgs, use_pipe", [["topic1", 1, "msg1", True], ["topic2", 4, "msg2", False]] + "topic , index , msgs, use_pipe", + [["topic1", 1, "msg1", True], ["topic2", 4, "msg2", False]], ) -def test_redis_producer_lset(producer, topic, index, msgs, use_pipe): - pipe = use_pipe_fcn(producer, use_pipe) +def test_redis_connector_lset(connector, topic, index, msgs, use_pipe): + pipe = use_pipe_fcn(connector, use_pipe) - ret = producer.lset(topic, index, msgs, pipe) + ret = connector.lset(topic, index, msgs, pipe) if pipe: - producer.r.pipeline().lset.assert_called_once_with(topic, index, msgs) + connector._redis_conn.pipeline().lset.assert_called_once_with(topic, index, msgs) assert ret == redis.Redis().pipeline().lset() else: - producer.r.lset.assert_called_once_with(topic, index, msgs) + connector._redis_conn.lset.assert_called_once_with(topic, index, msgs) assert ret == redis.Redis().lset() @pytest.mark.parametrize( "topic , index , msgs, use_pipe", - [["topic1", 1, TestMessage("msg1"), True], ["topic2", 4, TestMessage("msg2"), False]], + [ + ["topic1", 1, TestMessage("msg1"), True], + ["topic2", 4, TestMessage("msg2"), False], + ], ) -def test_redis_producer_lset_BECMessage(producer, topic, index, msgs, use_pipe): - pipe = use_pipe_fcn(producer, use_pipe) +def test_redis_connector_lset_BECMessage(connector, topic, index, msgs, use_pipe): + pipe = use_pipe_fcn(connector, use_pipe) - ret = producer.lset(topic, index, msgs, pipe) + ret = connector.lset(topic, index, msgs, pipe) if pipe: - producer.r.pipeline().lset.assert_called_once_with( + connector._redis_conn.pipeline().lset.assert_called_once_with( topic, index, MsgpackSerialization.dumps(msgs) ) assert ret == redis.Redis().pipeline().lset() else: - producer.r.lset.assert_called_once_with(topic, index, MsgpackSerialization.dumps(msgs)) + connector._redis_conn.lset.assert_called_once_with( + topic, index, MsgpackSerialization.dumps(msgs) + ) assert ret == redis.Redis().lset() @pytest.mark.parametrize( "topic, msgs, use_pipe", [["topic1", "msg1", True], ["topic2", "msg2", False]] ) -def test_redis_producer_rpush(producer, topic, msgs, use_pipe): - pipe = use_pipe_fcn(producer, use_pipe) +def test_redis_connector_rpush(connector, topic, msgs, use_pipe): + pipe = use_pipe_fcn(connector, use_pipe) - ret = producer.rpush(topic, msgs, pipe) + ret = connector.rpush(topic, msgs, pipe) if pipe: - producer.r.pipeline().rpush.assert_called_once_with(topic, msgs) + connector._redis_conn.pipeline().rpush.assert_called_once_with(topic, msgs) assert ret == redis.Redis().pipeline().rpush() else: - producer.r.rpush.assert_called_once_with(topic, msgs) + connector._redis_conn.rpush.assert_called_once_with(topic, msgs) assert ret == redis.Redis().rpush() @@ -263,421 +239,349 @@ def test_redis_producer_rpush(producer, topic, msgs, use_pipe): "topic, msgs, use_pipe", [["topic1", TestMessage("msg1"), True], ["topic2", TestMessage("msg2"), False]], ) -def test_redis_producer_rpush_BECMessage(producer, topic, msgs, use_pipe): - pipe = use_pipe_fcn(producer, use_pipe) +def test_redis_connector_rpush_BECMessage(connector, topic, msgs, use_pipe): + pipe = use_pipe_fcn(connector, use_pipe) - ret = producer.rpush(topic, msgs, pipe) + ret = connector.rpush(topic, msgs, pipe) if pipe: - producer.r.pipeline().rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs)) + connector._redis_conn.pipeline().rpush.assert_called_once_with( + topic, MsgpackSerialization.dumps(msgs) + ) assert ret == redis.Redis().pipeline().rpush() else: - producer.r.rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs)) + connector._redis_conn.rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs)) assert ret == redis.Redis().rpush() @pytest.mark.parametrize( "topic, start, end, use_pipe", [["topic1", 0, 4, True], ["topic2", 3, 7, False]] ) -def test_redis_producer_lrange(producer, topic, start, end, use_pipe): - pipe = use_pipe_fcn(producer, use_pipe) +def test_redis_connector_lrange(connector, topic, start, end, use_pipe): + pipe = use_pipe_fcn(connector, use_pipe) - ret = producer.lrange(topic, start, end, pipe) + ret = connector.lrange(topic, start, end, pipe) if pipe: - producer.r.pipeline().lrange.assert_called_once_with(topic, start, end) + connector._redis_conn.pipeline().lrange.assert_called_once_with(topic, start, end) assert ret == redis.Redis().pipeline().lrange() else: - producer.r.lrange.assert_called_once_with(topic, start, end) + connector._redis_conn.lrange.assert_called_once_with(topic, start, end) assert ret == [] @pytest.mark.parametrize( - "topic, msg, pipe, expire", [["topic1", "msg1", None, 400], ["topic2", "msg2", None, None]] + "topic, msg, pipe, expire", + [ + ["topic1", TestMessage("msg1"), None, 400], + ["topic2", TestMessage("msg2"), None, None], + ["topic3", "msg3", None, None], + ], ) -def test_redis_producer_set_and_publish(producer, topic, msg, pipe, expire): - producer.set_and_publish(topic, msg, pipe, expire) +def test_redis_connector_set_and_publish(connector, topic, msg, pipe, expire): + if not isinstance(msg, BECMessage): + with pytest.raises(TypeError): + connector.set_and_publish(topic, msg, pipe, expire) + else: + connector.set_and_publish(topic, msg, pipe, expire) - producer.r.pipeline().publish.assert_called_once_with(topic, msg) - producer.r.pipeline().set.assert_called_once_with(topic, msg) - if expire: - producer.r.pipeline().expire.assert_called_once_with(topic, expire) - if not pipe: - producer.r.pipeline().execute.assert_called_once() + connector._redis_conn.pipeline().publish.assert_called_once_with( + topic, MsgpackSerialization.dumps(msg) + ) + connector._redis_conn.pipeline().set.assert_called_once_with( + topic, MsgpackSerialization.dumps(msg), ex=expire + ) + if not pipe: + connector._redis_conn.pipeline().execute.assert_called_once() @pytest.mark.parametrize("topic, msg, expire", [["topic1", "msg1", None], ["topic2", "msg2", 400]]) -def test_redis_producer_set(producer, topic, msg, expire): +def test_redis_connector_set(connector, topic, msg, expire): pipe = None - producer.set(topic, msg, pipe, expire) + connector.set(topic, msg, pipe, expire) if pipe: - producer.r.pipeline().set.assert_called_once_with(topic, msg, ex=expire) + connector._redis_conn.pipeline().set.assert_called_once_with(topic, msg, ex=expire) else: - producer.r.set.assert_called_once_with(topic, msg, ex=expire) + connector._redis_conn.set.assert_called_once_with(topic, msg, ex=expire) @pytest.mark.parametrize("pattern", ["samx", "samy"]) -def test_redis_producer_keys(producer, pattern): - ret = producer.keys(pattern) - producer.r.keys.assert_called_once_with(pattern) +def test_redis_connector_keys(connector, pattern): + ret = connector.keys(pattern) + connector._redis_conn.keys.assert_called_once_with(pattern) assert ret == redis.Redis().keys() -def test_redis_producer_pipeline(producer): - ret = producer.pipeline() - producer.r.pipeline.assert_called_once() +def test_redis_connector_pipeline(connector): + ret = connector.pipeline() + connector._redis_conn.pipeline.assert_called_once() assert ret == redis.Redis().pipeline() -@pytest.mark.parametrize("topic,use_pipe", [["topic1", True], ["topic2", False]]) -def test_redis_producer_delete(producer, topic, use_pipe): - pipe = use_pipe_fcn(producer, use_pipe) - - producer.delete(topic, pipe) - - if pipe: - producer.pipeline().delete.assert_called_once_with(topic) - else: - producer.r.delete.assert_called_once_with(topic) - - -@pytest.mark.parametrize("topic, use_pipe", [["topic1", True], ["topic2", False]]) -def test_redis_producer_get(producer, topic, use_pipe): - pipe = use_pipe_fcn(producer, use_pipe) - - ret = producer.get(topic, pipe) - if pipe: - producer.pipeline().get.assert_called_once_with(topic) - assert ret == redis.Redis().pipeline().get() - else: - producer.r.get.assert_called_once_with(topic) - assert ret == redis.Redis().get() - - -def use_pipe_fcn(producer, use_pipe): +def use_pipe_fcn(connector, use_pipe): if use_pipe: - return producer.pipeline() + return connector.pipeline() return None +@pytest.mark.parametrize("topic,use_pipe", [["topic1", True], ["topic2", False]]) +def test_redis_connector_delete(connector, topic, use_pipe): + pipe = use_pipe_fcn(connector, use_pipe) + + connector.delete(topic, pipe) + + if pipe: + connector.pipeline().delete.assert_called_once_with(topic) + else: + connector._redis_conn.delete.assert_called_once_with(topic) + + +@pytest.mark.parametrize("topic, use_pipe", [["topic1", True], ["topic2", False]]) +def test_redis_connector_get(connector, topic, use_pipe): + pipe = use_pipe_fcn(connector, use_pipe) + + ret = connector.get(topic, pipe) + if pipe: + connector.pipeline().get.assert_called_once_with(topic) + assert ret == redis.Redis().pipeline().get() + else: + connector._redis_conn.get.assert_called_once_with(topic) + assert ret == redis.Redis().get() + + @pytest.mark.parametrize( - "topics, pattern", + "subscribed_topics, subscribed_patterns, msgs", [ - ["topics1", None], - [["topics1", "topics2"], None], - [None, "pattern1"], - [None, ["pattern1", "pattern2"]], + ["topics1", None, ["topics1"]], + [["topics1", "topics2"], None, ["topics1", "topics2"]], + [None, "pattern1", ["pattern1"]], + [None, ["patt*", "top*"], ["pattern1", "topics1"]], ], ) -def test_redis_consumer_init(consumer, topics, pattern): - with mock.patch("bec_lib.redis_connector.redis.Redis"): - consumer = RedisConsumer( - "localhost", "1", topics, pattern, redis_cls=redis.Redis, cb=lambda *args, **kwargs: ... +def test_redis_connector_register( + redisdb, connected_connector, subscribed_topics, subscribed_patterns, msgs +): + connector = connected_connector + test_msg = TestMessage("test") + cb_mock = mock.Mock(spec=[]) # spec is here to remove all attributes + if subscribed_topics: + connector.register( + subscribed_topics, subscribed_patterns, cb=cb_mock, start_thread=False, a=1 ) - - if topics: - if isinstance(topics, list): - assert consumer.topics == topics - else: - assert consumer.topics == [topics] - if pattern: - if isinstance(pattern, list): - assert consumer.pattern == pattern - else: - assert consumer.pattern == [pattern] - - assert consumer.r == redis.Redis() - assert consumer.pubsub == consumer.r.pubsub() - assert consumer.host == "localhost" - assert consumer.port == "1" + for msg in msgs: + connector.send(msg, TestMessage(msg)) + connector.poll_messages() + msg_object = MessageObject(msg, TestMessage(msg)) + cb_mock.assert_called_with(msg_object, a=1) + if subscribed_patterns: + connector.register( + subscribed_topics, subscribed_patterns, cb=cb_mock, start_thread=False, a=1 + ) + for msg in msgs: + connector.send(msg, TestMessage(msg)) + connector.poll_messages() + msg_object = MessageObject(msg, TestMessage(msg)) + cb_mock.assert_called_with(msg_object, a=1) -@pytest.mark.parametrize("pattern, topics", [["pattern", "topics1"], [None, "topics2"]]) -def test_redis_consumer_initialize_connector(consumer, pattern, topics): - consumer.pattern = pattern - - consumer.topics = topics - consumer.initialize_connector() - - if consumer.pattern is not None: - consumer.pubsub.psubscribe.assert_called_once_with(consumer.pattern) - else: - consumer.pubsub.subscribe.assert_called_with(consumer.topics) - - -def test_redis_consumer_poll_messages(consumer): +def test_redis_register_poll_messages(redisdb, connected_connector): + connector = connected_connector cb_fcn_has_been_called = False def cb_fcn(msg, **kwargs): nonlocal cb_fcn_has_been_called cb_fcn_has_been_called = True - print(msg) - - consumer.cb = cb_fcn + assert kwargs["a"] == 1 test_msg = TestMessage("test") - consumer.pubsub.get_message.return_value = { - "channel": "", - "data": MsgpackSerialization.dumps(test_msg), - } - ret = consumer.poll_messages() - consumer.pubsub.get_message.assert_called_once_with(ignore_subscribe_messages=True) + connector.register("test", cb=cb_fcn, a=1, start_thread=False) + redisdb.publish("test", MsgpackSerialization.dumps(test_msg)) + + connector.poll_messages(timeout=1) assert cb_fcn_has_been_called - -def test_redis_consumer_shutdown(consumer): - consumer.shutdown() - consumer.pubsub.close.assert_called_once() + with pytest.raises(TimeoutError): + connector.poll_messages(timeout=0.1) -def test_redis_consumer_additional_kwargs(connector): - cons = connector.consumer(topics="topic1", parent="here", cb=lambda *args, **kwargs: ...) - assert "parent" in cons.kwargs +def test_redis_connector_xadd(connector): + connector.xadd("topic1", {"key": "value"}) + connector._redis_conn.xadd.assert_called_once_with("topic1", {"key": "value"}) -@pytest.mark.parametrize( - "topics, pattern", - [ - ["topics1", None], - [["topics1", "topics2"], None], - [None, "pattern1"], - [None, ["pattern1", "pattern2"]], - ], -) -def test_mixin_init_topics_and_pattern(mixin, topics, pattern): - ret_topics, ret_pattern = mixin._init_topics_and_pattern(mixin, topics, pattern) - - if topics: - if isinstance(topics, list): - assert ret_topics == topics - else: - assert ret_topics == [topics] - if pattern: - if isinstance(pattern, list): - assert ret_pattern == pattern - else: - assert ret_pattern == [pattern] +def test_redis_connector_xadd_with_maxlen(connector): + connector.xadd("topic1", {"key": "value"}, max_size=100) + connector._redis_conn.xadd.assert_called_once_with("topic1", {"key": "value"}, maxlen=100) -def test_mixin_init_redis_cls(mixin, consumer): - mixin._init_redis_cls(consumer, None) - assert consumer.r == redis.Redis(host="localhost", port=1) +def test_redis_connector_xadd_with_expire(connector): + connector.xadd("topic1", {"key": "value"}, expire=100) + connector._redis_conn.pipeline().xadd.assert_called_once_with("topic1", {"key": "value"}) + connector._redis_conn.pipeline().expire.assert_called_once_with("topic1", 100) + connector._redis_conn.pipeline().execute.assert_called_once() -@pytest.mark.parametrize( - "topics, pattern", - [ - ["topics1", None], - [["topics1", "topics2"], None], - [None, "pattern1"], - [None, ["pattern1", "pattern2"]], - ], -) -def test_redis_consumer_threaded_init(consumer_threaded, topics, pattern): - with mock.patch("bec_lib.redis_connector.redis.Redis"): - consumer_threaded = RedisConsumerThreaded( - "localhost", "1", topics, pattern, redis_cls=redis.Redis, cb=lambda *args, **kwargs: ... - ) - - if topics: - if isinstance(topics, list): - assert consumer_threaded.topics == topics - else: - assert consumer_threaded.topics == [topics] - if pattern: - if isinstance(pattern, list): - assert consumer_threaded.pattern == pattern - else: - assert consumer_threaded.pattern == [pattern] - - assert consumer_threaded.r == redis.Redis() - assert consumer_threaded.pubsub == consumer_threaded.r.pubsub() - assert consumer_threaded.host == "localhost" - assert consumer_threaded.port == "1" - assert consumer_threaded.sleep_times == [0.005, 0.1] - assert consumer_threaded.last_received_msg == 0 - assert consumer_threaded.idle_time == 30 +def test_redis_connector_xread(connector): + connector.xread("topic1", "id") + connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None) -def test_redis_connector_xadd(producer): - producer.xadd("topic1", {"key": "value"}) - producer.r.xadd.assert_called_once_with("topic1", {"key": MsgpackSerialization.dumps("value")}) +def test_redis_connector_xadd(connector): + connector.xadd("topic1", {"key": "value"}) + connector._redis_conn.xadd.assert_called_once_with( + "topic1", {"key": MsgpackSerialization.dumps("value")} + ) test_msg = TestMessage("test") - producer.xadd("topic1", {"data": test_msg}) - producer.r.xadd.assert_called_with("topic1", {"data": MsgpackSerialization.dumps(test_msg)}) - producer.r.xrevrange.return_value = [ + connector.xadd("topic1", {"data": test_msg}) + connector._redis_conn.xadd.assert_called_with( + "topic1", {"data": MsgpackSerialization.dumps(test_msg)} + ) + connector._redis_conn.xrevrange.return_value = [ (b"1707391599960-0", {b"data": MsgpackSerialization.dumps(test_msg)}) ] - msg = producer.get_last("topic1") + msg = connector.get_last("topic1") assert msg == test_msg -def test_redis_connector_xadd_with_maxlen(producer): - producer.xadd("topic1", {"key": "value"}, max_size=100) - producer.r.xadd.assert_called_once_with( +def test_redis_connector_xadd_with_maxlen(connector): + connector.xadd("topic1", {"key": "value"}, max_size=100) + connector._redis_conn.xadd.assert_called_once_with( "topic1", {"key": MsgpackSerialization.dumps("value")}, maxlen=100 ) -def test_redis_connector_xadd_with_expire(producer): - producer.xadd("topic1", {"key": "value"}, expire=100) - producer.r.pipeline().xadd.assert_called_once_with( +def test_redis_connector_xadd_with_expire(connector): + connector.xadd("topic1", {"key": "value"}, expire=100) + connector._redis_conn.pipeline().xadd.assert_called_once_with( "topic1", {"key": MsgpackSerialization.dumps("value")} ) - producer.r.pipeline().expire.assert_called_once_with("topic1", 100) - producer.r.pipeline().execute.assert_called_once() + connector._redis_conn.pipeline().expire.assert_called_once_with("topic1", 100) + connector._redis_conn.pipeline().execute.assert_called_once() -def test_redis_connector_xread(producer): - producer.xread("topic1", "id") - producer.r.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None) +def test_redis_connector_xread(connector): + connector.xread("topic1", "id") + connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None) -def test_redis_connector_xread_without_id(producer): - producer.xread("topic1", from_start=True) - producer.r.xread.assert_called_once_with({"topic1": "0-0"}, count=None, block=None) - producer.r.xread.reset_mock() +def test_redis_connector_xread_without_id(connector): + connector.xread("topic1", from_start=True) + connector._redis_conn.xread.assert_called_once_with({"topic1": "0-0"}, count=None, block=None) + connector._redis_conn.xread.reset_mock() - producer.stream_keys["topic1"] = "id" - producer.xread("topic1") - producer.r.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None) + connector.stream_keys["topic1"] = "id" + connector.xread("topic1") + connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None) -def test_redis_connector_xread_from_end(producer): - producer.xread("topic1", from_start=False) - producer.r.xrevrange.assert_called_once_with("topic1", "+", "-", count=1) +def test_redis_connector_xread_from_end(connector): + connector.xread("topic1", from_start=False) + connector._redis_conn.xrevrange.assert_called_once_with("topic1", "+", "-", count=1) -def test_redis_connector_get_last(producer): - producer.r.xrevrange.return_value = [ +def test_redis_connector_get_last(connector): + connector._redis_conn.xrevrange.return_value = [ (b"1707391599960-0", {b"key": MsgpackSerialization.dumps("value")}) ] - msg = producer.get_last("topic1") - producer.r.xrevrange.assert_called_once_with("topic1", "+", "-", count=1) + msg = connector.get_last("topic1") + connector._redis_conn.xrevrange.assert_called_once_with("topic1", "+", "-", count=1) assert msg is None # no key given, default is b'data' - assert producer.get_last("topic1", "key") == "value" - assert producer.get_last("topic1", None) == {"key": "value"} + assert connector.get_last("topic1", "key") == "value" + assert connector.get_last("topic1", None) == {"key": "value"} -def test_redis_xrange(producer): - producer.xrange("topic1", "start", "end") - producer.r.xrange.assert_called_once_with("topic1", "start", "end", count=None) +def test_redis_connector_xread_without_id(connector): + connector.xread("topic1", from_start=True) + connector._redis_conn.xread.assert_called_once_with({"topic1": "0-0"}, count=None, block=None) + connector._redis_conn.xread.reset_mock() + + connector.stream_keys["topic1"] = "id" + connector.xread("topic1") + connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None) -def test_redis_xrange_topic_with_suffix(producer): - producer.xrange("topic1", "start", "end") - producer.r.xrange.assert_called_once_with("topic1", "start", "end", count=None) +def test_redis_xrange(connector): + connector.xrange("topic1", "start", "end") + connector._redis_conn.xrange.assert_called_once_with("topic1", "start", "end", count=None) -def test_redis_consumer_threaded_no_cb_without_messages(consumer_threaded): - with mock.patch.object(consumer_threaded.pubsub, "get_message", return_value=None): - consumer_threaded.cb = mock.MagicMock() - consumer_threaded.poll_messages() - consumer_threaded.cb.assert_not_called() +def test_redis_xrange_topic_with_suffix(connector): + connector.xrange("topic1", "start", "end") + connector._redis_conn.xrange.assert_called_once_with("topic1", "start", "end", count=None) -def test_redis_consumer_threaded_cb_called_with_messages(consumer_threaded): - message = {"channel": b"topic1", "data": MsgpackSerialization.dumps(TestMessage("test"))} - - with mock.patch.object(consumer_threaded.pubsub, "get_message", return_value=message): - consumer_threaded.cb = mock.MagicMock() - consumer_threaded.poll_messages() - msg_object = MessageObject("topic1", TestMessage("test")) - consumer_threaded.cb.assert_called_once_with(msg_object) +# def test_redis_stream_register_threaded_get_id(): +# register = RedisStreamConsumerThreaded( +# "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock() +# ) +# register.stream_keys["topic1"] = b"1691610882756-0" +# assert register.get_id("topic1") == b"1691610882756-0" +# assert register.get_id("doesnt_exist") == "0-0" -def test_redis_consumer_threaded_shutdown(consumer_threaded): - consumer_threaded.shutdown() - consumer_threaded.pubsub.close.assert_called_once() +# def test_redis_stream_register_threaded_poll_messages(): +# register = RedisStreamConsumerThreaded( +# "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock() +# ) +# with mock.patch.object( +# register, "get_newest_message", return_value=None +# ) as mock_get_newest_message: +# register.poll_messages() +# mock_get_newest_message.assert_called_once() +# register._redis_conn.xread.assert_not_called() -def test_redis_stream_consumer_threaded_get_newest_message(): - consumer = RedisStreamConsumerThreaded( - "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock() - ) - consumer.r.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})] - msgs = [] - consumer.get_newest_message(msgs) - assert "topic1" in consumer.stream_keys - assert consumer.stream_keys["topic1"] == b"1691610882756-0" +# def test_redis_stream_register_threaded_poll_messages_newest_only(): +# register = RedisStreamConsumerThreaded( +# "localhost", +# "1", +# topics="topic1", +# cb=mock.MagicMock(), +# redis_cls=mock.MagicMock(), +# newest_only=True, +# ) +# +# register._redis_conn.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})] +# register.poll_messages() +# register._redis_conn.xread.assert_not_called() +# register.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg")) -def test_redis_stream_consumer_threaded_get_newest_message_no_msg(): - consumer = RedisStreamConsumerThreaded( - "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock() - ) - consumer.r.xrevrange.return_value = [] - msgs = [] - consumer.get_newest_message(msgs) - assert "topic1" in consumer.stream_keys - assert consumer.stream_keys["topic1"] == "0-0" +# def test_redis_stream_register_threaded_poll_messages_read(): +# register = RedisStreamConsumerThreaded( +# "localhost", +# "1", +# topics="topic1", +# cb=mock.MagicMock(), +# redis_cls=mock.MagicMock(), +# ) +# register.stream_keys["topic1"] = "0-0" +# +# msg = [[b"topic1", [(b"1691610714612-0", {b"data": b"msg"})]]] +# +# register._redis_conn.xread.return_value = msg +# register.poll_messages() +# register._redis_conn.xread.assert_called_once_with({"topic1": "0-0"}, count=1) +# register.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg")) - -def test_redis_stream_consumer_threaded_get_id(): - consumer = RedisStreamConsumerThreaded( - "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock() - ) - consumer.stream_keys["topic1"] = b"1691610882756-0" - assert consumer.get_id("topic1") == b"1691610882756-0" - assert consumer.get_id("doesnt_exist") == "0-0" - - -def test_redis_stream_consumer_threaded_poll_messages(): - consumer = RedisStreamConsumerThreaded( - "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock() - ) - with mock.patch.object( - consumer, "get_newest_message", return_value=None - ) as mock_get_newest_message: - consumer.poll_messages() - mock_get_newest_message.assert_called_once() - consumer.r.xread.assert_not_called() - - -def test_redis_stream_consumer_threaded_poll_messages_newest_only(): - consumer = RedisStreamConsumerThreaded( - "localhost", - "1", - topics="topic1", - cb=mock.MagicMock(), - redis_cls=mock.MagicMock(), - newest_only=True, - ) - - consumer.r.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})] - consumer.poll_messages() - consumer.r.xread.assert_not_called() - consumer.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg")) - - -def test_redis_stream_consumer_threaded_poll_messages_read(): - consumer = RedisStreamConsumerThreaded( - "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock() - ) - consumer.stream_keys["topic1"] = "0-0" - - msg = [[b"topic1", [(b"1691610714612-0", {b"data": b"msg"})]]] - - consumer.r.xread.return_value = msg - consumer.poll_messages() - consumer.r.xread.assert_called_once_with({"topic1": "0-0"}, count=1) - consumer.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg")) - - -@pytest.mark.parametrize( - "topics,expected", - [ - ("topic1", ["topic1"]), - (["topic1"], ["topic1"]), - (["topic1", "topic2"], ["topic1", "topic2"]), - ], -) -def test_redis_stream_consumer_threaded_init_topics(topics, expected): - consumer = RedisStreamConsumerThreaded( - "localhost", "1", topics=topics, cb=mock.MagicMock(), redis_cls=mock.MagicMock() - ) - assert consumer.topics == expected +# @pytest.mark.parametrize( +# "topics,expected", +# [ +# ("topic1", ["topic1"]), +# (["topic1"], ["topic1"]), +# (["topic1", "topic2"], ["topic1", "topic2"]), +# ], +# ) +# def test_redis_stream_register_threaded_init_topics(topics, expected): +# register = RedisStreamConsumerThreaded( +# "localhost", +# "1", +# topics=topics, +# cb=mock.MagicMock(), +# redis_cls=mock.MagicMock(), +# ) +# assert register.topics == expected diff --git a/bec_lib/tests/test_scan_items.py b/bec_lib/tests/test_scan_items.py index b88c42f2..0ec441d5 100644 --- a/bec_lib/tests/test_scan_items.py +++ b/bec_lib/tests/test_scan_items.py @@ -63,7 +63,7 @@ from bec_lib.tests.utils import ConnectorMock ) def test_update_with_queue_status(queue_msg): scan_manager = ScanManager(ConnectorMock("")) - scan_manager.producer._get_buffer[MessageEndpoints.scan_queue_status()] = queue_msg + scan_manager.connector._get_buffer[MessageEndpoints.scan_queue_status()] = queue_msg scan_manager.update_with_queue_status(queue_msg) assert ( scan_manager.scan_storage.find_scan_by_ID("bfa582aa-f9cd-4258-ab5d-3e5d54d3dde5") diff --git a/bec_lib/tests/test_scan_report.py b/bec_lib/tests/test_scan_report.py index 6af654e4..a8393d86 100644 --- a/bec_lib/tests/test_scan_report.py +++ b/bec_lib/tests/test_scan_report.py @@ -105,6 +105,6 @@ def test_scan_report_get_mv_status(scan_report, lrange_return, expected): scan_report.request.request = messages.ScanQueueMessage( scan_type="mv", parameter={"args": {"samx": [5], "samy": [5]}} ) - with mock.patch.object(scan_report._client.device_manager.producer, "lrange") as mock_lrange: + with mock.patch.object(scan_report._client.device_manager.connector, "lrange") as mock_lrange: mock_lrange.return_value = lrange_return assert scan_report._get_mv_status() == expected diff --git a/bec_lib/util_scripts/init_config.py b/bec_lib/util_scripts/init_config.py index 394f5be3..bfaf3eab 100644 --- a/bec_lib/util_scripts/init_config.py +++ b/bec_lib/util_scripts/init_config.py @@ -14,7 +14,6 @@ parser.add_argument("--redis", default="localhost:6379", help="redis host and po clargs = parser.parse_args() connector = RedisConnector(clargs.redis) -producer = connector.producer() with open(clargs.config, "r", encoding="utf-8") as stream: data = yaml.safe_load(stream) @@ -22,4 +21,4 @@ for name, device in data.items(): device["name"] = name config_data = list(data.values()) msg = messages.AvailableResourceMessage(resource=config_data) -producer.set(MessageEndpoints.device_config(), msg) +connector.set(MessageEndpoints.device_config(), msg) diff --git a/data_processing/data_processing/dap_service_manager.py b/data_processing/data_processing/dap_service_manager.py index ddcd9858..c940a255 100644 --- a/data_processing/data_processing/dap_service_manager.py +++ b/data_processing/data_processing/dap_service_manager.py @@ -11,7 +11,6 @@ class DAPServiceManager: def __init__(self, services: list) -> None: self.connector = None - self.producer = None self._started = False self.client = None self._dap_request_thread = None @@ -24,13 +23,11 @@ class DAPServiceManager: """ Start the dap request consumer. """ - self._dap_request_thread = self.connector.consumer( - topics=MessageEndpoints.dap_request(), cb=self._dap_request_callback, parent=self + self.connector.register( + topics=MessageEndpoints.dap_request(), cb=self._dap_request_callback ) - self._dap_request_thread.start() - @staticmethod - def _dap_request_callback(msg: MessageObject, *, parent: DAPServiceManager) -> None: + def _dap_request_callback(self, msg: MessageObject) -> None: """ Callback function for dap request consumer. @@ -41,7 +38,7 @@ class DAPServiceManager: dap_request_msg = messages.DAPRequestMessage.loads(msg.value) if not dap_request_msg: return - parent.process_dap_request(dap_request_msg) + self.process_dap_request(dap_request_msg) def process_dap_request(self, dap_request_msg: messages.DAPRequestMessage) -> None: """ @@ -153,7 +150,7 @@ class DAPServiceManager: dap_response_msg = messages.DAPResponseMessage( success=success, data=data, error=error, dap_request=dap_request_msg, metadata=metadata ) - self.producer.set_and_publish( + self.connector.set_and_publish( MessageEndpoints.dap_response(metadata.get("RID")), dap_response_msg, expire=60 ) @@ -168,7 +165,6 @@ class DAPServiceManager: return self.client = client self.connector = client.connector - self.producer = self.connector.producer() self._start_dap_request_consumer() self.update_available_dap_services() self.publish_available_services() @@ -264,12 +260,12 @@ class DAPServiceManager: """send all available dap services to the broker""" msg = messages.AvailableResourceMessage(resource=self.available_dap_services) # pylint: disable=protected-access - self.producer.set( + self.connector.set( MessageEndpoints.dap_available_plugins(f"DAPServer/{self.client._service_id}"), msg ) def shutdown(self) -> None: if not self._started: return - self._dap_request_thread.stop() + self.connector.shutdown() self._started = False diff --git a/data_processing/data_processing/lmfit1d_service.py b/data_processing/data_processing/lmfit1d_service.py index 7eafed82..b14965b2 100644 --- a/data_processing/data_processing/lmfit1d_service.py +++ b/data_processing/data_processing/lmfit1d_service.py @@ -148,7 +148,7 @@ class LmfitService1D(DAPServiceBase): out = self.process() if out: stream_output, metadata = out - self.client.producer.xadd( + self.client.connector.xadd( MessageEndpoints.processed_data(self.model.__class__.__name__), msg={ "data": MsgpackSerialization.dumps( diff --git a/data_processing/tests/test_dap_service_manager.py b/data_processing/tests/test_dap_service_manager.py index 5594ec72..43fe8a8e 100644 --- a/data_processing/tests/test_dap_service_manager.py +++ b/data_processing/tests/test_dap_service_manager.py @@ -72,7 +72,7 @@ def test_DAPServiceManager_init(service_manager): def test_DAPServiceManager_request_callback(service_manager, msg, process_called): msg_obj = MessageObject(value=msg, topic="topic") with mock.patch.object(service_manager, "process_dap_request") as mock_process_dap_request: - service_manager._dap_request_callback(msg_obj, parent=service_manager) + service_manager._dap_request_callback(msg_obj) if process_called: mock_process_dap_request.assert_called_once_with(msg) diff --git a/data_processing/tests/test_lmfit1d_service.py b/data_processing/tests/test_lmfit1d_service.py index 60139648..8ab388f0 100644 --- a/data_processing/tests/test_lmfit1d_service.py +++ b/data_processing/tests/test_lmfit1d_service.py @@ -134,7 +134,7 @@ def test_LmfitService1D_process_until_finished(lmfit_service): lmfit_service.process_until_finished(event) assert get_data.call_count == 2 assert process.call_count == 2 - assert lmfit_service.client.producer.xadd.call_count == 2 + assert lmfit_service.client.connector.xadd.call_count == 2 def test_LmfitService1D_configure(lmfit_service): diff --git a/device_server/device_server/device_server.py b/device_server/device_server/device_server.py index f158ce6c..e3ceda81 100644 --- a/device_server/device_server/device_server.py +++ b/device_server/device_server/device_server.py @@ -18,7 +18,7 @@ from device_server.rpc_mixin import RPCMixin logger = bec_logger.logger -consumer_stop = threading.Event() +register_stop = threading.Event() class DisabledDeviceError(Exception): @@ -38,14 +38,10 @@ class DeviceServer(RPCMixin, BECService): super().__init__(config, connector_cls, unique_service=True) self._tasks = [] self.device_manager = None - self.threads = [] - self.sig_thread = None - self.sig_thread = self.connector.consumer( + self.connector.register( MessageEndpoints.scan_queue_modification(), - cb=self.consumer_interception_callback, - parent=self, + cb=self.register_interception_callback, ) - self.sig_thread.start() self.executor = ThreadPoolExecutor(max_workers=4) self._start_device_manager() @@ -55,19 +51,16 @@ class DeviceServer(RPCMixin, BECService): def start(self) -> None: """start the device server""" - if consumer_stop.is_set(): - consumer_stop.clear() + if register_stop.is_set(): + register_stop.clear() + + self.connector.register( + MessageEndpoints.device_instructions(), + event=register_stop, + cb=self.instructions_callback, + parent=self, + ) - self.threads = [ - self.connector.consumer( - MessageEndpoints.device_instructions(), - event=consumer_stop, - cb=self.instructions_callback, - parent=self, - ) - ] - for thread in self.threads: - thread.start() self.status = BECStatus.RUNNING def update_status(self, status: BECStatus): @@ -76,17 +69,13 @@ class DeviceServer(RPCMixin, BECService): def stop(self) -> None: """stop the device server""" - consumer_stop.set() - for thread in self.threads: - thread.join() + register_stop.set() self.status = BECStatus.IDLE def shutdown(self) -> None: """shutdown the device server""" super().shutdown() self.stop() - self.sig_thread.signal_event.set() - self.sig_thread.join() self.device_manager.shutdown() def _update_device_metadata(self, instr) -> None: @@ -97,8 +86,7 @@ class DeviceServer(RPCMixin, BECService): device_root = dev.split(".")[0] self.device_manager.devices.get(device_root).metadata = instr.metadata - @staticmethod - def consumer_interception_callback(msg, *, parent, **_kwargs) -> None: + def register_interception_callback(self, msg, **_kwargs) -> None: """callback for receiving scan modifications / interceptions""" mvalue = msg.value if mvalue is None: @@ -106,7 +94,7 @@ class DeviceServer(RPCMixin, BECService): return logger.info(f"Receiving: {mvalue.content}") if mvalue.content.get("action") in ["pause", "abort", "halt"]: - parent.stop_devices() + self.stop_devices() def stop_devices(self) -> None: """stop all enabled devices""" @@ -279,7 +267,7 @@ class DeviceServer(RPCMixin, BECService): devices = instr.content["device"] if not isinstance(devices, list): devices = [devices] - pipe = self.producer.pipeline() + pipe = self.connector.pipeline() for dev in devices: obj = self.device_manager.devices.get(dev) obj.metadata = instr.metadata @@ -288,11 +276,11 @@ class DeviceServer(RPCMixin, BECService): dev_msg = messages.DeviceReqStatusMessage( device=dev, success=True, metadata=instr.metadata ) - self.producer.set_and_publish(MessageEndpoints.device_req_status(dev), dev_msg, pipe) + self.connector.set_and_publish(MessageEndpoints.device_req_status(dev), dev_msg, pipe) pipe.execute() def _status_callback(self, status): - pipe = self.producer.pipeline() + pipe = self.connector.pipeline() if hasattr(status, "device"): obj = status.device else: @@ -302,12 +290,12 @@ class DeviceServer(RPCMixin, BECService): device=device_name, success=status.success, metadata=status.instruction.metadata ) logger.debug(f"req status for device {device_name}: {status.success}") - self.producer.set_and_publish( + self.connector.set_and_publish( MessageEndpoints.device_req_status(device_name), dev_msg, pipe ) response = status.instruction.metadata.get("response") if response: - self.producer.lpush( + self.connector.lpush( MessageEndpoints.device_req_status(status.instruction.metadata["RID"]), dev_msg, pipe, @@ -328,7 +316,7 @@ class DeviceServer(RPCMixin, BECService): dev_config_msg = messages.DeviceMessage( signals=obj.root.read_configuration(), metadata=metadata ) - self.producer.set_and_publish( + self.connector.set_and_publish( MessageEndpoints.device_read_configuration(obj.root.name), dev_config_msg, pipe ) @@ -342,7 +330,7 @@ class DeviceServer(RPCMixin, BECService): def _read_and_update_devices(self, devices: list[str], metadata: dict) -> list: start = time.time() - pipe = self.producer.pipeline() + pipe = self.connector.pipeline() signal_container = [] for dev in devices: device_root = dev.split(".")[0] @@ -354,17 +342,17 @@ class DeviceServer(RPCMixin, BECService): except Exception as exc: signals = self._retry_obj_method(dev, obj, "read", exc) - self.producer.set_and_publish( + self.connector.set_and_publish( MessageEndpoints.device_read(device_root), messages.DeviceMessage(signals=signals, metadata=metadata), pipe, ) - self.producer.set_and_publish( + self.connector.set_and_publish( MessageEndpoints.device_readback(device_root), messages.DeviceMessage(signals=signals, metadata=metadata), pipe, ) - self.producer.set( + self.connector.set( MessageEndpoints.device_status(device_root), messages.DeviceStatusMessage(device=device_root, status=0, metadata=metadata), pipe, @@ -377,7 +365,7 @@ class DeviceServer(RPCMixin, BECService): def _read_config_and_update_devices(self, devices: list[str], metadata: dict) -> list: start = time.time() - pipe = self.producer.pipeline() + pipe = self.connector.pipeline() signal_container = [] for dev in devices: self.device_manager.devices.get(dev).metadata = metadata @@ -387,7 +375,7 @@ class DeviceServer(RPCMixin, BECService): signal_container.append(signals) except Exception as exc: signals = self._retry_obj_method(dev, obj, "read_configuration", exc) - self.producer.set_and_publish( + self.connector.set_and_publish( MessageEndpoints.device_read_configuration(dev), messages.DeviceMessage(signals=signals, metadata=metadata), pipe, @@ -420,9 +408,11 @@ class DeviceServer(RPCMixin, BECService): f"Failed to run {method} on device {device_root}. Trying to load an old value." ) if method == "read": - old_msg = self.producer.get(MessageEndpoints.device_read(device_root)) + old_msg = self.connector.get(MessageEndpoints.device_read(device_root)) elif method == "read_configuration": - old_msg = self.producer.get(MessageEndpoints.device_read_configuration(device_root)) + old_msg = self.connector.get( + MessageEndpoints.device_read_configuration(device_root) + ) else: raise ValueError(f"Unknown method {method}.") if not old_msg: @@ -435,7 +425,7 @@ class DeviceServer(RPCMixin, BECService): if not isinstance(devices, list): devices = [devices] - pipe = self.producer.pipeline() + pipe = self.connector.pipeline() for dev in devices: obj = self.device_manager.devices[dev].obj if hasattr(obj, "_staged"): @@ -444,7 +434,7 @@ class DeviceServer(RPCMixin, BECService): logger.info(f"Device {obj.name} was already staged and will be first unstaged.") self.device_manager.devices[dev].obj.unstage() self.device_manager.devices[dev].obj.stage() - self.producer.set( + self.connector.set( MessageEndpoints.device_staged(dev), messages.DeviceStatusMessage(device=dev, status=1, metadata=instr.metadata), pipe, @@ -456,7 +446,7 @@ class DeviceServer(RPCMixin, BECService): if not isinstance(devices, list): devices = [devices] - pipe = self.producer.pipeline() + pipe = self.connector.pipeline() for dev in devices: obj = self.device_manager.devices[dev].obj if hasattr(obj, "_staged"): @@ -465,7 +455,7 @@ class DeviceServer(RPCMixin, BECService): self.device_manager.devices[dev].obj.unstage() else: logger.debug(f"Device {obj.name} was already unstaged.") - self.producer.set( + self.connector.set( MessageEndpoints.device_staged(dev), messages.DeviceStatusMessage(device=dev, status=0, metadata=instr.metadata), pipe, diff --git a/device_server/device_server/devices/config_update_handler.py b/device_server/device_server/devices/config_update_handler.py index 513003f3..089af643 100644 --- a/device_server/device_server/devices/config_update_handler.py +++ b/device_server/device_server/devices/config_update_handler.py @@ -16,17 +16,11 @@ class ConfigUpdateHandler: def __init__(self, device_manager: DeviceManagerDS) -> None: self.device_manager = device_manager self.connector = self.device_manager.connector - self._config_request_handler = None - - self._start_config_handler() - - def _start_config_handler(self) -> None: - self._config_request_handler = self.connector.consumer( + self.connector.register( MessageEndpoints.device_server_config_request(), cb=self._device_config_callback, parent=self, ) - self._config_request_handler.start() @staticmethod def _device_config_callback(msg, *, parent, **_kwargs) -> None: @@ -74,7 +68,7 @@ class ConfigUpdateHandler: accepted=accepted, message=error_msg, metadata=metadata ) RID = metadata.get("RID") - self.device_manager.producer.set( + self.device_manager.connector.set( MessageEndpoints.device_config_request_response(RID), msg, expire=60 ) @@ -97,7 +91,7 @@ class ConfigUpdateHandler: "low": device.obj.low_limit_travel.get(), "high": device.obj.high_limit_travel.get(), } - self.device_manager.producer.set_and_publish( + self.device_manager.connector.set_and_publish( MessageEndpoints.device_limits(device.name), messages.DeviceMessage(signals=limits), ) diff --git a/device_server/device_server/devices/devicemanager.py b/device_server/device_server/devices/devicemanager.py index 22c6a14e..8bf3843d 100644 --- a/device_server/device_server/devices/devicemanager.py +++ b/device_server/device_server/devices/devicemanager.py @@ -51,7 +51,7 @@ class DSDevice(DeviceBase): self.metadata = {} self.initialized = False - def initialize_device_buffer(self, producer): + def initialize_device_buffer(self, connector): """initialize the device read and readback buffer on redis with a new reading""" dev_msg = messages.DeviceMessage(signals=self.obj.read(), metadata={}) dev_config_msg = messages.DeviceMessage(signals=self.obj.read_configuration(), metadata={}) @@ -62,14 +62,14 @@ class DSDevice(DeviceBase): } else: limits = None - pipe = producer.pipeline() - producer.set_and_publish(MessageEndpoints.device_readback(self.name), dev_msg, pipe=pipe) - producer.set(topic=MessageEndpoints.device_read(self.name), msg=dev_msg, pipe=pipe) - producer.set_and_publish( + pipe = connector.pipeline() + connector.set_and_publish(MessageEndpoints.device_readback(self.name), dev_msg, pipe=pipe) + connector.set(topic=MessageEndpoints.device_read(self.name), msg=dev_msg, pipe=pipe) + connector.set_and_publish( MessageEndpoints.device_read_configuration(self.name), dev_config_msg, pipe=pipe ) if limits is not None: - producer.set_and_publish( + connector.set_and_publish( MessageEndpoints.device_limits(self.name), messages.DeviceMessage(signals=limits), pipe=pipe, @@ -318,7 +318,7 @@ class DeviceManagerDS(DeviceManagerBase): self.update_config(obj, config) # refresh the device info - pipe = self.producer.pipeline() + pipe = self.connector.pipeline() self.reset_device_data(obj, pipe) self.publish_device_info(obj, pipe) pipe.execute() @@ -369,7 +369,7 @@ class DeviceManagerDS(DeviceManagerBase): def initialize_enabled_device(self, opaas_obj): """connect to an enabled device and initialize the device buffer""" self.connect_device(opaas_obj.obj) - opaas_obj.initialize_device_buffer(self.producer) + opaas_obj.initialize_device_buffer(self.connector) @staticmethod def disconnect_device(obj): @@ -420,7 +420,7 @@ class DeviceManagerDS(DeviceManagerBase): """ interface = get_device_info(obj, {}) - self.producer.set( + self.connector.set( MessageEndpoints.device_info(obj.name), messages.DeviceInfoMessage(device=obj.name, info=interface), pipe, @@ -428,9 +428,9 @@ class DeviceManagerDS(DeviceManagerBase): def reset_device_data(self, obj: OphydObject, pipe=None) -> None: """delete all device data and device info""" - self.producer.delete(MessageEndpoints.device_status(obj.name), pipe) - self.producer.delete(MessageEndpoints.device_read(obj.name), pipe) - self.producer.delete(MessageEndpoints.device_info(obj.name), pipe) + self.connector.delete(MessageEndpoints.device_status(obj.name), pipe) + self.connector.delete(MessageEndpoints.device_read(obj.name), pipe) + self.connector.delete(MessageEndpoints.device_info(obj.name), pipe) def _obj_callback_readback(self, *_args, obj: OphydObject, **kwargs): if obj.connected: @@ -438,8 +438,8 @@ class DeviceManagerDS(DeviceManagerBase): signals = obj.read() metadata = self.devices.get(obj.root.name).metadata dev_msg = messages.DeviceMessage(signals=signals, metadata=metadata) - pipe = self.producer.pipeline() - self.producer.set_and_publish(MessageEndpoints.device_readback(name), dev_msg, pipe) + pipe = self.connector.pipeline() + self.connector.set_and_publish(MessageEndpoints.device_readback(name), dev_msg, pipe) pipe.execute() @typechecked @@ -466,7 +466,7 @@ class DeviceManagerDS(DeviceManagerBase): metadata = self.devices[name].metadata msg = messages.DeviceMonitorMessage(device=name, data=value, metadata=metadata) stream_msg = {"data": msg} - self.producer.xadd( + self.connector.xadd( MessageEndpoints.device_monitor(name), stream_msg, max_size=min(100, int(max_size // dsize)), @@ -476,7 +476,7 @@ class DeviceManagerDS(DeviceManagerBase): device = kwargs["obj"].root.name status = 0 metadata = self.devices[device].metadata - self.producer.send( + self.connector.send( MessageEndpoints.device_status(device), messages.DeviceStatusMessage(device=device, status=status, metadata=metadata), ) @@ -489,8 +489,8 @@ class DeviceManagerDS(DeviceManagerBase): device = kwargs["obj"].root.name status = int(kwargs.get("value")) metadata = self.devices[device].metadata - self.producer.set( - MessageEndpoints.device_status(kwargs["obj"].root.name), + self.connector.set( + MessageEndpoints.device_status(device), messages.DeviceStatusMessage(device=device, status=status, metadata=metadata), ) @@ -521,12 +521,12 @@ class DeviceManagerDS(DeviceManagerBase): ) ) ds_obj.emitted_points[metadata["scanID"]] = max_points - pipe = self.producer.pipeline() - self.producer.send(MessageEndpoints.device_read(obj.root.name), bundle, pipe=pipe) + pipe = self.connector.pipeline() + self.connector.send(MessageEndpoints.device_read(obj.root.name), bundle, pipe=pipe) msg = messages.DeviceStatusMessage( device=obj.root.name, status=max_points, metadata=metadata ) - self.producer.set_and_publish( + self.connector.set_and_publish( MessageEndpoints.device_progress(obj.root.name), msg, pipe=pipe ) pipe.execute() @@ -536,4 +536,4 @@ class DeviceManagerDS(DeviceManagerBase): msg = messages.ProgressMessage( value=value, max_value=max_value, done=done, metadata=metadata ) - self.producer.set_and_publish(MessageEndpoints.device_progress(obj.root.name), msg) + self.connector.set_and_publish(MessageEndpoints.device_progress(obj.root.name), msg) diff --git a/device_server/device_server/rpc_mixin.py b/device_server/device_server/rpc_mixin.py index d1e1d255..ce27d095 100644 --- a/device_server/device_server/rpc_mixin.py +++ b/device_server/device_server/rpc_mixin.py @@ -70,7 +70,7 @@ class RPCMixin: def _send_rpc_result_to_client( self, device: str, instr_params: dict, res: Any, result: StringIO ): - self.producer.set( + self.connector.set( MessageEndpoints.device_rpc(instr_params.get("rpc_id")), messages.DeviceRPCMessage( device=device, return_val=res, out=result.getvalue(), success=True @@ -175,7 +175,7 @@ class RPCMixin: } logger.info(f"Received exception: {exc_formatted}, {exc}") instr_params = instr.content.get("parameter") - self.producer.set( + self.connector.set( MessageEndpoints.device_rpc(instr_params.get("rpc_id")), messages.DeviceRPCMessage( device=instr.content["device"], return_val=None, out=exc_formatted, success=False diff --git a/device_server/tests/test_device_manager_ds.py b/device_server/tests/test_device_manager_ds.py index e042eaf9..8310e0b3 100644 --- a/device_server/tests/test_device_manager_ds.py +++ b/device_server/tests/test_device_manager_ds.py @@ -6,7 +6,7 @@ import numpy as np import pytest import yaml from bec_lib import MessageEndpoints, messages -from bec_lib.tests.utils import ConnectorMock, ProducerMock, create_session_from_config +from bec_lib.tests.utils import ConnectorMock, create_session_from_config from device_server.devices.devicemanager import DeviceManagerDS @@ -52,7 +52,7 @@ def load_device_manager(): service_mock = mock.MagicMock() service_mock.connector = ConnectorMock("", store_data=False) device_manager = DeviceManagerDS(service_mock, "") - device_manager.producer = service_mock.connector.producer() + device_manager.connector = service_mock.connector device_manager.config_update_handler = mock.MagicMock() with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file: device_manager._session = create_session_from_config(yaml.safe_load(session_file)) @@ -133,10 +133,10 @@ def test_flyer_event_callback(): device_manager._obj_flyer_callback( obj=samx.obj, value={"data": {"idata": np.random.rand(20), "edata": np.random.rand(20)}} ) - pipe = device_manager.producer.pipeline() + pipe = device_manager.connector.pipeline() bundle, progress = pipe._pipe_buffer[-2:] - # check producer method + # check connector method assert bundle[0] == "send" assert progress[0] == "set_and_publish" @@ -157,9 +157,9 @@ def test_obj_progress_callback(): samx = device_manager.devices.samx samx.metadata = {"scanID": "12345"} - with mock.patch.object(device_manager, "producer") as mock_producer: + with mock.patch.object(device_manager, "connector") as mock_connector: device_manager._obj_progress_callback(obj=samx.obj, value=1, max_value=2, done=False) - mock_producer.set_and_publish.assert_called_once_with( + mock_connector.set_and_publish.assert_called_once_with( MessageEndpoints.device_progress("samx"), messages.ProgressMessage( value=1, max_value=2, done=False, metadata={"scanID": "12345"} @@ -176,9 +176,9 @@ def test_obj_monitor_callback(value): eiger.metadata = {"scanID": "12345"} value_size = len(value.tobytes()) / 1e6 # MB max_size = 100 - with mock.patch.object(device_manager, "producer") as mock_producer: + with mock.patch.object(device_manager, "connector") as mock_connector: device_manager._obj_callback_monitor(obj=eiger.obj, value=value) - mock_producer.xadd.assert_called_once_with( + mock_connector.xadd.assert_called_once_with( MessageEndpoints.device_monitor(eiger.name), { "data": messages.DeviceMonitorMessage( diff --git a/device_server/tests/test_device_server.py b/device_server/tests/test_device_server.py index 3cf6967c..0f69f8c8 100644 --- a/device_server/tests/test_device_server.py +++ b/device_server/tests/test_device_server.py @@ -7,7 +7,7 @@ from bec_lib import Alarms, MessageEndpoints, ServiceConfig, messages from bec_lib.device import OnFailure from bec_lib.messages import BECStatus from bec_lib.redis_connector import MessageObject -from bec_lib.tests.utils import ConnectorMock, ConsumerMock +from bec_lib.tests.utils import ConnectorMock from ophyd import Staged from ophyd.utils import errors as ophyd_errors from test_device_manager_ds import device_manager, load_device_manager @@ -54,8 +54,6 @@ def test_start(device_server_mock): device_server.start() - assert device_server.threads - assert isinstance(device_server.threads[0], ConsumerMock) assert device_server.status == BECStatus.RUNNING @@ -187,10 +185,10 @@ def test_stop_devices(device_server_mock): ), ], ) -def test_consumer_interception_callback(device_server_mock, msg, stop_called): +def test_register_interception_callback(device_server_mock, msg, stop_called): device_server = device_server_mock with mock.patch.object(device_server, "stop_devices") as stop: - device_server.consumer_interception_callback(msg, parent=device_server) + device_server.register_interception_callback(msg, parent=device_server) if stop_called: stop.assert_called_once() else: @@ -640,7 +638,7 @@ def test_set_device(device_server_mock, instr): while True: res = [ msg - for msg in device_server.producer.message_sent + for msg in device_server.connector.message_sent if msg["queue"] == MessageEndpoints.device_req_status("samx") ] if res: @@ -676,7 +674,7 @@ def test_read_device(device_server_mock, instr): for device in devices: res = [ msg - for msg in device_server.producer.message_sent + for msg in device_server.connector.message_sent if msg["queue"] == MessageEndpoints.device_read(device) ] assert res[-1]["msg"].metadata["RID"] == instr.metadata["RID"] @@ -690,7 +688,7 @@ def test_read_config_and_update_devices(device_server_mock, devices): for device in devices: res = [ msg - for msg in device_server.producer.message_sent + for msg in device_server.connector.message_sent if msg["queue"] == MessageEndpoints.device_read_configuration(device) ] config = device_server.device_manager.devices[device].obj.read_configuration() @@ -755,8 +753,8 @@ def test_retry_obj_method_buffer(device_server_mock, instr): return signals_before = getattr(samx.obj, instr)() - device_server.producer = mock.MagicMock() - device_server.producer.get.return_value = messages.DeviceMessage( + device_server.connector = mock.MagicMock() + device_server.connector.get.return_value = messages.DeviceMessage( signals=signals_before, metadata={"RID": "test", "stream": "primary"} ) diff --git a/device_server/tests/test_rpc_mixin.py b/device_server/tests/test_rpc_mixin.py index 11f04d74..ba5bdf56 100644 --- a/device_server/tests/test_rpc_mixin.py +++ b/device_server/tests/test_rpc_mixin.py @@ -13,7 +13,7 @@ from device_server.rpc_mixin import RPCMixin def rpc_cls(): rpc_mixin = RPCMixin() rpc_mixin.connector = mock.MagicMock() - rpc_mixin.producer = mock.MagicMock() + rpc_mixin.connector = mock.MagicMock() rpc_mixin.device_manager = mock.MagicMock() yield rpc_mixin @@ -93,7 +93,7 @@ def test_get_result_from_rpc_list_from_stage(rpc_cls): def test_send_rpc_exception(rpc_cls, instr): rpc_cls._send_rpc_exception(Exception(), instr) - rpc_cls.producer.set.assert_called_once_with( + rpc_cls.connector.set.assert_called_once_with( MessageEndpoints.device_rpc("rpc_id"), messages.DeviceRPCMessage( device="device", @@ -108,7 +108,7 @@ def test_send_rpc_result_to_client(rpc_cls): result = mock.MagicMock() result.getvalue.return_value = "result" rpc_cls._send_rpc_result_to_client("device", {"rpc_id": "rpc_id"}, 1, result) - rpc_cls.producer.set.assert_called_once_with( + rpc_cls.connector.set.assert_called_once_with( MessageEndpoints.device_rpc("rpc_id"), messages.DeviceRPCMessage(device="device", return_val=1, out="result", success=True), expire=1800, diff --git a/file_writer/file_writer/file_writer.py b/file_writer/file_writer/file_writer.py index c40e35d5..6d51df90 100644 --- a/file_writer/file_writer/file_writer.py +++ b/file_writer/file_writer/file_writer.py @@ -325,7 +325,7 @@ class NexusFileWriter(FileWriter): file_data[key] = val if not isinstance(val, list) else merge_dicts(val) msg_data = {"file_path": file_path, "data": file_data} msg = messages.FileContentMessage(**msg_data) - self.file_writer_manager.producer.set_and_publish(MessageEndpoints.file_content(), msg) + self.file_writer_manager.connector.set_and_publish(MessageEndpoints.file_content(), msg) with h5py.File(file_path, "w") as file: HDF5StorageWriter.write(writer_storage._storage, device_storage, file) diff --git a/file_writer/file_writer/file_writer_manager.py b/file_writer/file_writer/file_writer_manager.py index edc00872..bc7ec4a0 100644 --- a/file_writer/file_writer/file_writer_manager.py +++ b/file_writer/file_writer/file_writer_manager.py @@ -80,10 +80,13 @@ class FileWriterManager(BECService): self._lock = threading.RLock() self.file_writer_config = self._service_config.service_config.get("file_writer") self.writer_mixin = FileWriterMixin(self.file_writer_config) - self.producer = self.connector.producer() self._start_device_manager() - self._start_scan_segment_consumer() - self._start_scan_status_consumer() + self.connector.register( + patterns=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self + ) + self.connector.register( + MessageEndpoints.scan_status(), cb=self._scan_status_callback, parent=self + ) self.scan_storage = {} self.file_writer = NexusFileWriter(self) @@ -92,20 +95,7 @@ class FileWriterManager(BECService): self.device_manager = DeviceManagerBase(self) self.device_manager.initialize([self.bootstrap_server]) - def _start_scan_segment_consumer(self): - self._scan_segment_consumer = self.connector.consumer( - pattern=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self - ) - self._scan_segment_consumer.start() - - def _start_scan_status_consumer(self): - self._scan_status_consumer = self.connector.consumer( - MessageEndpoints.scan_status(), cb=self._scan_status_callback, parent=self - ) - self._scan_status_consumer.start() - - @staticmethod - def _scan_segment_callback(msg: MessageObject, *, parent: FileWriterManager): + def _scan_segment_callback(self, msg: MessageObject, *, parent: FileWriterManager): msgs = msg.value for scan_msg in msgs: parent.insert_to_scan_storage(scan_msg) @@ -188,7 +178,7 @@ class FileWriterManager(BECService): return if self.scan_storage[scanID].baseline: return - baseline = self.producer.get(MessageEndpoints.public_scan_baseline(scanID)) + baseline = self.connector.get(MessageEndpoints.public_scan_baseline(scanID)) if not baseline: return self.scan_storage[scanID].baseline = baseline.content["data"] @@ -205,13 +195,13 @@ class FileWriterManager(BECService): """ if not self.scan_storage.get(scanID): return - msgs = self.producer.keys(MessageEndpoints.public_file(scanID, "*")) + msgs = self.connector.keys(MessageEndpoints.public_file(scanID, "*")) if not msgs: return # extract name from 'public//file/' names = [msg.decode().split("/")[-1] for msg in msgs] - file_msgs = [self.producer.get(msg.decode()) for msg in msgs] + file_msgs = [self.connector.get(msg.decode()) for msg in msgs] if not file_msgs: return for name, file_msg in zip(names, file_msgs): @@ -236,7 +226,7 @@ class FileWriterManager(BECService): if not self.scan_storage.get(scanID): return # get all async devices - async_device_keys = self.producer.keys(MessageEndpoints.device_async_readback(scanID, "*")) + async_device_keys = self.connector.keys(MessageEndpoints.device_async_readback(scanID, "*")) if not async_device_keys: return for device_key in async_device_keys: @@ -244,7 +234,7 @@ class FileWriterManager(BECService): device_name = key.split(MessageEndpoints.device_async_readback(scanID, ""))[-1].split( ":" )[0] - msgs = self.producer.xrange(key, min="-", max="+") + msgs = self.connector.xrange(key, min="-", max="+") if not msgs: continue self._process_async_data(msgs, scanID, device_name) @@ -298,7 +288,7 @@ class FileWriterManager(BECService): try: file_path = self.writer_mixin.compile_full_filename(scan, file_suffix) - self.producer.set_and_publish( + self.connector.set_and_publish( MessageEndpoints.public_file(scanID, "master"), messages.FileMessage(file_path=file_path, done=False), ) @@ -319,7 +309,7 @@ class FileWriterManager(BECService): ) successful = False self.scan_storage.pop(scanID) - self.producer.set_and_publish( + self.connector.set_and_publish( MessageEndpoints.public_file(scanID, "master"), messages.FileMessage(file_path=file_path, successful=successful), ) diff --git a/file_writer/tests/test_file_writer_manager.py b/file_writer/tests/test_file_writer_manager.py index 24ef4562..9da6b8b6 100644 --- a/file_writer/tests/test_file_writer_manager.py +++ b/file_writer/tests/test_file_writer_manager.py @@ -25,7 +25,7 @@ def load_FileWriter(): service_mock = mock.MagicMock() service_mock.connector = ConnectorMock("") device_manager = DeviceManagerBase(service_mock, "") - device_manager.producer = service_mock.connector.producer() + device_manager.connector = service_mock.connector with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file: device_manager._session = create_session_from_config(yaml.safe_load(session_file)) device_manager._load_session() @@ -152,13 +152,13 @@ def test_write_file_raises_alarm_on_error(): def test_update_baseline_reading(): file_manager = load_FileWriter() file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID") - with mock.patch.object(file_manager, "producer") as mock_producer: - mock_producer.get.return_value = messages.ScanBaselineMessage( + with mock.patch.object(file_manager, "connector") as mock_connector: + mock_connector.get.return_value = messages.ScanBaselineMessage( scanID="scanID", data={"data": "data"} ) file_manager.update_baseline_reading("scanID") assert file_manager.scan_storage["scanID"].baseline == {"data": "data"} - mock_producer.get.assert_called_once_with(MessageEndpoints.public_scan_baseline("scanID")) + mock_connector.get.assert_called_once_with(MessageEndpoints.public_scan_baseline("scanID")) def test_scan_storage_append(): @@ -178,30 +178,30 @@ def test_scan_storage_ready_to_write(): def test_update_file_references(): file_manager = load_FileWriter() - with mock.patch.object(file_manager, "producer") as mock_producer: + with mock.patch.object(file_manager, "connector") as mock_connector: file_manager.update_file_references("scanID") - mock_producer.keys.assert_not_called() + mock_connector.keys.assert_not_called() def test_update_file_references_gets_keys(): file_manager = load_FileWriter() file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID") - with mock.patch.object(file_manager, "producer") as mock_producer: + with mock.patch.object(file_manager, "connector") as mock_connector: file_manager.update_file_references("scanID") - mock_producer.keys.assert_called_once_with(MessageEndpoints.public_file("scanID", "*")) + mock_connector.keys.assert_called_once_with(MessageEndpoints.public_file("scanID", "*")) def test_update_async_data(): file_manager = load_FileWriter() file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID") - with mock.patch.object(file_manager, "producer") as mock_producer: + with mock.patch.object(file_manager, "connector") as mock_connector: with mock.patch.object(file_manager, "_process_async_data") as mock_process: key = MessageEndpoints.device_async_readback("scanID", "dev1") - mock_producer.keys.return_value = [key.encode()] + mock_connector.keys.return_value = [key.encode()] data = [(b"0-0", b'{"data": "data"}')] - mock_producer.xrange.return_value = data + mock_connector.xrange.return_value = data file_manager.update_async_data("scanID") - mock_producer.xrange.assert_called_once_with(key, min="-", max="+") + mock_connector.xrange.assert_called_once_with(key, min="-", max="+") mock_process.assert_called_once_with(data, "scanID", "dev1") diff --git a/scan_bundler/scan_bundler/bec_emitter.py b/scan_bundler/scan_bundler/bec_emitter.py index f50e19dc..2739dccd 100644 --- a/scan_bundler/scan_bundler/bec_emitter.py +++ b/scan_bundler/scan_bundler/bec_emitter.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: class BECEmitter(EmitterBase): def __init__(self, scan_bundler: ScanBundler) -> None: - super().__init__(scan_bundler.producer) + super().__init__(scan_bundler.connector) self.scan_bundler = scan_bundler def on_scan_point_emit(self, scanID: str, pointID: int): @@ -46,9 +46,16 @@ class BECEmitter(EmitterBase): data=sb.sync_storage[scanID]["baseline"], metadata=sb.sync_storage[scanID]["info"], ) - pipe = sb.producer.pipeline() - sb.producer.set( - MessageEndpoints.public_scan_baseline(scanID=scanID), msg, expire=1800, pipe=pipe + pipe = sb.connector.pipeline() + sb.connector.set( + MessageEndpoints.public_scan_baseline(scanID=scanID), + msg, + expire=1800, + pipe=pipe, + ) + sb.connector.set_and_publish( + MessageEndpoints.scan_baseline(), + msg, + pipe=pipe, ) - sb.producer.set_and_publish(MessageEndpoints.scan_baseline(), msg, pipe=pipe) pipe.execute() diff --git a/scan_bundler/scan_bundler/bluesky_emitter.py b/scan_bundler/scan_bundler/bluesky_emitter.py index 5e73d175..f1a63e4c 100644 --- a/scan_bundler/scan_bundler/bluesky_emitter.py +++ b/scan_bundler/scan_bundler/bluesky_emitter.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: class BlueskyEmitter(EmitterBase): def __init__(self, scan_bundler: ScanBundler) -> None: - super().__init__(scan_bundler.producer) + super().__init__(scan_bundler.connector) self.scan_bundler = scan_bundler self.bluesky_metadata = {} @@ -27,7 +27,7 @@ class BlueskyEmitter(EmitterBase): self.bluesky_metadata[scanID] = {} doc = self._get_run_start_document(scanID) self.bluesky_metadata[scanID]["start"] = doc - self.producer.raw_send(MessageEndpoints.bluesky_events(), msgpack.dumps(("start", doc))) + self.connector.raw_send(MessageEndpoints.bluesky_events(), msgpack.dumps(("start", doc))) self.send_descriptor_document(scanID) def _get_run_start_document(self, scanID) -> dict: @@ -71,7 +71,7 @@ class BlueskyEmitter(EmitterBase): """Bluesky only: send descriptor document""" doc = self._get_descriptor_document(scanID) self.bluesky_metadata[scanID]["descriptor"] = doc - self.producer.raw_send( + self.connector.raw_send( MessageEndpoints.bluesky_events(), msgpack.dumps(("descriptor", doc)) ) @@ -85,7 +85,7 @@ class BlueskyEmitter(EmitterBase): logger.warning(f"Failed to remove {scanID} from {storage}.") def send_bluesky_scan_point(self, scanID, pointID) -> None: - self.producer.raw_send( + self.connector.raw_send( MessageEndpoints.bluesky_events(), msgpack.dumps(("event", self._prepare_bluesky_event_data(scanID, pointID))), ) diff --git a/scan_bundler/scan_bundler/emitter.py b/scan_bundler/scan_bundler/emitter.py index 709f9d45..b6ffdfca 100644 --- a/scan_bundler/scan_bundler/emitter.py +++ b/scan_bundler/scan_bundler/emitter.py @@ -6,16 +6,16 @@ from bec_lib import messages class EmitterBase: - def __init__(self, producer) -> None: + def __init__(self, connector) -> None: self._send_buffer = Queue() - self.producer = producer - self._start_buffered_producer() + self.connector = connector + self._start_buffered_connector() - def _start_buffered_producer(self): - self._buffered_producer_thread = threading.Thread( + def _start_buffered_connector(self): + self._buffered_connector_thread = threading.Thread( target=self._buffered_publish, daemon=True, name="buffered_publisher" ) - self._buffered_producer_thread.start() + self._buffered_connector_thread.start() def add_message(self, msg: messages.BECMessage, endpoint: str, public: str = None): self._send_buffer.put((msg, endpoint, public)) @@ -37,20 +37,20 @@ class EmitterBase: time.sleep(0.1) return - pipe = self.producer.pipeline() + pipe = self.connector.pipeline() msgs = messages.BundleMessage() _, endpoint, _ = msgs_to_send[0] for msg, endpoint, public in msgs_to_send: msg_dump = msg msgs.append(msg_dump) if public: - self.producer.set( + self.connector.set( public, msg_dump, pipe=pipe, expire=1800, ) - self.producer.send(endpoint, msgs, pipe=pipe) + self.connector.send(endpoint, msgs, pipe=pipe) pipe.execute() def on_init(self, scanID: str): diff --git a/scan_bundler/scan_bundler/scan_bundler.py b/scan_bundler/scan_bundler/scan_bundler.py index 8e3699a8..a15402c7 100644 --- a/scan_bundler/scan_bundler/scan_bundler.py +++ b/scan_bundler/scan_bundler/scan_bundler.py @@ -20,9 +20,23 @@ class ScanBundler(BECService): self.device_manager = None self._start_device_manager() - self._start_device_read_consumer() - self._start_scan_queue_consumer() - self._start_scan_status_consumer() + self.connector.register( + patterns=MessageEndpoints.device_read("*"), + cb=self._device_read_callback, + name="device_read_register", + ) + self.connector.register( + MessageEndpoints.scan_queue_status(), + cb=self._scan_queue_callback, + group_id="scan_bundler", + name="scan_queue_register", + ) + self.connector.register( + MessageEndpoints.scan_status(), + cb=self._scan_status_callback, + group_id="scan_bundler", + name="scan_status_register", + ) self.sync_storage = {} self.monitored_devices = {} @@ -56,56 +70,24 @@ class ScanBundler(BECService): self.device_manager = DeviceManagerBase(self) self.device_manager.initialize(self.bootstrap_server) - def _start_device_read_consumer(self): - self._device_read_consumer = self.connector.consumer( - pattern=MessageEndpoints.device_read("*"), - cb=self._device_read_callback, - parent=self, - name="device_read_consumer", - ) - self._device_read_consumer.start() - - def _start_scan_queue_consumer(self): - self._scan_queue_consumer = self.connector.consumer( - MessageEndpoints.scan_queue_status(), - cb=self._scan_queue_callback, - group_id="scan_bundler", - parent=self, - name="scan_queue_consumer", - ) - self._scan_queue_consumer.start() - - def _start_scan_status_consumer(self): - self._scan_status_consumer = self.connector.consumer( - MessageEndpoints.scan_status(), - cb=self._scan_status_callback, - group_id="scan_bundler", - parent=self, - name="scan_status_consumer", - ) - self._scan_status_consumer.start() - - @staticmethod - def _device_read_callback(msg, parent, **_kwargs): + def _device_read_callback(self, msg, **_kwargs): # pylint: disable=protected-access dev = msg.topic.split(MessageEndpoints._device_read + "/")[-1] msgs = msg.value logger.debug(f"Received reading from device {dev}") if not isinstance(msgs, list): msgs = [msgs] - task = parent.executor.submit(parent._add_device_to_storage, msgs, dev) - parent.executor_tasks.append(task) + task = self.executor.submit(self._add_device_to_storage, msgs, dev) + self.executor_tasks.append(task) - @staticmethod - def _scan_queue_callback(msg, parent, **_kwargs): + def _scan_queue_callback(self, msg, **_kwargs): msg = msg.value logger.trace(msg) - parent.current_queue = msg.content["queue"]["primary"].get("info") + self.current_queue = msg.content["queue"]["primary"].get("info") - @staticmethod - def _scan_status_callback(msg, parent, **_kwargs): + def _scan_status_callback(self, msg, **_kwargs): msg = msg.value - parent.handle_scan_status_message(msg) + self.handle_scan_status_message(msg) def handle_scan_status_message(self, msg: messages.ScanStatusMessage) -> None: """handle scan status messages""" @@ -270,7 +252,7 @@ class ScanBundler(BECService): } def _get_scan_status_history(self, length): - return self.producer.lrange(MessageEndpoints.scan_status() + "_list", length * -1, -1) + return self.connector.lrange(MessageEndpoints.scan_status() + "_list", length * -1, -1) def _wait_for_scanID(self, scanID, timeout_time=10): elapsed_time = 0 @@ -344,10 +326,10 @@ class ScanBundler(BECService): self.sync_storage[scanID][pointID][dev.name] = read def _get_last_device_readback(self, devices: list) -> list: - pipe = self.producer.pipeline() + pipe = self.connector.pipeline() for dev in devices: - self.producer.get(MessageEndpoints.device_readback(dev.name), pipe) - return [msg.content["signals"] for msg in self.producer.execute_pipeline(pipe)] + self.connector.get(MessageEndpoints.device_readback(dev.name), pipe) + return [msg.content["signals"] for msg in self.connector.execute_pipeline(pipe)] def cleanup_storage(self): """remove old scanIDs to free memory""" diff --git a/scan_bundler/tests/test_bec_emitter.py b/scan_bundler/tests/test_bec_emitter.py index 2826452c..78c0f42e 100644 --- a/scan_bundler/tests/test_bec_emitter.py +++ b/scan_bundler/tests/test_bec_emitter.py @@ -57,16 +57,16 @@ def test_send_baseline_BEC(): sb.sync_storage[scanID] = {"info": {}, "status": "open", "sent": set()} sb.sync_storage[scanID]["baseline"] = {} msg = messages.ScanBaselineMessage(scanID=scanID, data=sb.sync_storage[scanID]["baseline"]) - with mock.patch.object(sb, "producer") as producer: + with mock.patch.object(sb, "connector") as connector: bec_emitter._send_baseline(scanID) - pipe = producer.pipeline() - producer.set.assert_called_once_with( + pipe = connector.pipeline() + connector.set.assert_called_once_with( MessageEndpoints.public_scan_baseline(scanID), msg, expire=1800, pipe=pipe, ) - producer.set_and_publish.assert_called_once_with( + connector.set_and_publish.assert_called_once_with( MessageEndpoints.scan_baseline(), msg, pipe=pipe, diff --git a/scan_bundler/tests/test_bluesky_emitter.py b/scan_bundler/tests/test_bluesky_emitter.py index ed84762c..64219728 100644 --- a/scan_bundler/tests/test_bluesky_emitter.py +++ b/scan_bundler/tests/test_bluesky_emitter.py @@ -13,7 +13,7 @@ from scan_bundler.bluesky_emitter import BlueskyEmitter def test_run_start_document(scanID): sb = load_ScanBundlerMock() bls_emitter = BlueskyEmitter(sb) - with mock.patch.object(bls_emitter.producer, "raw_send") as send: + with mock.patch.object(bls_emitter.connector, "raw_send") as send: with mock.patch.object(bls_emitter, "send_descriptor_document") as send_descr: with mock.patch.object( bls_emitter, "_get_run_start_document", return_value={} @@ -45,7 +45,7 @@ def test_send_descriptor_document(): bls_emitter = BlueskyEmitter(sb) scanID = "lkajsdl" bls_emitter.bluesky_metadata[scanID] = {} - with mock.patch.object(bls_emitter.producer, "raw_send") as send: + with mock.patch.object(bls_emitter.connector, "raw_send") as send: with mock.patch.object( bls_emitter, "_get_descriptor_document", return_value={} ) as get_descr: diff --git a/scan_bundler/tests/test_emitter.py b/scan_bundler/tests/test_emitter.py index 6354141e..8bb3ace7 100644 --- a/scan_bundler/tests/test_emitter.py +++ b/scan_bundler/tests/test_emitter.py @@ -51,30 +51,30 @@ from scan_bundler.emitter import EmitterBase ], ) def test_publish_data(msgs): - producer = mock.MagicMock() - with mock.patch.object(EmitterBase, "_start_buffered_producer") as start: - emitter = EmitterBase(producer) + connector = mock.MagicMock() + with mock.patch.object(EmitterBase, "_start_buffered_connector") as start: + emitter = EmitterBase(connector) start.assert_called_once() with mock.patch.object(emitter, "_get_messages_from_buffer", return_value=msgs) as get_msgs: emitter._publish_data() get_msgs.assert_called_once() if not msgs: - producer.send.assert_not_called() + connector.send.assert_not_called() return - pipe = producer.pipeline() + pipe = connector.pipeline() msgs_bundle = messages.BundleMessage() _, endpoint, _ = msgs[0] for msg, endpoint, public in msgs: msg_dump = msg msgs_bundle.append(msg_dump) if public: - producer.set.assert_has_calls( - producer.set(public, msg_dump, pipe=pipe, expire=1800) + connector.set.assert_has_calls( + connector.set(public, msg_dump, pipe=pipe, expire=1800) ) - producer.send.assert_called_with(endpoint, msgs_bundle, pipe=pipe) + connector.send.assert_called_with(endpoint, msgs_bundle, pipe=pipe) @pytest.mark.parametrize( @@ -93,8 +93,8 @@ def test_publish_data(msgs): ], ) def test_add_message(msg, endpoint, public): - producer = mock.MagicMock() - emitter = EmitterBase(producer) + connector = mock.MagicMock() + emitter = EmitterBase(connector) emitter.add_message(msg, endpoint, public) msgs = emitter._get_messages_from_buffer() out_msg, out_endpoint, out_public = msgs[0] diff --git a/scan_bundler/tests/test_scan_bundler.py b/scan_bundler/tests/test_scan_bundler.py index 991fbf84..4e2575a2 100644 --- a/scan_bundler/tests/test_scan_bundler.py +++ b/scan_bundler/tests/test_scan_bundler.py @@ -36,7 +36,7 @@ def load_ScanBundlerMock(): service_mock = mock.MagicMock() service_mock.connector = ConnectorMock("") device_manager = ScanBundlerDeviceManagerMock(service_mock, "") - device_manager.producer = service_mock.connector.producer() + device_manager.connector = service_mock.connector with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file: device_manager._session = create_session_from_config(yaml.safe_load(session_file)) device_manager._load_session() @@ -74,7 +74,7 @@ def test_device_read_callback(): msg.topic = MessageEndpoints.device_read("samx") with mock.patch.object(scan_bundler, "_add_device_to_storage") as add_dev: - scan_bundler._device_read_callback(msg, scan_bundler) + scan_bundler._device_read_callback(msg) add_dev.assert_called_once_with([dev_msg], "samx") @@ -157,7 +157,7 @@ def test_wait_for_scanID(scanID, storageID, scan_msg): ) def test_get_scan_status_history(msgs): sb = load_ScanBundlerMock() - with mock.patch.object(sb.producer, "lrange", return_value=[msg for msg in msgs]) as lrange: + with mock.patch.object(sb.connector, "lrange", return_value=[msg for msg in msgs]) as lrange: res = sb._get_scan_status_history(5) lrange.assert_called_once_with(MessageEndpoints.scan_status() + "_list", -5, -1) assert res == msgs @@ -371,7 +371,7 @@ def test_scan_queue_callback(queue_msg): sb = load_ScanBundlerMock() msg = MessageMock() msg.value = queue_msg - sb._scan_queue_callback(msg, sb) + sb._scan_queue_callback(msg) assert sb.current_queue == queue_msg.content["queue"]["primary"].get("info") @@ -399,7 +399,7 @@ def test_scan_status_callback(scan_msg): msg.value = scan_msg with mock.patch.object(sb, "handle_scan_status_message") as handle_scan_status_message_mock: - sb._scan_status_callback(msg, sb) + sb._scan_status_callback(msg) handle_scan_status_message_mock.assert_called_once_with(scan_msg) @@ -744,10 +744,10 @@ def test_get_last_device_readback(): signals={"samx": {"samx": 0.51, "setpoint": 0.5, "motor_is_moving": 0}}, metadata={"scanID": "laksjd", "readout_priority": "monitored"}, ) - with mock.patch.object(sb, "producer") as producer_mock: - producer_mock.execute_pipeline.return_value = [dev_msg] + with mock.patch.object(sb, "connector") as connector_mock: + connector_mock.execute_pipeline.return_value = [dev_msg] ret = sb._get_last_device_readback([sb.device_manager.devices.samx]) - assert producer_mock.get.mock_calls == [ - mock.call(MessageEndpoints.device_readback("samx"), producer_mock.pipeline()) + assert connector_mock.get.mock_calls == [ + mock.call(MessageEndpoints.device_readback("samx"), connector_mock.pipeline()) ] assert ret == [dev_msg.content["signals"]] diff --git a/scan_server/scan_plugins/LamNIFermatScan.py b/scan_server/scan_plugins/LamNIFermatScan.py index b643a62e..86661435 100644 --- a/scan_server/scan_plugins/LamNIFermatScan.py +++ b/scan_server/scan_plugins/LamNIFermatScan.py @@ -458,7 +458,7 @@ class LamNIFermatScan(ScanBase, LamNIMixin): yield from self.stubs.kickoff(device="rtx") while True: yield from self.stubs.read_and_wait(group="primary", wait_group="readout_primary") - msg = self.device_manager.producer.get(MessageEndpoints.device_status("rt_scan")) + msg = self.device_manager.connector.get(MessageEndpoints.device_status("rt_scan")) if msg: status = msg status_id = status.content.get("status", 1) diff --git a/scan_server/scan_plugins/owis_grid.py b/scan_server/scan_plugins/owis_grid.py index 5104e1ea..19a93cdb 100644 --- a/scan_server/scan_plugins/owis_grid.py +++ b/scan_server/scan_plugins/owis_grid.py @@ -163,7 +163,7 @@ class OwisGrid(AsyncFlyScanBase): def scan_progress(self) -> int: """Timeout of the progress bar. This gets updated in the frequency of scan segments""" - msg = self.device_manager.producer.get(MessageEndpoints.device_progress("mcs")) + msg = self.device_manager.connector.get(MessageEndpoints.device_progress("mcs")) if not msg: self.timeout_progress += 1 return self.timeout_progress diff --git a/scan_server/scan_plugins/sgalil_grid.py b/scan_server/scan_plugins/sgalil_grid.py index d890533c..0f0595fc 100644 --- a/scan_server/scan_plugins/sgalil_grid.py +++ b/scan_server/scan_plugins/sgalil_grid.py @@ -106,7 +106,7 @@ class SgalilGrid(AsyncFlyScanBase): def scan_progress(self) -> int: """Timeout of the progress bar. This gets updated in the frequency of scan segments""" - msg = self.device_manager.producer.get(MessageEndpoints.device_progress("mcs")) + msg = self.device_manager.connector.get(MessageEndpoints.device_progress("mcs")) if not msg: self.timeout_progress += 1 return self.timeout_progress diff --git a/scan_server/scan_server/device_validation.py b/scan_server/scan_server/device_validation.py index 59e4a84a..3b4d8535 100644 --- a/scan_server/scan_server/device_validation.py +++ b/scan_server/scan_server/device_validation.py @@ -10,8 +10,8 @@ class DeviceValidation: Mixin class for validation methods """ - def __init__(self, producer, worker): - self.producer = producer + def __init__(self, connector, worker): + self.connector = connector self.worker = worker def get_device_status(self, endpoint: MessageEndpoints, devices: list) -> list: @@ -25,10 +25,10 @@ class DeviceValidation: Returns: list: List of BECMessage objects """ - pipe = self.producer.pipeline() + pipe = self.connector.pipeline() for dev in devices: - self.producer.get(endpoint(dev), pipe) - return self.producer.execute_pipeline(pipe) + self.connector.get(endpoint(dev), pipe) + return self.connector.execute_pipeline(pipe) def devices_are_ready( self, diff --git a/scan_server/scan_server/scan_guard.py b/scan_server/scan_server/scan_guard.py index d9028b53..fd7f5831 100644 --- a/scan_server/scan_server/scan_guard.py +++ b/scan_server/scan_server/scan_guard.py @@ -27,25 +27,19 @@ class ScanGuard: self.parent = parent self.device_manager = self.parent.device_manager self.connector = self.parent.connector - self.producer = self.connector.producer() - self._start_scan_queue_request_consumer() - def _start_scan_queue_request_consumer(self): - self._scan_queue_request_consumer = self.connector.consumer( + self.connector.register( MessageEndpoints.scan_queue_request(), cb=self._scan_queue_request_callback, parent=self, ) - self._scan_queue_modification_request_consumer = self.connector.consumer( + self.connector.register( MessageEndpoints.scan_queue_modification_request(), cb=self._scan_queue_modification_request_callback, parent=self, ) - self._scan_queue_request_consumer.start() - self._scan_queue_modification_request_consumer.start() - def _is_valid_scan_request(self, request) -> ScanStatus: try: self._check_valid_request(request) @@ -63,7 +57,7 @@ class ScanGuard: raise ScanRejection("Invalid request.") def _check_valid_scan(self, request) -> None: - avail_scans = self.producer.get(MessageEndpoints.available_scans()) + avail_scans = self.connector.get(MessageEndpoints.available_scans()) scan_type = request.content.get("scan_type") if scan_type not in avail_scans.resource: raise ScanRejection(f"Unknown scan type {scan_type}.") @@ -140,7 +134,7 @@ class ScanGuard: message=scan_status.message, metadata=metadata, ) - self.device_manager.producer.send(sqrr, rrm) + self.device_manager.connector.send(sqrr, rrm) def _handle_scan_request(self, msg): """ @@ -181,10 +175,10 @@ class ScanGuard: self._send_scan_request_response(ScanStatus(), mod_msg.metadata) sqm = MessageEndpoints.scan_queue_modification() - self.device_manager.producer.send(sqm, mod_msg) + self.device_manager.connector.send(sqm, mod_msg) def _append_to_scan_queue(self, msg): logger.info("Appending new scan to queue") msg = msg sqi = MessageEndpoints.scan_queue_insert() - self.device_manager.producer.send(sqi, msg) + self.device_manager.connector.send(sqi, msg) diff --git a/scan_server/scan_server/scan_manager.py b/scan_server/scan_server/scan_manager.py index aad81460..fa4210be 100644 --- a/scan_server/scan_server/scan_manager.py +++ b/scan_server/scan_server/scan_manager.py @@ -101,7 +101,7 @@ class ScanManager: def publish_available_scans(self): """send all available scans to the broker""" - self.parent.producer.set( + self.parent.connector.set( MessageEndpoints.available_scans(), AvailableResourceMessage(resource=self.available_scans), ) diff --git a/scan_server/scan_server/scan_queue.py b/scan_server/scan_server/scan_queue.py index 559a461d..9da7df09 100644 --- a/scan_server/scan_server/scan_queue.py +++ b/scan_server/scan_server/scan_queue.py @@ -51,11 +51,10 @@ class QueueManager: def __init__(self, parent) -> None: self.parent = parent self.connector = parent.connector - self.producer = parent.producer self.num_queues = 1 self.key = "" self.queues = {} - self._start_scan_queue_consumer() + self._start_scan_queue_register() self._lock = threading.RLock() def add_to_queue(self, scan_queue: str, msg: messages.ScanQueueMessage, position=-1) -> None: @@ -91,17 +90,15 @@ class QueueManager: self.queues[queue_name] = ScanQueue(self, queue_name=queue_name) self.queues[queue_name].start_worker() - def _start_scan_queue_consumer(self) -> None: - self._scan_queue_consumer = self.connector.consumer( + def _start_scan_queue_register(self) -> None: + self.connector.register( MessageEndpoints.scan_queue_insert(), cb=self._scan_queue_callback, parent=self ) - self._scan_queue_modification_consumer = self.connector.consumer( + self.connector.register( MessageEndpoints.scan_queue_modification(), cb=self._scan_queue_modification_callback, parent=self, ) - self._scan_queue_consumer.start() - self._scan_queue_modification_consumer.start() @staticmethod def _scan_queue_callback(msg, parent, **_kwargs) -> None: @@ -233,7 +230,7 @@ class QueueManager: logger.info("New scan queue:") for queue in self.describe_queue(): logger.info(f"\n {queue}") - self.producer.set_and_publish( + self.connector.set_and_publish( MessageEndpoints.scan_queue_status(), messages.ScanQueueStatusMessage(queue=queue_export), ) @@ -685,7 +682,7 @@ class InstructionQueueItem: self.instructions = [] self.parent = parent self.queue = RequestBlockQueue(instruction_queue=self, assembler=assembler) - self.producer = self.parent.queue_manager.producer + self.connector = self.parent.queue_manager.connector self._is_scan = False self.is_active = False # set to true while a worker is processing the instructions self.completed = False @@ -790,7 +787,7 @@ class InstructionQueueItem: msg = messages.ScanQueueHistoryMessage( status=self.status.name, queueID=self.queue_id, info=self.describe() ) - self.parent.queue_manager.producer.lpush( + self.parent.queue_manager.connector.lpush( MessageEndpoints.scan_queue_history(), msg, max_size=100 ) diff --git a/scan_server/scan_server/scan_server.py b/scan_server/scan_server/scan_server.py index 0d12400e..e63a3d61 100644 --- a/scan_server/scan_server/scan_server.py +++ b/scan_server/scan_server/scan_server.py @@ -23,7 +23,6 @@ class ScanServer(BECService): def __init__(self, config: ServiceConfig, connector_cls: ConnectorBase): super().__init__(config, connector_cls, unique_service=True) - self.producer = self.connector.producer() self._start_scan_manager() self._start_queue_manager() self._start_device_manager() @@ -52,15 +51,12 @@ class ScanServer(BECService): self.scan_guard = ScanGuard(parent=self) def _start_alarm_handler(self): - self._alarm_consumer = self.connector.consumer( - MessageEndpoints.alarm(), cb=self._alarm_callback, parent=self - ) - self._alarm_consumer.start() + self.connector.register(MessageEndpoints.alarm(), cb=self._alarm_callback, parent=self) def _reset_scan_number(self): - if self.producer.get(MessageEndpoints.scan_number()) is None: + if self.connector.get(MessageEndpoints.scan_number()) is None: self.scan_number = 1 - if self.producer.get(MessageEndpoints.dataset_number()) is None: + if self.connector.get(MessageEndpoints.dataset_number()) is None: self.dataset_number = 1 @staticmethod @@ -74,25 +70,24 @@ class ScanServer(BECService): @property def scan_number(self) -> int: """get the current scan number""" - return int(self.producer.get(MessageEndpoints.scan_number())) + return int(self.connector.get(MessageEndpoints.scan_number())) @scan_number.setter def scan_number(self, val: int): """set the current scan number""" - self.producer.set(MessageEndpoints.scan_number(), val) + self.connector.set(MessageEndpoints.scan_number(), val) @property def dataset_number(self) -> int: """get the current dataset number""" - return int(self.producer.get(MessageEndpoints.dataset_number())) + return int(self.connector.get(MessageEndpoints.dataset_number())) @dataset_number.setter def dataset_number(self, val: int): """set the current dataset number""" - self.producer.set(MessageEndpoints.dataset_number(), val) + self.connector.set(MessageEndpoints.dataset_number(), val) def shutdown(self) -> None: """shutdown the scan server""" - self.device_manager.shutdown() self.queue_manager.shutdown() diff --git a/scan_server/scan_server/scan_stubs.py b/scan_server/scan_server/scan_stubs.py index fcd05a75..d06d04f5 100644 --- a/scan_server/scan_server/scan_stubs.py +++ b/scan_server/scan_server/scan_stubs.py @@ -5,7 +5,9 @@ import uuid from collections.abc import Callable import numpy as np -from bec_lib import MessageEndpoints, ProducerConnector, Status, bec_logger, messages + +from bec_lib import MessageEndpoints, Status, bec_logger, messages +from bec_lib.connector import ConnectorBase from .errors import DeviceMessageError, ScanAbortion @@ -13,8 +15,8 @@ logger = bec_logger.logger class ScanStubs: - def __init__(self, producer: ProducerConnector, device_msg_callback: Callable = None) -> None: - self.producer = producer + def __init__(self, connector: ConnectorBase, device_msg_callback: Callable = None) -> None: + self.connector = connector self.device_msg_metadata = ( device_msg_callback if device_msg_callback is not None else lambda: {} ) @@ -62,7 +64,7 @@ class ScanStubs: def _get_from_rpc(self, rpc_id): while True: - msg = self.producer.get(MessageEndpoints.device_rpc(rpc_id)) + msg = self.connector.get(MessageEndpoints.device_rpc(rpc_id)) if msg: break time.sleep(0.001) @@ -81,7 +83,7 @@ class ScanStubs: if not isinstance(return_val, dict): return return_val if return_val.get("type") == "status" and return_val.get("RID"): - return Status(self.producer, return_val.get("RID")) + return Status(self.connector, return_val.get("RID")) return return_val def set_and_wait(self, *, device: list[str], positions: list | np.ndarray): @@ -182,7 +184,7 @@ class ScanStubs: DIID (int): device instruction ID """ - msg = self.producer.get(MessageEndpoints.device_req_status(device)) + msg = self.connector.get(MessageEndpoints.device_req_status(device)) if not msg: return 0 matching_RID = msg.metadata.get("RID") == RID @@ -199,7 +201,7 @@ class ScanStubs: RID (str): request ID """ - msg = self.producer.get(MessageEndpoints.device_progress(device)) + msg = self.connector.get(MessageEndpoints.device_progress(device)) if not msg: return None matching_RID = msg.metadata.get("RID") == RID diff --git a/scan_server/scan_server/scan_worker.py b/scan_server/scan_server/scan_worker.py index bab2bf96..00a2e53d 100644 --- a/scan_server/scan_server/scan_worker.py +++ b/scan_server/scan_server/scan_worker.py @@ -41,7 +41,7 @@ class ScanWorker(threading.Thread): self._groups = {} self.interception_msg = None self.reset() - self.validate = DeviceValidation(self.device_manager.producer, self) + self.validate = DeviceValidation(self.device_manager.connector, self) def open_scan(self, instr: messages.DeviceInstructionMessage) -> None: """ @@ -138,7 +138,7 @@ class ScanWorker(threading.Thread): """ devices = [dev.name for dev in self.device_manager.devices.get_software_triggered_devices()] self._last_trigger = instr - self.device_manager.producer.send( + self.device_manager.connector.send( MessageEndpoints.device_instructions(), messages.DeviceInstructionMessage( device=devices, @@ -157,7 +157,7 @@ class ScanWorker(threading.Thread): """ # send instruction - self.device_manager.producer.send(MessageEndpoints.device_instructions(), instr) + self.device_manager.connector.send(MessageEndpoints.device_instructions(), instr) def read_devices(self, instr: messages.DeviceInstructionMessage) -> None: """ @@ -171,7 +171,7 @@ class ScanWorker(threading.Thread): self._publish_readback(instr) return - producer = self.device_manager.producer + connector = self.device_manager.connector devices = instr.content.get("device") if devices is None: @@ -181,7 +181,7 @@ class ScanWorker(threading.Thread): readout_priority=self.readout_priority ) ] - producer.send( + connector.send( MessageEndpoints.device_instructions(), messages.DeviceInstructionMessage( device=devices, @@ -201,7 +201,7 @@ class ScanWorker(threading.Thread): """ # logger.info("kickoff") - self.device_manager.producer.send( + self.device_manager.connector.send( MessageEndpoints.device_instructions(), messages.DeviceInstructionMessage( device=instr.content.get("device"), @@ -225,7 +225,7 @@ class ScanWorker(threading.Thread): devices = instr.content.get("device") if not isinstance(devices, list): devices = [devices] - self.device_manager.producer.send( + self.device_manager.connector.send( MessageEndpoints.device_instructions(), messages.DeviceInstructionMessage( device=devices, @@ -251,7 +251,7 @@ class ScanWorker(threading.Thread): ) ] params = instr.content["parameter"] - self.device_manager.producer.send( + self.device_manager.connector.send( MessageEndpoints.device_instructions(), messages.DeviceInstructionMessage( device=baseline_devices, action="read", parameter=params, metadata=instr.metadata @@ -266,7 +266,7 @@ class ScanWorker(threading.Thread): instr (DeviceInstructionMessage): Device instruction received from the scan assembler """ devices = [dev.name for dev in self.device_manager.devices.enabled_devices] - self.device_manager.producer.send( + self.device_manager.connector.send( MessageEndpoints.device_instructions(), messages.DeviceInstructionMessage( device=devices, @@ -285,7 +285,7 @@ class ScanWorker(threading.Thread): Args: instr (DeviceInstructionMessage): Device instruction received from the scan assembler """ - producer = self.device_manager.producer + connector = self.device_manager.connector data = instr.content["parameter"]["data"] devices = instr.content["device"] if not isinstance(devices, list): @@ -294,7 +294,7 @@ class ScanWorker(threading.Thread): data = [data] for device, dev_data in zip(devices, data): msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata) - producer.set_and_publish(MessageEndpoints.device_read(device), msg) + connector.set_and_publish(MessageEndpoints.device_read(device), msg) def send_rpc(self, instr: messages.DeviceInstructionMessage) -> None: """ @@ -304,7 +304,7 @@ class ScanWorker(threading.Thread): instr (DeviceInstructionMessage): Device instruction received from the scan assembler """ - self.device_manager.producer.send(MessageEndpoints.device_instructions(), instr) + self.device_manager.connector.send(MessageEndpoints.device_instructions(), instr) def process_scan_report_instruction(self, instr): """ @@ -333,7 +333,7 @@ class ScanWorker(threading.Thread): if dev.name not in async_devices ] for det in async_devices: - self.device_manager.producer.send( + self.device_manager.connector.send( MessageEndpoints.device_instructions(), messages.DeviceInstructionMessage( device=det, @@ -344,7 +344,7 @@ class ScanWorker(threading.Thread): ) self._staged_devices.update(async_devices) - self.device_manager.producer.send( + self.device_manager.connector.send( MessageEndpoints.device_instructions(), messages.DeviceInstructionMessage( device=devices, @@ -375,7 +375,7 @@ class ScanWorker(threading.Thread): parameter = {} if not instr else instr.content["parameter"] metadata = {} if not instr else instr.metadata self._staged_devices.difference_update(devices) - self.device_manager.producer.send( + self.device_manager.connector.send( MessageEndpoints.device_instructions(), messages.DeviceInstructionMessage( device=devices, action="unstage", parameter=parameter, metadata=metadata @@ -459,7 +459,7 @@ class ScanWorker(threading.Thread): matching_DIID = device_status[ind].metadata.get("DIID") >= devices[ind][1] matching_RID = device_status[ind].metadata.get("RID") == instr.metadata["RID"] if matching_DIID and matching_RID: - last_pos_msg = self.device_manager.producer.get( + last_pos_msg = self.device_manager.connector.get( MessageEndpoints.device_readback(failed_device[0]) ) last_pos = last_pos_msg.content["signals"][failed_device[0]]["value"] @@ -603,25 +603,25 @@ class ScanWorker(threading.Thread): def _publish_readback( self, instr: messages.DeviceInstructionMessage, devices: list = None ) -> None: - producer = self.device_manager.producer + connector = self.device_manager.connector if not devices: devices = instr.content.get("device") # cached readout readouts = self._get_readback(devices) - pipe = producer.pipeline() + pipe = connector.pipeline() for readout, device in zip(readouts, devices): msg = messages.DeviceMessage(signals=readout, metadata=instr.metadata) - producer.set_and_publish(MessageEndpoints.device_read(device), msg, pipe) + connector.set_and_publish(MessageEndpoints.device_read(device), msg, pipe) return pipe.execute() def _get_readback(self, devices: list) -> list: - producer = self.device_manager.producer + connector = self.device_manager.connector # cached readout - pipe = producer.pipeline() + pipe = connector.pipeline() for dev in devices: - producer.get(MessageEndpoints.device_readback(dev), pipe=pipe) - return producer.execute_pipeline(pipe) + connector.get(MessageEndpoints.device_readback(dev), pipe=pipe) + return connector.execute_pipeline(pipe) def _check_for_interruption(self) -> None: if self.status == InstructionQueueStatus.PAUSED: @@ -700,11 +700,13 @@ class ScanWorker(threading.Thread): scanID=self.current_scanID, status=status, info=self.current_scan_info ) expire = None if status in ["open", "paused"] else 1800 - pipe = self.device_manager.producer.pipeline() - self.device_manager.producer.set( + pipe = self.device_manager.connector.pipeline() + self.device_manager.connector.set( MessageEndpoints.public_scan_info(self.current_scanID), msg, pipe=pipe, expire=expire ) - self.device_manager.producer.set_and_publish(MessageEndpoints.scan_status(), msg, pipe=pipe) + self.device_manager.connector.set_and_publish( + MessageEndpoints.scan_status(), msg, pipe=pipe + ) pipe.execute() def _process_instructions(self, queue: InstructionQueueItem) -> None: diff --git a/scan_server/scan_server/scans.py b/scan_server/scan_server/scans.py index e946187b..77c0fde4 100644 --- a/scan_server/scan_server/scans.py +++ b/scan_server/scan_server/scans.py @@ -213,7 +213,7 @@ class RequestBase(ABC): if metadata is None: self.metadata = {} self.stubs = ScanStubs( - producer=self.device_manager.producer, device_msg_callback=self.device_msg_metadata + connector=self.device_manager.connector, device_msg_callback=self.device_msg_metadata ) @property @@ -239,7 +239,7 @@ class RequestBase(ABC): def run_pre_scan_macros(self): """run pre scan macros if any""" - macros = self.device_manager.producer.lrange(MessageEndpoints.pre_scan_macros(), 0, -1) + macros = self.device_manager.connector.lrange(MessageEndpoints.pre_scan_macros(), 0, -1) for macro in macros: macro = macro.decode().strip() func_name = self._get_func_name_from_macro(macro) @@ -558,12 +558,12 @@ class SyncFlyScanBase(ScanBase, ABC): def _get_flyer_status(self) -> list: flyer = self.scan_motors[0] - producer = self.device_manager.producer + connector = self.device_manager.connector - pipe = producer.pipeline() - producer.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe) - producer.get(MessageEndpoints.device_readback(flyer), pipe) - return producer.execute_pipeline(pipe) + pipe = connector.pipeline() + connector.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe) + connector.get(MessageEndpoints.device_readback(flyer), pipe) + return connector.execute_pipeline(pipe) @abstractmethod def scan_core(self): @@ -1098,7 +1098,7 @@ class RoundScanFlySim(SyncFlyScanBase): while True: yield from self.stubs.read_and_wait(group="primary", wait_group="readout_primary") - status = self.device_manager.producer.get(MessageEndpoints.device_status(self.flyer)) + status = self.device_manager.connector.get(MessageEndpoints.device_status(self.flyer)) if status: device_is_idle = status.content.get("status", 1) == 0 matching_RID = self.metadata.get("RID") == status.metadata.get("RID") @@ -1318,12 +1318,12 @@ class MonitorScan(ScanBase): self._check_limits() def _get_flyer_status(self) -> list: - producer = self.device_manager.producer + connector = self.device_manager.connector - pipe = producer.pipeline() - producer.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe) - producer.get(MessageEndpoints.device_readback(self.flyer), pipe) - return producer.execute_pipeline(pipe) + pipe = connector.pipeline() + connector.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe) + connector.get(MessageEndpoints.device_readback(self.flyer), pipe) + return connector.execute_pipeline(pipe) def scan_core(self): yield from self.stubs.set( diff --git a/scan_server/tests/test_scan_guard.py b/scan_server/tests/test_scan_guard.py index d0fb886d..d62030c0 100644 --- a/scan_server/tests/test_scan_guard.py +++ b/scan_server/tests/test_scan_guard.py @@ -12,7 +12,7 @@ from scan_server.scan_guard import ScanGuard, ScanRejection, ScanStatus @pytest.fixture def scan_guard_mock(scan_server_mock): sg = ScanGuard(parent=scan_server_mock) - sg.device_manager.producer = mock.MagicMock() + sg.device_manager.connector = mock.MagicMock() yield sg @@ -113,8 +113,8 @@ def test_valid_request(scan_server_mock, scan_queue_msg, valid): def test_check_valid_scan_raises_for_unknown_scan(scan_guard_mock): sg = scan_guard_mock - sg.producer = mock.MagicMock() - sg.producer.get.return_value = messages.AvailableResourceMessage( + sg.connector = mock.MagicMock() + sg.connector.get.return_value = messages.AvailableResourceMessage( resource={"fermat_scan": "fermat_scan"} ) @@ -130,8 +130,8 @@ def test_check_valid_scan_raises_for_unknown_scan(scan_guard_mock): def test_check_valid_scan_accepts_known_scan(scan_guard_mock): sg = scan_guard_mock - sg.producer = mock.MagicMock() - sg.producer.get.return_value = messages.AvailableResourceMessage( + sg.connector = mock.MagicMock() + sg.connector.get.return_value = messages.AvailableResourceMessage( resource={"fermat_scan": "fermat_scan"} ) @@ -146,8 +146,8 @@ def test_check_valid_scan_accepts_known_scan(scan_guard_mock): def test_check_valid_scan_device_rpc(scan_guard_mock): sg = scan_guard_mock - sg.producer = mock.MagicMock() - sg.producer.get.return_value = messages.AvailableResourceMessage( + sg.connector = mock.MagicMock() + sg.connector.get.return_value = messages.AvailableResourceMessage( resource={"device_rpc": "device_rpc"} ) request = messages.ScanQueueMessage( @@ -162,8 +162,8 @@ def test_check_valid_scan_device_rpc(scan_guard_mock): def test_check_valid_scan_device_rpc_raises(scan_guard_mock): sg = scan_guard_mock - sg.producer = mock.MagicMock() - sg.producer.get.return_value = messages.AvailableResourceMessage( + sg.connector = mock.MagicMock() + sg.connector.get.return_value = messages.AvailableResourceMessage( resource={"device_rpc": "device_rpc"} ) request = messages.ScanQueueMessage( @@ -184,7 +184,7 @@ def test_handle_scan_modification_request(scan_guard_mock): msg = messages.ScanQueueModificationMessage( scanID="scanID", action="abort", parameter={}, metadata={"RID": "RID"} ) - with mock.patch.object(sg.device_manager.producer, "send") as send: + with mock.patch.object(sg.device_manager.connector, "send") as send: sg._handle_scan_modification_request(msg) send.assert_called_once_with(MessageEndpoints.scan_queue_modification(), msg) @@ -207,7 +207,7 @@ def test_append_to_scan_queue(scan_guard_mock): parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, queue="primary", ) - with mock.patch.object(sg.device_manager.producer, "send") as send: + with mock.patch.object(sg.device_manager.connector, "send") as send: sg._append_to_scan_queue(msg) send.assert_called_once_with(MessageEndpoints.scan_queue_insert(), msg) @@ -251,7 +251,7 @@ def test_scan_queue_modification_request_callback(scan_guard_mock): def test_send_scan_request_response(scan_guard_mock): sg = scan_guard_mock - with mock.patch.object(sg.device_manager.producer, "send") as send: + with mock.patch.object(sg.device_manager.connector, "send") as send: sg._send_scan_request_response(ScanStatus(), {"RID": "RID"}) send.assert_called_once_with( MessageEndpoints.scan_queue_request_response(), diff --git a/scan_server/tests/test_scan_server_queue.py b/scan_server/tests/test_scan_server_queue.py index d8d9d6fc..e81c0856 100644 --- a/scan_server/tests/test_scan_server_queue.py +++ b/scan_server/tests/test_scan_server_queue.py @@ -166,40 +166,40 @@ def test_set_halt_disables_return_to_start(queuemanager_mock): def test_set_pause(queuemanager_mock): queue_manager = queuemanager_mock() - queue_manager.producer.message_sent = [] + queue_manager.connector.message_sent = [] queue_manager.set_pause(queue="primary") assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED - assert len(queue_manager.producer.message_sent) == 1 + assert len(queue_manager.connector.message_sent) == 1 assert ( - queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status() + queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status() ) def test_set_deferred_pause(queuemanager_mock): queue_manager = queuemanager_mock() - queue_manager.producer.message_sent = [] + queue_manager.connector.message_sent = [] queue_manager.set_deferred_pause(queue="primary") assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED - assert len(queue_manager.producer.message_sent) == 1 + assert len(queue_manager.connector.message_sent) == 1 assert ( - queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status() + queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status() ) def test_set_continue(queuemanager_mock): queue_manager = queuemanager_mock() - queue_manager.producer.message_sent = [] + queue_manager.connector.message_sent = [] queue_manager.set_continue(queue="primary") assert queue_manager.queues["primary"].status == ScanQueueStatus.RUNNING - assert len(queue_manager.producer.message_sent) == 1 + assert len(queue_manager.connector.message_sent) == 1 assert ( - queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status() + queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status() ) def test_set_abort(queuemanager_mock): queue_manager = queuemanager_mock() - queue_manager.producer.message_sent = [] + queue_manager.connector.message_sent = [] msg = messages.ScanQueueMessage( scan_type="mv", parameter={"args": {"samx": (1,)}, "kwargs": {}}, @@ -210,23 +210,23 @@ def test_set_abort(queuemanager_mock): queue_manager.add_to_queue(scan_queue="primary", msg=msg) queue_manager.set_abort(queue="primary") assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED - assert len(queue_manager.producer.message_sent) == 2 + assert len(queue_manager.connector.message_sent) == 2 assert ( - queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status() + queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status() ) def test_set_abort_with_empty_queue(queuemanager_mock): queue_manager = queuemanager_mock() - queue_manager.producer.message_sent = [] + queue_manager.connector.message_sent = [] queue_manager.set_abort(queue="primary") assert queue_manager.queues["primary"].status == ScanQueueStatus.RUNNING - assert len(queue_manager.producer.message_sent) == 0 + assert len(queue_manager.connector.message_sent) == 0 def test_set_clear_sends_message(queuemanager_mock): queue_manager = queuemanager_mock() - queue_manager.producer.message_sent = [] + queue_manager.connector.message_sent = [] setter_mock = mock.Mock(wraps=ScanQueue.worker_status.fset) # pylint: disable=assignment-from-no-return # pylint: disable=too-many-function-args @@ -238,9 +238,9 @@ def test_set_clear_sends_message(queuemanager_mock): mock_property.fset.assert_called_once_with( queue_manager.queues["primary"], InstructionQueueStatus.STOPPED ) - assert len(queue_manager.producer.message_sent) == 1 + assert len(queue_manager.connector.message_sent) == 1 assert ( - queue_manager.producer.message_sent[0].get("queue") + queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status() ) diff --git a/scan_server/tests/test_scan_stubs.py b/scan_server/tests/test_scan_stubs.py index a356ef3e..b3c540b1 100644 --- a/scan_server/tests/test_scan_stubs.py +++ b/scan_server/tests/test_scan_stubs.py @@ -11,7 +11,7 @@ from scan_server.scan_stubs import ScanAbortion, ScanStubs @pytest.fixture def stubs(): connector = ConnectorMock("") - yield ScanStubs(connector.producer()) + yield ScanStubs(connector) @pytest.mark.parametrize( @@ -36,7 +36,11 @@ def stubs(): device="rtx", action="kickoff", parameter={ - "configure": {"num_pos": 5, "positions": [1, 2, 3, 4, 5], "exp_time": 2}, + "configure": { + "num_pos": 5, + "positions": [1, 2, 3, 4, 5], + "exp_time": 2, + }, "wait_group": "kickoff", }, metadata={}, @@ -45,6 +49,8 @@ def stubs(): ], ) def test_kickoff(stubs, device, parameter, metadata, reference_msg): + connector = ConnectorMock("") + stubs = ScanStubs(connector) msg = list(stubs.kickoff(device=device, parameter=parameter, metadata=metadata)) assert msg[0] == reference_msg @@ -52,12 +58,19 @@ def test_kickoff(stubs, device, parameter, metadata, reference_msg): @pytest.mark.parametrize( "msg,raised_error", [ - (messages.DeviceRPCMessage(device="samx", return_val="", out="", success=True), None), + ( + messages.DeviceRPCMessage(device="samx", return_val="", out="", success=True), + None, + ), ( messages.DeviceRPCMessage( device="samx", return_val="", - out={"error": "TypeError", "msg": "some weird error", "traceback": "traceback"}, + out={ + "error": "TypeError", + "msg": "some weird error", + "traceback": "traceback", + }, success=False, ), ScanAbortion, @@ -69,8 +82,7 @@ def test_kickoff(stubs, device, parameter, metadata, reference_msg): ], ) def test_rpc_raises_scan_abortion(stubs, msg, raised_error): - msg = msg - with mock.patch.object(stubs.producer, "get", return_value=msg) as prod_get: + with mock.patch.object(stubs.connector, "get", return_value=msg) as prod_get: if raised_error is None: stubs._get_from_rpc("rpc-id") else: @@ -106,8 +118,8 @@ def test_rpc_raises_scan_abortion(stubs, msg, raised_error): def test_device_progress(stubs, msg, ret_value, raised_error): if raised_error: with pytest.raises(DeviceMessageError): - with mock.patch.object(stubs.producer, "get", return_value=msg): + with mock.patch.object(stubs.connector, "get", return_value=msg): assert stubs.get_device_progress(device="samx", RID="rid") == ret_value return - with mock.patch.object(stubs.producer, "get", return_value=msg): + with mock.patch.object(stubs.connector, "get", return_value=msg): assert stubs.get_device_progress(device="samx", RID="rid") == ret_value diff --git a/scan_server/tests/test_scan_worker.py b/scan_server/tests/test_scan_worker.py index 47c90b4c..bb19147c 100644 --- a/scan_server/tests/test_scan_worker.py +++ b/scan_server/tests/test_scan_worker.py @@ -4,7 +4,7 @@ from unittest import mock import pytest from bec_lib import MessageEndpoints, messages -from bec_lib.tests.utils import ProducerMock, dm, dm_with_devices +from bec_lib.tests.utils import ConnectorMock, dm, dm_with_devices from utils import scan_server_mock from scan_server.errors import DeviceMessageError, ScanAbortion @@ -22,7 +22,7 @@ from scan_server.scan_worker import ScanWorker @pytest.fixture def scan_worker_mock(scan_server_mock) -> ScanWorker: - scan_server_mock.device_manager.producer = mock.MagicMock() + scan_server_mock.device_manager.connector = mock.MagicMock() scan_worker = ScanWorker(parent=scan_server_mock) yield scan_worker @@ -295,7 +295,7 @@ def test_wait_for_devices(scan_worker_mock, instructions, wait_type): def test_complete_devices(scan_worker_mock, instructions): worker = scan_worker_mock with mock.patch.object(worker, "_wait_for_status") as wait_for_status_mock: - with mock.patch.object(worker.device_manager.producer, "send") as send_mock: + with mock.patch.object(worker.device_manager.connector, "send") as send_mock: worker.complete_devices(instructions) if instructions.content["device"]: devices = instructions.content["device"] @@ -328,7 +328,7 @@ def test_complete_devices(scan_worker_mock, instructions): ) def test_pre_scan(scan_worker_mock, instructions): worker = scan_worker_mock - with mock.patch.object(worker.device_manager.producer, "send") as send_mock: + with mock.patch.object(worker.device_manager.connector, "send") as send_mock: with mock.patch.object(worker, "_wait_for_status") as wait_for_status_mock: worker.pre_scan(instructions) devices = [dev.name for dev in worker.device_manager.devices.enabled_devices] @@ -457,12 +457,12 @@ def test_pre_scan(scan_worker_mock, instructions): ) def test_check_for_failed_movements(scan_worker_mock, device_status, devices, instr, abort): worker = scan_worker_mock - worker.device_manager.producer = ProducerMock() + worker.device_manager.connector = ConnectorMock() if abort: with pytest.raises(ScanAbortion): - worker.device_manager.producer._get_buffer[MessageEndpoints.device_readback("samx")] = ( - messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={}) - ) + worker.device_manager.connector._get_buffer[ + MessageEndpoints.device_readback("samx") + ] = messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={}) worker._check_for_failed_movements(device_status, devices, instr) else: worker._check_for_failed_movements(device_status, devices, instr) @@ -577,12 +577,12 @@ def test_check_for_failed_movements(scan_worker_mock, device_status, devices, in ) def test_wait_for_idle(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReqStatusMessage): worker = scan_worker_mock - worker.device_manager.producer = ProducerMock() + worker.device_manager.connector = ConnectorMock() with mock.patch.object( worker.validate, "get_device_status", return_value=[req_msg] ) as device_status: - worker.device_manager.producer._get_buffer[MessageEndpoints.device_readback("samx")] = ( + worker.device_manager.connector._get_buffer[MessageEndpoints.device_readback("samx")] = ( messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={}) ) @@ -635,7 +635,7 @@ def test_wait_for_idle(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReq ) def test_wait_for_read(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReqStatusMessage): worker = scan_worker_mock - worker.device_manager.producer = ProducerMock() + worker.device_manager.connector = ConnectorMock() with mock.patch.object( worker.validate, "get_device_status", return_value=[req_msg] @@ -643,9 +643,9 @@ def test_wait_for_read(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReq with mock.patch.object(worker, "_check_for_interruption") as interruption_mock: assert worker._groups == {} worker._groups["scan_motor"] = {"samx": 3, "samy": 4} - worker.device_manager.producer._get_buffer[MessageEndpoints.device_readback("samx")] = ( - messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={}) - ) + worker.device_manager.connector._get_buffer[ + MessageEndpoints.device_readback("samx") + ] = messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={}) worker._add_wait_group(msg1) worker._wait_for_read(msg2) assert worker._groups == {"scan_motor": {"samy": 4}} @@ -730,7 +730,7 @@ def test_wait_for_device_server(scan_worker_mock): ) def test_set_devices(scan_worker_mock, instr): worker = scan_worker_mock - with mock.patch.object(worker.device_manager.producer, "send") as send_mock: + with mock.patch.object(worker.device_manager.connector, "send") as send_mock: worker.set_devices(instr) send_mock.assert_called_once_with(MessageEndpoints.device_instructions(), instr) @@ -755,7 +755,7 @@ def test_set_devices(scan_worker_mock, instr): ) def test_trigger_devices(scan_worker_mock, instr): worker = scan_worker_mock - with mock.patch.object(worker.device_manager.producer, "send") as send_mock: + with mock.patch.object(worker.device_manager.connector, "send") as send_mock: worker.trigger_devices(instr) devices = [ dev.name for dev in worker.device_manager.devices.get_software_triggered_devices() @@ -797,7 +797,7 @@ def test_trigger_devices(scan_worker_mock, instr): ) def test_send_rpc(scan_worker_mock, instr): worker = scan_worker_mock - with mock.patch.object(worker.device_manager.producer, "send") as send_mock: + with mock.patch.object(worker.device_manager.connector, "send") as send_mock: worker.send_rpc(instr) send_mock.assert_called_once_with(MessageEndpoints.device_instructions(), instr) @@ -840,7 +840,7 @@ def test_read_devices(scan_worker_mock, instr): instr_devices = [] worker.readout_priority.update({"monitored": instr_devices}) devices = [dev.name for dev in worker._get_devices_from_instruction(instr)] - with mock.patch.object(worker.device_manager.producer, "send") as send_mock: + with mock.patch.object(worker.device_manager.connector, "send") as send_mock: worker.read_devices(instr) if instr.content.get("device"): @@ -888,7 +888,7 @@ def test_read_devices(scan_worker_mock, instr): ) def test_kickoff_devices(scan_worker_mock, instr, devices, parameter, metadata): worker = scan_worker_mock - with mock.patch.object(worker.device_manager.producer, "send") as send_mock: + with mock.patch.object(worker.device_manager.connector, "send") as send_mock: worker.kickoff_devices(instr) send_mock.assert_called_once_with( MessageEndpoints.device_instructions(), @@ -920,29 +920,27 @@ def test_kickoff_devices(scan_worker_mock, instr, devices, parameter, metadata): def test_publish_readback(scan_worker_mock, instr, devices): worker = scan_worker_mock with mock.patch.object(worker, "_get_readback", return_value=[{}]) as get_readback: - with mock.patch.object(worker.device_manager, "producer") as producer_mock: + with mock.patch.object(worker.device_manager, "connector") as connector_mock: worker._publish_readback(instr) get_readback.assert_called_once_with(["samx"]) - pipe = producer_mock.pipeline() + pipe = connector_mock.pipeline() msg = messages.DeviceMessage(signals={}, metadata=instr.metadata) - - producer_mock.set_and_publish.assert_called_once_with( + connector_mock.set_and_publish.assert_called_once_with( MessageEndpoints.device_read("samx"), msg, pipe ) - pipe.execute.assert_called_once() def test_get_readback(scan_worker_mock): worker = scan_worker_mock devices = ["samx"] - with mock.patch.object(worker.device_manager, "producer") as producer_mock: + with mock.patch.object(worker.device_manager, "connector") as connector_mock: worker._get_readback(devices) - pipe = producer_mock.pipeline() - producer_mock.get.assert_called_once_with( + pipe = connector_mock.pipeline() + connector_mock.get.assert_called_once_with( MessageEndpoints.device_readback("samx"), pipe=pipe ) - producer_mock.execute_pipeline.assert_called_once() + connector_mock.execute_pipeline.assert_called_once() def test_publish_data_as_read(scan_worker_mock): @@ -958,12 +956,12 @@ def test_publish_data_as_read(scan_worker_mock): "RID": "requestID", }, ) - with mock.patch.object(worker.device_manager, "producer") as producer_mock: + with mock.patch.object(worker.device_manager, "connector") as connector_mock: worker.publish_data_as_read(instr) msg = messages.DeviceMessage( signals=instr.content["parameter"]["data"], metadata=instr.metadata ) - producer_mock.set_and_publish.assert_called_once_with( + connector_mock.set_and_publish.assert_called_once_with( MessageEndpoints.device_read("samx"), msg ) @@ -983,13 +981,13 @@ def test_publish_data_as_read_multiple(scan_worker_mock): "RID": "requestID", }, ) - with mock.patch.object(worker.device_manager, "producer") as producer_mock: + with mock.patch.object(worker.device_manager, "connector") as connector_mock: worker.publish_data_as_read(instr) mock_calls = [] for device, dev_data in zip(devices, data): msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata) mock_calls.append(mock.call(MessageEndpoints.device_read(device), msg)) - assert producer_mock.set_and_publish.mock_calls == mock_calls + assert connector_mock.set_and_publish.mock_calls == mock_calls def test_check_for_interruption(scan_worker_mock): @@ -1048,7 +1046,7 @@ def test_open_scan(scan_worker_mock, instr, corr_num_points, scan_id): if "pointID" in instr.metadata: worker.max_point_id = instr.metadata["pointID"] - assert worker.parent.producer.get(MessageEndpoints.scan_number()) == None + assert worker.parent.connector.get(MessageEndpoints.scan_number()) == None with mock.patch.object(worker, "current_instruction_queue_item") as queue_mock: with mock.patch.object(worker, "_initialize_scan_info") as init_mock: @@ -1181,7 +1179,7 @@ def test_stage_device(scan_worker_mock, msg): worker.device_manager.devices["eiger"]._config["readoutPriority"] = "async" with mock.patch.object(worker, "_wait_for_stage") as wait_mock: - with mock.patch.object(worker.device_manager.producer, "send") as send_mock: + with mock.patch.object(worker.device_manager.connector, "send") as send_mock: worker.stage_devices(msg) async_devices = [dev.name for dev in worker.device_manager.devices.async_devices()] devices = [ @@ -1251,7 +1249,7 @@ def test_unstage_device(scan_worker_mock, msg, devices, parameter, metadata, cle if not devices: devices = [dev.name for dev in worker.device_manager.devices.enabled_devices] - with mock.patch.object(worker.device_manager.producer, "send") as send_mock: + with mock.patch.object(worker.device_manager.connector, "send") as send_mock: with mock.patch.object(worker, "_wait_for_stage") as wait_mock: worker.unstage_devices(msg, devices, cleanup) @@ -1270,12 +1268,12 @@ def test_unstage_device(scan_worker_mock, msg, devices, parameter, metadata, cle @pytest.mark.parametrize("status,expire", [("open", None), ("closed", 1800), ("aborted", 1800)]) def test_send_scan_status(scan_worker_mock, status, expire): worker = scan_worker_mock - worker.device_manager.producer = ProducerMock() + worker.device_manager.connector = ConnectorMock() worker.current_scanID = str(uuid.uuid4()) worker._send_scan_status(status) scan_info_msgs = [ msg - for msg in worker.device_manager.producer.message_sent + for msg in worker.device_manager.connector.message_sent if msg["queue"] == MessageEndpoints.public_scan_info(scanID=worker.current_scanID) ] assert len(scan_info_msgs) == 1 diff --git a/scan_server/tests/test_scans.py b/scan_server/tests/test_scans.py index 0c8880b6..e824ee2f 100644 --- a/scan_server/tests/test_scans.py +++ b/scan_server/tests/test_scans.py @@ -6,7 +6,7 @@ import numpy as np import pytest from bec_lib import messages from bec_lib.devicemanager import DeviceContainer -from bec_lib.tests.utils import ProducerMock +from bec_lib.tests.utils import ConnectorMock from scan_plugins.LamNIFermatScan import LamNIFermatScan from scan_plugins.otf_scan import OTFScan @@ -80,7 +80,7 @@ class DeviceMock: class DMMock: devices = DeviceContainer() - producer = ProducerMock() + connector = ConnectorMock() def add_device(self, name): self.devices[name] = DeviceMock(name) @@ -1099,7 +1099,7 @@ def test_pre_scan_macro(): device_manager=device_manager, parameter=scan_msg.content["parameter"] ) with mock.patch.object( - request.device_manager.producer, + request.device_manager.connector, "lrange", new_callable=mock.PropertyMock, return_value=macros, diff --git a/scihub/scihub/scibec/config_handler.py b/scihub/scihub/scibec/config_handler.py index 3ab7143b..1776d6a4 100644 --- a/scihub/scihub/scibec/config_handler.py +++ b/scihub/scihub/scibec/config_handler.py @@ -26,9 +26,9 @@ dir_path = os.path.abspath(os.path.join(os.path.dirname(bec_lib.__file__), "./co class ConfigHandler: def __init__(self, scibec_connector: SciBecConnector, connector: ConnectorBase) -> None: self.scibec_connector = scibec_connector + self.connector = connector self.device_manager = DeviceManager(self.scibec_connector.scihub) self.device_manager.initialize(scibec_connector.config.redis) - self.producer = connector.producer() self.validator = SciBecValidator(os.path.join(dir_path, "openapi_schema.json")) def parse_config_request(self, msg: messages.DeviceConfigMessage) -> None: @@ -53,7 +53,7 @@ class ConfigHandler: def send_config(self, msg: messages.DeviceConfigMessage) -> None: """broadcast a new config""" - self.producer.send(MessageEndpoints.device_config_update(), msg) + self.connector.send(MessageEndpoints.device_config_update(), msg) def send_config_request_reply(self, accepted, error_msg, metadata): """send a config request reply""" @@ -61,7 +61,7 @@ class ConfigHandler: accepted=accepted, message=error_msg, metadata=metadata ) RID = metadata.get("RID") - self.producer.set(MessageEndpoints.device_config_request_response(RID), msg, expire=60) + self.connector.set(MessageEndpoints.device_config_request_response(RID), msg, expire=60) def _set_config(self, msg: messages.DeviceConfigMessage): config = msg.content["config"] @@ -127,14 +127,14 @@ class ConfigHandler: def _update_device_server(self, RID: str, config: dict, action="update") -> None: msg = messages.DeviceConfigMessage(action=action, config=config, metadata={"RID": RID}) - self.producer.send(MessageEndpoints.device_server_config_request(), msg) + self.connector.send(MessageEndpoints.device_server_config_request(), msg) def _wait_for_device_server_update(self, RID: str, timeout_time=10) -> bool: timeout = timeout_time time_step = 0.05 elapsed_time = 0 while True: - msg = self.producer.get(MessageEndpoints.device_config_request_response(RID)) + msg = self.connector.get(MessageEndpoints.device_config_request_response(RID)) if msg: return msg.content["accepted"], msg @@ -188,11 +188,11 @@ class ConfigHandler: self.validator.validate_device_patch(update) def update_config_in_redis(self, device): - config = self.device_manager.producer.get(MessageEndpoints.device_config()) + config = self.device_manager.connector.get(MessageEndpoints.device_config()) config = config.content["resource"] index = next( index for index, dev_conf in enumerate(config) if dev_conf["name"] == device.name ) config[index] = device._config msg = messages.AvailableResourceMessage(resource=config) - self.device_manager.producer.set(MessageEndpoints.device_config(), msg) + self.device_manager.connector.set(MessageEndpoints.device_config(), msg) diff --git a/scihub/scihub/scibec/scibec_connector.py b/scihub/scihub/scibec/scibec_connector.py index 56730c55..2e3455bf 100644 --- a/scihub/scihub/scibec/scibec_connector.py +++ b/scihub/scihub/scibec/scibec_connector.py @@ -31,7 +31,6 @@ class SciBecConnector: def __init__(self, scihub: SciHub, connector: ConnectorBase) -> None: self.scihub = scihub self.connector = connector - self.producer = connector.producer() self.scibec = None self.host = None self.target_bl = None @@ -132,25 +131,24 @@ class SciBecConnector: """ Set the scibec account in redis """ - self.producer.set( + self.connector.set( MessageEndpoints.scibec(), messages.CredentialsMessage(credentials={"url": self.host, "token": f"Bearer {token}"}), ) def set_redis_config(self, config): msg = messages.AvailableResourceMessage(resource=config) - self.producer.set(MessageEndpoints.device_config(), msg) + self.connector.set(MessageEndpoints.device_config(), msg) def _start_metadata_handler(self) -> None: self._metadata_handler = SciBecMetadataHandler(self) def _start_config_request_handler(self) -> None: - self._config_request_handler = self.connector.consumer( + self._config_request_handler = self.connector.register( MessageEndpoints.device_config_request(), cb=self._device_config_request_callback, parent=self, ) - self._config_request_handler.start() @staticmethod def _device_config_request_callback(msg, *, parent, **_kwargs) -> None: @@ -159,7 +157,7 @@ class SciBecConnector: def connect_to_scibec(self): """ - Connect to SciBec and set the producer to the write account + Connect to SciBec and set the connector to the write account """ self._load_environment() if not self._env_configured: @@ -205,7 +203,7 @@ class SciBecConnector: write_account = self.scibec_info["activeExperiment"]["writeAccount"] if write_account[0] == "p": write_account = write_account.replace("p", "e") - self.producer.set(MessageEndpoints.account(), write_account.encode()) + self.connector.set(MessageEndpoints.account(), write_account.encode()) def shutdown(self): """ diff --git a/scihub/scihub/scibec/scibec_metadata_handler.py b/scihub/scihub/scibec/scibec_metadata_handler.py index 2a07b4b3..4dd68767 100644 --- a/scihub/scihub/scibec/scibec_metadata_handler.py +++ b/scihub/scihub/scibec/scibec_metadata_handler.py @@ -15,22 +15,20 @@ if TYPE_CHECKING: class SciBecMetadataHandler: def __init__(self, scibec_connector: SciBecConnector) -> None: self.scibec_connector = scibec_connector - self._scan_status_consumer = None + self._scan_status_register = None self._start_scan_subscription() self._file_subscription = None self._start_file_subscription() def _start_scan_subscription(self): - self._scan_status_consumer = self.scibec_connector.connector.consumer( + self._scan_status_register = self.scibec_connector.connector.register( MessageEndpoints.scan_status(), cb=self._handle_scan_status, parent=self ) - self._scan_status_consumer.start() def _start_file_subscription(self): - self._file_subscription = self.scibec_connector.connector.consumer( + self._file_subscription = self.scibec_connector.connector.register( MessageEndpoints.file_content(), cb=self._handle_file_content, parent=self ) - self._file_subscription.start() @staticmethod def _handle_scan_status(msg, *, parent, **_kwargs) -> None: @@ -171,7 +169,7 @@ class SciBecMetadataHandler: """ Shutdown the metadata handler """ - if self._scan_status_consumer: - self._scan_status_consumer.shutdown() + if self._scan_status_register: + self._scan_status_register.shutdown() if self._file_subscription: self._file_subscription.shutdown() diff --git a/scihub/scihub/scilog/scilog.py b/scihub/scihub/scilog/scilog.py index 97c20796..0829c37d 100644 --- a/scihub/scihub/scilog/scilog.py +++ b/scihub/scihub/scilog/scilog.py @@ -22,7 +22,6 @@ class SciLogConnector: def __init__(self, scihub: SciHub, connector: RedisConnector) -> None: self.scihub = scihub self.connector = connector - self.producer = self.connector.producer() self.host = None self.user = None self.user_secret = None @@ -44,7 +43,7 @@ class SciLogConnector: def set_bec_token(self, token: str) -> None: """set the scilog token in redis""" - self.producer.set( + self.connector.set( MessageEndpoints.logbook(), msgpack.dumps({"url": self.host, "user": self.user, "token": f"Bearer {token}"}), ) diff --git a/scihub/tests/test_scibec_config_handler.py b/scihub/tests/test_scibec_config_handler.py index e601a27f..8c8b1605 100644 --- a/scihub/tests/test_scibec_config_handler.py +++ b/scihub/tests/test_scibec_config_handler.py @@ -334,7 +334,7 @@ def test_config_handler_update_device_config_available_keys(config_handler, avai def test_config_handler_wait_for_device_server_update(config_handler): RID = "12345" - with mock.patch.object(config_handler.producer, "get") as mock_get: + with mock.patch.object(config_handler.connector, "get") as mock_get: mock_get.side_effect = [ None, None, @@ -346,7 +346,7 @@ def test_config_handler_wait_for_device_server_update(config_handler): def test_config_handler_wait_for_device_server_update_timeout(config_handler): RID = "12345" - with mock.patch.object(config_handler.producer, "get", return_value=None) as mock_get: + with mock.patch.object(config_handler.connector, "get", return_value=None) as mock_get: with pytest.raises(TimeoutError): config_handler._wait_for_device_server_update(RID, timeout_time=0.1) mock_get.assert_called() diff --git a/scihub/tests/test_scibec_connector.py b/scihub/tests/test_scibec_connector.py index cdceb743..a2db5bbd 100644 --- a/scihub/tests/test_scibec_connector.py +++ b/scihub/tests/test_scibec_connector.py @@ -138,6 +138,6 @@ def test_scibec_update_experiment_info(SciBecMock): def test_update_eaccount_in_redis(SciBecMock): SciBecMock.scibec_info = {"activeExperiment": {"writeAccount": "p12345"}} - with mock.patch.object(SciBecMock, "producer") as mock_producer: + with mock.patch.object(SciBecMock, "connector") as mock_connector: SciBecMock._update_eaccount_in_redis() - mock_producer.set.assert_called_once_with(MessageEndpoints.account(), b"e12345") + mock_connector.set.assert_called_once_with(MessageEndpoints.account(), b"e12345")