refactor!(connector): unify connector/redis_connector in one class

This commit is contained in:
guijar_m 2024-01-31 13:43:01 +01:00
parent 4edc5d02fe
commit b92a79b0c0
86 changed files with 1212 additions and 1745 deletions

View File

@ -38,16 +38,16 @@ class ReadbackDataMixin:
def get_request_done_msgs(self): def get_request_done_msgs(self):
"""get all request-done messages""" """get all request-done messages"""
pipe = self.device_manager.producer.pipeline() pipe = self.device_manager.connector.pipeline()
for dev in self.devices: for dev in self.devices:
self.device_manager.producer.get(MessageEndpoints.device_req_status(dev), pipe) self.device_manager.connector.get(MessageEndpoints.device_req_status(dev), pipe)
return self.device_manager.producer.execute_pipeline(pipe) return self.device_manager.connector.execute_pipeline(pipe)
def wait_for_RID(self, request): def wait_for_RID(self, request):
"""wait for the readback's metadata to match the request ID""" """wait for the readback's metadata to match the request ID"""
while True: while True:
msgs = [ msgs = [
self.device_manager.producer.get(MessageEndpoints.device_readback(dev)) self.device_manager.connector.get(MessageEndpoints.device_readback(dev))
for dev in self.devices for dev in self.devices
] ]
if all(msg.metadata.get("RID") == request.metadata["RID"] for msg in msgs if msg): if all(msg.metadata.get("RID") == request.metadata["RID"] for msg in msgs if msg):

View File

@ -27,7 +27,7 @@ class LiveUpdatesScanProgress(LiveUpdatesTable):
Update the progressbar based on the device status message. Returns True if the scan is finished. Update the progressbar based on the device status message. Returns True if the scan is finished.
""" """
self.check_alarms() self.check_alarms()
status = self.bec.producer.get(MessageEndpoints.device_progress(device_names[0])) status = self.bec.connector.get(MessageEndpoints.device_progress(device_names[0]))
if not status: if not status:
logger.debug("waiting for new data point") logger.debug("waiting for new data point")
await asyncio.sleep(0.1) await asyncio.sleep(0.1)

View File

@ -13,7 +13,7 @@ from bec_client.callbacks.move_device import (
@pytest.fixture @pytest.fixture
def readback_data_mixin(bec_client): def readback_data_mixin(bec_client):
with mock.patch.object(bec_client.device_manager, "producer"): with mock.patch.object(bec_client.device_manager, "connector"):
yield ReadbackDataMixin(bec_client.device_manager, ["samx", "samy"]) yield ReadbackDataMixin(bec_client.device_manager, ["samx", "samy"])
@ -102,7 +102,7 @@ async def test_move_callback_with_report_instruction(bec_client):
def test_readback_data_mixin(readback_data_mixin): def test_readback_data_mixin(readback_data_mixin):
readback_data_mixin.device_manager.producer.get.side_effect = [ readback_data_mixin.device_manager.connector.get.side_effect = [
messages.DeviceMessage( messages.DeviceMessage(
signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}}, signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}},
metadata={"device": "samx"}, metadata={"device": "samx"},
@ -121,7 +121,7 @@ def test_readback_data_mixin_multiple_hints(readback_data_mixin):
"samx_setpoint", "samx_setpoint",
"samx", "samx",
] ]
readback_data_mixin.device_manager.producer.get.side_effect = [ readback_data_mixin.device_manager.connector.get.side_effect = [
messages.DeviceMessage( messages.DeviceMessage(
signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}}, signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}},
metadata={"device": "samx"}, metadata={"device": "samx"},
@ -137,7 +137,7 @@ def test_readback_data_mixin_multiple_hints(readback_data_mixin):
def test_readback_data_mixin_multiple_no_hints(readback_data_mixin): def test_readback_data_mixin_multiple_no_hints(readback_data_mixin):
readback_data_mixin.device_manager.devices.samx._info["hints"]["fields"] = [] readback_data_mixin.device_manager.devices.samx._info["hints"]["fields"] = []
readback_data_mixin.device_manager.producer.get.side_effect = [ readback_data_mixin.device_manager.connector.get.side_effect = [
messages.DeviceMessage( messages.DeviceMessage(
signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}}, signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}},
metadata={"device": "samx"}, metadata={"device": "samx"},
@ -153,18 +153,18 @@ def test_readback_data_mixin_multiple_no_hints(readback_data_mixin):
def test_get_request_done_msgs(readback_data_mixin): def test_get_request_done_msgs(readback_data_mixin):
res = readback_data_mixin.get_request_done_msgs() res = readback_data_mixin.get_request_done_msgs()
readback_data_mixin.device_manager.producer.pipeline.assert_called_once() readback_data_mixin.device_manager.connector.pipeline.assert_called_once()
assert ( assert (
mock.call( mock.call(
MessageEndpoints.device_req_status("samx"), MessageEndpoints.device_req_status("samx"),
readback_data_mixin.device_manager.producer.pipeline.return_value, readback_data_mixin.device_manager.connector.pipeline.return_value,
) )
in readback_data_mixin.device_manager.producer.get.call_args_list in readback_data_mixin.device_manager.connector.get.call_args_list
) )
assert ( assert (
mock.call( mock.call(
MessageEndpoints.device_req_status("samy"), MessageEndpoints.device_req_status("samy"),
readback_data_mixin.device_manager.producer.pipeline.return_value, readback_data_mixin.device_manager.connector.pipeline.return_value,
) )
in readback_data_mixin.device_manager.producer.get.call_args_list in readback_data_mixin.device_manager.connector.get.call_args_list
) )

View File

@ -15,7 +15,7 @@ async def test_update_progressbar_continues_without_device_data():
live_update = LiveUpdatesScanProgress(bec=bec, report_instruction={}, request=request) live_update = LiveUpdatesScanProgress(bec=bec, report_instruction={}, request=request)
progressbar = mock.MagicMock() progressbar = mock.MagicMock()
bec.producer.get.return_value = None bec.connector.get.return_value = None
res = await live_update._update_progressbar(progressbar, "async_dev1") res = await live_update._update_progressbar(progressbar, "async_dev1")
assert res is False assert res is False
@ -29,7 +29,7 @@ async def test_update_progressbar_continues_when_scanID_doesnt_match():
live_update.scan_item = mock.MagicMock() live_update.scan_item = mock.MagicMock()
live_update.scan_item.scanID = "scanID2" live_update.scan_item.scanID = "scanID2"
bec.producer.get.return_value = messages.ProgressMessage( bec.connector.get.return_value = messages.ProgressMessage(
value=1, max_value=10, done=False, metadata={"scanID": "scanID"} value=1, max_value=10, done=False, metadata={"scanID": "scanID"}
) )
res = await live_update._update_progressbar(progressbar, "async_dev1") res = await live_update._update_progressbar(progressbar, "async_dev1")
@ -45,7 +45,7 @@ async def test_update_progressbar_continues_when_msg_specifies_no_value():
live_update.scan_item = mock.MagicMock() live_update.scan_item = mock.MagicMock()
live_update.scan_item.scanID = "scanID" live_update.scan_item.scanID = "scanID"
bec.producer.get.return_value = messages.ProgressMessage( bec.connector.get.return_value = messages.ProgressMessage(
value=None, max_value=None, done=None, metadata={"scanID": "scanID"} value=None, max_value=None, done=None, metadata={"scanID": "scanID"}
) )
res = await live_update._update_progressbar(progressbar, "async_dev1") res = await live_update._update_progressbar(progressbar, "async_dev1")
@ -61,7 +61,7 @@ async def test_update_progressbar_updates_max_value():
live_update.scan_item = mock.MagicMock() live_update.scan_item = mock.MagicMock()
live_update.scan_item.scanID = "scanID" live_update.scan_item.scanID = "scanID"
bec.producer.get.return_value = messages.ProgressMessage( bec.connector.get.return_value = messages.ProgressMessage(
value=10, max_value=20, done=False, metadata={"scanID": "scanID"} value=10, max_value=20, done=False, metadata={"scanID": "scanID"}
) )
res = await live_update._update_progressbar(progressbar, "async_dev1") res = await live_update._update_progressbar(progressbar, "async_dev1")
@ -79,7 +79,7 @@ async def test_update_progressbar_returns_true_when_max_value_is_reached():
live_update.scan_item = mock.MagicMock() live_update.scan_item = mock.MagicMock()
live_update.scan_item.scanID = "scanID" live_update.scan_item.scanID = "scanID"
bec.producer.get.return_value = messages.ProgressMessage( bec.connector.get.return_value = messages.ProgressMessage(
value=10, max_value=10, done=True, metadata={"scanID": "scanID"} value=10, max_value=10, done=True, metadata={"scanID": "scanID"}
) )
res = await live_update._update_progressbar(progressbar, "async_dev1") res = await live_update._update_progressbar(progressbar, "async_dev1")

View File

@ -463,11 +463,11 @@ def test_file_writer(client):
md={"datasetID": 325}, md={"datasetID": 325},
) )
assert len(scan.scan.data) == 100 assert len(scan.scan.data) == 100
msg = bec.device_manager.producer.get(MessageEndpoints.public_file(scan.scan.scanID, "master")) msg = bec.device_manager.connector.get(MessageEndpoints.public_file(scan.scan.scanID, "master"))
while True: while True:
if msg: if msg:
break break
msg = bec.device_manager.producer.get( msg = bec.device_manager.connector.get(
MessageEndpoints.public_file(scan.scan.scanID, "master") MessageEndpoints.public_file(scan.scan.scanID, "master")
) )

View File

@ -3,7 +3,6 @@ from bec_lib.bec_service import BECService
from bec_lib.channel_monitor import channel_monitor_launch from bec_lib.channel_monitor import channel_monitor_launch
from bec_lib.client import BECClient from bec_lib.client import BECClient
from bec_lib.config_helper import ConfigHelper from bec_lib.config_helper import ConfigHelper
from bec_lib.connector import ProducerConnector
from bec_lib.device import DeviceBase, DeviceStatus, Status from bec_lib.device import DeviceBase, DeviceStatus, Status
from bec_lib.devicemanager import DeviceConfigError, DeviceContainer, DeviceManagerBase from bec_lib.devicemanager import DeviceConfigError, DeviceContainer, DeviceManagerBase
from bec_lib.endpoints import MessageEndpoints from bec_lib.endpoints import MessageEndpoints

View File

@ -48,23 +48,21 @@ class AlarmBase(Exception):
class AlarmHandler: class AlarmHandler:
def __init__(self, connector: RedisConnector) -> None: def __init__(self, connector: RedisConnector) -> None:
self.connector = connector self.connector = connector
self.alarm_consumer = None
self.alarms_stack = deque(maxlen=100) self.alarms_stack = deque(maxlen=100)
self._raised_alarms = deque(maxlen=100) self._raised_alarms = deque(maxlen=100)
self._lock = threading.RLock() self._lock = threading.RLock()
def start(self): def start(self):
"""start the alarm handler and its subscriptions""" """start the alarm handler and its subscriptions"""
self.alarm_consumer = self.connector.consumer( self.connector.register(
topics=MessageEndpoints.alarm(), topics=MessageEndpoints.alarm(),
name="AlarmHandler", name="AlarmHandler",
cb=self._alarm_consumer_callback, cb=self._alarm_register_callback,
parent=self, parent=self,
) )
self.alarm_consumer.start()
@staticmethod @staticmethod
def _alarm_consumer_callback(msg, *, parent, **_kwargs): def _alarm_register_callback(msg, *, parent, **_kwargs):
parent.add_alarm(msg.value) parent.add_alarm(msg.value)
@threadlocked @threadlocked
@ -136,4 +134,4 @@ class AlarmHandler:
def shutdown(self): def shutdown(self):
"""shutdown the alarm handler""" """shutdown the alarm handler"""
self.alarm_consumer.shutdown() self.connector.shutdown()

View File

@ -8,12 +8,12 @@ from bec_lib.endpoints import MessageEndpoints
if TYPE_CHECKING: if TYPE_CHECKING:
from bec_lib import messages from bec_lib import messages
from bec_lib.redis_connector import RedisProducer from bec_lib.connector import ConnectorBase
class AsyncDataHandler: class AsyncDataHandler:
def __init__(self, producer: RedisProducer): def __init__(self, connector: ConnectorBase):
self.producer = producer self.connector = connector
def get_async_data_for_scan(self, scan_id: str) -> dict[list]: def get_async_data_for_scan(self, scan_id: str) -> dict[list]:
""" """
@ -25,7 +25,9 @@ class AsyncDataHandler:
Returns: Returns:
dict[list]: the async data for the scan sorted by device name dict[list]: the async data for the scan sorted by device name
""" """
async_device_keys = self.producer.keys(MessageEndpoints.device_async_readback(scan_id, "*")) async_device_keys = self.connector.keys(
MessageEndpoints.device_async_readback(scan_id, "*")
)
async_data = {} async_data = {}
for device_key in async_device_keys: for device_key in async_device_keys:
key = device_key.decode() key = device_key.decode()
@ -50,7 +52,7 @@ class AsyncDataHandler:
list: the async data for the device list: the async data for the device
""" """
key = MessageEndpoints.device_async_readback(scan_id, device_name) key = MessageEndpoints.device_async_readback(scan_id, device_name)
msgs = self.producer.xrange(key, min="-", max="+") msgs = self.connector.xrange(key, min="-", max="+")
if not msgs: if not msgs:
return [] return []
return self.process_async_data(msgs) return self.process_async_data(msgs)

View File

@ -54,14 +54,14 @@ class BECWidgetsConnector:
def __init__(self, gui_id: str, bec_client: BECClient = None) -> None: def __init__(self, gui_id: str, bec_client: BECClient = None) -> None:
self._client = bec_client self._client = bec_client
self.gui_id = gui_id self.gui_id = gui_id
# TODO replace with a global producer # TODO replace with a global connector
if self._client is None: if self._client is None:
if "bec" in builtins.__dict__: if "bec" in builtins.__dict__:
self._client = builtins.bec self._client = builtins.bec
else: else:
self._client = BECClient() self._client = BECClient()
self._client.start() self._client.start()
self._producer = self._client.connector.producer() self._connector = self._client.connector
def set_plot_config(self, plot_id: str, config: dict) -> None: def set_plot_config(self, plot_id: str, config: dict) -> None:
""" """
@ -72,7 +72,7 @@ class BECWidgetsConnector:
config (dict): The config to set. config (dict): The config to set.
""" """
msg = messages.GUIConfigMessage(config=config) msg = messages.GUIConfigMessage(config=config)
self._producer.set_and_publish(MessageEndpoints.gui_config(plot_id), msg) self._connector.set_and_publish(MessageEndpoints.gui_config(plot_id), msg)
def close(self, plot_id: str) -> None: def close(self, plot_id: str) -> None:
""" """
@ -82,7 +82,7 @@ class BECWidgetsConnector:
plot_id (str): The id of the plot. plot_id (str): The id of the plot.
""" """
msg = messages.GUIInstructionMessage(action="close", parameter={}) msg = messages.GUIInstructionMessage(action="close", parameter={})
self._producer.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg) self._connector.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)
def config_dialog(self, plot_id: str) -> None: def config_dialog(self, plot_id: str) -> None:
""" """
@ -92,7 +92,7 @@ class BECWidgetsConnector:
plot_id (str): The id of the plot. plot_id (str): The id of the plot.
""" """
msg = messages.GUIInstructionMessage(action="config_dialog", parameter={}) msg = messages.GUIInstructionMessage(action="config_dialog", parameter={})
self._producer.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg) self._connector.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)
def send_data(self, plot_id: str, data: dict) -> None: def send_data(self, plot_id: str, data: dict) -> None:
""" """
@ -103,9 +103,9 @@ class BECWidgetsConnector:
data (dict): The data to send. data (dict): The data to send.
""" """
msg = messages.GUIDataMessage(data=data) msg = messages.GUIDataMessage(data=data)
self._producer.set_and_publish(topic=MessageEndpoints.gui_data(plot_id), msg=msg) self._connector.set_and_publish(topic=MessageEndpoints.gui_data(plot_id), msg=msg)
# TODO bec_dispatcher can only handle set_and_publish ATM # TODO bec_dispatcher can only handle set_and_publish ATM
# self._producer.xadd(topic=MessageEndpoints.gui_data(plot_id),msg= {"data": msg}) # self._connector.xadd(topic=MessageEndpoints.gui_data(plot_id),msg= {"data": msg})
def clear(self, plot_id: str) -> None: def clear(self, plot_id: str) -> None:
""" """
@ -115,7 +115,7 @@ class BECWidgetsConnector:
plot_id (str): The id of the plot. plot_id (str): The id of the plot.
""" """
msg = messages.GUIInstructionMessage(action="clear", parameter={}) msg = messages.GUIInstructionMessage(action="clear", parameter={})
self._producer.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg) self._connector.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)
class BECPlotter: class BECPlotter:

View File

@ -41,7 +41,6 @@ class BECService:
self.connector = connector_cls(self.bootstrap_server) self.connector = connector_cls(self.bootstrap_server)
self._unique_service = unique_service self._unique_service = unique_service
self.wait_for_server = wait_for_server self.wait_for_server = wait_for_server
self.producer = self.connector.producer()
self.__service_id = str(uuid.uuid4()) self.__service_id = str(uuid.uuid4())
self._user = getpass.getuser() self._user = getpass.getuser()
self._hostname = socket.gethostname() self._hostname = socket.gethostname()
@ -110,11 +109,11 @@ class BECService:
) )
def _update_existing_services(self): def _update_existing_services(self):
service_keys = self.producer.keys(MessageEndpoints.service_status("*")) service_keys = self.connector.keys(MessageEndpoints.service_status("*"))
if not service_keys: if not service_keys:
return return
services = [service.decode().split(":", maxsplit=1)[0] for service in service_keys] services = [service.decode().split(":", maxsplit=1)[0] for service in service_keys]
msgs = [self.producer.get(service) for service in services] msgs = [self.connector.get(service) for service in services]
self._services_info = {msg.content["name"]: msg for msg in msgs if msg is not None} self._services_info = {msg.content["name"]: msg for msg in msgs if msg is not None}
def _update_service_info(self): def _update_service_info(self):
@ -124,7 +123,7 @@ class BECService:
self._service_info_event.wait(timeout=3) self._service_info_event.wait(timeout=3)
def _send_service_status(self): def _send_service_status(self):
self.producer.set_and_publish( self.connector.set_and_publish(
topic=MessageEndpoints.service_status(self._service_id), topic=MessageEndpoints.service_status(self._service_id),
msg=messages.StatusMessage( msg=messages.StatusMessage(
name=self._service_name, name=self._service_name,
@ -189,7 +188,7 @@ class BECService:
) )
) )
msg = messages.ServiceMetricMessage(name=self.__class__.__name__, metrics=data) msg = messages.ServiceMetricMessage(name=self.__class__.__name__, metrics=data)
self.producer.send(MessageEndpoints.metrics(self._service_id), msg) self.connector.send(MessageEndpoints.metrics(self._service_id), msg)
self._metrics_emitter_event.wait(timeout=1) self._metrics_emitter_event.wait(timeout=1)
def set_global_var(self, name: str, val: Any) -> None: def set_global_var(self, name: str, val: Any) -> None:
@ -200,7 +199,7 @@ class BECService:
val (Any): Value of the variable val (Any): Value of the variable
""" """
self.producer.set(MessageEndpoints.global_vars(name), messages.VariableMessage(value=val)) self.connector.set(MessageEndpoints.global_vars(name), messages.VariableMessage(value=val))
def get_global_var(self, name: str) -> Any: def get_global_var(self, name: str) -> Any:
"""Get a global variable from Redis """Get a global variable from Redis
@ -211,7 +210,7 @@ class BECService:
Returns: Returns:
Any: Value of the variable Any: Value of the variable
""" """
msg = self.producer.get(MessageEndpoints.global_vars(name)) msg = self.connector.get(MessageEndpoints.global_vars(name))
if msg: if msg:
return msg.content.get("value") return msg.content.get("value")
return None return None
@ -223,12 +222,12 @@ class BECService:
name (str): Name of the variable name (str): Name of the variable
""" """
self.producer.delete(MessageEndpoints.global_vars(name)) self.connector.delete(MessageEndpoints.global_vars(name))
def global_vars(self) -> str: def global_vars(self) -> str:
"""Get all available global variables""" """Get all available global variables"""
# sadly, this cannot be a property as it causes side effects with IPython's tab completion # sadly, this cannot be a property as it causes side effects with IPython's tab completion
available_keys = self.producer.keys(MessageEndpoints.global_vars("*")) available_keys = self.connector.keys(MessageEndpoints.global_vars("*"))
def get_endpoint_from_topic(topic: str) -> str: def get_endpoint_from_topic(topic: str) -> str:
return topic.decode().split(MessageEndpoints.global_vars(""))[-1] return topic.decode().split(MessageEndpoints.global_vars(""))[-1]
@ -252,6 +251,7 @@ class BECService:
def shutdown(self): def shutdown(self):
"""shutdown the BECService""" """shutdown the BECService"""
self.connector.shutdown()
self._service_info_event.set() self._service_info_event.set()
if self._service_info_thread: if self._service_info_thread:
self._service_info_thread.join() self._service_info_thread.join()

View File

@ -16,11 +16,11 @@ def channel_callback(msg, **_kwargs):
print(json.dumps(out, indent=4, default=lambda o: "<not serializable object>")) print(json.dumps(out, indent=4, default=lambda o: "<not serializable object>"))
def _start_consumer(config_path, topic): def _start_register(config_path, topic):
config = ServiceConfig(config_path) config = ServiceConfig(config_path)
connector = RedisConnector(config.redis) connector = RedisConnector(config.redis)
consumer = connector.consumer(topics=topic, cb=channel_callback) register = connector.register(topics=topic, cb=channel_callback)
consumer.start() register.start()
event = threading.Event() event = threading.Event()
event.wait() event.wait()
@ -38,4 +38,4 @@ def channel_monitor_launch():
config_path = clargs.config config_path = clargs.config
topic = clargs.channel topic = clargs.channel
_start_consumer(config_path, topic) _start_register(config_path, topic)

View File

@ -91,7 +91,7 @@ class BECClient(BECService, UserScriptsMixin):
@property @property
def active_account(self) -> str: def active_account(self) -> str:
"""get the currently active target (e)account""" """get the currently active target (e)account"""
return self.producer.get(MessageEndpoints.account()) return self.connector.get(MessageEndpoints.account())
def start(self): def start(self):
"""start the client""" """start the client"""
@ -133,13 +133,13 @@ class BECClient(BECService, UserScriptsMixin):
@property @property
def pre_scan_hooks(self): def pre_scan_hooks(self):
"""currently stored pre-scan hooks""" """currently stored pre-scan hooks"""
return self.producer.lrange(MessageEndpoints.pre_scan_macros(), 0, -1) return self.connector.lrange(MessageEndpoints.pre_scan_macros(), 0, -1)
@pre_scan_hooks.setter @pre_scan_hooks.setter
def pre_scan_hooks(self, hooks: list): def pre_scan_hooks(self, hooks: list):
self.producer.delete(MessageEndpoints.pre_scan_macros()) self.connector.delete(MessageEndpoints.pre_scan_macros())
for hook in hooks: for hook in hooks:
self.producer.lpush(MessageEndpoints.pre_scan_macros(), hook) self.connector.lpush(MessageEndpoints.pre_scan_macros(), hook)
def _load_scans(self): def _load_scans(self):
self.scans = Scans(self) self.scans = Scans(self)

View File

@ -26,7 +26,6 @@ logger = bec_logger.logger
class ConfigHelper: class ConfigHelper:
def __init__(self, connector: RedisConnector, service_name: str = None) -> None: def __init__(self, connector: RedisConnector, service_name: str = None) -> None:
self.connector = connector self.connector = connector
self.producer = connector.producer()
self._service_name = service_name self._service_name = service_name
def update_session_with_file(self, file_path: str, save_recovery: bool = True) -> None: def update_session_with_file(self, file_path: str, save_recovery: bool = True) -> None:
@ -71,7 +70,7 @@ class ConfigHelper:
print(f"Config was written to {file_path}.") print(f"Config was written to {file_path}.")
def _save_config_to_file(self, file_path: str, raise_on_error: bool = True) -> bool: def _save_config_to_file(self, file_path: str, raise_on_error: bool = True) -> bool:
config = self.producer.get(MessageEndpoints.device_config()) config = self.connector.get(MessageEndpoints.device_config())
if not config: if not config:
if raise_on_error: if raise_on_error:
raise DeviceConfigError("No config found in the session.") raise DeviceConfigError("No config found in the session.")
@ -99,7 +98,7 @@ class ConfigHelper:
if action in ["update", "add", "set"] and not config: if action in ["update", "add", "set"] and not config:
raise DeviceConfigError(f"Config cannot be empty for an {action} request.") raise DeviceConfigError(f"Config cannot be empty for an {action} request.")
RID = str(uuid.uuid4()) RID = str(uuid.uuid4())
self.producer.send( self.connector.send(
MessageEndpoints.device_config_request(), MessageEndpoints.device_config_request(),
DeviceConfigMessage(action=action, config=config, metadata={"RID": RID}), DeviceConfigMessage(action=action, config=config, metadata={"RID": RID}),
) )
@ -145,7 +144,7 @@ class ConfigHelper:
elapsed_time = 0 elapsed_time = 0
max_time = timeout max_time = timeout
while True: while True:
service_messages = self.producer.lrange(MessageEndpoints.service_response(RID), 0, -1) service_messages = self.connector.lrange(MessageEndpoints.service_response(RID), 0, -1)
if not service_messages: if not service_messages:
time.sleep(0.005) time.sleep(0.005)
elapsed_time += 0.005 elapsed_time += 0.005
@ -185,7 +184,7 @@ class ConfigHelper:
""" """
start = 0 start = 0
while True: while True:
msg = self.producer.get(MessageEndpoints.device_config_request_response(RID)) msg = self.connector.get(MessageEndpoints.device_config_request_response(RID))
if msg is None: if msg is None:
time.sleep(0.01) time.sleep(0.01)
start += 0.01 start += 0.01

View File

@ -6,7 +6,8 @@ import threading
import traceback import traceback
from bec_lib.logger import bec_logger from bec_lib.logger import bec_logger
from bec_lib.messages import BECMessage from bec_lib.messages import BECMessage, LogMessage
from bec_lib.endpoints import MessageEndpoints
logger = bec_logger.logger logger = bec_logger.logger
@ -33,154 +34,98 @@ class MessageObject:
return f"MessageObject(topic={self.topic}, value={self._value})" return f"MessageObject(topic={self.topic}, value={self._value})"
class ConnectorBase(abc.ABC): class StoreInterface(abc.ABC):
""" """StoreBase defines the interface for storing data"""
ConnectorBase implements producer and consumer clients for communicating with a broker.
One ought to inherit from this base class and provide at least customized producer and consumer methods.
""" def __init__(self, store):
def __init__(self, bootstrap_server: list):
self.bootstrap = bootstrap_server
self._threads = []
def producer(self, **kwargs) -> ProducerConnector:
raise NotImplementedError
def consumer(self, **kwargs) -> ConsumerConnectorThreaded:
raise NotImplementedError
def shutdown(self):
for t in self._threads:
t.signal_event.set()
t.join()
def raise_warning(self, msg):
raise NotImplementedError
def send_log(self, msg):
raise NotImplementedError
def poll_messages(self):
"""Poll for new messages, receive them and execute callbacks"""
pass pass
def pipeline(self):
pass
class ProducerConnector(abc.ABC): def execute_pipeline(self):
pass
def lpush(
self, topic: str, msg: str, pipe=None, max_size: int = None, expire: int = None
) -> None:
raise NotImplementedError
def lset(self, topic: str, index: int, msg: str, pipe=None) -> None:
raise NotImplementedError
def rpush(self, topic: str, msg: str, pipe=None) -> int:
raise NotImplementedError
def lrange(self, topic: str, start: int, end: int, pipe=None):
raise NotImplementedError
def set(self, topic: str, msg, pipe=None, expire: int = None) -> None:
raise NotImplementedError
def keys(self, pattern: str) -> list:
raise NotImplementedError
def delete(self, topic, pipe=None):
raise NotImplementedError
def get(self, topic: str, pipe=None):
raise NotImplementedError
def xadd(self, topic: str, msg: dict, max_size=None, pipe=None, expire: int = None):
raise NotImplementedError
def xread(
self,
topic: str,
id: str = None,
count: int = None,
block: int = None,
pipe=None,
from_start=False,
) -> list:
raise NotImplementedError
def xrange(self, topic: str, min: str, max: str, count: int = None, pipe=None):
raise NotImplementedError
class PubSubInterface(abc.ABC):
def raw_send(self, topic: str, msg: bytes) -> None: def raw_send(self, topic: str, msg: bytes) -> None:
raise NotImplementedError raise NotImplementedError
def send(self, topic: str, msg: BECMessage) -> None: def send(self, topic: str, msg: BECMessage) -> None:
raise NotImplementedError raise NotImplementedError
def register(self, topics=None, pattern=None, cb=None, start_thread=True, **kwargs):
class ConsumerConnector(abc.ABC):
def __init__(
self, bootstrap_server, cb, topics=None, pattern=None, group_id=None, event=None, **kwargs
):
"""
ConsumerConnector class defines the communication with the broker for consuming messages.
An implementation ought to inherit from this class and implement the initialize_connector and poll_messages methods.
Args:
bootstrap_server: list of bootstrap servers, e.g. ["localhost:9092", "localhost:9093"]
topics: the topic(s) to which the connector should attach
event: external event to trigger start and stop of the connector
cb: callback function; will be triggered from within poll_messages
kwargs: additional keyword arguments
"""
self.bootstrap = bootstrap_server
self.topics = topics
self.pattern = pattern
self.group_id = group_id
self.connector = None
self.cb = cb
self.kwargs = kwargs
if not self.topics and not self.pattern:
raise ConsumerConnectorError("Either a topic or a patter must be specified.")
def initialize_connector(self) -> None:
"""
initialize the connector instance self.connector
The connector will be initialized once the thread is started
"""
raise NotImplementedError raise NotImplementedError
def poll_messages(self) -> None: def poll_messages(self, timeout=None):
""" """Poll for new messages, receive them and execute callbacks"""
Poll messages from self.connector and call the callback function self.cb raise NotImplementedError
""" def run_messages_loop(self):
raise NotImplementedError() raise NotImplementedError
class ConsumerConnectorThreaded(ConsumerConnector, threading.Thread):
def __init__(
self,
bootstrap_server,
cb,
topics=None,
pattern=None,
group_id=None,
event=None,
name=None,
**kwargs,
):
"""
ConsumerConnectorThreaded class defines the threaded communication with the broker for consuming messages.
An implementation ought to inherit from this class and implement the initialize_connector and poll_messages methods.
Once started, the connector is expected to poll new messages until the signal_event is set.
Args:
bootstrap_server: list of bootstrap servers, e.g. ["localhost:9092", "localhost:9093"]
topics: the topic(s) to which the connector should attach
event: external event to trigger start and stop of the connector
cb: callback function; will be triggered from within poll_messages
kwargs: additional keyword arguments
"""
super().__init__(
bootstrap_server=bootstrap_server,
topics=topics,
pattern=pattern,
group_id=group_id,
event=event,
cb=cb,
**kwargs,
)
if name is not None:
thread_kwargs = {"name": name, "daemon": True}
else:
thread_kwargs = {"daemon": True}
super(ConsumerConnector, self).__init__(**thread_kwargs)
self.signal_event = event if event is not None else threading.Event()
def run(self):
self.initialize_connector()
while True:
try:
self.poll_messages()
except Exception as e:
logger.error(traceback.format_exc())
_thread.interrupt_main()
raise e
finally:
if self.signal_event.is_set():
self.shutdown()
break
def shutdown(self): def shutdown(self):
self.signal_event.set() raise NotImplementedError
# def stop(self) -> None:
# """
# Stop consumer
# Returns:
# """ class ConnectorBase(PubSubInterface, StoreInterface):
# self.signal_event.set() def raise_warning(self, msg):
# self.connector.close() raise NotImplementedError
# self.join()
def log_warning(self, msg):
"""send a warning"""
self.send(MessageEndpoints.log(), LogMessage(log_type="warning", log_msg=msg))
def log_message(self, msg):
"""send a log message"""
self.send(MessageEndpoints.log(), LogMessage(log_type="log", log_msg=msg))
def log_error(self, msg):
"""send an error as log"""
self.send(MessageEndpoints.log(), LogMessage(log_type="error", log_msg=msg))
def set_and_publish(self, topic: str, msg, pipe=None, expire: int = None) -> None:
raise NotImplementedError

View File

@ -78,7 +78,7 @@ class DAPPluginObjectBase:
converted_kwargs[key] = val converted_kwargs[key] = val
kwargs = converted_kwargs kwargs = converted_kwargs
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
self._client.producer.set_and_publish( self._client.connector.set_and_publish(
MessageEndpoints.dap_request(), MessageEndpoints.dap_request(),
messages.DAPRequestMessage( messages.DAPRequestMessage(
dap_cls=self._plugin_info["class"], dap_cls=self._plugin_info["class"],
@ -110,7 +110,7 @@ class DAPPluginObjectBase:
while True: while True:
if time.time() - start_time > timeout: if time.time() - start_time > timeout:
raise TimeoutError("Timeout waiting for DAP response.") raise TimeoutError("Timeout waiting for DAP response.")
response = self._client.producer.get(MessageEndpoints.dap_response(request_id)) response = self._client.connector.get(MessageEndpoints.dap_response(request_id))
if not response: if not response:
time.sleep(0.005) time.sleep(0.005)
continue continue
@ -128,7 +128,7 @@ class DAPPluginObjectBase:
return return
self._plugin_config["class_args"] = self._plugin_info.get("class_args") self._plugin_config["class_args"] = self._plugin_info.get("class_args")
self._plugin_config["class_kwargs"] = self._plugin_info.get("class_kwargs") self._plugin_config["class_kwargs"] = self._plugin_info.get("class_kwargs")
self._client.producer.set_and_publish( self._client.connector.set_and_publish(
MessageEndpoints.dap_request(), MessageEndpoints.dap_request(),
messages.DAPRequestMessage( messages.DAPRequestMessage(
dap_cls=self._plugin_info["class"], dap_cls=self._plugin_info["class"],
@ -149,7 +149,7 @@ class DAPPluginObject(DAPPluginObjectBase):
""" """
Get the data from last run. Get the data from last run.
""" """
msg = self._client.producer.get_last(MessageEndpoints.processed_data(self._service_name)) msg = self._client.connector.get_last(MessageEndpoints.processed_data(self._service_name))
if not msg: if not msg:
return None return None
return self._convert_result(msg) return self._convert_result(msg)

View File

@ -38,7 +38,7 @@ class DAPPlugins:
service for service in available_services if service.startswith("DAPServer/") service for service in available_services if service.startswith("DAPServer/")
] ]
for service in dap_services: for service in dap_services:
available_plugins = self._parent.producer.get( available_plugins = self._parent.connector.get(
MessageEndpoints.dap_available_plugins(service) MessageEndpoints.dap_available_plugins(service)
) )
if available_plugins is None: if available_plugins is None:

View File

@ -10,7 +10,7 @@ from typeguard import typechecked
from bec_lib import messages from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger from bec_lib.logger import bec_logger
from bec_lib.redis_connector import RedisProducer from bec_lib.redis_connector import RedisConnector
logger = bec_logger.logger logger = bec_logger.logger
@ -61,15 +61,15 @@ class ReadoutPriority(str, enum.Enum):
class Status: class Status:
def __init__(self, producer: RedisProducer, RID: str) -> None: def __init__(self, connector: RedisConnector, RID: str) -> None:
""" """
Status object for RPC calls Status object for RPC calls
Args: Args:
producer (RedisProducer): Redis producer connector (RedisConnector): Redis connector
RID (str): Request ID RID (str): Request ID
""" """
self._producer = producer self._connector = connector
self._RID = RID self._RID = RID
def __eq__(self, __value: object) -> bool: def __eq__(self, __value: object) -> bool:
@ -91,7 +91,7 @@ class Status:
raise TimeoutError() raise TimeoutError()
while True: while True:
request_status = self._producer.lrange( request_status = self._connector.lrange(
MessageEndpoints.device_req_status(self._RID), 0, -1 MessageEndpoints.device_req_status(self._RID), 0, -1
) )
if request_status: if request_status:
@ -251,7 +251,7 @@ class DeviceBase:
if not isinstance(return_val, dict): if not isinstance(return_val, dict):
return return_val return return_val
if return_val.get("type") == "status" and return_val.get("RID"): if return_val.get("type") == "status" and return_val.get("RID"):
return Status(self.root.parent.producer, return_val.get("RID")) return Status(self.root.parent.connector, return_val.get("RID"))
return return_val return return_val
def _get_rpc_response(self, request_id, rpc_id) -> Any: def _get_rpc_response(self, request_id, rpc_id) -> Any:
@ -267,7 +267,7 @@ class DeviceBase:
f" {scan_queue_request.response.content['message']}" f" {scan_queue_request.response.content['message']}"
) )
while True: while True:
msg = self.root.parent.producer.get(MessageEndpoints.device_rpc(rpc_id)) msg = self.root.parent.connector.get(MessageEndpoints.device_rpc(rpc_id))
if msg: if msg:
break break
time.sleep(0.01) time.sleep(0.01)
@ -296,7 +296,7 @@ class DeviceBase:
msg = self._prepare_rpc_msg(rpc_id, request_id, device, func_call, *args, **kwargs) msg = self._prepare_rpc_msg(rpc_id, request_id, device, func_call, *args, **kwargs)
# send RPC message # send RPC message
self.root.parent.producer.send(MessageEndpoints.scan_queue_request(), msg) self.root.parent.connector.send(MessageEndpoints.scan_queue_request(), msg)
# wait for RPC response # wait for RPC response
if not wait_for_rpc_response: if not wait_for_rpc_response:
@ -496,7 +496,7 @@ class DeviceBase:
# def read(self, cached, filter_readback=True): # def read(self, cached, filter_readback=True):
# """get the last reading from a device""" # """get the last reading from a device"""
# val = self.parent.producer.get(MessageEndpoints.device_read(self.name)) # val = self.parent.connector.get(MessageEndpoints.device_read(self.name))
# if not val: # if not val:
# return None # return None
# if filter_readback: # if filter_readback:
@ -505,7 +505,7 @@ class DeviceBase:
# #
# def readback(self, filter_readback=True): # def readback(self, filter_readback=True):
# """get the last readback value from a device""" # """get the last readback value from a device"""
# val = self.parent.producer.get(MessageEndpoints.device_readback(self.name)) # val = self.parent.connector.get(MessageEndpoints.device_readback(self.name))
# if not val: # if not val:
# return None # return None
# if filter_readback: # if filter_readback:
@ -515,7 +515,7 @@ class DeviceBase:
# @property # @property
# def device_status(self): # def device_status(self):
# """get the current status of the device""" # """get the current status of the device"""
# val = self.parent.producer.get(MessageEndpoints.device_status(self.name)) # val = self.parent.connector.get(MessageEndpoints.device_status(self.name))
# if val is None: # if val is None:
# return val # return val
# val = DeviceStatusMessage.loads(val) # val = DeviceStatusMessage.loads(val)
@ -524,7 +524,7 @@ class DeviceBase:
# @property # @property
# def signals(self): # def signals(self):
# """get the last signals from a device""" # """get the last signals from a device"""
# val = self.parent.producer.get(MessageEndpoints.device_read(self.name)) # val = self.parent.connector.get(MessageEndpoints.device_read(self.name))
# if val is None: # if val is None:
# return None # return None
# self._signals = DeviceMessage.loads(val).content["signals"] # self._signals = DeviceMessage.loads(val).content["signals"]
@ -593,11 +593,11 @@ class OphydInterfaceBase(DeviceBase):
if is_config_signal: if is_config_signal:
return self.read_configuration(cached=cached) return self.read_configuration(cached=cached)
if use_readback: if use_readback:
val = self.root.parent.producer.get( val = self.root.parent.connector.get(
MessageEndpoints.device_readback(self.root.name) MessageEndpoints.device_readback(self.root.name)
) )
else: else:
val = self.root.parent.producer.get(MessageEndpoints.device_read(self.root.name)) val = self.root.parent.connector.get(MessageEndpoints.device_read(self.root.name))
if not val: if not val:
return None return None
@ -623,7 +623,7 @@ class OphydInterfaceBase(DeviceBase):
if is_signal and not is_config_signal: if is_signal and not is_config_signal:
return self.read(cached=True) return self.read(cached=True)
val = self.root.parent.producer.get( val = self.root.parent.connector.get(
MessageEndpoints.device_read_configuration(self.root.name) MessageEndpoints.device_read_configuration(self.root.name)
) )
if not val: if not val:
@ -766,7 +766,7 @@ class AdjustableMixin:
""" """
Returns the device limits. Returns the device limits.
""" """
limit_msg = self.root.parent.producer.get(MessageEndpoints.device_limits(self.root.name)) limit_msg = self.root.parent.connector.get(MessageEndpoints.device_limits(self.root.name))
if not limit_msg: if not limit_msg:
return [0, 0] return [0, 0]
limits = [ limits = [

View File

@ -370,8 +370,7 @@ class DeviceManagerBase:
_request_config_parsed = None # parsed config request _request_config_parsed = None # parsed config request
_response = None # response message _response = None # response message
_connector_base_consumer = {} _connector_base_register = {}
producer = None
config_helper = None config_helper = None
_device_cls = DeviceBase _device_cls = DeviceBase
_status_cb = [] _status_cb = []
@ -464,7 +463,7 @@ class DeviceManagerBase:
""" """
if not msg.metadata.get("RID"): if not msg.metadata.get("RID"):
return return
self.producer.lpush( self.connector.lpush(
MessageEndpoints.service_response(msg.metadata["RID"]), MessageEndpoints.service_response(msg.metadata["RID"]),
messages.ServiceResponseMessage( messages.ServiceResponseMessage(
# pylint: disable=no-member # pylint: disable=no-member
@ -487,25 +486,20 @@ class DeviceManagerBase:
self._remove_device(dev) self._remove_device(dev)
def _start_connectors(self, bootstrap_server) -> None: def _start_connectors(self, bootstrap_server) -> None:
self._start_base_consumer() self._start_base_register()
self.producer = self.connector.producer()
self._start_custom_connectors(bootstrap_server)
def _start_base_consumer(self) -> None: def _start_base_register(self) -> None:
""" """
Start consuming messages for all base topics. This method will be called upon startup. Start consuming messages for all base topics. This method will be called upon startup.
Returns: Returns:
""" """
self._connector_base_consumer["device_config_update"] = self.connector.consumer( self.connector.register(
MessageEndpoints.device_config_update(), MessageEndpoints.device_config_update(),
cb=self._device_config_update_callback, cb=self._device_config_update_callback,
parent=self, parent=self,
) )
# self._connector_base_consumer["log"].start()
self._connector_base_consumer["device_config_update"].start()
@staticmethod @staticmethod
def _log_callback(msg, *, parent, **kwargs) -> None: def _log_callback(msg, *, parent, **kwargs) -> None:
""" """
@ -541,48 +535,11 @@ class DeviceManagerBase:
self._load_session() self._load_session()
def _get_redis_device_config(self) -> list: def _get_redis_device_config(self) -> list:
devices = self.producer.get(MessageEndpoints.device_config()) devices = self.connector.get(MessageEndpoints.device_config())
if not devices: if not devices:
return [] return []
return devices.content["resource"] return devices.content["resource"]
def _stop_base_consumer(self):
"""
Stop all base consumers by setting the corresponding event
Returns:
"""
if self.connector is not None:
for _, con in self._connector_base_consumer.items():
con.signal_event.set()
con.join()
def _stop_consumer(self):
"""
Stop all consumers
Returns:
"""
self._stop_base_consumer()
self._stop_custom_consumer()
def _start_custom_connectors(self, bootstrap_server) -> None:
"""
Override this method in a derived class to start custom connectors upon initialization.
Args:
bootstrap_server: Kafka bootstrap server
Returns:
"""
def _stop_custom_consumer(self) -> None:
"""
Stop all custom consumers. Override this method in a derived class.
Returns:
"""
def _add_device(self, dev: dict, msg: messages.DeviceInfoMessage): def _add_device(self, dev: dict, msg: messages.DeviceInfoMessage):
name = msg.content["device"] name = msg.content["device"]
info = msg.content["info"] info = msg.content["info"]
@ -621,8 +578,7 @@ class DeviceManagerBase:
logger.error(f"Failed to load device {dev}: {content}") logger.error(f"Failed to load device {dev}: {content}")
def _get_device_info(self, device_name) -> DeviceInfoMessage: def _get_device_info(self, device_name) -> DeviceInfoMessage:
msg = self.producer.get(MessageEndpoints.device_info(device_name)) return self.connector.get(MessageEndpoints.device_info(device_name))
return msg
def check_request_validity(self, msg: DeviceConfigMessage) -> None: def check_request_validity(self, msg: DeviceConfigMessage) -> None:
""" """
@ -663,10 +619,7 @@ class DeviceManagerBase:
""" """
Shutdown all connectors. Shutdown all connectors.
""" """
try: self.connector.shutdown()
self.connector.shutdown()
except RuntimeError as runtime_error:
logger.error(f"Failed to shutdown connector. {runtime_error}")
def __del__(self): def __del__(self):
self.shutdown() self.shutdown()

View File

@ -24,7 +24,6 @@ if TYPE_CHECKING:
class LogbookConnector: class LogbookConnector:
def __init__(self, connector: RedisConnector) -> None: def __init__(self, connector: RedisConnector) -> None:
self.connector = connector self.connector = connector
self.producer = connector.producer()
self.connected = False self.connected = False
self._scilog_module = None self._scilog_module = None
self._connect() self._connect()
@ -34,12 +33,12 @@ class LogbookConnector:
if "scilog" not in sys.modules: if "scilog" not in sys.modules:
return return
msg = self.producer.get(MessageEndpoints.logbook()) msg = self.connector.get(MessageEndpoints.logbook())
if not msg: if not msg:
return return
msg = msgpack.loads(msg) msg = msgpack.loads(msg)
account = self.producer.get(MessageEndpoints.account()) account = self.connector.get(MessageEndpoints.account())
if not account: if not account:
return return
account = account.decode() account = account.decode()
@ -54,7 +53,7 @@ class LogbookConnector:
try: try:
logbooks = self.log.get_logbooks(readACL={"inq": [account]}) logbooks = self.log.get_logbooks(readACL={"inq": [account]})
except HTTPError: except HTTPError:
self.producer.set(MessageEndpoints.logbook(), b"") self.connector.set(MessageEndpoints.logbook(), b"")
return return
if len(logbooks) > 1: if len(logbooks) > 1:
logger.warning("Found two logbooks. Taking the first one.") logger.warning("Found two logbooks. Taking the first one.")

View File

@ -45,7 +45,6 @@ class BECLogger:
self.bootstrap_server = None self.bootstrap_server = None
self.connector = None self.connector = None
self.service_name = None self.service_name = None
self.producer = None
self.logger = loguru_logger self.logger = loguru_logger
self._log_level = LogLevel.INFO self._log_level = LogLevel.INFO
self.level = self._log_level self.level = self._log_level
@ -73,7 +72,6 @@ class BECLogger:
self.bootstrap_server = bootstrap_server self.bootstrap_server = bootstrap_server
self.connector = connector_cls(bootstrap_server) self.connector = connector_cls(bootstrap_server)
self.service_name = service_name self.service_name = service_name
self.producer = self.connector.producer()
self._configured = True self._configured = True
self._update_sinks() self._update_sinks()
@ -82,7 +80,7 @@ class BECLogger:
return return
msg = json.loads(msg) msg = json.loads(msg)
msg["service_name"] = self.service_name msg["service_name"] = self.service_name
self.producer.send( self.connector.send(
topic=MessageEndpoints.log(), topic=MessageEndpoints.log(),
msg=bec_lib.messages.LogMessage(log_type=msg["record"]["level"]["name"], log_msg=msg), msg=bec_lib.messages.LogMessage(log_type=msg["record"]["level"]["name"], log_msg=msg),
) )

View File

@ -152,7 +152,7 @@ class ObserverManager:
def _get_installed_observer(self): def _get_installed_observer(self):
# get current observer list from Redis # get current observer list from Redis
observer_msg = self.device_manager.producer.get(MessageEndpoints.observer()) observer_msg = self.device_manager.connector.get(MessageEndpoints.observer())
if observer_msg is None: if observer_msg is None:
return [] return []
return [Observer.from_dict(obs) for obs in observer_msg.content["observer"]] return [Observer.from_dict(obs) for obs in observer_msg.content["observer"]]

View File

@ -122,12 +122,16 @@ class QueueStorage:
if history < 0: if history < 0:
history *= -1 history *= -1
return self.scan_manager.producer.lrange(MessageEndpoints.scan_queue_history(), 0, history) return self.scan_manager.connector.lrange(
MessageEndpoints.scan_queue_history(),
0,
history,
)
@property @property
def current_scan_queue(self) -> dict: def current_scan_queue(self) -> dict:
"""get the current scan queue from redis""" """get the current scan queue from redis"""
msg = self.scan_manager.producer.get(MessageEndpoints.scan_queue_status()) msg = self.scan_manager.connector.get(MessageEndpoints.scan_queue_status())
if msg: if msg:
self._current_scan_queue = msg.content["queue"] self._current_scan_queue = msg.content["queue"]
return self._current_scan_queue return self._current_scan_queue

View File

@ -1,19 +1,18 @@
from __future__ import annotations from __future__ import annotations
import time import collections
import queue
import sys
import threading
import warnings import warnings
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import louie
import redis import redis
import redis.client
from bec_lib.connector import ( from bec_lib.connector import ConnectorBase, MessageObject
ConnectorBase,
ConsumerConnector,
ConsumerConnectorThreaded,
MessageObject,
ProducerConnector,
)
from bec_lib.endpoints import MessageEndpoints from bec_lib.endpoints import MessageEndpoints
from bec_lib.messages import AlarmMessage, BECMessage, LogMessage from bec_lib.messages import AlarmMessage, BECMessage, LogMessage
from bec_lib.serialization import MsgpackSerialization from bec_lib.serialization import MsgpackSerialization
@ -31,7 +30,6 @@ def catch_connection_error(func):
return func(*args, **kwargs) return func(*args, **kwargs)
except redis.exceptions.ConnectionError: except redis.exceptions.ConnectionError:
warnings.warn("Failed to connect to redis. Is the server running?") warnings.warn("Failed to connect to redis. Is the server running?")
time.sleep(0.1)
return None return None
return wrapper return wrapper
@ -40,149 +38,80 @@ def catch_connection_error(func):
class RedisConnector(ConnectorBase): class RedisConnector(ConnectorBase):
def __init__(self, bootstrap: list, redis_cls=None): def __init__(self, bootstrap: list, redis_cls=None):
super().__init__(bootstrap) super().__init__(bootstrap)
self.redis_cls = redis_cls
self.host, self.port = ( self.host, self.port = (
bootstrap[0].split(":") if isinstance(bootstrap, list) else bootstrap.split(":") bootstrap[0].split(":") if isinstance(bootstrap, list) else bootstrap.split(":")
) )
self._notifications_producer = RedisProducer(
host=self.host, port=self.port, redis_cls=self.redis_cls
)
def producer(self, **kwargs): if redis_cls:
return RedisProducer(host=self.host, port=self.port, redis_cls=self.redis_cls) self._redis_conn = redis_cls(host=self.host, port=self.port)
else:
self._redis_conn = redis.Redis(host=self.host, port=self.port)
# pylint: disable=too-many-arguments # main pubsub connection
def consumer( self._pubsub_conn = self._redis_conn.pubsub()
self, self._pubsub_conn.ignore_subscribe_messages = True
topics=None, # keep track of topics and callbacks
pattern=None, self._topics_cb = collections.defaultdict(list)
group_id=None,
event=None,
cb=None,
threaded=True,
name=None,
**kwargs,
):
if cb is None:
raise ValueError("The callback function must be specified.")
if threaded: self._events_listener_thread = None
if topics is None and pattern is None: self._events_dispatcher_thread = None
raise ValueError("Topics must be set for threaded consumer") self._messages_queue = queue.Queue()
listener = RedisConsumerThreaded( self._stop_events_listener_thread = threading.Event()
self.host,
self.port,
topics,
pattern,
group_id,
event,
cb,
redis_cls=self.redis_cls,
name=name,
**kwargs,
)
self._threads.append(listener)
return listener
return RedisConsumer(
self.host,
self.port,
topics,
pattern,
group_id,
event,
cb,
redis_cls=self.redis_cls,
**kwargs,
)
def stream_consumer( self.stream_keys = {}
self,
topics=None,
pattern=None,
group_id=None,
event=None,
cb=None,
from_start=False,
newest_only=False,
**kwargs,
):
"""
Threaded stream consumer for redis streams.
Args: def shutdown(self):
topics (str, list): topics to subscribe to if self._events_listener_thread:
pattern (str, list): pattern to subscribe to self._stop_events_listener_thread.set()
group_id (str): group id self._events_listener_thread.join()
event (threading.Event): event to stop the consumer self._events_listener_thread = None
cb (function): callback function if self._events_dispatcher_thread:
from_start (bool): read from start. Defaults to False. self._messages_queue.put(StopIteration)
newest_only (bool): read only the newest message. Defaults to False. self._events_dispatcher_thread.join()
""" self._events_dispatcher_thread = None
if cb is None: # release all connections
raise ValueError("The callback function must be specified.") self._pubsub_conn.close()
self._redis_conn.close()
if pattern:
raise ValueError("Pattern is currently not supported for stream consumer.")
if topics is None and pattern is None:
raise ValueError("Topics must be set for stream consumer.")
listener = RedisStreamConsumerThreaded(
self.host,
self.port,
topics,
pattern,
group_id,
event,
cb,
redis_cls=self.redis_cls,
from_start=from_start,
newest_only=newest_only,
**kwargs,
)
self._threads.append(listener)
return listener
@catch_connection_error @catch_connection_error
def log_warning(self, msg): def log_warning(self, msg):
"""send a warning""" """send a warning"""
self._notifications_producer.send( self.send(MessageEndpoints.log(), LogMessage(log_type="warning", log_msg=msg))
MessageEndpoints.log(), LogMessage(log_type="warning", log_msg=msg)
)
@catch_connection_error @catch_connection_error
def log_message(self, msg): def log_message(self, msg):
"""send a log message""" """send a log message"""
self._notifications_producer.send( self.send(MessageEndpoints.log(), LogMessage(log_type="log", log_msg=msg))
MessageEndpoints.log(), LogMessage(log_type="log", log_msg=msg)
)
@catch_connection_error @catch_connection_error
def log_error(self, msg): def log_error(self, msg):
"""send an error as log""" """send an error as log"""
self._notifications_producer.send( self.send(MessageEndpoints.log(), LogMessage(log_type="error", log_msg=msg))
MessageEndpoints.log(), LogMessage(log_type="error", log_msg=msg)
)
@catch_connection_error @catch_connection_error
def raise_alarm(self, severity: Alarms, alarm_type: str, source: str, msg: str, metadata: dict): def raise_alarm(
self,
severity: Alarms,
alarm_type: str,
source: str,
msg: str,
metadata: dict,
):
"""raise an alarm""" """raise an alarm"""
self._notifications_producer.set_and_publish( alarm_msg = AlarmMessage(
MessageEndpoints.alarm(), severity=severity,
AlarmMessage( alarm_type=alarm_type,
severity=severity, alarm_type=alarm_type, source=source, msg=msg, metadata=metadata source=source,
), msg=msg,
metadata=metadata,
) )
self.set_and_publish(MessageEndpoints.alarm(), alarm_msg)
def pipeline(self):
"""Create a new pipeline"""
return self._redis_conn.pipeline()
class RedisProducer(ProducerConnector): @catch_connection_error
def __init__(self, host: str, port: int, redis_cls=None) -> None:
# pylint: disable=invalid-name
if redis_cls:
self.r = redis_cls(host=host, port=port)
return
self.r = redis.Redis(host=host, port=port)
self.stream_keys = {}
def execute_pipeline(self, pipeline): def execute_pipeline(self, pipeline):
"""Execute the pipeline and returns the results with decoded BECMessages""" """Execute the pipeline and returns the results with decoded BECMessages"""
ret = [] ret = []
@ -197,7 +126,7 @@ class RedisProducer(ProducerConnector):
@catch_connection_error @catch_connection_error
def raw_send(self, topic: str, msg: bytes, pipe=None): def raw_send(self, topic: str, msg: bytes, pipe=None):
"""send to redis without any check on message type""" """send to redis without any check on message type"""
client = pipe if pipe is not None else self.r client = pipe if pipe is not None else self._redis_conn
client.publish(topic, msg) client.publish(topic, msg)
def send(self, topic: str, msg: BECMessage, pipe=None) -> None: def send(self, topic: str, msg: BECMessage, pipe=None) -> None:
@ -206,6 +135,95 @@ class RedisProducer(ProducerConnector):
raise TypeError(f"Message {msg} is not a BECMessage") raise TypeError(f"Message {msg} is not a BECMessage")
self.raw_send(topic, MsgpackSerialization.dumps(msg), pipe) self.raw_send(topic, MsgpackSerialization.dumps(msg), pipe)
def register(self, topics=None, patterns=None, cb=None, start_thread=True, **kwargs):
if self._events_listener_thread is None:
# create the thread that will get all messages for this connector;
# under the hood, it uses asyncio - this lets the possibility to stop
# the loop on demand
self._events_listener_thread = threading.Thread(
target=self._get_messages_loop,
args=(self._pubsub_conn,),
)
self._events_listener_thread.start()
# make a weakref from the callable, using louie;
# it can create safe refs for simple functions as well as methods
cb_ref = louie.saferef.safe_ref(cb)
if patterns is not None:
if isinstance(patterns, str):
patterns = [patterns]
self._pubsub_conn.psubscribe(patterns)
for pattern in patterns:
self._topics_cb[pattern].append((cb_ref, kwargs))
else:
if isinstance(topics, str):
topics = [topics]
self._pubsub_conn.subscribe(topics)
for topic in topics:
self._topics_cb[topic].append((cb_ref, kwargs))
if start_thread and self._events_dispatcher_thread is None:
# start dispatcher thread
self._events_dispatcher_thread = threading.Thread(target=self.dispatch_events)
self._events_dispatcher_thread.start()
def _get_messages_loop(self, pubsub) -> None:
"""
Start a listening coroutine to deal with redis events and wait for completion
"""
while not self._stop_events_listener_thread.is_set():
try:
msg = pubsub.get_message(timeout=1)
except Exception:
sys.excepthook(*sys.exc_info())
else:
if msg is not None:
self._messages_queue.put(msg)
def _handle_message(self, msg):
if msg["type"].endswith("subscribe"):
# ignore subscribe messages
return False
channel = msg["channel"].decode()
if msg["pattern"] is not None:
callbacks = self._topics_cb[msg["pattern"].decode()]
else:
callbacks = self._topics_cb[channel]
msg = MessageObject(
topic=channel,
value=MsgpackSerialization.loads(msg["data"]),
)
for cb_ref, kwargs in callbacks:
cb = cb_ref()
if cb:
try:
cb(msg, **kwargs)
except Exception:
sys.excepthook(*sys.exc_info())
return True
def poll_messages(self, timeout=None) -> None:
while True:
try:
msg = self._messages_queue.get(timeout=timeout)
except queue.Empty:
raise TimeoutError(
f"{self}: poll_messages: did not receive a message within {timeout} seconds"
)
else:
if msg is StopIteration:
return False
if self._handle_message(msg):
return True
else:
continue
def dispatch_events(self):
while self.poll_messages():
...
@catch_connection_error @catch_connection_error
def lpush( def lpush(
self, topic: str, msg: str, pipe=None, max_size: int = None, expire: int = None self, topic: str, msg: str, pipe=None, max_size: int = None, expire: int = None
@ -229,7 +247,7 @@ class RedisProducer(ProducerConnector):
@catch_connection_error @catch_connection_error
def lset(self, topic: str, index: int, msg: str, pipe=None) -> None: def lset(self, topic: str, index: int, msg: str, pipe=None) -> None:
client = pipe if pipe is not None else self.r client = pipe if pipe is not None else self._redis_conn
if isinstance(msg, BECMessage): if isinstance(msg, BECMessage):
msg = MsgpackSerialization.dumps(msg) msg = MsgpackSerialization.dumps(msg)
return client.lset(topic, index, msg) return client.lset(topic, index, msg)
@ -241,7 +259,7 @@ class RedisProducer(ProducerConnector):
values at the tail of the list stored at key. If key does not exist, values at the tail of the list stored at key. If key does not exist,
it is created as empty list before performing the push operation. When it is created as empty list before performing the push operation. When
key holds a value that is not a list, an error is returned.""" key holds a value that is not a list, an error is returned."""
client = pipe if pipe is not None else self.r client = pipe if pipe is not None else self._redis_conn
if isinstance(msg, BECMessage): if isinstance(msg, BECMessage):
msg = MsgpackSerialization.dumps(msg) msg = MsgpackSerialization.dumps(msg)
return client.rpush(topic, msg) return client.rpush(topic, msg)
@ -254,7 +272,7 @@ class RedisProducer(ProducerConnector):
of the list stored at key. The offsets start and stop are zero-based indexes, of the list stored at key. The offsets start and stop are zero-based indexes,
with 0 being the first element of the list (the head of the list), 1 being with 0 being the first element of the list (the head of the list), 1 being
the next element and so on.""" the next element and so on."""
client = pipe if pipe is not None else self.r client = pipe if pipe is not None else self._redis_conn
cmd_result = client.lrange(topic, start, end) cmd_result = client.lrange(topic, start, end)
if pipe: if pipe:
return cmd_result return cmd_result
@ -268,23 +286,21 @@ class RedisProducer(ProducerConnector):
ret.append(msg) ret.append(msg)
return ret return ret
@catch_connection_error
def set_and_publish(self, topic: str, msg, pipe=None, expire: int = None) -> None: def set_and_publish(self, topic: str, msg, pipe=None, expire: int = None) -> None:
"""piped combination of self.publish and self.set""" """piped combination of self.publish and self.set"""
client = pipe if pipe is not None else self.pipeline() client = pipe if pipe is not None else self.pipeline()
if isinstance(msg, BECMessage): if not isinstance(msg, BECMessage):
msg = MsgpackSerialization.dumps(msg) raise TypeError(f"Message {msg} is not a BECMessage")
client.publish(topic, msg) msg = MsgpackSerialization.dumps(msg)
client.set(topic, msg) self.set(topic, msg, pipe=client, expire=expire)
if expire: self.raw_send(topic, msg, pipe=client)
client.expire(topic, expire)
if not pipe: if not pipe:
client.execute() client.execute()
@catch_connection_error @catch_connection_error
def set(self, topic: str, msg, pipe=None, expire: int = None) -> None: def set(self, topic: str, msg, pipe=None, expire: int = None) -> None:
"""set redis value""" """set redis value"""
client = pipe if pipe is not None else self.r client = pipe if pipe is not None else self._redis_conn
if isinstance(msg, BECMessage): if isinstance(msg, BECMessage):
msg = MsgpackSerialization.dumps(msg) msg = MsgpackSerialization.dumps(msg)
client.set(topic, msg, ex=expire) client.set(topic, msg, ex=expire)
@ -292,23 +308,18 @@ class RedisProducer(ProducerConnector):
@catch_connection_error @catch_connection_error
def keys(self, pattern: str) -> list: def keys(self, pattern: str) -> list:
"""returns all keys matching a pattern""" """returns all keys matching a pattern"""
return self.r.keys(pattern) return self._redis_conn.keys(pattern)
@catch_connection_error
def pipeline(self):
"""create a new pipeline"""
return self.r.pipeline()
@catch_connection_error @catch_connection_error
def delete(self, topic, pipe=None): def delete(self, topic, pipe=None):
"""delete topic""" """delete topic"""
client = pipe if pipe is not None else self.r client = pipe if pipe is not None else self._redis_conn
client.delete(topic) client.delete(topic)
@catch_connection_error @catch_connection_error
def get(self, topic: str, pipe=None): def get(self, topic: str, pipe=None):
"""retrieve entry, either via hgetall or get""" """retrieve entry, either via hgetall or get"""
client = pipe if pipe is not None else self.r client = pipe if pipe is not None else self._redis_conn
data = client.get(topic) data = client.get(topic)
if pipe: if pipe:
return data return data
@ -339,7 +350,7 @@ class RedisProducer(ProducerConnector):
elif expire: elif expire:
client = self.pipeline() client = self.pipeline()
else: else:
client = self.r client = self._redis_conn
for key, msg in msg_dict.items(): for key, msg in msg_dict.items():
msg_dict[key] = MsgpackSerialization.dumps(msg) msg_dict[key] = MsgpackSerialization.dumps(msg)
@ -356,7 +367,7 @@ class RedisProducer(ProducerConnector):
@catch_connection_error @catch_connection_error
def get_last(self, topic: str, key="data"): def get_last(self, topic: str, key="data"):
"""retrieve last entry from stream""" """retrieve last entry from stream"""
client = self.r client = self._redis_conn
try: try:
_, msg_dict = client.xrevrange(topic, "+", "-", count=1)[0] _, msg_dict = client.xrevrange(topic, "+", "-", count=1)[0]
except TypeError: except TypeError:
@ -370,7 +381,12 @@ class RedisProducer(ProducerConnector):
@catch_connection_error @catch_connection_error
def xread( def xread(
self, topic: str, id: str = None, count: int = None, block: int = None, from_start=False self,
topic: str,
id: str = None,
count: int = None,
block: int = None,
from_start=False,
) -> list: ) -> list:
""" """
read from stream read from stream
@ -395,13 +411,13 @@ class RedisProducer(ProducerConnector):
>>> key = msg[0][1][0][0] >>> key = msg[0][1][0][0]
>>> next_msg = redis.xread("test", key, count=1) >>> next_msg = redis.xread("test", key, count=1)
""" """
client = self.r client = self._redis_conn
if from_start: if from_start:
self.stream_keys[topic] = "0-0" self.stream_keys[topic] = "0-0"
if topic not in self.stream_keys: if topic not in self.stream_keys:
if id is None: if id is None:
try: try:
msg = self.r.xrevrange(topic, "+", "-", count=1) msg = client.xrevrange(topic, "+", "-", count=1)
if msg: if msg:
self.stream_keys[topic] = msg[0][0] self.stream_keys[topic] = msg[0][0]
out = {} out = {}
@ -438,7 +454,7 @@ class RedisProducer(ProducerConnector):
max (str): max id. Use "+" to read to end max (str): max id. Use "+" to read to end
count (int, optional): number of messages to read. Defaults to None. count (int, optional): number of messages to read. Defaults to None.
""" """
client = self.r client = self._redis_conn
msgs = [] msgs = []
for reading in client.xrange(topic, min, max, count=count): for reading in client.xrange(topic, min, max, count=count):
index, msg_dict = reading index, msg_dict = reading
@ -446,270 +462,3 @@ class RedisProducer(ProducerConnector):
{k.decode(): MsgpackSerialization.loads(msg) for k, msg in msg_dict.items()} {k.decode(): MsgpackSerialization.loads(msg) for k, msg in msg_dict.items()}
) )
return msgs return msgs
class RedisConsumerMixin:
def _init_topics_and_pattern(self, topics, pattern):
if topics:
if not isinstance(topics, list):
topics = [topics]
if pattern:
if not isinstance(pattern, list):
pattern = [pattern]
return topics, pattern
def _init_redis_cls(self, redis_cls):
# pylint: disable=invalid-name
if redis_cls:
self.r = redis_cls(host=self.host, port=self.port)
else:
self.r = redis.Redis(host=self.host, port=self.port)
@catch_connection_error
def initialize_connector(self) -> None:
if self.pattern is not None:
self.pubsub.psubscribe(self.pattern)
else:
self.pubsub.subscribe(self.topics)
class RedisConsumer(RedisConsumerMixin, ConsumerConnector):
# pylint: disable=too-many-arguments
def __init__(
self,
host,
port,
topics=None,
pattern=None,
group_id=None,
event=None,
cb=None,
redis_cls=None,
**kwargs,
):
self.host = host
self.port = port
bootstrap_server = "".join([host, ":", port])
topics, pattern = self._init_topics_and_pattern(topics, pattern)
super().__init__(
bootstrap_server=bootstrap_server,
topics=topics,
pattern=pattern,
group_id=group_id,
event=event,
cb=cb,
**kwargs,
)
self.error_message_sent = False
self._init_redis_cls(redis_cls)
self.pubsub = self.r.pubsub()
self.initialize_connector()
@catch_connection_error
def poll_messages(self) -> None:
"""
Poll messages from self.connector and call the callback function self.cb
"""
message = self.pubsub.get_message(ignore_subscribe_messages=True)
if message is not None:
msg = MessageObject(
topic=message["channel"], value=MsgpackSerialization.loads(message["data"])
)
return self.cb(msg, **self.kwargs)
time.sleep(0.01)
return None
def shutdown(self):
"""shutdown the consumer"""
self.pubsub.close()
class RedisStreamConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded):
# pylint: disable=too-many-arguments
def __init__(
self,
host,
port,
topics=None,
pattern=None,
group_id=None,
event=None,
cb=None,
redis_cls=None,
from_start=False,
newest_only=False,
**kwargs,
):
self.host = host
self.port = port
self.from_start = from_start
self.newest_only = newest_only
bootstrap_server = "".join([host, ":", port])
topics, pattern = self._init_topics_and_pattern(topics, pattern)
super().__init__(
bootstrap_server=bootstrap_server,
topics=topics,
pattern=pattern,
group_id=group_id,
event=event,
cb=cb,
**kwargs,
)
self._init_redis_cls(redis_cls)
self.sleep_times = [0.005, 0.1]
self.last_received_msg = 0
self.idle_time = 30
self.error_message_sent = False
self.stream_keys = {}
def initialize_connector(self) -> None:
pass
def _init_topics_and_pattern(self, topics, pattern):
if topics:
if not isinstance(topics, list):
topics = [topics]
if pattern:
if not isinstance(pattern, list):
pattern = [pattern]
return topics, pattern
def get_id(self, topic: str) -> str:
"""
Get the stream key for the given topic.
Args:
topic (str): topic to get the stream key for
"""
if topic not in self.stream_keys:
return "0-0"
return self.stream_keys.get(topic)
def get_newest_message(self, container: list, append=True) -> None:
"""
Get the newest message from the stream and update the stream key. If
append is True, append the message to the container.
Args:
container (list): container to append the message to
append (bool, optional): append to container. Defaults to True.
"""
for topic in self.topics:
msg = self.r.xrevrange(topic, "+", "-", count=1)
if msg:
if append:
container.append((topic, msg[0][1]))
self.stream_keys[topic] = msg[0][0]
else:
self.stream_keys[topic] = "0-0"
@catch_connection_error
def poll_messages(self) -> None:
"""
Poll messages from self.connector and call the callback function self.cb
"""
if self.pattern is not None:
topics = [key.decode() for key in self.r.scan_iter(match=self.pattern, _type="stream")]
else:
topics = self.topics
messages = []
if self.newest_only:
self.get_newest_message(messages)
elif not self.from_start and not self.stream_keys:
self.get_newest_message(messages, append=False)
else:
streams = {topic: self.get_id(topic) for topic in topics}
read_msgs = self.r.xread(streams, count=1)
if read_msgs:
for msg in read_msgs:
topic = msg[0].decode()
messages.append((topic, msg[1][0][1]))
self.stream_keys[topic] = msg[1][-1][0]
if messages:
if MessageEndpoints.log() not in topics:
# no need to update the update frequency just for logs
self.last_received_msg = time.time()
for topic, msg in messages:
try:
msg = MsgpackSerialization.loads(msg[b"data"])
except RuntimeError:
msg = msg[b"data"]
msg_obj = MessageObject(topic=topic, value=msg)
self.cb(msg_obj, **self.kwargs)
else:
sleep_time = int(bool(time.time() - self.last_received_msg > self.idle_time))
if self.sleep_times[sleep_time]:
time.sleep(self.sleep_times[sleep_time])
class RedisConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded):
# pylint: disable=too-many-arguments
def __init__(
self,
host,
port,
topics=None,
pattern=None,
group_id=None,
event=None,
cb=None,
redis_cls=None,
name=None,
**kwargs,
):
self.host = host
self.port = port
bootstrap_server = "".join([host, ":", port])
topics, pattern = self._init_topics_and_pattern(topics, pattern)
super().__init__(
bootstrap_server=bootstrap_server,
topics=topics,
pattern=pattern,
group_id=group_id,
event=event,
cb=cb,
name=name,
**kwargs,
)
self._init_redis_cls(redis_cls)
self.pubsub = self.r.pubsub()
self.sleep_times = [0.005, 0.1]
self.last_received_msg = 0
self.idle_time = 30
self.error_message_sent = False
@catch_connection_error
def poll_messages(self) -> None:
"""
Poll messages from self.connector and call the callback function self.cb
Note: pubsub messages are supposed to be BECMessage objects only
"""
messages = self.pubsub.get_message(ignore_subscribe_messages=True)
if messages is not None:
if f"{MessageEndpoints.log()}".encode() not in messages["channel"]:
# no need to update the update frequency just for logs
self.last_received_msg = time.time()
msg = MessageObject(
topic=messages["channel"].decode(),
value=MsgpackSerialization.loads(messages["data"]),
)
self.cb(msg, **self.kwargs)
else:
sleep_time = int(bool(time.time() - self.last_received_msg > self.idle_time))
if self.sleep_times[sleep_time]:
time.sleep(self.sleep_times[sleep_time])
def shutdown(self):
super().shutdown()
self.pubsub.close()

View File

@ -46,7 +46,7 @@ class ScanItem:
self.data = ScanData() self.data = ScanData()
self.async_data = {} self.async_data = {}
self.baseline = ScanData() self.baseline = ScanData()
self._async_data_handler = AsyncDataHandler(scan_manager.producer) self._async_data_handler = AsyncDataHandler(scan_manager.connector)
self.open_scan_defs = set() self.open_scan_defs = set()
self.open_queue_group = None self.open_queue_group = None
self.num_points = None self.num_points = None

View File

@ -25,44 +25,31 @@ class ScanManager:
connector (BECConnector): BECConnector instance connector (BECConnector): BECConnector instance
""" """
self.connector = connector self.connector = connector
self.producer = self.connector.producer()
self.queue_storage = QueueStorage(scan_manager=self) self.queue_storage = QueueStorage(scan_manager=self)
self.request_storage = RequestStorage(scan_manager=self) self.request_storage = RequestStorage(scan_manager=self)
self.scan_storage = ScanStorage(scan_manager=self) self.scan_storage = ScanStorage(scan_manager=self)
self._scan_queue_consumer = self.connector.consumer( self.connector.register(
topics=MessageEndpoints.scan_queue_status(), topics=MessageEndpoints.scan_queue_status(),
cb=self._scan_queue_status_callback, cb=self._scan_queue_status_callback,
parent=self,
) )
self._scan_queue_request_consumer = self.connector.consumer( self.connector.register(
topics=MessageEndpoints.scan_queue_request(), topics=MessageEndpoints.scan_queue_request(),
cb=self._scan_queue_request_callback, cb=self._scan_queue_request_callback,
parent=self,
) )
self._scan_queue_request_response_consumer = self.connector.consumer( self.connector.register(
topics=MessageEndpoints.scan_queue_request_response(), topics=MessageEndpoints.scan_queue_request_response(),
cb=self._scan_queue_request_response_callback, cb=self._scan_queue_request_response_callback,
parent=self,
) )
self._scan_status_consumer = self.connector.consumer( self.connector.register(
topics=MessageEndpoints.scan_status(), cb=self._scan_status_callback, parent=self topics=MessageEndpoints.scan_status(), cb=self._scan_status_callback
) )
self._scan_segment_consumer = self.connector.consumer( self.connector.register(
topics=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self topics=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback
) )
self._baseline_consumer = self.connector.consumer( self.connector.register(topics=MessageEndpoints.scan_baseline(), cb=self._baseline_callback)
topics=MessageEndpoints.scan_baseline(), cb=self._baseline_callback, parent=self
)
self._scan_queue_consumer.start()
self._scan_queue_request_consumer.start()
self._scan_queue_request_response_consumer.start()
self._scan_status_consumer.start()
self._scan_segment_consumer.start()
self._baseline_consumer.start()
def update_with_queue_status(self, queue: messages.ScanQueueStatusMessage) -> None: def update_with_queue_status(self, queue: messages.ScanQueueStatusMessage) -> None:
"""update storage with a new queue status message""" """update storage with a new queue status message"""
@ -84,7 +71,7 @@ class ScanManager:
action = "deferred_pause" if deferred_pause else "pause" action = "deferred_pause" if deferred_pause else "pause"
logger.info(f"Requesting {action}") logger.info(f"Requesting {action}")
return self.producer.send( return self.connector.send(
MessageEndpoints.scan_queue_modification_request(), MessageEndpoints.scan_queue_modification_request(),
messages.ScanQueueModificationMessage(scanID=scanID, action=action, parameter={}), messages.ScanQueueModificationMessage(scanID=scanID, action=action, parameter={}),
) )
@ -99,7 +86,7 @@ class ScanManager:
if scanID is None: if scanID is None:
scanID = self.scan_storage.current_scanID scanID = self.scan_storage.current_scanID
logger.info("Requesting scan abortion") logger.info("Requesting scan abortion")
self.producer.send( self.connector.send(
MessageEndpoints.scan_queue_modification_request(), MessageEndpoints.scan_queue_modification_request(),
messages.ScanQueueModificationMessage(scanID=scanID, action="abort", parameter={}), messages.ScanQueueModificationMessage(scanID=scanID, action="abort", parameter={}),
) )
@ -114,7 +101,7 @@ class ScanManager:
if scanID is None: if scanID is None:
scanID = self.scan_storage.current_scanID scanID = self.scan_storage.current_scanID
logger.info("Requesting scan halt") logger.info("Requesting scan halt")
self.producer.send( self.connector.send(
MessageEndpoints.scan_queue_modification_request(), MessageEndpoints.scan_queue_modification_request(),
messages.ScanQueueModificationMessage(scanID=scanID, action="halt", parameter={}), messages.ScanQueueModificationMessage(scanID=scanID, action="halt", parameter={}),
) )
@ -129,7 +116,7 @@ class ScanManager:
if scanID is None: if scanID is None:
scanID = self.scan_storage.current_scanID scanID = self.scan_storage.current_scanID
logger.info("Requesting scan continuation") logger.info("Requesting scan continuation")
self.producer.send( self.connector.send(
MessageEndpoints.scan_queue_modification_request(), MessageEndpoints.scan_queue_modification_request(),
messages.ScanQueueModificationMessage(scanID=scanID, action="continue", parameter={}), messages.ScanQueueModificationMessage(scanID=scanID, action="continue", parameter={}),
) )
@ -137,7 +124,7 @@ class ScanManager:
def request_queue_reset(self): def request_queue_reset(self):
"""request a scan queue reset""" """request a scan queue reset"""
logger.info("Requesting a queue reset") logger.info("Requesting a queue reset")
self.producer.send( self.connector.send(
MessageEndpoints.scan_queue_modification_request(), MessageEndpoints.scan_queue_modification_request(),
messages.ScanQueueModificationMessage(scanID=None, action="clear", parameter={}), messages.ScanQueueModificationMessage(scanID=None, action="clear", parameter={}),
) )
@ -151,7 +138,7 @@ class ScanManager:
logger.info("Requesting to abort and repeat a scan") logger.info("Requesting to abort and repeat a scan")
position = "replace" if replace else "append" position = "replace" if replace else "append"
self.producer.send( self.connector.send(
MessageEndpoints.scan_queue_modification_request(), MessageEndpoints.scan_queue_modification_request(),
messages.ScanQueueModificationMessage( messages.ScanQueueModificationMessage(
scanID=scanID, action="restart", parameter={"position": position, "RID": requestID} scanID=scanID, action="restart", parameter={"position": position, "RID": requestID}
@ -162,7 +149,7 @@ class ScanManager:
@property @property
def next_scan_number(self): def next_scan_number(self):
"""get the next scan number from redis""" """get the next scan number from redis"""
num = self.producer.get(MessageEndpoints.scan_number()) num = self.connector.get(MessageEndpoints.scan_number())
if num is None: if num is None:
logger.warning("Failed to retrieve scan number from redis.") logger.warning("Failed to retrieve scan number from redis.")
return -1 return -1
@ -172,63 +159,51 @@ class ScanManager:
@typechecked @typechecked
def next_scan_number(self, val: int): def next_scan_number(self, val: int):
"""set the next scan number in redis""" """set the next scan number in redis"""
return self.producer.set(MessageEndpoints.scan_number(), val) return self.connector.set(MessageEndpoints.scan_number(), val)
@property @property
def next_dataset_number(self): def next_dataset_number(self):
"""get the next dataset number from redis""" """get the next dataset number from redis"""
return int(self.producer.get(MessageEndpoints.dataset_number())) return int(self.connector.get(MessageEndpoints.dataset_number()))
@next_dataset_number.setter @next_dataset_number.setter
@typechecked @typechecked
def next_dataset_number(self, val: int): def next_dataset_number(self, val: int):
"""set the next dataset number in redis""" """set the next dataset number in redis"""
return self.producer.set(MessageEndpoints.dataset_number(), val) return self.connector.set(MessageEndpoints.dataset_number(), val)
@staticmethod def _scan_queue_status_callback(self, msg, **_kwargs) -> None:
def _scan_queue_status_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
queue_status = msg.value queue_status = msg.value
if not queue_status: if not queue_status:
return return
parent.update_with_queue_status(queue_status) self.update_with_queue_status(queue_status)
@staticmethod def _scan_queue_request_callback(self, msg, **_kwargs) -> None:
def _scan_queue_request_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
request = msg.value request = msg.value
parent.request_storage.update_with_request(request) self.request_storage.update_with_request(request)
@staticmethod def _scan_queue_request_response_callback(self, msg, **_kwargs) -> None:
def _scan_queue_request_response_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
response = msg.value response = msg.value
logger.debug(response) logger.debug(response)
parent.request_storage.update_with_response(response) self.request_storage.update_with_response(response)
@staticmethod def _scan_status_callback(self, msg, **_kwargs) -> None:
def _scan_status_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
scan = msg.value scan = msg.value
parent.scan_storage.update_with_scan_status(scan) self.scan_storage.update_with_scan_status(scan)
@staticmethod def _scan_segment_callback(self, msg, **_kwargs) -> None:
def _scan_segment_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
scan_msgs = msg.value scan_msgs = msg.value
if not isinstance(scan_msgs, list): if not isinstance(scan_msgs, list):
scan_msgs = [scan_msgs] scan_msgs = [scan_msgs]
for scan_msg in scan_msgs: for scan_msg in scan_msgs:
parent.scan_storage.add_scan_segment(scan_msg) self.scan_storage.add_scan_segment(scan_msg)
@staticmethod def _baseline_callback(self, msg, **_kwargs) -> None:
def _baseline_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
msg = msg.value msg = msg.value
parent.scan_storage.add_scan_baseline(msg) self.scan_storage.add_scan_baseline(msg)
def __str__(self) -> str: def __str__(self) -> str:
return "\n".join(self.queue_storage.describe_queue()) return "\n".join(self.queue_storage.describe_queue())
def shutdown(self): def shutdown(self):
"""stop the scan manager's threads""" pass
self._scan_queue_consumer.shutdown()
self._scan_queue_request_consumer.shutdown()
self._scan_queue_request_response_consumer.shutdown()
self._scan_status_consumer.shutdown()
self._scan_segment_consumer.shutdown()
self._baseline_consumer.shutdown()

View File

@ -89,7 +89,7 @@ class ScanReport:
def _get_mv_status(self) -> bool: def _get_mv_status(self) -> bool:
"""get the status of a move request""" """get the status of a move request"""
motors = list(self.request.request.content["parameter"]["args"].keys()) motors = list(self.request.request.content["parameter"]["args"].keys())
request_status = self._client.device_manager.producer.lrange( request_status = self._client.device_manager.connector.lrange(
MessageEndpoints.device_req_status(self.request.requestID), 0, -1 MessageEndpoints.device_req_status(self.request.requestID), 0, -1
) )
if len(request_status) == len(motors): if len(request_status) == len(motors):

View File

@ -100,9 +100,9 @@ class ScanObject:
return None return None
return self.scan_info.get("scan_report_hint") return self.scan_info.get("scan_report_hint")
def _start_consumer(self, request: messages.ScanQueueMessage) -> ConsumerConnector: def _start_register(self, request: messages.ScanQueueMessage) -> ConsumerConnector:
"""Start a consumer for the given request""" """Start a register for the given request"""
consumer = self.client.device_manager.connector.consumer( register = self.client.device_manager.connector.register(
[ [
MessageEndpoints.device_readback(dev) MessageEndpoints.device_readback(dev)
for dev in request.content["parameter"]["args"].keys() for dev in request.content["parameter"]["args"].keys()
@ -110,11 +110,11 @@ class ScanObject:
threaded=False, threaded=False,
cb=(lambda msg: msg), cb=(lambda msg: msg),
) )
return consumer return register
def _send_scan_request(self, request: messages.ScanQueueMessage) -> None: def _send_scan_request(self, request: messages.ScanQueueMessage) -> None:
"""Send a scan request to the scan server""" """Send a scan request to the scan server"""
self.client.device_manager.producer.send(MessageEndpoints.scan_queue_request(), request) self.client.device_manager.connector.send(MessageEndpoints.scan_queue_request(), request)
class Scans: class Scans:
@ -136,7 +136,7 @@ class Scans:
def _import_scans(self): def _import_scans(self):
"""Import scans from the scan server""" """Import scans from the scan server"""
available_scans = self.parent.producer.get(MessageEndpoints.available_scans()) available_scans = self.parent.connector.get(MessageEndpoints.available_scans())
if available_scans is None: if available_scans is None:
logger.warning("No scans available. Are redis and the BEC server running?") logger.warning("No scans available. Are redis and the BEC server running?")
return return

View File

@ -14,7 +14,9 @@ logger = bec_logger.logger
DEFAULT_SERVICE_CONFIG = { DEFAULT_SERVICE_CONFIG = {
"redis": {"host": "localhost", "port": 6379}, "redis": {"host": "localhost", "port": 6379},
"service_config": {"file_writer": {"plugin": "default_NeXus_format", "base_path": "./"}}, "service_config": {
"file_writer": {"plugin": "default_NeXus_format", "base_path": os.path.dirname(__file__)}
},
} }
@ -32,7 +34,7 @@ class ServiceConfig:
self._update_config(service_config=config, redis=redis) self._update_config(service_config=config, redis=redis)
self.service_config = self.config.get( self.service_config = self.config.get(
"service_config", {"file_writer": {"plugin": "default_NeXus_format", "base_path": "./"}} "service_config", DEFAULT_SERVICE_CONFIG["service_config"]
) )
def _update_config(self, **kwargs): def _update_config(self, **kwargs):

View File

@ -56,7 +56,7 @@ def queue_is_empty(queue) -> bool: # pragma: no cover
def get_queue(bec): # pragma: no cover def get_queue(bec): # pragma: no cover
return bec.queue.producer.get(MessageEndpoints.scan_queue_status()) return bec.queue.connector.get(MessageEndpoints.scan_queue_status())
def wait_for_empty_queue(bec): # pragma: no cover def wait_for_empty_queue(bec): # pragma: no cover
@ -484,7 +484,6 @@ def bec_client():
with open(f"{dir_path}/tests/test_config.yaml", "r", encoding="utf-8") as f: with open(f"{dir_path}/tests/test_config.yaml", "r", encoding="utf-8") as f:
builtins.__dict__["test_session"] = create_session_from_config(yaml.safe_load(f)) builtins.__dict__["test_session"] = create_session_from_config(yaml.safe_load(f))
device_manager._session = builtins.__dict__["test_session"] device_manager._session = builtins.__dict__["test_session"]
device_manager.producer = device_manager.connector.producer()
client.wait_for_service = lambda service_name: None client.wait_for_service = lambda service_name: None
device_manager._load_session() device_manager._load_session()
for name, dev in device_manager.devices.items(): for name, dev in device_manager.devices.items():
@ -497,37 +496,23 @@ def bec_client():
class PipelineMock: # pragma: no cover class PipelineMock: # pragma: no cover
_pipe_buffer = [] _pipe_buffer = []
_producer = None _connector = None
def __init__(self, producer) -> None: def __init__(self, connector) -> None:
self._producer = producer self._connector = connector
def execute(self): def execute(self):
if not self._producer.store_data: if not self._connector.store_data:
self._pipe_buffer = [] self._pipe_buffer = []
return [] return []
res = [ res = [
getattr(self._producer, method)(*args, **kwargs) getattr(self._connector, method)(*args, **kwargs)
for method, args, kwargs in self._pipe_buffer for method, args, kwargs in self._pipe_buffer
] ]
self._pipe_buffer = [] self._pipe_buffer = []
return res return res
class ConsumerMock: # pragma: no cover
def __init__(self) -> None:
self.signal_event = SignalMock()
def start(self):
pass
def join(self):
pass
def shutdown(self):
pass
class SignalMock: # pragma: no cover class SignalMock: # pragma: no cover
def __init__(self) -> None: def __init__(self) -> None:
self.is_set = False self.is_set = False
@ -536,12 +521,36 @@ class SignalMock: # pragma: no cover
self.is_set = True self.is_set = True
class ProducerMock: # pragma: no cover class ConnectorMock(ConnectorBase): # pragma: no cover
def __init__(self, store_data=True) -> None: def __init__(self, bootstrap_server="localhost:0000", store_data=True):
super().__init__(bootstrap_server)
self.message_sent = [] self.message_sent = []
self._get_buffer = {} self._get_buffer = {}
self.store_data = store_data self.store_data = store_data
def raise_alarm(
self, severity: Alarms, alarm_type: str, source: str, msg: dict, metadata: dict
):
pass
def log_error(self, *args, **kwargs):
pass
def shutdown(self):
pass
def register(self, *args, **kwargs):
pass
def set(self, *args, **kwargs):
pass
def set_and_publish(self, *args, **kwargs):
pass
def keys(self, *args, **kwargs):
return []
def set(self, topic, msg, pipe=None, expire: int = None): def set(self, topic, msg, pipe=None, expire: int = None):
if pipe: if pipe:
pipe._pipe_buffer.append(("set", (topic, msg), {"expire": expire})) pipe._pipe_buffer.append(("set", (topic, msg), {"expire": expire}))
@ -592,9 +601,6 @@ class ProducerMock: # pragma: no cover
self._get_buffer.pop(topic, None) self._get_buffer.pop(topic, None)
return val return val
def keys(self, pattern: str) -> list:
return []
def pipeline(self): def pipeline(self):
return PipelineMock(self) return PipelineMock(self)
@ -609,29 +615,6 @@ class ProducerMock: # pragma: no cover
return return
class ConnectorMock(ConnectorBase): # pragma: no cover
def __init__(self, bootstrap_server: list, store_data=True):
super().__init__(bootstrap_server)
self.store_data = store_data
def consumer(self, *args, **kwargs) -> ConsumerMock:
return ConsumerMock()
def producer(self, *args, **kwargs):
return ProducerMock(self.store_data)
def raise_alarm(
self, severity: Alarms, alarm_type: str, source: str, msg: dict, metadata: dict
):
pass
def log_error(self, *args, **kwargs):
pass
def shutdown(self):
pass
def create_session_from_config(config: dict) -> dict: def create_session_from_config(config: dict) -> dict:
device_configs = [] device_configs = []
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())

View File

@ -5,6 +5,8 @@ __version__ = "1.12.1"
if __name__ == "__main__": if __name__ == "__main__":
setup( setup(
install_requires=[ install_requires=[
"hiredis",
"louie",
"numpy", "numpy",
"scipy", "scipy",
"msgpack", "msgpack",
@ -22,7 +24,16 @@ if __name__ == "__main__":
"lmfit", "lmfit",
], ],
extras_require={ extras_require={
"dev": ["pytest", "pytest-random-order", "coverage", "pandas", "black", "pylint"] "dev": [
"pytest",
"pytest-random-order",
"pytest-redis",
"pytest-timeout",
"coverage",
"pandas",
"black",
"pylint",
]
}, },
entry_points={"console_scripts": ["bec-channel-monitor = bec_lib:channel_monitor_launch"]}, entry_points={"console_scripts": ["bec-channel-monitor = bec_lib:channel_monitor_launch"]},
package_data={"bec_lib.tests": ["*.yaml"], "bec_lib.configs": ["*.yaml", "*.json"]}, package_data={"bec_lib.tests": ["*.yaml"], "bec_lib.configs": ["*.yaml", "*.json"]},

View File

@ -17,7 +17,7 @@ def test_bec_widgets_connector_set_plot_config(bec_client):
config = {"x": "test", "y": "test", "color": "test", "size": "test", "shape": "test"} config = {"x": "test", "y": "test", "color": "test", "size": "test", "shape": "test"}
connector.set_plot_config(plot_id="plot_id", config=config) connector.set_plot_config(plot_id="plot_id", config=config)
msg = messages.GUIConfigMessage(config=config) msg = messages.GUIConfigMessage(config=config)
bec_client.connector.producer().set_and_publish.assert_called_once_with( bec_client.connector.set_and_publish.assert_called_once_with(
MessageEndpoints.gui_config("plot_id"), msg MessageEndpoints.gui_config("plot_id"), msg
) is None ) is None
@ -26,7 +26,7 @@ def test_bec_widgets_connector_close(bec_client):
connector = BECWidgetsConnector(gui_id="gui_id", bec_client=bec_client) connector = BECWidgetsConnector(gui_id="gui_id", bec_client=bec_client)
connector.close("plot_id") connector.close("plot_id")
msg = messages.GUIInstructionMessage(action="close", parameter={}) msg = messages.GUIInstructionMessage(action="close", parameter={})
bec_client.connector.producer().set_and_publish.assert_called_once_with( bec_client.connector.set_and_publish.assert_called_once_with(
MessageEndpoints.gui_instructions("plot_id"), msg MessageEndpoints.gui_instructions("plot_id"), msg
) )
@ -36,7 +36,7 @@ def test_bec_widgets_connector_send_data(bec_client):
data = {"x": [1, 2, 3], "y": [1, 2, 3]} data = {"x": [1, 2, 3], "y": [1, 2, 3]}
connector.send_data("plot_id", data) connector.send_data("plot_id", data)
msg = messages.GUIDataMessage(data=data) msg = messages.GUIDataMessage(data=data)
bec_client.connector.producer().set_and_publish.assert_called_once_with( bec_client.connector.set_and_publish.assert_called_once_with(
topic=MessageEndpoints.gui_data("plot_id"), msg=msg topic=MessageEndpoints.gui_data("plot_id"), msg=msg
) )
@ -45,7 +45,7 @@ def test_bec_widgets_connector_clear(bec_client):
connector = BECWidgetsConnector(gui_id="gui_id", bec_client=bec_client) connector = BECWidgetsConnector(gui_id="gui_id", bec_client=bec_client)
connector.clear("plot_id") connector.clear("plot_id")
msg = messages.GUIInstructionMessage(action="clear", parameter={}) msg = messages.GUIInstructionMessage(action="clear", parameter={})
bec_client.connector.producer().set_and_publish.assert_called_once_with( bec_client.connector.set_and_publish.assert_called_once_with(
MessageEndpoints.gui_instructions("plot_id"), msg MessageEndpoints.gui_instructions("plot_id"), msg
) )

View File

@ -124,8 +124,8 @@ def test_bec_service_update_existing_services():
messages.StatusMessage(name="service2", status=BECStatus.IDLE, info={}, metadata={}), messages.StatusMessage(name="service2", status=BECStatus.IDLE, info={}, metadata={}),
] ]
connector_cls = mock.MagicMock() connector_cls = mock.MagicMock()
connector_cls().producer().keys.return_value = service_keys connector_cls().keys.return_value = service_keys
connector_cls().producer().get.side_effect = [msg for msg in service_msgs] connector_cls().get.side_effect = [msg for msg in service_msgs]
service = BECService( service = BECService(
config=f"{os.path.dirname(bec_lib.__file__)}/tests/test_service_config.yaml", config=f"{os.path.dirname(bec_lib.__file__)}/tests/test_service_config.yaml",
connector_cls=connector_cls, connector_cls=connector_cls,
@ -144,8 +144,8 @@ def test_bec_service_update_existing_services_ignores_wrong_msgs():
None, None,
] ]
connector_cls = mock.MagicMock() connector_cls = mock.MagicMock()
connector_cls().producer().keys.return_value = service_keys connector_cls().keys.return_value = service_keys
connector_cls().producer().get.side_effect = [service_msgs[0], None] connector_cls().get.side_effect = [service_msgs[0], None]
service = BECService( service = BECService(
config=f"{os.path.dirname(bec_lib.__file__)}/tests/test_service_config.yaml", config=f"{os.path.dirname(bec_lib.__file__)}/tests/test_service_config.yaml",
connector_cls=connector_cls, connector_cls=connector_cls,

View File

@ -13,7 +13,7 @@ def test_channel_monitor_callback():
mock_print.assert_called_once() mock_print.assert_called_once()
def test_channel_monitor_start_consumer(): def test_channel_monitor_start_register():
with mock.patch("bec_lib.channel_monitor.argparse") as mock_argparse: with mock.patch("bec_lib.channel_monitor.argparse") as mock_argparse:
with mock.patch("bec_lib.channel_monitor.ServiceConfig") as mock_config: with mock.patch("bec_lib.channel_monitor.ServiceConfig") as mock_config:
with mock.patch("bec_lib.channel_monitor.RedisConnector") as mock_connector: with mock.patch("bec_lib.channel_monitor.RedisConnector") as mock_connector:
@ -26,6 +26,6 @@ def test_channel_monitor_start_consumer():
mock_config.return_value = mock.MagicMock() mock_config.return_value = mock.MagicMock()
mock_connector.return_value = mock.MagicMock() mock_connector.return_value = mock.MagicMock()
channel_monitor_launch() channel_monitor_launch()
mock_connector().consumer.assert_called_once() mock_connector().register.assert_called_once()
mock_connector().consumer.return_value.start.assert_called_once() mock_connector().register.return_value.start.assert_called_once()
mock_threading.Event().wait.assert_called_once() mock_threading.Event().wait.assert_called_once()

View File

@ -49,7 +49,7 @@ def test_config_helper_save_current_session():
connector = mock.MagicMock() connector = mock.MagicMock()
config_helper = ConfigHelper(connector) config_helper = ConfigHelper(connector)
connector.producer().get.return_value = messages.AvailableResourceMessage( connector.get.return_value = messages.AvailableResourceMessage(
resource=[ resource=[
{ {
"id": "648c817f67d3c7cd6a354e8e", "id": "648c817f67d3c7cd6a354e8e",
@ -158,9 +158,7 @@ def test_send_config_request_raises_for_rejected_update(config_helper):
def test_wait_for_config_reply(): def test_wait_for_config_reply():
connector = mock.MagicMock() connector = mock.MagicMock()
config_helper = ConfigHelper(connector) config_helper = ConfigHelper(connector)
connector.producer().get.return_value = messages.RequestResponseMessage( connector.get.return_value = messages.RequestResponseMessage(accepted=True, message="test")
accepted=True, message="test"
)
res = config_helper.wait_for_config_reply("test") res = config_helper.wait_for_config_reply("test")
assert res == messages.RequestResponseMessage(accepted=True, message="test") assert res == messages.RequestResponseMessage(accepted=True, message="test")
@ -169,7 +167,7 @@ def test_wait_for_config_reply():
def test_wait_for_config_raises_timeout(): def test_wait_for_config_raises_timeout():
connector = mock.MagicMock() connector = mock.MagicMock()
config_helper = ConfigHelper(connector) config_helper = ConfigHelper(connector)
connector.producer().get.return_value = None connector.get.return_value = None
with pytest.raises(DeviceConfigError): with pytest.raises(DeviceConfigError):
config_helper.wait_for_config_reply("test", timeout=0.3) config_helper.wait_for_config_reply("test", timeout=0.3)
@ -178,7 +176,7 @@ def test_wait_for_config_raises_timeout():
def test_wait_for_service_response(): def test_wait_for_service_response():
connector = mock.MagicMock() connector = mock.MagicMock()
config_helper = ConfigHelper(connector) config_helper = ConfigHelper(connector)
connector.producer().lrange.side_effect = [ connector.lrange.side_effect = [
[], [],
[ [
messages.ServiceResponseMessage( messages.ServiceResponseMessage(
@ -196,7 +194,7 @@ def test_wait_for_service_response():
def test_wait_for_service_response_raises_timeout(): def test_wait_for_service_response_raises_timeout():
connector = mock.MagicMock() connector = mock.MagicMock()
config_helper = ConfigHelper(connector) config_helper = ConfigHelper(connector)
connector.producer().lrange.return_value = [] connector.lrange.return_value = []
with pytest.raises(DeviceConfigError): with pytest.raises(DeviceConfigError):
config_helper.wait_for_service_response("test", timeout=0.3) config_helper.wait_for_service_response("test", timeout=0.3)

View File

@ -349,7 +349,7 @@ def dap(dap_plugin_message):
} }
client = mock.MagicMock() client = mock.MagicMock()
client.service_status = dap_services client.service_status = dap_services
client.producer.get.return_value = dap_plugin_message client.connector.get.return_value = dap_plugin_message
dap_plugins = DAPPlugins(client) dap_plugins = DAPPlugins(client)
yield dap_plugins yield dap_plugins
@ -367,7 +367,7 @@ def test_dap_plugins_construction(dap):
def test_dap_plugin_fit(dap): def test_dap_plugin_fit(dap):
with mock.patch.object(dap.GaussianModel, "_wait_for_dap_response") as mock_wait: with mock.patch.object(dap.GaussianModel, "_wait_for_dap_response") as mock_wait:
dap.GaussianModel.fit() dap.GaussianModel.fit()
dap._parent.producer.set_and_publish.assert_called_once() dap._parent.connector.set_and_publish.assert_called_once()
mock_wait.assert_called_once() mock_wait.assert_called_once()
@ -380,7 +380,7 @@ def test_dap_auto_run(dap):
def test_dap_wait_for_dap_response_waits_for_RID(dap): def test_dap_wait_for_dap_response_waits_for_RID(dap):
dap._parent.producer.get.return_value = messages.DAPResponseMessage( dap._parent.connector.get.return_value = messages.DAPResponseMessage(
success=True, data={}, metadata={"RID": "wrong_ID"} success=True, data={}, metadata={"RID": "wrong_ID"}
) )
with pytest.raises(TimeoutError): with pytest.raises(TimeoutError):
@ -388,7 +388,7 @@ def test_dap_wait_for_dap_response_waits_for_RID(dap):
def test_dap_wait_for_dap_respnse_returns(dap): def test_dap_wait_for_dap_respnse_returns(dap):
dap._parent.producer.get.return_value = messages.DAPResponseMessage( dap._parent.connector.get.return_value = messages.DAPResponseMessage(
success=True, data={}, metadata={"RID": "1234"} success=True, data={}, metadata={"RID": "1234"}
) )
val = dap.GaussianModel._wait_for_dap_response(request_id="1234", timeout=0.1) val = dap.GaussianModel._wait_for_dap_response(request_id="1234", timeout=0.1)
@ -429,11 +429,11 @@ def test_dap_select_raises_on_wrong_device(dap):
def test_dap_get_data(dap): def test_dap_get_data(dap):
dap._parent.producer.get_last.return_value = messages.ProcessedDataMessage( dap._parent.connector.get_last.return_value = messages.ProcessedDataMessage(
data=[{"x": [1, 2, 3], "y": [4, 5, 6]}, {"fit_parameters": {"amplitude": 1}}] data=[{"x": [1, 2, 3], "y": [4, 5, 6]}, {"fit_parameters": {"amplitude": 1}}]
) )
data = dap.GaussianModel.get_data() data = dap.GaussianModel.get_data()
dap._parent.producer.get_last.assert_called_once_with( dap._parent.connector.get_last.assert_called_once_with(
MessageEndpoints.processed_data("GaussianModel") MessageEndpoints.processed_data("GaussianModel")
) )
@ -443,13 +443,13 @@ def test_dap_get_data(dap):
def test_dap_update_dap_config_not_called_without_device(dap): def test_dap_update_dap_config_not_called_without_device(dap):
dap.GaussianModel._update_dap_config(request_id="1234") dap.GaussianModel._update_dap_config(request_id="1234")
dap._parent.producer.set_and_publish.assert_not_called() dap._parent.connector.set_and_publish.assert_not_called()
def test_dap_update_dap_config(dap): def test_dap_update_dap_config(dap):
dap.GaussianModel._plugin_config["selected_device"] = ["samx", "samx"] dap.GaussianModel._plugin_config["selected_device"] = ["samx", "samx"]
dap.GaussianModel._update_dap_config(request_id="1234") dap.GaussianModel._update_dap_config(request_id="1234")
dap._parent.producer.set_and_publish.assert_called_with( dap._parent.connector.set_and_publish.assert_called_with(
MessageEndpoints.dap_request(), MessageEndpoints.dap_request(),
messages.DAPRequestMessage( messages.DAPRequestMessage(
dap_cls="LmfitService1D", dap_cls="LmfitService1D",

View File

@ -91,15 +91,14 @@ def test_get_config_calls_load(dm):
dm, "_get_redis_device_config", return_value={"devices": [{}]} dm, "_get_redis_device_config", return_value={"devices": [{}]}
) as get_redis_config: ) as get_redis_config:
with mock.patch.object(dm, "_load_session") as load_session: with mock.patch.object(dm, "_load_session") as load_session:
with mock.patch.object(dm, "producer") as producer: dm._get_config()
dm._get_config() get_redis_config.assert_called_once()
get_redis_config.assert_called_once() load_session.assert_called_once()
load_session.assert_called_once()
def test_get_redis_device_config(dm): def test_get_redis_device_config(dm):
with mock.patch.object(dm, "producer") as producer: with mock.patch.object(dm, "connector") as connector:
producer.get.return_value = messages.AvailableResourceMessage(resource={"devices": [{}]}) connector.get.return_value = messages.AvailableResourceMessage(resource={"devices": [{}]})
assert dm._get_redis_device_config() == {"devices": [{}]} assert dm._get_redis_device_config() == {"devices": [{}]}

View File

@ -23,7 +23,7 @@ def test_nested_device_root(dev):
def test_read(dev): def test_read(dev):
with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get: with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
mock_get.return_value = messages.DeviceMessage( mock_get.return_value = messages.DeviceMessage(
signals={ signals={
"samx": {"value": 0, "timestamp": 1701105880.1711318}, "samx": {"value": 0, "timestamp": 1701105880.1711318},
@ -42,7 +42,7 @@ def test_read(dev):
def test_read_filtered_hints(dev): def test_read_filtered_hints(dev):
with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get: with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
mock_get.return_value = messages.DeviceMessage( mock_get.return_value = messages.DeviceMessage(
signals={ signals={
"samx": {"value": 0, "timestamp": 1701105880.1711318}, "samx": {"value": 0, "timestamp": 1701105880.1711318},
@ -57,7 +57,7 @@ def test_read_filtered_hints(dev):
def test_read_use_read(dev): def test_read_use_read(dev):
with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get: with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
data = { data = {
"samx": {"value": 0, "timestamp": 1701105880.1711318}, "samx": {"value": 0, "timestamp": 1701105880.1711318},
"samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492}, "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492},
@ -72,7 +72,7 @@ def test_read_use_read(dev):
def test_read_nested_device(dev): def test_read_nested_device(dev):
with mock.patch.object(dev.dyn_signals.root.parent.producer, "get") as mock_get: with mock.patch.object(dev.dyn_signals.root.parent.connector, "get") as mock_get:
data = { data = {
"dyn_signals_messages_message1": {"value": 0, "timestamp": 1701105880.0716832}, "dyn_signals_messages_message1": {"value": 0, "timestamp": 1701105880.0716832},
"dyn_signals_messages_message2": {"value": 0, "timestamp": 1701105880.071722}, "dyn_signals_messages_message2": {"value": 0, "timestamp": 1701105880.071722},
@ -93,7 +93,7 @@ def test_read_nested_device(dev):
) )
def test_read_kind_hinted(dev, kind, cached): def test_read_kind_hinted(dev, kind, cached):
with mock.patch.object(dev.samx.readback, "_run") as mock_run: with mock.patch.object(dev.samx.readback, "_run") as mock_run:
with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get: with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
data = { data = {
"samx": {"value": 0, "timestamp": 1701105880.1711318}, "samx": {"value": 0, "timestamp": 1701105880.1711318},
"samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492}, "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492},
@ -138,7 +138,7 @@ def test_read_configuration_cached(dev, is_signal, is_config_signal, method):
with mock.patch.object( with mock.patch.object(
dev.samx.readback, "_get_rpc_signal_info", return_value=(is_signal, is_config_signal, True) dev.samx.readback, "_get_rpc_signal_info", return_value=(is_signal, is_config_signal, True)
): ):
with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get: with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
mock_get.return_value = messages.DeviceMessage( mock_get.return_value = messages.DeviceMessage(
signals={ signals={
"samx": {"value": 0, "timestamp": 1701105880.1711318}, "samx": {"value": 0, "timestamp": 1701105880.1711318},
@ -182,7 +182,7 @@ def test_get_rpc_func_name_read(dev):
) )
def test_get_rpc_func_name_readback_get(dev, kind, cached): def test_get_rpc_func_name_readback_get(dev, kind, cached):
with mock.patch.object(dev.samx.readback, "_run") as mock_rpc: with mock.patch.object(dev.samx.readback, "_run") as mock_rpc:
with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get: with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
mock_get.return_value = messages.DeviceMessage( mock_get.return_value = messages.DeviceMessage(
signals={ signals={
"samx": {"value": 0, "timestamp": 1701105880.1711318}, "samx": {"value": 0, "timestamp": 1701105880.1711318},
@ -219,9 +219,7 @@ def test_handle_rpc_response_returns_status(dev, bec_client):
msg = messages.DeviceRPCMessage( msg = messages.DeviceRPCMessage(
device="samx", return_val={"type": "status", "RID": "request_id"}, out="done", success=True device="samx", return_val={"type": "status", "RID": "request_id"}, out="done", success=True
) )
assert dev.samx._handle_rpc_response(msg) == Status( assert dev.samx._handle_rpc_response(msg) == Status(bec_client.device_manager, "request_id")
bec_client.device_manager.producer, "request_id"
)
def test_handle_rpc_response_raises(dev): def test_handle_rpc_response_raises(dev):
@ -348,7 +346,7 @@ def test_device_update_user_parameter(device_obj, user_param, val, out, raised_e
def test_status_wait(): def test_status_wait():
producer = mock.MagicMock() connector = mock.MagicMock()
def lrange_mock(*args, **kwargs): def lrange_mock(*args, **kwargs):
yield False yield False
@ -358,8 +356,8 @@ def test_status_wait():
return next(lmock) return next(lmock)
lmock = lrange_mock() lmock = lrange_mock()
producer.lrange = get_lrange connector.lrange = get_lrange
status = Status(producer, "test") status = Status(connector, "test")
status.wait() status.wait()
@ -561,7 +559,7 @@ def test_show_all():
def test_adjustable_mixin_limits(): def test_adjustable_mixin_limits():
adj = AdjustableMixin() adj = AdjustableMixin()
adj.root = mock.MagicMock() adj.root = mock.MagicMock()
adj.root.parent.producer.get.return_value = messages.DeviceMessage( adj.root.parent.connector.get.return_value = messages.DeviceMessage(
signals={"low": -12, "high": 12}, metadata={} signals={"low": -12, "high": 12}, metadata={}
) )
assert adj.limits == [-12, 12] assert adj.limits == [-12, 12]
@ -570,7 +568,7 @@ def test_adjustable_mixin_limits():
def test_adjustable_mixin_limits_missing(): def test_adjustable_mixin_limits_missing():
adj = AdjustableMixin() adj = AdjustableMixin()
adj.root = mock.MagicMock() adj.root = mock.MagicMock()
adj.root.parent.producer.get.return_value = None adj.root.parent.connector.get.return_value = None
assert adj.limits == [0, 0] assert adj.limits == [0, 0]
@ -585,7 +583,7 @@ def test_adjustable_mixin_set_low_limit():
adj = AdjustableMixin() adj = AdjustableMixin()
adj.update_config = mock.MagicMock() adj.update_config = mock.MagicMock()
adj.root = mock.MagicMock() adj.root = mock.MagicMock()
adj.root.parent.producer.get.return_value = messages.DeviceMessage( adj.root.parent.connector.get.return_value = messages.DeviceMessage(
signals={"low": -12, "high": 12}, metadata={} signals={"low": -12, "high": 12}, metadata={}
) )
adj.low_limit = -20 adj.low_limit = -20
@ -596,7 +594,7 @@ def test_adjustable_mixin_set_high_limit():
adj = AdjustableMixin() adj = AdjustableMixin()
adj.update_config = mock.MagicMock() adj.update_config = mock.MagicMock()
adj.root = mock.MagicMock() adj.root = mock.MagicMock()
adj.root.parent.producer.get.return_value = messages.DeviceMessage( adj.root.parent.connector.get.return_value = messages.DeviceMessage(
signals={"low": -12, "high": 12}, metadata={} signals={"low": -12, "high": 12}, metadata={}
) )
adj.high_limit = 20 adj.high_limit = 20

View File

@ -97,9 +97,9 @@ def device_manager(dm_with_devices):
def test_observer_manager_None(device_manager): def test_observer_manager_None(device_manager):
with mock.patch.object(device_manager.producer, "get", return_value=None) as producer_get: with mock.patch.object(device_manager.connector, "get", return_value=None) as connector_get:
observer_manager = ObserverManager(device_manager=device_manager) observer_manager = ObserverManager(device_manager=device_manager)
producer_get.assert_called_once_with(MessageEndpoints.observer()) connector_get.assert_called_once_with(MessageEndpoints.observer())
assert len(observer_manager._observer) == 0 assert len(observer_manager._observer) == 0
@ -115,9 +115,9 @@ def test_observer_manager_msg(device_manager):
} }
] ]
) )
with mock.patch.object(device_manager.producer, "get", return_value=msg) as producer_get: with mock.patch.object(device_manager.connector, "get", return_value=msg) as connector_get:
observer_manager = ObserverManager(device_manager=device_manager) observer_manager = ObserverManager(device_manager=device_manager)
producer_get.assert_called_once_with(MessageEndpoints.observer()) connector_get.assert_called_once_with(MessageEndpoints.observer())
assert len(observer_manager._observer) == 1 assert len(observer_manager._observer) == 1
@ -139,7 +139,7 @@ def test_observer_manager_msg(device_manager):
], ],
) )
def test_add_observer(device_manager, observer, raises_error): def test_add_observer(device_manager, observer, raises_error):
with mock.patch.object(device_manager.producer, "get", return_value=None) as producer_get: with mock.patch.object(device_manager.connector, "get", return_value=None) as connector_get:
observer_manager = ObserverManager(device_manager=device_manager) observer_manager = ObserverManager(device_manager=device_manager)
observer_manager.add_observer(observer) observer_manager.add_observer(observer)
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
@ -185,7 +185,7 @@ def test_add_observer_existing_device(device_manager, observer, raises_error):
"limits": [380, None], "limits": [380, None],
} }
) )
with mock.patch.object(device_manager.producer, "get", return_value=None) as producer_get: with mock.patch.object(device_manager.connector, "get", return_value=None) as connector_get:
observer_manager = ObserverManager(device_manager=device_manager) observer_manager = ObserverManager(device_manager=device_manager)
observer_manager.add_observer(default_observer) observer_manager.add_observer(default_observer)
if raises_error: if raises_error:

View File

@ -12,113 +12,71 @@ from bec_lib.messages import AlarmMessage, BECMessage, LogMessage
from bec_lib.redis_connector import ( from bec_lib.redis_connector import (
MessageObject, MessageObject,
RedisConnector, RedisConnector,
RedisConsumer,
RedisConsumerMixin,
RedisConsumerThreaded,
RedisProducer,
RedisStreamConsumerThreaded,
) )
from bec_lib.serialization import MsgpackSerialization from bec_lib.serialization import MsgpackSerialization
@pytest.fixture
def producer():
with mock.patch("bec_lib.redis_connector.redis.Redis"):
prod = RedisProducer("localhost", 1)
yield prod
@pytest.fixture @pytest.fixture
def connector(): def connector():
with mock.patch("bec_lib.redis_connector.redis.Redis"): with mock.patch("bec_lib.redis_connector.redis.Redis"):
connector = RedisConnector("localhost:1") connector = RedisConnector("localhost:1")
try:
yield connector
finally:
connector.shutdown()
@pytest.fixture
def connected_connector(redis_proc):
connector = RedisConnector(f"localhost:{redis_proc.port}")
try:
yield connector yield connector
finally:
connector.shutdown()
@pytest.fixture
def consumer():
with mock.patch("bec_lib.redis_connector.redis.Redis"):
consumer = RedisConsumer("localhost", "1", topics="topics")
yield consumer
@pytest.fixture
def consumer_threaded():
with mock.patch("bec_lib.redis_connector.redis.Redis"):
consumer_threaded = RedisConsumerThreaded("localhost", "1", topics="topics")
yield consumer_threaded
@pytest.fixture
def mixin():
with mock.patch("bec_lib.redis_connector.redis.Redis"):
mixin = RedisConsumerMixin
yield mixin
def test_redis_connector_producer(connector):
ret = connector.producer()
assert isinstance(ret, RedisProducer)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"topics, threaded", [["topics", True], ["topics", False], [None, True], [None, False]] "topics, threaded",
[["topics", True], ["topics", False], [None, True], [None, False]],
) )
def test_redis_connector_consumer(connector, threaded, topics): def test_redis_connector_register(connected_connector, threaded, topics):
pattern = None breakpoint()
len_of_threads = len(connector._threads) connector = connected_connector
if topics is None:
if threaded: with pytest.raises(TypeError):
if topics is None and pattern is None: ret = connector.register(
with pytest.raises(ValueError) as exc_info: topics=topics, cb=lambda *args, **kwargs: ..., start_thread=threaded
ret = connector.consumer(
topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ...
)
assert exc_info.value.args[0] == "Topics must be set for threaded consumer"
else:
ret = connector.consumer(
topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ...
) )
assert len(connector._threads) == len_of_threads + 1
assert isinstance(ret, RedisConsumerThreaded)
else: else:
if not topics: ret = connector.register(
with pytest.raises(ConsumerConnectorError): topics=topics, cb=lambda *args, **kwargs: ..., start_thread=threaded
ret = connector.consumer( )
topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ... if threaded:
) assert connector._events_listener_thread is not None
return
ret = connector.consumer(topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ...)
assert isinstance(ret, RedisConsumer)
def test_redis_connector_log_warning(connector): def test_redis_connector_log_warning(connector):
connector._notifications_producer.send = mock.MagicMock() with mock.patch.object(connector, "send", return_value=None):
connector.log_warning("msg")
connector.log_warning("msg") connector.send.assert_called_once_with(
connector._notifications_producer.send.assert_called_once_with( MessageEndpoints.log(), LogMessage(log_type="warning", log_msg="msg")
MessageEndpoints.log(), LogMessage(log_type="warning", log_msg="msg") )
)
def test_redis_connector_log_message(connector): def test_redis_connector_log_message(connector):
connector._notifications_producer.send = mock.MagicMock() with mock.patch.object(connector, "send", return_value=None):
connector.log_message("msg")
connector.log_message("msg") connector.send.assert_called_once_with(
connector._notifications_producer.send.assert_called_once_with( MessageEndpoints.log(), LogMessage(log_type="log", log_msg="msg")
MessageEndpoints.log(), LogMessage(log_type="log", log_msg="msg") )
)
def test_redis_connector_log_error(connector): def test_redis_connector_log_error(connector):
connector._notifications_producer.send = mock.MagicMock() with mock.patch.object(connector, "send", return_value=None):
connector.log_error("msg")
connector.log_error("msg") connector.send.assert_called_once_with(
connector._notifications_producer.send.assert_called_once_with( MessageEndpoints.log(), LogMessage(log_type="error", log_msg="msg")
MessageEndpoints.log(), LogMessage(log_type="error", log_msg="msg") )
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -130,20 +88,24 @@ def test_redis_connector_log_error(connector):
], ],
) )
def test_redis_connector_raise_alarm(connector, severity, alarm_type, source, msg, metadata): def test_redis_connector_raise_alarm(connector, severity, alarm_type, source, msg, metadata):
connector._notifications_producer.set_and_publish = mock.MagicMock() with mock.patch.object(connector, "set_and_publish", return_value=None):
connector.raise_alarm(severity, alarm_type, source, msg, metadata)
connector.raise_alarm(severity, alarm_type, source, msg, metadata) connector.set_and_publish.assert_called_once_with(
MessageEndpoints.alarm(),
connector._notifications_producer.set_and_publish.assert_called_once_with( AlarmMessage(
MessageEndpoints.alarm(), severity=severity,
AlarmMessage( alarm_type=alarm_type,
severity=severity, alarm_type=alarm_type, source=source, msg=msg, metadata=metadata source=source,
), msg=msg,
) metadata=metadata,
),
)
@dataclass(eq=False) @dataclass(eq=False)
class TestMessage(BECMessage): class TestMessage(BECMessage):
__test__ = False # just for pytest to ignore this class
msg_type = "test_message" msg_type = "test_message"
msg: str msg: str
# have to add this field here, # have to add this field here,
@ -160,30 +122,36 @@ bec_messages.TestMessage = TestMessage
@pytest.mark.parametrize( @pytest.mark.parametrize(
"topic , msg", [["topic1", TestMessage("msg1")], ["topic2", TestMessage("msg2")]] "topic , msg", [["topic1", TestMessage("msg1")], ["topic2", TestMessage("msg2")]]
) )
def test_redis_producer_send(producer, topic, msg): def test_redis_connector_send(connector, topic, msg):
producer.send(topic, msg) connector.send(topic, msg)
producer.r.publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg)) connector._redis_conn.publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg))
producer.send(topic, msg, pipe=producer.pipeline()) connector.send(topic, msg, pipe=connector.pipeline())
producer.r.pipeline().publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg)) connector._redis_conn.pipeline().publish.assert_called_once_with(
topic, MsgpackSerialization.dumps(msg)
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"topic, msgs, max_size, expire", "topic, msgs, max_size, expire",
[["topic1", "msgs", None, None], ["topic1", "msgs", 10, None], ["topic1", "msgs", None, 100]], [
["topic1", "msgs", None, None],
["topic1", "msgs", 10, None],
["topic1", "msgs", None, 100],
],
) )
def test_redis_producer_lpush(producer, topic, msgs, max_size, expire): def test_redis_connector_lpush(connector, topic, msgs, max_size, expire):
pipe = None pipe = None
producer.lpush(topic, msgs, pipe, max_size, expire) connector.lpush(topic, msgs, pipe, max_size, expire)
producer.r.pipeline().lpush.assert_called_once_with(topic, msgs) connector._redis_conn.pipeline().lpush.assert_called_once_with(topic, msgs)
if max_size: if max_size:
producer.r.pipeline().ltrim.assert_called_once_with(topic, 0, max_size) connector._redis_conn.pipeline().ltrim.assert_called_once_with(topic, 0, max_size)
if expire: if expire:
producer.r.pipeline().expire.assert_called_once_with(topic, expire) connector._redis_conn.pipeline().expire.assert_called_once_with(topic, expire)
if not pipe: if not pipe:
producer.r.pipeline().execute.assert_called_once() connector._redis_conn.pipeline().execute.assert_called_once()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -194,68 +162,76 @@ def test_redis_producer_lpush(producer, topic, msgs, max_size, expire):
["topic1", TestMessage("msgs"), None, 100], ["topic1", TestMessage("msgs"), None, 100],
], ],
) )
def test_redis_producer_lpush_BECMessage(producer, topic, msgs, max_size, expire): def test_redis_connector_lpush_BECMessage(connector, topic, msgs, max_size, expire):
pipe = None pipe = None
producer.lpush(topic, msgs, pipe, max_size, expire) connector.lpush(topic, msgs, pipe, max_size, expire)
producer.r.pipeline().lpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs)) connector._redis_conn.pipeline().lpush.assert_called_once_with(
topic, MsgpackSerialization.dumps(msgs)
)
if max_size: if max_size:
producer.r.pipeline().ltrim.assert_called_once_with(topic, 0, max_size) connector._redis_conn.pipeline().ltrim.assert_called_once_with(topic, 0, max_size)
if expire: if expire:
producer.r.pipeline().expire.assert_called_once_with(topic, expire) connector._redis_conn.pipeline().expire.assert_called_once_with(topic, expire)
if not pipe: if not pipe:
producer.r.pipeline().execute.assert_called_once() connector._redis_conn.pipeline().execute.assert_called_once()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"topic , index , msgs, use_pipe", [["topic1", 1, "msg1", True], ["topic2", 4, "msg2", False]] "topic , index , msgs, use_pipe",
[["topic1", 1, "msg1", True], ["topic2", 4, "msg2", False]],
) )
def test_redis_producer_lset(producer, topic, index, msgs, use_pipe): def test_redis_connector_lset(connector, topic, index, msgs, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe) pipe = use_pipe_fcn(connector, use_pipe)
ret = producer.lset(topic, index, msgs, pipe) ret = connector.lset(topic, index, msgs, pipe)
if pipe: if pipe:
producer.r.pipeline().lset.assert_called_once_with(topic, index, msgs) connector._redis_conn.pipeline().lset.assert_called_once_with(topic, index, msgs)
assert ret == redis.Redis().pipeline().lset() assert ret == redis.Redis().pipeline().lset()
else: else:
producer.r.lset.assert_called_once_with(topic, index, msgs) connector._redis_conn.lset.assert_called_once_with(topic, index, msgs)
assert ret == redis.Redis().lset() assert ret == redis.Redis().lset()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"topic , index , msgs, use_pipe", "topic , index , msgs, use_pipe",
[["topic1", 1, TestMessage("msg1"), True], ["topic2", 4, TestMessage("msg2"), False]], [
["topic1", 1, TestMessage("msg1"), True],
["topic2", 4, TestMessage("msg2"), False],
],
) )
def test_redis_producer_lset_BECMessage(producer, topic, index, msgs, use_pipe): def test_redis_connector_lset_BECMessage(connector, topic, index, msgs, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe) pipe = use_pipe_fcn(connector, use_pipe)
ret = producer.lset(topic, index, msgs, pipe) ret = connector.lset(topic, index, msgs, pipe)
if pipe: if pipe:
producer.r.pipeline().lset.assert_called_once_with( connector._redis_conn.pipeline().lset.assert_called_once_with(
topic, index, MsgpackSerialization.dumps(msgs) topic, index, MsgpackSerialization.dumps(msgs)
) )
assert ret == redis.Redis().pipeline().lset() assert ret == redis.Redis().pipeline().lset()
else: else:
producer.r.lset.assert_called_once_with(topic, index, MsgpackSerialization.dumps(msgs)) connector._redis_conn.lset.assert_called_once_with(
topic, index, MsgpackSerialization.dumps(msgs)
)
assert ret == redis.Redis().lset() assert ret == redis.Redis().lset()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"topic, msgs, use_pipe", [["topic1", "msg1", True], ["topic2", "msg2", False]] "topic, msgs, use_pipe", [["topic1", "msg1", True], ["topic2", "msg2", False]]
) )
def test_redis_producer_rpush(producer, topic, msgs, use_pipe): def test_redis_connector_rpush(connector, topic, msgs, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe) pipe = use_pipe_fcn(connector, use_pipe)
ret = producer.rpush(topic, msgs, pipe) ret = connector.rpush(topic, msgs, pipe)
if pipe: if pipe:
producer.r.pipeline().rpush.assert_called_once_with(topic, msgs) connector._redis_conn.pipeline().rpush.assert_called_once_with(topic, msgs)
assert ret == redis.Redis().pipeline().rpush() assert ret == redis.Redis().pipeline().rpush()
else: else:
producer.r.rpush.assert_called_once_with(topic, msgs) connector._redis_conn.rpush.assert_called_once_with(topic, msgs)
assert ret == redis.Redis().rpush() assert ret == redis.Redis().rpush()
@ -263,421 +239,349 @@ def test_redis_producer_rpush(producer, topic, msgs, use_pipe):
"topic, msgs, use_pipe", "topic, msgs, use_pipe",
[["topic1", TestMessage("msg1"), True], ["topic2", TestMessage("msg2"), False]], [["topic1", TestMessage("msg1"), True], ["topic2", TestMessage("msg2"), False]],
) )
def test_redis_producer_rpush_BECMessage(producer, topic, msgs, use_pipe): def test_redis_connector_rpush_BECMessage(connector, topic, msgs, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe) pipe = use_pipe_fcn(connector, use_pipe)
ret = producer.rpush(topic, msgs, pipe) ret = connector.rpush(topic, msgs, pipe)
if pipe: if pipe:
producer.r.pipeline().rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs)) connector._redis_conn.pipeline().rpush.assert_called_once_with(
topic, MsgpackSerialization.dumps(msgs)
)
assert ret == redis.Redis().pipeline().rpush() assert ret == redis.Redis().pipeline().rpush()
else: else:
producer.r.rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs)) connector._redis_conn.rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs))
assert ret == redis.Redis().rpush() assert ret == redis.Redis().rpush()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"topic, start, end, use_pipe", [["topic1", 0, 4, True], ["topic2", 3, 7, False]] "topic, start, end, use_pipe", [["topic1", 0, 4, True], ["topic2", 3, 7, False]]
) )
def test_redis_producer_lrange(producer, topic, start, end, use_pipe): def test_redis_connector_lrange(connector, topic, start, end, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe) pipe = use_pipe_fcn(connector, use_pipe)
ret = producer.lrange(topic, start, end, pipe) ret = connector.lrange(topic, start, end, pipe)
if pipe: if pipe:
producer.r.pipeline().lrange.assert_called_once_with(topic, start, end) connector._redis_conn.pipeline().lrange.assert_called_once_with(topic, start, end)
assert ret == redis.Redis().pipeline().lrange() assert ret == redis.Redis().pipeline().lrange()
else: else:
producer.r.lrange.assert_called_once_with(topic, start, end) connector._redis_conn.lrange.assert_called_once_with(topic, start, end)
assert ret == [] assert ret == []
@pytest.mark.parametrize( @pytest.mark.parametrize(
"topic, msg, pipe, expire", [["topic1", "msg1", None, 400], ["topic2", "msg2", None, None]] "topic, msg, pipe, expire",
[
["topic1", TestMessage("msg1"), None, 400],
["topic2", TestMessage("msg2"), None, None],
["topic3", "msg3", None, None],
],
) )
def test_redis_producer_set_and_publish(producer, topic, msg, pipe, expire): def test_redis_connector_set_and_publish(connector, topic, msg, pipe, expire):
producer.set_and_publish(topic, msg, pipe, expire) if not isinstance(msg, BECMessage):
with pytest.raises(TypeError):
connector.set_and_publish(topic, msg, pipe, expire)
else:
connector.set_and_publish(topic, msg, pipe, expire)
producer.r.pipeline().publish.assert_called_once_with(topic, msg) connector._redis_conn.pipeline().publish.assert_called_once_with(
producer.r.pipeline().set.assert_called_once_with(topic, msg) topic, MsgpackSerialization.dumps(msg)
if expire: )
producer.r.pipeline().expire.assert_called_once_with(topic, expire) connector._redis_conn.pipeline().set.assert_called_once_with(
if not pipe: topic, MsgpackSerialization.dumps(msg), ex=expire
producer.r.pipeline().execute.assert_called_once() )
if not pipe:
connector._redis_conn.pipeline().execute.assert_called_once()
@pytest.mark.parametrize("topic, msg, expire", [["topic1", "msg1", None], ["topic2", "msg2", 400]]) @pytest.mark.parametrize("topic, msg, expire", [["topic1", "msg1", None], ["topic2", "msg2", 400]])
def test_redis_producer_set(producer, topic, msg, expire): def test_redis_connector_set(connector, topic, msg, expire):
pipe = None pipe = None
producer.set(topic, msg, pipe, expire) connector.set(topic, msg, pipe, expire)
if pipe: if pipe:
producer.r.pipeline().set.assert_called_once_with(topic, msg, ex=expire) connector._redis_conn.pipeline().set.assert_called_once_with(topic, msg, ex=expire)
else: else:
producer.r.set.assert_called_once_with(topic, msg, ex=expire) connector._redis_conn.set.assert_called_once_with(topic, msg, ex=expire)
@pytest.mark.parametrize("pattern", ["samx", "samy"]) @pytest.mark.parametrize("pattern", ["samx", "samy"])
def test_redis_producer_keys(producer, pattern): def test_redis_connector_keys(connector, pattern):
ret = producer.keys(pattern) ret = connector.keys(pattern)
producer.r.keys.assert_called_once_with(pattern) connector._redis_conn.keys.assert_called_once_with(pattern)
assert ret == redis.Redis().keys() assert ret == redis.Redis().keys()
def test_redis_producer_pipeline(producer): def test_redis_connector_pipeline(connector):
ret = producer.pipeline() ret = connector.pipeline()
producer.r.pipeline.assert_called_once() connector._redis_conn.pipeline.assert_called_once()
assert ret == redis.Redis().pipeline() assert ret == redis.Redis().pipeline()
@pytest.mark.parametrize("topic,use_pipe", [["topic1", True], ["topic2", False]]) def use_pipe_fcn(connector, use_pipe):
def test_redis_producer_delete(producer, topic, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe)
producer.delete(topic, pipe)
if pipe:
producer.pipeline().delete.assert_called_once_with(topic)
else:
producer.r.delete.assert_called_once_with(topic)
@pytest.mark.parametrize("topic, use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_producer_get(producer, topic, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe)
ret = producer.get(topic, pipe)
if pipe:
producer.pipeline().get.assert_called_once_with(topic)
assert ret == redis.Redis().pipeline().get()
else:
producer.r.get.assert_called_once_with(topic)
assert ret == redis.Redis().get()
def use_pipe_fcn(producer, use_pipe):
if use_pipe: if use_pipe:
return producer.pipeline() return connector.pipeline()
return None return None
@pytest.mark.parametrize("topic,use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_connector_delete(connector, topic, use_pipe):
pipe = use_pipe_fcn(connector, use_pipe)
connector.delete(topic, pipe)
if pipe:
connector.pipeline().delete.assert_called_once_with(topic)
else:
connector._redis_conn.delete.assert_called_once_with(topic)
@pytest.mark.parametrize("topic, use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_connector_get(connector, topic, use_pipe):
pipe = use_pipe_fcn(connector, use_pipe)
ret = connector.get(topic, pipe)
if pipe:
connector.pipeline().get.assert_called_once_with(topic)
assert ret == redis.Redis().pipeline().get()
else:
connector._redis_conn.get.assert_called_once_with(topic)
assert ret == redis.Redis().get()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"topics, pattern", "subscribed_topics, subscribed_patterns, msgs",
[ [
["topics1", None], ["topics1", None, ["topics1"]],
[["topics1", "topics2"], None], [["topics1", "topics2"], None, ["topics1", "topics2"]],
[None, "pattern1"], [None, "pattern1", ["pattern1"]],
[None, ["pattern1", "pattern2"]], [None, ["patt*", "top*"], ["pattern1", "topics1"]],
], ],
) )
def test_redis_consumer_init(consumer, topics, pattern): def test_redis_connector_register(
with mock.patch("bec_lib.redis_connector.redis.Redis"): redisdb, connected_connector, subscribed_topics, subscribed_patterns, msgs
consumer = RedisConsumer( ):
"localhost", "1", topics, pattern, redis_cls=redis.Redis, cb=lambda *args, **kwargs: ... connector = connected_connector
test_msg = TestMessage("test")
cb_mock = mock.Mock(spec=[]) # spec is here to remove all attributes
if subscribed_topics:
connector.register(
subscribed_topics, subscribed_patterns, cb=cb_mock, start_thread=False, a=1
) )
for msg in msgs:
if topics: connector.send(msg, TestMessage(msg))
if isinstance(topics, list): connector.poll_messages()
assert consumer.topics == topics msg_object = MessageObject(msg, TestMessage(msg))
else: cb_mock.assert_called_with(msg_object, a=1)
assert consumer.topics == [topics] if subscribed_patterns:
if pattern: connector.register(
if isinstance(pattern, list): subscribed_topics, subscribed_patterns, cb=cb_mock, start_thread=False, a=1
assert consumer.pattern == pattern )
else: for msg in msgs:
assert consumer.pattern == [pattern] connector.send(msg, TestMessage(msg))
connector.poll_messages()
assert consumer.r == redis.Redis() msg_object = MessageObject(msg, TestMessage(msg))
assert consumer.pubsub == consumer.r.pubsub() cb_mock.assert_called_with(msg_object, a=1)
assert consumer.host == "localhost"
assert consumer.port == "1"
@pytest.mark.parametrize("pattern, topics", [["pattern", "topics1"], [None, "topics2"]]) def test_redis_register_poll_messages(redisdb, connected_connector):
def test_redis_consumer_initialize_connector(consumer, pattern, topics): connector = connected_connector
consumer.pattern = pattern
consumer.topics = topics
consumer.initialize_connector()
if consumer.pattern is not None:
consumer.pubsub.psubscribe.assert_called_once_with(consumer.pattern)
else:
consumer.pubsub.subscribe.assert_called_with(consumer.topics)
def test_redis_consumer_poll_messages(consumer):
cb_fcn_has_been_called = False cb_fcn_has_been_called = False
def cb_fcn(msg, **kwargs): def cb_fcn(msg, **kwargs):
nonlocal cb_fcn_has_been_called nonlocal cb_fcn_has_been_called
cb_fcn_has_been_called = True cb_fcn_has_been_called = True
print(msg) assert kwargs["a"] == 1
consumer.cb = cb_fcn
test_msg = TestMessage("test") test_msg = TestMessage("test")
consumer.pubsub.get_message.return_value = { connector.register("test", cb=cb_fcn, a=1, start_thread=False)
"channel": "", redisdb.publish("test", MsgpackSerialization.dumps(test_msg))
"data": MsgpackSerialization.dumps(test_msg),
} connector.poll_messages(timeout=1)
ret = consumer.poll_messages()
consumer.pubsub.get_message.assert_called_once_with(ignore_subscribe_messages=True)
assert cb_fcn_has_been_called assert cb_fcn_has_been_called
with pytest.raises(TimeoutError):
def test_redis_consumer_shutdown(consumer): connector.poll_messages(timeout=0.1)
consumer.shutdown()
consumer.pubsub.close.assert_called_once()
def test_redis_consumer_additional_kwargs(connector): def test_redis_connector_xadd(connector):
cons = connector.consumer(topics="topic1", parent="here", cb=lambda *args, **kwargs: ...) connector.xadd("topic1", {"key": "value"})
assert "parent" in cons.kwargs connector._redis_conn.xadd.assert_called_once_with("topic1", {"key": "value"})
@pytest.mark.parametrize( def test_redis_connector_xadd_with_maxlen(connector):
"topics, pattern", connector.xadd("topic1", {"key": "value"}, max_size=100)
[ connector._redis_conn.xadd.assert_called_once_with("topic1", {"key": "value"}, maxlen=100)
["topics1", None],
[["topics1", "topics2"], None],
[None, "pattern1"],
[None, ["pattern1", "pattern2"]],
],
)
def test_mixin_init_topics_and_pattern(mixin, topics, pattern):
ret_topics, ret_pattern = mixin._init_topics_and_pattern(mixin, topics, pattern)
if topics:
if isinstance(topics, list):
assert ret_topics == topics
else:
assert ret_topics == [topics]
if pattern:
if isinstance(pattern, list):
assert ret_pattern == pattern
else:
assert ret_pattern == [pattern]
def test_mixin_init_redis_cls(mixin, consumer): def test_redis_connector_xadd_with_expire(connector):
mixin._init_redis_cls(consumer, None) connector.xadd("topic1", {"key": "value"}, expire=100)
assert consumer.r == redis.Redis(host="localhost", port=1) connector._redis_conn.pipeline().xadd.assert_called_once_with("topic1", {"key": "value"})
connector._redis_conn.pipeline().expire.assert_called_once_with("topic1", 100)
connector._redis_conn.pipeline().execute.assert_called_once()
@pytest.mark.parametrize( def test_redis_connector_xread(connector):
"topics, pattern", connector.xread("topic1", "id")
[ connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)
["topics1", None],
[["topics1", "topics2"], None],
[None, "pattern1"],
[None, ["pattern1", "pattern2"]],
],
)
def test_redis_consumer_threaded_init(consumer_threaded, topics, pattern):
with mock.patch("bec_lib.redis_connector.redis.Redis"):
consumer_threaded = RedisConsumerThreaded(
"localhost", "1", topics, pattern, redis_cls=redis.Redis, cb=lambda *args, **kwargs: ...
)
if topics:
if isinstance(topics, list):
assert consumer_threaded.topics == topics
else:
assert consumer_threaded.topics == [topics]
if pattern:
if isinstance(pattern, list):
assert consumer_threaded.pattern == pattern
else:
assert consumer_threaded.pattern == [pattern]
assert consumer_threaded.r == redis.Redis()
assert consumer_threaded.pubsub == consumer_threaded.r.pubsub()
assert consumer_threaded.host == "localhost"
assert consumer_threaded.port == "1"
assert consumer_threaded.sleep_times == [0.005, 0.1]
assert consumer_threaded.last_received_msg == 0
assert consumer_threaded.idle_time == 30
def test_redis_connector_xadd(producer): def test_redis_connector_xadd(connector):
producer.xadd("topic1", {"key": "value"}) connector.xadd("topic1", {"key": "value"})
producer.r.xadd.assert_called_once_with("topic1", {"key": MsgpackSerialization.dumps("value")}) connector._redis_conn.xadd.assert_called_once_with(
"topic1", {"key": MsgpackSerialization.dumps("value")}
)
test_msg = TestMessage("test") test_msg = TestMessage("test")
producer.xadd("topic1", {"data": test_msg}) connector.xadd("topic1", {"data": test_msg})
producer.r.xadd.assert_called_with("topic1", {"data": MsgpackSerialization.dumps(test_msg)}) connector._redis_conn.xadd.assert_called_with(
producer.r.xrevrange.return_value = [ "topic1", {"data": MsgpackSerialization.dumps(test_msg)}
)
connector._redis_conn.xrevrange.return_value = [
(b"1707391599960-0", {b"data": MsgpackSerialization.dumps(test_msg)}) (b"1707391599960-0", {b"data": MsgpackSerialization.dumps(test_msg)})
] ]
msg = producer.get_last("topic1") msg = connector.get_last("topic1")
assert msg == test_msg assert msg == test_msg
def test_redis_connector_xadd_with_maxlen(producer): def test_redis_connector_xadd_with_maxlen(connector):
producer.xadd("topic1", {"key": "value"}, max_size=100) connector.xadd("topic1", {"key": "value"}, max_size=100)
producer.r.xadd.assert_called_once_with( connector._redis_conn.xadd.assert_called_once_with(
"topic1", {"key": MsgpackSerialization.dumps("value")}, maxlen=100 "topic1", {"key": MsgpackSerialization.dumps("value")}, maxlen=100
) )
def test_redis_connector_xadd_with_expire(producer): def test_redis_connector_xadd_with_expire(connector):
producer.xadd("topic1", {"key": "value"}, expire=100) connector.xadd("topic1", {"key": "value"}, expire=100)
producer.r.pipeline().xadd.assert_called_once_with( connector._redis_conn.pipeline().xadd.assert_called_once_with(
"topic1", {"key": MsgpackSerialization.dumps("value")} "topic1", {"key": MsgpackSerialization.dumps("value")}
) )
producer.r.pipeline().expire.assert_called_once_with("topic1", 100) connector._redis_conn.pipeline().expire.assert_called_once_with("topic1", 100)
producer.r.pipeline().execute.assert_called_once() connector._redis_conn.pipeline().execute.assert_called_once()
def test_redis_connector_xread(producer): def test_redis_connector_xread(connector):
producer.xread("topic1", "id") connector.xread("topic1", "id")
producer.r.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None) connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)
def test_redis_connector_xread_without_id(producer): def test_redis_connector_xread_without_id(connector):
producer.xread("topic1", from_start=True) connector.xread("topic1", from_start=True)
producer.r.xread.assert_called_once_with({"topic1": "0-0"}, count=None, block=None) connector._redis_conn.xread.assert_called_once_with({"topic1": "0-0"}, count=None, block=None)
producer.r.xread.reset_mock() connector._redis_conn.xread.reset_mock()
producer.stream_keys["topic1"] = "id" connector.stream_keys["topic1"] = "id"
producer.xread("topic1") connector.xread("topic1")
producer.r.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None) connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)
def test_redis_connector_xread_from_end(producer): def test_redis_connector_xread_from_end(connector):
producer.xread("topic1", from_start=False) connector.xread("topic1", from_start=False)
producer.r.xrevrange.assert_called_once_with("topic1", "+", "-", count=1) connector._redis_conn.xrevrange.assert_called_once_with("topic1", "+", "-", count=1)
def test_redis_connector_get_last(producer): def test_redis_connector_get_last(connector):
producer.r.xrevrange.return_value = [ connector._redis_conn.xrevrange.return_value = [
(b"1707391599960-0", {b"key": MsgpackSerialization.dumps("value")}) (b"1707391599960-0", {b"key": MsgpackSerialization.dumps("value")})
] ]
msg = producer.get_last("topic1") msg = connector.get_last("topic1")
producer.r.xrevrange.assert_called_once_with("topic1", "+", "-", count=1) connector._redis_conn.xrevrange.assert_called_once_with("topic1", "+", "-", count=1)
assert msg is None # no key given, default is b'data' assert msg is None # no key given, default is b'data'
assert producer.get_last("topic1", "key") == "value" assert connector.get_last("topic1", "key") == "value"
assert producer.get_last("topic1", None) == {"key": "value"} assert connector.get_last("topic1", None) == {"key": "value"}
def test_redis_xrange(producer): def test_redis_connector_xread_without_id(connector):
producer.xrange("topic1", "start", "end") connector.xread("topic1", from_start=True)
producer.r.xrange.assert_called_once_with("topic1", "start", "end", count=None) connector._redis_conn.xread.assert_called_once_with({"topic1": "0-0"}, count=None, block=None)
connector._redis_conn.xread.reset_mock()
connector.stream_keys["topic1"] = "id"
connector.xread("topic1")
connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)
def test_redis_xrange_topic_with_suffix(producer): def test_redis_xrange(connector):
producer.xrange("topic1", "start", "end") connector.xrange("topic1", "start", "end")
producer.r.xrange.assert_called_once_with("topic1", "start", "end", count=None) connector._redis_conn.xrange.assert_called_once_with("topic1", "start", "end", count=None)
def test_redis_consumer_threaded_no_cb_without_messages(consumer_threaded): def test_redis_xrange_topic_with_suffix(connector):
with mock.patch.object(consumer_threaded.pubsub, "get_message", return_value=None): connector.xrange("topic1", "start", "end")
consumer_threaded.cb = mock.MagicMock() connector._redis_conn.xrange.assert_called_once_with("topic1", "start", "end", count=None)
consumer_threaded.poll_messages()
consumer_threaded.cb.assert_not_called()
def test_redis_consumer_threaded_cb_called_with_messages(consumer_threaded): # def test_redis_stream_register_threaded_get_id():
message = {"channel": b"topic1", "data": MsgpackSerialization.dumps(TestMessage("test"))} # register = RedisStreamConsumerThreaded(
# "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
with mock.patch.object(consumer_threaded.pubsub, "get_message", return_value=message): # )
consumer_threaded.cb = mock.MagicMock() # register.stream_keys["topic1"] = b"1691610882756-0"
consumer_threaded.poll_messages() # assert register.get_id("topic1") == b"1691610882756-0"
msg_object = MessageObject("topic1", TestMessage("test")) # assert register.get_id("doesnt_exist") == "0-0"
consumer_threaded.cb.assert_called_once_with(msg_object)
def test_redis_consumer_threaded_shutdown(consumer_threaded): # def test_redis_stream_register_threaded_poll_messages():
consumer_threaded.shutdown() # register = RedisStreamConsumerThreaded(
consumer_threaded.pubsub.close.assert_called_once() # "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
# )
# with mock.patch.object(
# register, "get_newest_message", return_value=None
# ) as mock_get_newest_message:
# register.poll_messages()
# mock_get_newest_message.assert_called_once()
# register._redis_conn.xread.assert_not_called()
def test_redis_stream_consumer_threaded_get_newest_message(): # def test_redis_stream_register_threaded_poll_messages_newest_only():
consumer = RedisStreamConsumerThreaded( # register = RedisStreamConsumerThreaded(
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock() # "localhost",
) # "1",
consumer.r.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})] # topics="topic1",
msgs = [] # cb=mock.MagicMock(),
consumer.get_newest_message(msgs) # redis_cls=mock.MagicMock(),
assert "topic1" in consumer.stream_keys # newest_only=True,
assert consumer.stream_keys["topic1"] == b"1691610882756-0" # )
#
# register._redis_conn.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})]
# register.poll_messages()
# register._redis_conn.xread.assert_not_called()
# register.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))
def test_redis_stream_consumer_threaded_get_newest_message_no_msg(): # def test_redis_stream_register_threaded_poll_messages_read():
consumer = RedisStreamConsumerThreaded( # register = RedisStreamConsumerThreaded(
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock() # "localhost",
) # "1",
consumer.r.xrevrange.return_value = [] # topics="topic1",
msgs = [] # cb=mock.MagicMock(),
consumer.get_newest_message(msgs) # redis_cls=mock.MagicMock(),
assert "topic1" in consumer.stream_keys # )
assert consumer.stream_keys["topic1"] == "0-0" # register.stream_keys["topic1"] = "0-0"
#
# msg = [[b"topic1", [(b"1691610714612-0", {b"data": b"msg"})]]]
#
# register._redis_conn.xread.return_value = msg
# register.poll_messages()
# register._redis_conn.xread.assert_called_once_with({"topic1": "0-0"}, count=1)
# register.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))
# @pytest.mark.parametrize(
def test_redis_stream_consumer_threaded_get_id(): # "topics,expected",
consumer = RedisStreamConsumerThreaded( # [
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock() # ("topic1", ["topic1"]),
) # (["topic1"], ["topic1"]),
consumer.stream_keys["topic1"] = b"1691610882756-0" # (["topic1", "topic2"], ["topic1", "topic2"]),
assert consumer.get_id("topic1") == b"1691610882756-0" # ],
assert consumer.get_id("doesnt_exist") == "0-0" # )
# def test_redis_stream_register_threaded_init_topics(topics, expected):
# register = RedisStreamConsumerThreaded(
def test_redis_stream_consumer_threaded_poll_messages(): # "localhost",
consumer = RedisStreamConsumerThreaded( # "1",
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock() # topics=topics,
) # cb=mock.MagicMock(),
with mock.patch.object( # redis_cls=mock.MagicMock(),
consumer, "get_newest_message", return_value=None # )
) as mock_get_newest_message: # assert register.topics == expected
consumer.poll_messages()
mock_get_newest_message.assert_called_once()
consumer.r.xread.assert_not_called()
def test_redis_stream_consumer_threaded_poll_messages_newest_only():
consumer = RedisStreamConsumerThreaded(
"localhost",
"1",
topics="topic1",
cb=mock.MagicMock(),
redis_cls=mock.MagicMock(),
newest_only=True,
)
consumer.r.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})]
consumer.poll_messages()
consumer.r.xread.assert_not_called()
consumer.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))
def test_redis_stream_consumer_threaded_poll_messages_read():
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
consumer.stream_keys["topic1"] = "0-0"
msg = [[b"topic1", [(b"1691610714612-0", {b"data": b"msg"})]]]
consumer.r.xread.return_value = msg
consumer.poll_messages()
consumer.r.xread.assert_called_once_with({"topic1": "0-0"}, count=1)
consumer.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))
@pytest.mark.parametrize(
"topics,expected",
[
("topic1", ["topic1"]),
(["topic1"], ["topic1"]),
(["topic1", "topic2"], ["topic1", "topic2"]),
],
)
def test_redis_stream_consumer_threaded_init_topics(topics, expected):
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics=topics, cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
assert consumer.topics == expected

View File

@ -63,7 +63,7 @@ from bec_lib.tests.utils import ConnectorMock
) )
def test_update_with_queue_status(queue_msg): def test_update_with_queue_status(queue_msg):
scan_manager = ScanManager(ConnectorMock("")) scan_manager = ScanManager(ConnectorMock(""))
scan_manager.producer._get_buffer[MessageEndpoints.scan_queue_status()] = queue_msg scan_manager.connector._get_buffer[MessageEndpoints.scan_queue_status()] = queue_msg
scan_manager.update_with_queue_status(queue_msg) scan_manager.update_with_queue_status(queue_msg)
assert ( assert (
scan_manager.scan_storage.find_scan_by_ID("bfa582aa-f9cd-4258-ab5d-3e5d54d3dde5") scan_manager.scan_storage.find_scan_by_ID("bfa582aa-f9cd-4258-ab5d-3e5d54d3dde5")

View File

@ -105,6 +105,6 @@ def test_scan_report_get_mv_status(scan_report, lrange_return, expected):
scan_report.request.request = messages.ScanQueueMessage( scan_report.request.request = messages.ScanQueueMessage(
scan_type="mv", parameter={"args": {"samx": [5], "samy": [5]}} scan_type="mv", parameter={"args": {"samx": [5], "samy": [5]}}
) )
with mock.patch.object(scan_report._client.device_manager.producer, "lrange") as mock_lrange: with mock.patch.object(scan_report._client.device_manager.connector, "lrange") as mock_lrange:
mock_lrange.return_value = lrange_return mock_lrange.return_value = lrange_return
assert scan_report._get_mv_status() == expected assert scan_report._get_mv_status() == expected

View File

@ -14,7 +14,6 @@ parser.add_argument("--redis", default="localhost:6379", help="redis host and po
clargs = parser.parse_args() clargs = parser.parse_args()
connector = RedisConnector(clargs.redis) connector = RedisConnector(clargs.redis)
producer = connector.producer()
with open(clargs.config, "r", encoding="utf-8") as stream: with open(clargs.config, "r", encoding="utf-8") as stream:
data = yaml.safe_load(stream) data = yaml.safe_load(stream)
@ -22,4 +21,4 @@ for name, device in data.items():
device["name"] = name device["name"] = name
config_data = list(data.values()) config_data = list(data.values())
msg = messages.AvailableResourceMessage(resource=config_data) msg = messages.AvailableResourceMessage(resource=config_data)
producer.set(MessageEndpoints.device_config(), msg) connector.set(MessageEndpoints.device_config(), msg)

View File

@ -11,7 +11,6 @@ class DAPServiceManager:
def __init__(self, services: list) -> None: def __init__(self, services: list) -> None:
self.connector = None self.connector = None
self.producer = None
self._started = False self._started = False
self.client = None self.client = None
self._dap_request_thread = None self._dap_request_thread = None
@ -24,13 +23,11 @@ class DAPServiceManager:
""" """
Start the dap request consumer. Start the dap request consumer.
""" """
self._dap_request_thread = self.connector.consumer( self.connector.register(
topics=MessageEndpoints.dap_request(), cb=self._dap_request_callback, parent=self topics=MessageEndpoints.dap_request(), cb=self._dap_request_callback
) )
self._dap_request_thread.start()
@staticmethod def _dap_request_callback(self, msg: MessageObject) -> None:
def _dap_request_callback(msg: MessageObject, *, parent: DAPServiceManager) -> None:
""" """
Callback function for dap request consumer. Callback function for dap request consumer.
@ -41,7 +38,7 @@ class DAPServiceManager:
dap_request_msg = messages.DAPRequestMessage.loads(msg.value) dap_request_msg = messages.DAPRequestMessage.loads(msg.value)
if not dap_request_msg: if not dap_request_msg:
return return
parent.process_dap_request(dap_request_msg) self.process_dap_request(dap_request_msg)
def process_dap_request(self, dap_request_msg: messages.DAPRequestMessage) -> None: def process_dap_request(self, dap_request_msg: messages.DAPRequestMessage) -> None:
""" """
@ -153,7 +150,7 @@ class DAPServiceManager:
dap_response_msg = messages.DAPResponseMessage( dap_response_msg = messages.DAPResponseMessage(
success=success, data=data, error=error, dap_request=dap_request_msg, metadata=metadata success=success, data=data, error=error, dap_request=dap_request_msg, metadata=metadata
) )
self.producer.set_and_publish( self.connector.set_and_publish(
MessageEndpoints.dap_response(metadata.get("RID")), dap_response_msg, expire=60 MessageEndpoints.dap_response(metadata.get("RID")), dap_response_msg, expire=60
) )
@ -168,7 +165,6 @@ class DAPServiceManager:
return return
self.client = client self.client = client
self.connector = client.connector self.connector = client.connector
self.producer = self.connector.producer()
self._start_dap_request_consumer() self._start_dap_request_consumer()
self.update_available_dap_services() self.update_available_dap_services()
self.publish_available_services() self.publish_available_services()
@ -264,12 +260,12 @@ class DAPServiceManager:
"""send all available dap services to the broker""" """send all available dap services to the broker"""
msg = messages.AvailableResourceMessage(resource=self.available_dap_services) msg = messages.AvailableResourceMessage(resource=self.available_dap_services)
# pylint: disable=protected-access # pylint: disable=protected-access
self.producer.set( self.connector.set(
MessageEndpoints.dap_available_plugins(f"DAPServer/{self.client._service_id}"), msg MessageEndpoints.dap_available_plugins(f"DAPServer/{self.client._service_id}"), msg
) )
def shutdown(self) -> None: def shutdown(self) -> None:
if not self._started: if not self._started:
return return
self._dap_request_thread.stop() self.connector.shutdown()
self._started = False self._started = False

View File

@ -148,7 +148,7 @@ class LmfitService1D(DAPServiceBase):
out = self.process() out = self.process()
if out: if out:
stream_output, metadata = out stream_output, metadata = out
self.client.producer.xadd( self.client.connector.xadd(
MessageEndpoints.processed_data(self.model.__class__.__name__), MessageEndpoints.processed_data(self.model.__class__.__name__),
msg={ msg={
"data": MsgpackSerialization.dumps( "data": MsgpackSerialization.dumps(

View File

@ -72,7 +72,7 @@ def test_DAPServiceManager_init(service_manager):
def test_DAPServiceManager_request_callback(service_manager, msg, process_called): def test_DAPServiceManager_request_callback(service_manager, msg, process_called):
msg_obj = MessageObject(value=msg, topic="topic") msg_obj = MessageObject(value=msg, topic="topic")
with mock.patch.object(service_manager, "process_dap_request") as mock_process_dap_request: with mock.patch.object(service_manager, "process_dap_request") as mock_process_dap_request:
service_manager._dap_request_callback(msg_obj, parent=service_manager) service_manager._dap_request_callback(msg_obj)
if process_called: if process_called:
mock_process_dap_request.assert_called_once_with(msg) mock_process_dap_request.assert_called_once_with(msg)

View File

@ -134,7 +134,7 @@ def test_LmfitService1D_process_until_finished(lmfit_service):
lmfit_service.process_until_finished(event) lmfit_service.process_until_finished(event)
assert get_data.call_count == 2 assert get_data.call_count == 2
assert process.call_count == 2 assert process.call_count == 2
assert lmfit_service.client.producer.xadd.call_count == 2 assert lmfit_service.client.connector.xadd.call_count == 2
def test_LmfitService1D_configure(lmfit_service): def test_LmfitService1D_configure(lmfit_service):

View File

@ -18,7 +18,7 @@ from device_server.rpc_mixin import RPCMixin
logger = bec_logger.logger logger = bec_logger.logger
consumer_stop = threading.Event() register_stop = threading.Event()
class DisabledDeviceError(Exception): class DisabledDeviceError(Exception):
@ -38,14 +38,10 @@ class DeviceServer(RPCMixin, BECService):
super().__init__(config, connector_cls, unique_service=True) super().__init__(config, connector_cls, unique_service=True)
self._tasks = [] self._tasks = []
self.device_manager = None self.device_manager = None
self.threads = [] self.connector.register(
self.sig_thread = None
self.sig_thread = self.connector.consumer(
MessageEndpoints.scan_queue_modification(), MessageEndpoints.scan_queue_modification(),
cb=self.consumer_interception_callback, cb=self.register_interception_callback,
parent=self,
) )
self.sig_thread.start()
self.executor = ThreadPoolExecutor(max_workers=4) self.executor = ThreadPoolExecutor(max_workers=4)
self._start_device_manager() self._start_device_manager()
@ -55,19 +51,16 @@ class DeviceServer(RPCMixin, BECService):
def start(self) -> None: def start(self) -> None:
"""start the device server""" """start the device server"""
if consumer_stop.is_set(): if register_stop.is_set():
consumer_stop.clear() register_stop.clear()
self.connector.register(
MessageEndpoints.device_instructions(),
event=register_stop,
cb=self.instructions_callback,
parent=self,
)
self.threads = [
self.connector.consumer(
MessageEndpoints.device_instructions(),
event=consumer_stop,
cb=self.instructions_callback,
parent=self,
)
]
for thread in self.threads:
thread.start()
self.status = BECStatus.RUNNING self.status = BECStatus.RUNNING
def update_status(self, status: BECStatus): def update_status(self, status: BECStatus):
@ -76,17 +69,13 @@ class DeviceServer(RPCMixin, BECService):
def stop(self) -> None: def stop(self) -> None:
"""stop the device server""" """stop the device server"""
consumer_stop.set() register_stop.set()
for thread in self.threads:
thread.join()
self.status = BECStatus.IDLE self.status = BECStatus.IDLE
def shutdown(self) -> None: def shutdown(self) -> None:
"""shutdown the device server""" """shutdown the device server"""
super().shutdown() super().shutdown()
self.stop() self.stop()
self.sig_thread.signal_event.set()
self.sig_thread.join()
self.device_manager.shutdown() self.device_manager.shutdown()
def _update_device_metadata(self, instr) -> None: def _update_device_metadata(self, instr) -> None:
@ -97,8 +86,7 @@ class DeviceServer(RPCMixin, BECService):
device_root = dev.split(".")[0] device_root = dev.split(".")[0]
self.device_manager.devices.get(device_root).metadata = instr.metadata self.device_manager.devices.get(device_root).metadata = instr.metadata
@staticmethod def register_interception_callback(self, msg, **_kwargs) -> None:
def consumer_interception_callback(msg, *, parent, **_kwargs) -> None:
"""callback for receiving scan modifications / interceptions""" """callback for receiving scan modifications / interceptions"""
mvalue = msg.value mvalue = msg.value
if mvalue is None: if mvalue is None:
@ -106,7 +94,7 @@ class DeviceServer(RPCMixin, BECService):
return return
logger.info(f"Receiving: {mvalue.content}") logger.info(f"Receiving: {mvalue.content}")
if mvalue.content.get("action") in ["pause", "abort", "halt"]: if mvalue.content.get("action") in ["pause", "abort", "halt"]:
parent.stop_devices() self.stop_devices()
def stop_devices(self) -> None: def stop_devices(self) -> None:
"""stop all enabled devices""" """stop all enabled devices"""
@ -279,7 +267,7 @@ class DeviceServer(RPCMixin, BECService):
devices = instr.content["device"] devices = instr.content["device"]
if not isinstance(devices, list): if not isinstance(devices, list):
devices = [devices] devices = [devices]
pipe = self.producer.pipeline() pipe = self.connector.pipeline()
for dev in devices: for dev in devices:
obj = self.device_manager.devices.get(dev) obj = self.device_manager.devices.get(dev)
obj.metadata = instr.metadata obj.metadata = instr.metadata
@ -288,11 +276,11 @@ class DeviceServer(RPCMixin, BECService):
dev_msg = messages.DeviceReqStatusMessage( dev_msg = messages.DeviceReqStatusMessage(
device=dev, success=True, metadata=instr.metadata device=dev, success=True, metadata=instr.metadata
) )
self.producer.set_and_publish(MessageEndpoints.device_req_status(dev), dev_msg, pipe) self.connector.set_and_publish(MessageEndpoints.device_req_status(dev), dev_msg, pipe)
pipe.execute() pipe.execute()
def _status_callback(self, status): def _status_callback(self, status):
pipe = self.producer.pipeline() pipe = self.connector.pipeline()
if hasattr(status, "device"): if hasattr(status, "device"):
obj = status.device obj = status.device
else: else:
@ -302,12 +290,12 @@ class DeviceServer(RPCMixin, BECService):
device=device_name, success=status.success, metadata=status.instruction.metadata device=device_name, success=status.success, metadata=status.instruction.metadata
) )
logger.debug(f"req status for device {device_name}: {status.success}") logger.debug(f"req status for device {device_name}: {status.success}")
self.producer.set_and_publish( self.connector.set_and_publish(
MessageEndpoints.device_req_status(device_name), dev_msg, pipe MessageEndpoints.device_req_status(device_name), dev_msg, pipe
) )
response = status.instruction.metadata.get("response") response = status.instruction.metadata.get("response")
if response: if response:
self.producer.lpush( self.connector.lpush(
MessageEndpoints.device_req_status(status.instruction.metadata["RID"]), MessageEndpoints.device_req_status(status.instruction.metadata["RID"]),
dev_msg, dev_msg,
pipe, pipe,
@ -328,7 +316,7 @@ class DeviceServer(RPCMixin, BECService):
dev_config_msg = messages.DeviceMessage( dev_config_msg = messages.DeviceMessage(
signals=obj.root.read_configuration(), metadata=metadata signals=obj.root.read_configuration(), metadata=metadata
) )
self.producer.set_and_publish( self.connector.set_and_publish(
MessageEndpoints.device_read_configuration(obj.root.name), dev_config_msg, pipe MessageEndpoints.device_read_configuration(obj.root.name), dev_config_msg, pipe
) )
@ -342,7 +330,7 @@ class DeviceServer(RPCMixin, BECService):
def _read_and_update_devices(self, devices: list[str], metadata: dict) -> list: def _read_and_update_devices(self, devices: list[str], metadata: dict) -> list:
start = time.time() start = time.time()
pipe = self.producer.pipeline() pipe = self.connector.pipeline()
signal_container = [] signal_container = []
for dev in devices: for dev in devices:
device_root = dev.split(".")[0] device_root = dev.split(".")[0]
@ -354,17 +342,17 @@ class DeviceServer(RPCMixin, BECService):
except Exception as exc: except Exception as exc:
signals = self._retry_obj_method(dev, obj, "read", exc) signals = self._retry_obj_method(dev, obj, "read", exc)
self.producer.set_and_publish( self.connector.set_and_publish(
MessageEndpoints.device_read(device_root), MessageEndpoints.device_read(device_root),
messages.DeviceMessage(signals=signals, metadata=metadata), messages.DeviceMessage(signals=signals, metadata=metadata),
pipe, pipe,
) )
self.producer.set_and_publish( self.connector.set_and_publish(
MessageEndpoints.device_readback(device_root), MessageEndpoints.device_readback(device_root),
messages.DeviceMessage(signals=signals, metadata=metadata), messages.DeviceMessage(signals=signals, metadata=metadata),
pipe, pipe,
) )
self.producer.set( self.connector.set(
MessageEndpoints.device_status(device_root), MessageEndpoints.device_status(device_root),
messages.DeviceStatusMessage(device=device_root, status=0, metadata=metadata), messages.DeviceStatusMessage(device=device_root, status=0, metadata=metadata),
pipe, pipe,
@ -377,7 +365,7 @@ class DeviceServer(RPCMixin, BECService):
def _read_config_and_update_devices(self, devices: list[str], metadata: dict) -> list: def _read_config_and_update_devices(self, devices: list[str], metadata: dict) -> list:
start = time.time() start = time.time()
pipe = self.producer.pipeline() pipe = self.connector.pipeline()
signal_container = [] signal_container = []
for dev in devices: for dev in devices:
self.device_manager.devices.get(dev).metadata = metadata self.device_manager.devices.get(dev).metadata = metadata
@ -387,7 +375,7 @@ class DeviceServer(RPCMixin, BECService):
signal_container.append(signals) signal_container.append(signals)
except Exception as exc: except Exception as exc:
signals = self._retry_obj_method(dev, obj, "read_configuration", exc) signals = self._retry_obj_method(dev, obj, "read_configuration", exc)
self.producer.set_and_publish( self.connector.set_and_publish(
MessageEndpoints.device_read_configuration(dev), MessageEndpoints.device_read_configuration(dev),
messages.DeviceMessage(signals=signals, metadata=metadata), messages.DeviceMessage(signals=signals, metadata=metadata),
pipe, pipe,
@ -420,9 +408,11 @@ class DeviceServer(RPCMixin, BECService):
f"Failed to run {method} on device {device_root}. Trying to load an old value." f"Failed to run {method} on device {device_root}. Trying to load an old value."
) )
if method == "read": if method == "read":
old_msg = self.producer.get(MessageEndpoints.device_read(device_root)) old_msg = self.connector.get(MessageEndpoints.device_read(device_root))
elif method == "read_configuration": elif method == "read_configuration":
old_msg = self.producer.get(MessageEndpoints.device_read_configuration(device_root)) old_msg = self.connector.get(
MessageEndpoints.device_read_configuration(device_root)
)
else: else:
raise ValueError(f"Unknown method {method}.") raise ValueError(f"Unknown method {method}.")
if not old_msg: if not old_msg:
@ -435,7 +425,7 @@ class DeviceServer(RPCMixin, BECService):
if not isinstance(devices, list): if not isinstance(devices, list):
devices = [devices] devices = [devices]
pipe = self.producer.pipeline() pipe = self.connector.pipeline()
for dev in devices: for dev in devices:
obj = self.device_manager.devices[dev].obj obj = self.device_manager.devices[dev].obj
if hasattr(obj, "_staged"): if hasattr(obj, "_staged"):
@ -444,7 +434,7 @@ class DeviceServer(RPCMixin, BECService):
logger.info(f"Device {obj.name} was already staged and will be first unstaged.") logger.info(f"Device {obj.name} was already staged and will be first unstaged.")
self.device_manager.devices[dev].obj.unstage() self.device_manager.devices[dev].obj.unstage()
self.device_manager.devices[dev].obj.stage() self.device_manager.devices[dev].obj.stage()
self.producer.set( self.connector.set(
MessageEndpoints.device_staged(dev), MessageEndpoints.device_staged(dev),
messages.DeviceStatusMessage(device=dev, status=1, metadata=instr.metadata), messages.DeviceStatusMessage(device=dev, status=1, metadata=instr.metadata),
pipe, pipe,
@ -456,7 +446,7 @@ class DeviceServer(RPCMixin, BECService):
if not isinstance(devices, list): if not isinstance(devices, list):
devices = [devices] devices = [devices]
pipe = self.producer.pipeline() pipe = self.connector.pipeline()
for dev in devices: for dev in devices:
obj = self.device_manager.devices[dev].obj obj = self.device_manager.devices[dev].obj
if hasattr(obj, "_staged"): if hasattr(obj, "_staged"):
@ -465,7 +455,7 @@ class DeviceServer(RPCMixin, BECService):
self.device_manager.devices[dev].obj.unstage() self.device_manager.devices[dev].obj.unstage()
else: else:
logger.debug(f"Device {obj.name} was already unstaged.") logger.debug(f"Device {obj.name} was already unstaged.")
self.producer.set( self.connector.set(
MessageEndpoints.device_staged(dev), MessageEndpoints.device_staged(dev),
messages.DeviceStatusMessage(device=dev, status=0, metadata=instr.metadata), messages.DeviceStatusMessage(device=dev, status=0, metadata=instr.metadata),
pipe, pipe,

View File

@ -16,17 +16,11 @@ class ConfigUpdateHandler:
def __init__(self, device_manager: DeviceManagerDS) -> None: def __init__(self, device_manager: DeviceManagerDS) -> None:
self.device_manager = device_manager self.device_manager = device_manager
self.connector = self.device_manager.connector self.connector = self.device_manager.connector
self._config_request_handler = None self.connector.register(
self._start_config_handler()
def _start_config_handler(self) -> None:
self._config_request_handler = self.connector.consumer(
MessageEndpoints.device_server_config_request(), MessageEndpoints.device_server_config_request(),
cb=self._device_config_callback, cb=self._device_config_callback,
parent=self, parent=self,
) )
self._config_request_handler.start()
@staticmethod @staticmethod
def _device_config_callback(msg, *, parent, **_kwargs) -> None: def _device_config_callback(msg, *, parent, **_kwargs) -> None:
@ -74,7 +68,7 @@ class ConfigUpdateHandler:
accepted=accepted, message=error_msg, metadata=metadata accepted=accepted, message=error_msg, metadata=metadata
) )
RID = metadata.get("RID") RID = metadata.get("RID")
self.device_manager.producer.set( self.device_manager.connector.set(
MessageEndpoints.device_config_request_response(RID), msg, expire=60 MessageEndpoints.device_config_request_response(RID), msg, expire=60
) )
@ -97,7 +91,7 @@ class ConfigUpdateHandler:
"low": device.obj.low_limit_travel.get(), "low": device.obj.low_limit_travel.get(),
"high": device.obj.high_limit_travel.get(), "high": device.obj.high_limit_travel.get(),
} }
self.device_manager.producer.set_and_publish( self.device_manager.connector.set_and_publish(
MessageEndpoints.device_limits(device.name), MessageEndpoints.device_limits(device.name),
messages.DeviceMessage(signals=limits), messages.DeviceMessage(signals=limits),
) )

View File

@ -51,7 +51,7 @@ class DSDevice(DeviceBase):
self.metadata = {} self.metadata = {}
self.initialized = False self.initialized = False
def initialize_device_buffer(self, producer): def initialize_device_buffer(self, connector):
"""initialize the device read and readback buffer on redis with a new reading""" """initialize the device read and readback buffer on redis with a new reading"""
dev_msg = messages.DeviceMessage(signals=self.obj.read(), metadata={}) dev_msg = messages.DeviceMessage(signals=self.obj.read(), metadata={})
dev_config_msg = messages.DeviceMessage(signals=self.obj.read_configuration(), metadata={}) dev_config_msg = messages.DeviceMessage(signals=self.obj.read_configuration(), metadata={})
@ -62,14 +62,14 @@ class DSDevice(DeviceBase):
} }
else: else:
limits = None limits = None
pipe = producer.pipeline() pipe = connector.pipeline()
producer.set_and_publish(MessageEndpoints.device_readback(self.name), dev_msg, pipe=pipe) connector.set_and_publish(MessageEndpoints.device_readback(self.name), dev_msg, pipe=pipe)
producer.set(topic=MessageEndpoints.device_read(self.name), msg=dev_msg, pipe=pipe) connector.set(topic=MessageEndpoints.device_read(self.name), msg=dev_msg, pipe=pipe)
producer.set_and_publish( connector.set_and_publish(
MessageEndpoints.device_read_configuration(self.name), dev_config_msg, pipe=pipe MessageEndpoints.device_read_configuration(self.name), dev_config_msg, pipe=pipe
) )
if limits is not None: if limits is not None:
producer.set_and_publish( connector.set_and_publish(
MessageEndpoints.device_limits(self.name), MessageEndpoints.device_limits(self.name),
messages.DeviceMessage(signals=limits), messages.DeviceMessage(signals=limits),
pipe=pipe, pipe=pipe,
@ -318,7 +318,7 @@ class DeviceManagerDS(DeviceManagerBase):
self.update_config(obj, config) self.update_config(obj, config)
# refresh the device info # refresh the device info
pipe = self.producer.pipeline() pipe = self.connector.pipeline()
self.reset_device_data(obj, pipe) self.reset_device_data(obj, pipe)
self.publish_device_info(obj, pipe) self.publish_device_info(obj, pipe)
pipe.execute() pipe.execute()
@ -369,7 +369,7 @@ class DeviceManagerDS(DeviceManagerBase):
def initialize_enabled_device(self, opaas_obj): def initialize_enabled_device(self, opaas_obj):
"""connect to an enabled device and initialize the device buffer""" """connect to an enabled device and initialize the device buffer"""
self.connect_device(opaas_obj.obj) self.connect_device(opaas_obj.obj)
opaas_obj.initialize_device_buffer(self.producer) opaas_obj.initialize_device_buffer(self.connector)
@staticmethod @staticmethod
def disconnect_device(obj): def disconnect_device(obj):
@ -420,7 +420,7 @@ class DeviceManagerDS(DeviceManagerBase):
""" """
interface = get_device_info(obj, {}) interface = get_device_info(obj, {})
self.producer.set( self.connector.set(
MessageEndpoints.device_info(obj.name), MessageEndpoints.device_info(obj.name),
messages.DeviceInfoMessage(device=obj.name, info=interface), messages.DeviceInfoMessage(device=obj.name, info=interface),
pipe, pipe,
@ -428,9 +428,9 @@ class DeviceManagerDS(DeviceManagerBase):
def reset_device_data(self, obj: OphydObject, pipe=None) -> None: def reset_device_data(self, obj: OphydObject, pipe=None) -> None:
"""delete all device data and device info""" """delete all device data and device info"""
self.producer.delete(MessageEndpoints.device_status(obj.name), pipe) self.connector.delete(MessageEndpoints.device_status(obj.name), pipe)
self.producer.delete(MessageEndpoints.device_read(obj.name), pipe) self.connector.delete(MessageEndpoints.device_read(obj.name), pipe)
self.producer.delete(MessageEndpoints.device_info(obj.name), pipe) self.connector.delete(MessageEndpoints.device_info(obj.name), pipe)
def _obj_callback_readback(self, *_args, obj: OphydObject, **kwargs): def _obj_callback_readback(self, *_args, obj: OphydObject, **kwargs):
if obj.connected: if obj.connected:
@ -438,8 +438,8 @@ class DeviceManagerDS(DeviceManagerBase):
signals = obj.read() signals = obj.read()
metadata = self.devices.get(obj.root.name).metadata metadata = self.devices.get(obj.root.name).metadata
dev_msg = messages.DeviceMessage(signals=signals, metadata=metadata) dev_msg = messages.DeviceMessage(signals=signals, metadata=metadata)
pipe = self.producer.pipeline() pipe = self.connector.pipeline()
self.producer.set_and_publish(MessageEndpoints.device_readback(name), dev_msg, pipe) self.connector.set_and_publish(MessageEndpoints.device_readback(name), dev_msg, pipe)
pipe.execute() pipe.execute()
@typechecked @typechecked
@ -466,7 +466,7 @@ class DeviceManagerDS(DeviceManagerBase):
metadata = self.devices[name].metadata metadata = self.devices[name].metadata
msg = messages.DeviceMonitorMessage(device=name, data=value, metadata=metadata) msg = messages.DeviceMonitorMessage(device=name, data=value, metadata=metadata)
stream_msg = {"data": msg} stream_msg = {"data": msg}
self.producer.xadd( self.connector.xadd(
MessageEndpoints.device_monitor(name), MessageEndpoints.device_monitor(name),
stream_msg, stream_msg,
max_size=min(100, int(max_size // dsize)), max_size=min(100, int(max_size // dsize)),
@ -476,7 +476,7 @@ class DeviceManagerDS(DeviceManagerBase):
device = kwargs["obj"].root.name device = kwargs["obj"].root.name
status = 0 status = 0
metadata = self.devices[device].metadata metadata = self.devices[device].metadata
self.producer.send( self.connector.send(
MessageEndpoints.device_status(device), MessageEndpoints.device_status(device),
messages.DeviceStatusMessage(device=device, status=status, metadata=metadata), messages.DeviceStatusMessage(device=device, status=status, metadata=metadata),
) )
@ -489,8 +489,8 @@ class DeviceManagerDS(DeviceManagerBase):
device = kwargs["obj"].root.name device = kwargs["obj"].root.name
status = int(kwargs.get("value")) status = int(kwargs.get("value"))
metadata = self.devices[device].metadata metadata = self.devices[device].metadata
self.producer.set( self.connector.set(
MessageEndpoints.device_status(kwargs["obj"].root.name), MessageEndpoints.device_status(device),
messages.DeviceStatusMessage(device=device, status=status, metadata=metadata), messages.DeviceStatusMessage(device=device, status=status, metadata=metadata),
) )
@ -521,12 +521,12 @@ class DeviceManagerDS(DeviceManagerBase):
) )
) )
ds_obj.emitted_points[metadata["scanID"]] = max_points ds_obj.emitted_points[metadata["scanID"]] = max_points
pipe = self.producer.pipeline() pipe = self.connector.pipeline()
self.producer.send(MessageEndpoints.device_read(obj.root.name), bundle, pipe=pipe) self.connector.send(MessageEndpoints.device_read(obj.root.name), bundle, pipe=pipe)
msg = messages.DeviceStatusMessage( msg = messages.DeviceStatusMessage(
device=obj.root.name, status=max_points, metadata=metadata device=obj.root.name, status=max_points, metadata=metadata
) )
self.producer.set_and_publish( self.connector.set_and_publish(
MessageEndpoints.device_progress(obj.root.name), msg, pipe=pipe MessageEndpoints.device_progress(obj.root.name), msg, pipe=pipe
) )
pipe.execute() pipe.execute()
@ -536,4 +536,4 @@ class DeviceManagerDS(DeviceManagerBase):
msg = messages.ProgressMessage( msg = messages.ProgressMessage(
value=value, max_value=max_value, done=done, metadata=metadata value=value, max_value=max_value, done=done, metadata=metadata
) )
self.producer.set_and_publish(MessageEndpoints.device_progress(obj.root.name), msg) self.connector.set_and_publish(MessageEndpoints.device_progress(obj.root.name), msg)

View File

@ -70,7 +70,7 @@ class RPCMixin:
def _send_rpc_result_to_client( def _send_rpc_result_to_client(
self, device: str, instr_params: dict, res: Any, result: StringIO self, device: str, instr_params: dict, res: Any, result: StringIO
): ):
self.producer.set( self.connector.set(
MessageEndpoints.device_rpc(instr_params.get("rpc_id")), MessageEndpoints.device_rpc(instr_params.get("rpc_id")),
messages.DeviceRPCMessage( messages.DeviceRPCMessage(
device=device, return_val=res, out=result.getvalue(), success=True device=device, return_val=res, out=result.getvalue(), success=True
@ -175,7 +175,7 @@ class RPCMixin:
} }
logger.info(f"Received exception: {exc_formatted}, {exc}") logger.info(f"Received exception: {exc_formatted}, {exc}")
instr_params = instr.content.get("parameter") instr_params = instr.content.get("parameter")
self.producer.set( self.connector.set(
MessageEndpoints.device_rpc(instr_params.get("rpc_id")), MessageEndpoints.device_rpc(instr_params.get("rpc_id")),
messages.DeviceRPCMessage( messages.DeviceRPCMessage(
device=instr.content["device"], return_val=None, out=exc_formatted, success=False device=instr.content["device"], return_val=None, out=exc_formatted, success=False

View File

@ -6,7 +6,7 @@ import numpy as np
import pytest import pytest
import yaml import yaml
from bec_lib import MessageEndpoints, messages from bec_lib import MessageEndpoints, messages
from bec_lib.tests.utils import ConnectorMock, ProducerMock, create_session_from_config from bec_lib.tests.utils import ConnectorMock, create_session_from_config
from device_server.devices.devicemanager import DeviceManagerDS from device_server.devices.devicemanager import DeviceManagerDS
@ -52,7 +52,7 @@ def load_device_manager():
service_mock = mock.MagicMock() service_mock = mock.MagicMock()
service_mock.connector = ConnectorMock("", store_data=False) service_mock.connector = ConnectorMock("", store_data=False)
device_manager = DeviceManagerDS(service_mock, "") device_manager = DeviceManagerDS(service_mock, "")
device_manager.producer = service_mock.connector.producer() device_manager.connector = service_mock.connector
device_manager.config_update_handler = mock.MagicMock() device_manager.config_update_handler = mock.MagicMock()
with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file: with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file:
device_manager._session = create_session_from_config(yaml.safe_load(session_file)) device_manager._session = create_session_from_config(yaml.safe_load(session_file))
@ -133,10 +133,10 @@ def test_flyer_event_callback():
device_manager._obj_flyer_callback( device_manager._obj_flyer_callback(
obj=samx.obj, value={"data": {"idata": np.random.rand(20), "edata": np.random.rand(20)}} obj=samx.obj, value={"data": {"idata": np.random.rand(20), "edata": np.random.rand(20)}}
) )
pipe = device_manager.producer.pipeline() pipe = device_manager.connector.pipeline()
bundle, progress = pipe._pipe_buffer[-2:] bundle, progress = pipe._pipe_buffer[-2:]
# check producer method # check connector method
assert bundle[0] == "send" assert bundle[0] == "send"
assert progress[0] == "set_and_publish" assert progress[0] == "set_and_publish"
@ -157,9 +157,9 @@ def test_obj_progress_callback():
samx = device_manager.devices.samx samx = device_manager.devices.samx
samx.metadata = {"scanID": "12345"} samx.metadata = {"scanID": "12345"}
with mock.patch.object(device_manager, "producer") as mock_producer: with mock.patch.object(device_manager, "connector") as mock_connector:
device_manager._obj_progress_callback(obj=samx.obj, value=1, max_value=2, done=False) device_manager._obj_progress_callback(obj=samx.obj, value=1, max_value=2, done=False)
mock_producer.set_and_publish.assert_called_once_with( mock_connector.set_and_publish.assert_called_once_with(
MessageEndpoints.device_progress("samx"), MessageEndpoints.device_progress("samx"),
messages.ProgressMessage( messages.ProgressMessage(
value=1, max_value=2, done=False, metadata={"scanID": "12345"} value=1, max_value=2, done=False, metadata={"scanID": "12345"}
@ -176,9 +176,9 @@ def test_obj_monitor_callback(value):
eiger.metadata = {"scanID": "12345"} eiger.metadata = {"scanID": "12345"}
value_size = len(value.tobytes()) / 1e6 # MB value_size = len(value.tobytes()) / 1e6 # MB
max_size = 100 max_size = 100
with mock.patch.object(device_manager, "producer") as mock_producer: with mock.patch.object(device_manager, "connector") as mock_connector:
device_manager._obj_callback_monitor(obj=eiger.obj, value=value) device_manager._obj_callback_monitor(obj=eiger.obj, value=value)
mock_producer.xadd.assert_called_once_with( mock_connector.xadd.assert_called_once_with(
MessageEndpoints.device_monitor(eiger.name), MessageEndpoints.device_monitor(eiger.name),
{ {
"data": messages.DeviceMonitorMessage( "data": messages.DeviceMonitorMessage(

View File

@ -7,7 +7,7 @@ from bec_lib import Alarms, MessageEndpoints, ServiceConfig, messages
from bec_lib.device import OnFailure from bec_lib.device import OnFailure
from bec_lib.messages import BECStatus from bec_lib.messages import BECStatus
from bec_lib.redis_connector import MessageObject from bec_lib.redis_connector import MessageObject
from bec_lib.tests.utils import ConnectorMock, ConsumerMock from bec_lib.tests.utils import ConnectorMock
from ophyd import Staged from ophyd import Staged
from ophyd.utils import errors as ophyd_errors from ophyd.utils import errors as ophyd_errors
from test_device_manager_ds import device_manager, load_device_manager from test_device_manager_ds import device_manager, load_device_manager
@ -54,8 +54,6 @@ def test_start(device_server_mock):
device_server.start() device_server.start()
assert device_server.threads
assert isinstance(device_server.threads[0], ConsumerMock)
assert device_server.status == BECStatus.RUNNING assert device_server.status == BECStatus.RUNNING
@ -187,10 +185,10 @@ def test_stop_devices(device_server_mock):
), ),
], ],
) )
def test_consumer_interception_callback(device_server_mock, msg, stop_called): def test_register_interception_callback(device_server_mock, msg, stop_called):
device_server = device_server_mock device_server = device_server_mock
with mock.patch.object(device_server, "stop_devices") as stop: with mock.patch.object(device_server, "stop_devices") as stop:
device_server.consumer_interception_callback(msg, parent=device_server) device_server.register_interception_callback(msg, parent=device_server)
if stop_called: if stop_called:
stop.assert_called_once() stop.assert_called_once()
else: else:
@ -640,7 +638,7 @@ def test_set_device(device_server_mock, instr):
while True: while True:
res = [ res = [
msg msg
for msg in device_server.producer.message_sent for msg in device_server.connector.message_sent
if msg["queue"] == MessageEndpoints.device_req_status("samx") if msg["queue"] == MessageEndpoints.device_req_status("samx")
] ]
if res: if res:
@ -676,7 +674,7 @@ def test_read_device(device_server_mock, instr):
for device in devices: for device in devices:
res = [ res = [
msg msg
for msg in device_server.producer.message_sent for msg in device_server.connector.message_sent
if msg["queue"] == MessageEndpoints.device_read(device) if msg["queue"] == MessageEndpoints.device_read(device)
] ]
assert res[-1]["msg"].metadata["RID"] == instr.metadata["RID"] assert res[-1]["msg"].metadata["RID"] == instr.metadata["RID"]
@ -690,7 +688,7 @@ def test_read_config_and_update_devices(device_server_mock, devices):
for device in devices: for device in devices:
res = [ res = [
msg msg
for msg in device_server.producer.message_sent for msg in device_server.connector.message_sent
if msg["queue"] == MessageEndpoints.device_read_configuration(device) if msg["queue"] == MessageEndpoints.device_read_configuration(device)
] ]
config = device_server.device_manager.devices[device].obj.read_configuration() config = device_server.device_manager.devices[device].obj.read_configuration()
@ -755,8 +753,8 @@ def test_retry_obj_method_buffer(device_server_mock, instr):
return return
signals_before = getattr(samx.obj, instr)() signals_before = getattr(samx.obj, instr)()
device_server.producer = mock.MagicMock() device_server.connector = mock.MagicMock()
device_server.producer.get.return_value = messages.DeviceMessage( device_server.connector.get.return_value = messages.DeviceMessage(
signals=signals_before, metadata={"RID": "test", "stream": "primary"} signals=signals_before, metadata={"RID": "test", "stream": "primary"}
) )

View File

@ -13,7 +13,7 @@ from device_server.rpc_mixin import RPCMixin
def rpc_cls(): def rpc_cls():
rpc_mixin = RPCMixin() rpc_mixin = RPCMixin()
rpc_mixin.connector = mock.MagicMock() rpc_mixin.connector = mock.MagicMock()
rpc_mixin.producer = mock.MagicMock() rpc_mixin.connector = mock.MagicMock()
rpc_mixin.device_manager = mock.MagicMock() rpc_mixin.device_manager = mock.MagicMock()
yield rpc_mixin yield rpc_mixin
@ -93,7 +93,7 @@ def test_get_result_from_rpc_list_from_stage(rpc_cls):
def test_send_rpc_exception(rpc_cls, instr): def test_send_rpc_exception(rpc_cls, instr):
rpc_cls._send_rpc_exception(Exception(), instr) rpc_cls._send_rpc_exception(Exception(), instr)
rpc_cls.producer.set.assert_called_once_with( rpc_cls.connector.set.assert_called_once_with(
MessageEndpoints.device_rpc("rpc_id"), MessageEndpoints.device_rpc("rpc_id"),
messages.DeviceRPCMessage( messages.DeviceRPCMessage(
device="device", device="device",
@ -108,7 +108,7 @@ def test_send_rpc_result_to_client(rpc_cls):
result = mock.MagicMock() result = mock.MagicMock()
result.getvalue.return_value = "result" result.getvalue.return_value = "result"
rpc_cls._send_rpc_result_to_client("device", {"rpc_id": "rpc_id"}, 1, result) rpc_cls._send_rpc_result_to_client("device", {"rpc_id": "rpc_id"}, 1, result)
rpc_cls.producer.set.assert_called_once_with( rpc_cls.connector.set.assert_called_once_with(
MessageEndpoints.device_rpc("rpc_id"), MessageEndpoints.device_rpc("rpc_id"),
messages.DeviceRPCMessage(device="device", return_val=1, out="result", success=True), messages.DeviceRPCMessage(device="device", return_val=1, out="result", success=True),
expire=1800, expire=1800,

View File

@ -325,7 +325,7 @@ class NexusFileWriter(FileWriter):
file_data[key] = val if not isinstance(val, list) else merge_dicts(val) file_data[key] = val if not isinstance(val, list) else merge_dicts(val)
msg_data = {"file_path": file_path, "data": file_data} msg_data = {"file_path": file_path, "data": file_data}
msg = messages.FileContentMessage(**msg_data) msg = messages.FileContentMessage(**msg_data)
self.file_writer_manager.producer.set_and_publish(MessageEndpoints.file_content(), msg) self.file_writer_manager.connector.set_and_publish(MessageEndpoints.file_content(), msg)
with h5py.File(file_path, "w") as file: with h5py.File(file_path, "w") as file:
HDF5StorageWriter.write(writer_storage._storage, device_storage, file) HDF5StorageWriter.write(writer_storage._storage, device_storage, file)

View File

@ -80,10 +80,13 @@ class FileWriterManager(BECService):
self._lock = threading.RLock() self._lock = threading.RLock()
self.file_writer_config = self._service_config.service_config.get("file_writer") self.file_writer_config = self._service_config.service_config.get("file_writer")
self.writer_mixin = FileWriterMixin(self.file_writer_config) self.writer_mixin = FileWriterMixin(self.file_writer_config)
self.producer = self.connector.producer()
self._start_device_manager() self._start_device_manager()
self._start_scan_segment_consumer() self.connector.register(
self._start_scan_status_consumer() patterns=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self
)
self.connector.register(
MessageEndpoints.scan_status(), cb=self._scan_status_callback, parent=self
)
self.scan_storage = {} self.scan_storage = {}
self.file_writer = NexusFileWriter(self) self.file_writer = NexusFileWriter(self)
@ -92,20 +95,7 @@ class FileWriterManager(BECService):
self.device_manager = DeviceManagerBase(self) self.device_manager = DeviceManagerBase(self)
self.device_manager.initialize([self.bootstrap_server]) self.device_manager.initialize([self.bootstrap_server])
def _start_scan_segment_consumer(self): def _scan_segment_callback(self, msg: MessageObject, *, parent: FileWriterManager):
self._scan_segment_consumer = self.connector.consumer(
pattern=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self
)
self._scan_segment_consumer.start()
def _start_scan_status_consumer(self):
self._scan_status_consumer = self.connector.consumer(
MessageEndpoints.scan_status(), cb=self._scan_status_callback, parent=self
)
self._scan_status_consumer.start()
@staticmethod
def _scan_segment_callback(msg: MessageObject, *, parent: FileWriterManager):
msgs = msg.value msgs = msg.value
for scan_msg in msgs: for scan_msg in msgs:
parent.insert_to_scan_storage(scan_msg) parent.insert_to_scan_storage(scan_msg)
@ -188,7 +178,7 @@ class FileWriterManager(BECService):
return return
if self.scan_storage[scanID].baseline: if self.scan_storage[scanID].baseline:
return return
baseline = self.producer.get(MessageEndpoints.public_scan_baseline(scanID)) baseline = self.connector.get(MessageEndpoints.public_scan_baseline(scanID))
if not baseline: if not baseline:
return return
self.scan_storage[scanID].baseline = baseline.content["data"] self.scan_storage[scanID].baseline = baseline.content["data"]
@ -205,13 +195,13 @@ class FileWriterManager(BECService):
""" """
if not self.scan_storage.get(scanID): if not self.scan_storage.get(scanID):
return return
msgs = self.producer.keys(MessageEndpoints.public_file(scanID, "*")) msgs = self.connector.keys(MessageEndpoints.public_file(scanID, "*"))
if not msgs: if not msgs:
return return
# extract name from 'public/<scanID>/file/<name>' # extract name from 'public/<scanID>/file/<name>'
names = [msg.decode().split("/")[-1] for msg in msgs] names = [msg.decode().split("/")[-1] for msg in msgs]
file_msgs = [self.producer.get(msg.decode()) for msg in msgs] file_msgs = [self.connector.get(msg.decode()) for msg in msgs]
if not file_msgs: if not file_msgs:
return return
for name, file_msg in zip(names, file_msgs): for name, file_msg in zip(names, file_msgs):
@ -236,7 +226,7 @@ class FileWriterManager(BECService):
if not self.scan_storage.get(scanID): if not self.scan_storage.get(scanID):
return return
# get all async devices # get all async devices
async_device_keys = self.producer.keys(MessageEndpoints.device_async_readback(scanID, "*")) async_device_keys = self.connector.keys(MessageEndpoints.device_async_readback(scanID, "*"))
if not async_device_keys: if not async_device_keys:
return return
for device_key in async_device_keys: for device_key in async_device_keys:
@ -244,7 +234,7 @@ class FileWriterManager(BECService):
device_name = key.split(MessageEndpoints.device_async_readback(scanID, ""))[-1].split( device_name = key.split(MessageEndpoints.device_async_readback(scanID, ""))[-1].split(
":" ":"
)[0] )[0]
msgs = self.producer.xrange(key, min="-", max="+") msgs = self.connector.xrange(key, min="-", max="+")
if not msgs: if not msgs:
continue continue
self._process_async_data(msgs, scanID, device_name) self._process_async_data(msgs, scanID, device_name)
@ -298,7 +288,7 @@ class FileWriterManager(BECService):
try: try:
file_path = self.writer_mixin.compile_full_filename(scan, file_suffix) file_path = self.writer_mixin.compile_full_filename(scan, file_suffix)
self.producer.set_and_publish( self.connector.set_and_publish(
MessageEndpoints.public_file(scanID, "master"), MessageEndpoints.public_file(scanID, "master"),
messages.FileMessage(file_path=file_path, done=False), messages.FileMessage(file_path=file_path, done=False),
) )
@ -319,7 +309,7 @@ class FileWriterManager(BECService):
) )
successful = False successful = False
self.scan_storage.pop(scanID) self.scan_storage.pop(scanID)
self.producer.set_and_publish( self.connector.set_and_publish(
MessageEndpoints.public_file(scanID, "master"), MessageEndpoints.public_file(scanID, "master"),
messages.FileMessage(file_path=file_path, successful=successful), messages.FileMessage(file_path=file_path, successful=successful),
) )

View File

@ -25,7 +25,7 @@ def load_FileWriter():
service_mock = mock.MagicMock() service_mock = mock.MagicMock()
service_mock.connector = ConnectorMock("") service_mock.connector = ConnectorMock("")
device_manager = DeviceManagerBase(service_mock, "") device_manager = DeviceManagerBase(service_mock, "")
device_manager.producer = service_mock.connector.producer() device_manager.connector = service_mock.connector
with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file: with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file:
device_manager._session = create_session_from_config(yaml.safe_load(session_file)) device_manager._session = create_session_from_config(yaml.safe_load(session_file))
device_manager._load_session() device_manager._load_session()
@ -152,13 +152,13 @@ def test_write_file_raises_alarm_on_error():
def test_update_baseline_reading(): def test_update_baseline_reading():
file_manager = load_FileWriter() file_manager = load_FileWriter()
file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID") file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID")
with mock.patch.object(file_manager, "producer") as mock_producer: with mock.patch.object(file_manager, "connector") as mock_connector:
mock_producer.get.return_value = messages.ScanBaselineMessage( mock_connector.get.return_value = messages.ScanBaselineMessage(
scanID="scanID", data={"data": "data"} scanID="scanID", data={"data": "data"}
) )
file_manager.update_baseline_reading("scanID") file_manager.update_baseline_reading("scanID")
assert file_manager.scan_storage["scanID"].baseline == {"data": "data"} assert file_manager.scan_storage["scanID"].baseline == {"data": "data"}
mock_producer.get.assert_called_once_with(MessageEndpoints.public_scan_baseline("scanID")) mock_connector.get.assert_called_once_with(MessageEndpoints.public_scan_baseline("scanID"))
def test_scan_storage_append(): def test_scan_storage_append():
@ -178,30 +178,30 @@ def test_scan_storage_ready_to_write():
def test_update_file_references(): def test_update_file_references():
file_manager = load_FileWriter() file_manager = load_FileWriter()
with mock.patch.object(file_manager, "producer") as mock_producer: with mock.patch.object(file_manager, "connector") as mock_connector:
file_manager.update_file_references("scanID") file_manager.update_file_references("scanID")
mock_producer.keys.assert_not_called() mock_connector.keys.assert_not_called()
def test_update_file_references_gets_keys(): def test_update_file_references_gets_keys():
file_manager = load_FileWriter() file_manager = load_FileWriter()
file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID") file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID")
with mock.patch.object(file_manager, "producer") as mock_producer: with mock.patch.object(file_manager, "connector") as mock_connector:
file_manager.update_file_references("scanID") file_manager.update_file_references("scanID")
mock_producer.keys.assert_called_once_with(MessageEndpoints.public_file("scanID", "*")) mock_connector.keys.assert_called_once_with(MessageEndpoints.public_file("scanID", "*"))
def test_update_async_data(): def test_update_async_data():
file_manager = load_FileWriter() file_manager = load_FileWriter()
file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID") file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID")
with mock.patch.object(file_manager, "producer") as mock_producer: with mock.patch.object(file_manager, "connector") as mock_connector:
with mock.patch.object(file_manager, "_process_async_data") as mock_process: with mock.patch.object(file_manager, "_process_async_data") as mock_process:
key = MessageEndpoints.device_async_readback("scanID", "dev1") key = MessageEndpoints.device_async_readback("scanID", "dev1")
mock_producer.keys.return_value = [key.encode()] mock_connector.keys.return_value = [key.encode()]
data = [(b"0-0", b'{"data": "data"}')] data = [(b"0-0", b'{"data": "data"}')]
mock_producer.xrange.return_value = data mock_connector.xrange.return_value = data
file_manager.update_async_data("scanID") file_manager.update_async_data("scanID")
mock_producer.xrange.assert_called_once_with(key, min="-", max="+") mock_connector.xrange.assert_called_once_with(key, min="-", max="+")
mock_process.assert_called_once_with(data, "scanID", "dev1") mock_process.assert_called_once_with(data, "scanID", "dev1")

View File

@ -14,7 +14,7 @@ if TYPE_CHECKING:
class BECEmitter(EmitterBase): class BECEmitter(EmitterBase):
def __init__(self, scan_bundler: ScanBundler) -> None: def __init__(self, scan_bundler: ScanBundler) -> None:
super().__init__(scan_bundler.producer) super().__init__(scan_bundler.connector)
self.scan_bundler = scan_bundler self.scan_bundler = scan_bundler
def on_scan_point_emit(self, scanID: str, pointID: int): def on_scan_point_emit(self, scanID: str, pointID: int):
@ -46,9 +46,16 @@ class BECEmitter(EmitterBase):
data=sb.sync_storage[scanID]["baseline"], data=sb.sync_storage[scanID]["baseline"],
metadata=sb.sync_storage[scanID]["info"], metadata=sb.sync_storage[scanID]["info"],
) )
pipe = sb.producer.pipeline() pipe = sb.connector.pipeline()
sb.producer.set( sb.connector.set(
MessageEndpoints.public_scan_baseline(scanID=scanID), msg, expire=1800, pipe=pipe MessageEndpoints.public_scan_baseline(scanID=scanID),
msg,
expire=1800,
pipe=pipe,
)
sb.connector.set_and_publish(
MessageEndpoints.scan_baseline(),
msg,
pipe=pipe,
) )
sb.producer.set_and_publish(MessageEndpoints.scan_baseline(), msg, pipe=pipe)
pipe.execute() pipe.execute()

View File

@ -17,7 +17,7 @@ if TYPE_CHECKING:
class BlueskyEmitter(EmitterBase): class BlueskyEmitter(EmitterBase):
def __init__(self, scan_bundler: ScanBundler) -> None: def __init__(self, scan_bundler: ScanBundler) -> None:
super().__init__(scan_bundler.producer) super().__init__(scan_bundler.connector)
self.scan_bundler = scan_bundler self.scan_bundler = scan_bundler
self.bluesky_metadata = {} self.bluesky_metadata = {}
@ -27,7 +27,7 @@ class BlueskyEmitter(EmitterBase):
self.bluesky_metadata[scanID] = {} self.bluesky_metadata[scanID] = {}
doc = self._get_run_start_document(scanID) doc = self._get_run_start_document(scanID)
self.bluesky_metadata[scanID]["start"] = doc self.bluesky_metadata[scanID]["start"] = doc
self.producer.raw_send(MessageEndpoints.bluesky_events(), msgpack.dumps(("start", doc))) self.connector.raw_send(MessageEndpoints.bluesky_events(), msgpack.dumps(("start", doc)))
self.send_descriptor_document(scanID) self.send_descriptor_document(scanID)
def _get_run_start_document(self, scanID) -> dict: def _get_run_start_document(self, scanID) -> dict:
@ -71,7 +71,7 @@ class BlueskyEmitter(EmitterBase):
"""Bluesky only: send descriptor document""" """Bluesky only: send descriptor document"""
doc = self._get_descriptor_document(scanID) doc = self._get_descriptor_document(scanID)
self.bluesky_metadata[scanID]["descriptor"] = doc self.bluesky_metadata[scanID]["descriptor"] = doc
self.producer.raw_send( self.connector.raw_send(
MessageEndpoints.bluesky_events(), msgpack.dumps(("descriptor", doc)) MessageEndpoints.bluesky_events(), msgpack.dumps(("descriptor", doc))
) )
@ -85,7 +85,7 @@ class BlueskyEmitter(EmitterBase):
logger.warning(f"Failed to remove {scanID} from {storage}.") logger.warning(f"Failed to remove {scanID} from {storage}.")
def send_bluesky_scan_point(self, scanID, pointID) -> None: def send_bluesky_scan_point(self, scanID, pointID) -> None:
self.producer.raw_send( self.connector.raw_send(
MessageEndpoints.bluesky_events(), MessageEndpoints.bluesky_events(),
msgpack.dumps(("event", self._prepare_bluesky_event_data(scanID, pointID))), msgpack.dumps(("event", self._prepare_bluesky_event_data(scanID, pointID))),
) )

View File

@ -6,16 +6,16 @@ from bec_lib import messages
class EmitterBase: class EmitterBase:
def __init__(self, producer) -> None: def __init__(self, connector) -> None:
self._send_buffer = Queue() self._send_buffer = Queue()
self.producer = producer self.connector = connector
self._start_buffered_producer() self._start_buffered_connector()
def _start_buffered_producer(self): def _start_buffered_connector(self):
self._buffered_producer_thread = threading.Thread( self._buffered_connector_thread = threading.Thread(
target=self._buffered_publish, daemon=True, name="buffered_publisher" target=self._buffered_publish, daemon=True, name="buffered_publisher"
) )
self._buffered_producer_thread.start() self._buffered_connector_thread.start()
def add_message(self, msg: messages.BECMessage, endpoint: str, public: str = None): def add_message(self, msg: messages.BECMessage, endpoint: str, public: str = None):
self._send_buffer.put((msg, endpoint, public)) self._send_buffer.put((msg, endpoint, public))
@ -37,20 +37,20 @@ class EmitterBase:
time.sleep(0.1) time.sleep(0.1)
return return
pipe = self.producer.pipeline() pipe = self.connector.pipeline()
msgs = messages.BundleMessage() msgs = messages.BundleMessage()
_, endpoint, _ = msgs_to_send[0] _, endpoint, _ = msgs_to_send[0]
for msg, endpoint, public in msgs_to_send: for msg, endpoint, public in msgs_to_send:
msg_dump = msg msg_dump = msg
msgs.append(msg_dump) msgs.append(msg_dump)
if public: if public:
self.producer.set( self.connector.set(
public, public,
msg_dump, msg_dump,
pipe=pipe, pipe=pipe,
expire=1800, expire=1800,
) )
self.producer.send(endpoint, msgs, pipe=pipe) self.connector.send(endpoint, msgs, pipe=pipe)
pipe.execute() pipe.execute()
def on_init(self, scanID: str): def on_init(self, scanID: str):

View File

@ -20,9 +20,23 @@ class ScanBundler(BECService):
self.device_manager = None self.device_manager = None
self._start_device_manager() self._start_device_manager()
self._start_device_read_consumer() self.connector.register(
self._start_scan_queue_consumer() patterns=MessageEndpoints.device_read("*"),
self._start_scan_status_consumer() cb=self._device_read_callback,
name="device_read_register",
)
self.connector.register(
MessageEndpoints.scan_queue_status(),
cb=self._scan_queue_callback,
group_id="scan_bundler",
name="scan_queue_register",
)
self.connector.register(
MessageEndpoints.scan_status(),
cb=self._scan_status_callback,
group_id="scan_bundler",
name="scan_status_register",
)
self.sync_storage = {} self.sync_storage = {}
self.monitored_devices = {} self.monitored_devices = {}
@ -56,56 +70,24 @@ class ScanBundler(BECService):
self.device_manager = DeviceManagerBase(self) self.device_manager = DeviceManagerBase(self)
self.device_manager.initialize(self.bootstrap_server) self.device_manager.initialize(self.bootstrap_server)
def _start_device_read_consumer(self): def _device_read_callback(self, msg, **_kwargs):
self._device_read_consumer = self.connector.consumer(
pattern=MessageEndpoints.device_read("*"),
cb=self._device_read_callback,
parent=self,
name="device_read_consumer",
)
self._device_read_consumer.start()
def _start_scan_queue_consumer(self):
self._scan_queue_consumer = self.connector.consumer(
MessageEndpoints.scan_queue_status(),
cb=self._scan_queue_callback,
group_id="scan_bundler",
parent=self,
name="scan_queue_consumer",
)
self._scan_queue_consumer.start()
def _start_scan_status_consumer(self):
self._scan_status_consumer = self.connector.consumer(
MessageEndpoints.scan_status(),
cb=self._scan_status_callback,
group_id="scan_bundler",
parent=self,
name="scan_status_consumer",
)
self._scan_status_consumer.start()
@staticmethod
def _device_read_callback(msg, parent, **_kwargs):
# pylint: disable=protected-access # pylint: disable=protected-access
dev = msg.topic.split(MessageEndpoints._device_read + "/")[-1] dev = msg.topic.split(MessageEndpoints._device_read + "/")[-1]
msgs = msg.value msgs = msg.value
logger.debug(f"Received reading from device {dev}") logger.debug(f"Received reading from device {dev}")
if not isinstance(msgs, list): if not isinstance(msgs, list):
msgs = [msgs] msgs = [msgs]
task = parent.executor.submit(parent._add_device_to_storage, msgs, dev) task = self.executor.submit(self._add_device_to_storage, msgs, dev)
parent.executor_tasks.append(task) self.executor_tasks.append(task)
@staticmethod def _scan_queue_callback(self, msg, **_kwargs):
def _scan_queue_callback(msg, parent, **_kwargs):
msg = msg.value msg = msg.value
logger.trace(msg) logger.trace(msg)
parent.current_queue = msg.content["queue"]["primary"].get("info") self.current_queue = msg.content["queue"]["primary"].get("info")
@staticmethod def _scan_status_callback(self, msg, **_kwargs):
def _scan_status_callback(msg, parent, **_kwargs):
msg = msg.value msg = msg.value
parent.handle_scan_status_message(msg) self.handle_scan_status_message(msg)
def handle_scan_status_message(self, msg: messages.ScanStatusMessage) -> None: def handle_scan_status_message(self, msg: messages.ScanStatusMessage) -> None:
"""handle scan status messages""" """handle scan status messages"""
@ -270,7 +252,7 @@ class ScanBundler(BECService):
} }
def _get_scan_status_history(self, length): def _get_scan_status_history(self, length):
return self.producer.lrange(MessageEndpoints.scan_status() + "_list", length * -1, -1) return self.connector.lrange(MessageEndpoints.scan_status() + "_list", length * -1, -1)
def _wait_for_scanID(self, scanID, timeout_time=10): def _wait_for_scanID(self, scanID, timeout_time=10):
elapsed_time = 0 elapsed_time = 0
@ -344,10 +326,10 @@ class ScanBundler(BECService):
self.sync_storage[scanID][pointID][dev.name] = read self.sync_storage[scanID][pointID][dev.name] = read
def _get_last_device_readback(self, devices: list) -> list: def _get_last_device_readback(self, devices: list) -> list:
pipe = self.producer.pipeline() pipe = self.connector.pipeline()
for dev in devices: for dev in devices:
self.producer.get(MessageEndpoints.device_readback(dev.name), pipe) self.connector.get(MessageEndpoints.device_readback(dev.name), pipe)
return [msg.content["signals"] for msg in self.producer.execute_pipeline(pipe)] return [msg.content["signals"] for msg in self.connector.execute_pipeline(pipe)]
def cleanup_storage(self): def cleanup_storage(self):
"""remove old scanIDs to free memory""" """remove old scanIDs to free memory"""

View File

@ -57,16 +57,16 @@ def test_send_baseline_BEC():
sb.sync_storage[scanID] = {"info": {}, "status": "open", "sent": set()} sb.sync_storage[scanID] = {"info": {}, "status": "open", "sent": set()}
sb.sync_storage[scanID]["baseline"] = {} sb.sync_storage[scanID]["baseline"] = {}
msg = messages.ScanBaselineMessage(scanID=scanID, data=sb.sync_storage[scanID]["baseline"]) msg = messages.ScanBaselineMessage(scanID=scanID, data=sb.sync_storage[scanID]["baseline"])
with mock.patch.object(sb, "producer") as producer: with mock.patch.object(sb, "connector") as connector:
bec_emitter._send_baseline(scanID) bec_emitter._send_baseline(scanID)
pipe = producer.pipeline() pipe = connector.pipeline()
producer.set.assert_called_once_with( connector.set.assert_called_once_with(
MessageEndpoints.public_scan_baseline(scanID), MessageEndpoints.public_scan_baseline(scanID),
msg, msg,
expire=1800, expire=1800,
pipe=pipe, pipe=pipe,
) )
producer.set_and_publish.assert_called_once_with( connector.set_and_publish.assert_called_once_with(
MessageEndpoints.scan_baseline(), MessageEndpoints.scan_baseline(),
msg, msg,
pipe=pipe, pipe=pipe,

View File

@ -13,7 +13,7 @@ from scan_bundler.bluesky_emitter import BlueskyEmitter
def test_run_start_document(scanID): def test_run_start_document(scanID):
sb = load_ScanBundlerMock() sb = load_ScanBundlerMock()
bls_emitter = BlueskyEmitter(sb) bls_emitter = BlueskyEmitter(sb)
with mock.patch.object(bls_emitter.producer, "raw_send") as send: with mock.patch.object(bls_emitter.connector, "raw_send") as send:
with mock.patch.object(bls_emitter, "send_descriptor_document") as send_descr: with mock.patch.object(bls_emitter, "send_descriptor_document") as send_descr:
with mock.patch.object( with mock.patch.object(
bls_emitter, "_get_run_start_document", return_value={} bls_emitter, "_get_run_start_document", return_value={}
@ -45,7 +45,7 @@ def test_send_descriptor_document():
bls_emitter = BlueskyEmitter(sb) bls_emitter = BlueskyEmitter(sb)
scanID = "lkajsdl" scanID = "lkajsdl"
bls_emitter.bluesky_metadata[scanID] = {} bls_emitter.bluesky_metadata[scanID] = {}
with mock.patch.object(bls_emitter.producer, "raw_send") as send: with mock.patch.object(bls_emitter.connector, "raw_send") as send:
with mock.patch.object( with mock.patch.object(
bls_emitter, "_get_descriptor_document", return_value={} bls_emitter, "_get_descriptor_document", return_value={}
) as get_descr: ) as get_descr:

View File

@ -51,30 +51,30 @@ from scan_bundler.emitter import EmitterBase
], ],
) )
def test_publish_data(msgs): def test_publish_data(msgs):
producer = mock.MagicMock() connector = mock.MagicMock()
with mock.patch.object(EmitterBase, "_start_buffered_producer") as start: with mock.patch.object(EmitterBase, "_start_buffered_connector") as start:
emitter = EmitterBase(producer) emitter = EmitterBase(connector)
start.assert_called_once() start.assert_called_once()
with mock.patch.object(emitter, "_get_messages_from_buffer", return_value=msgs) as get_msgs: with mock.patch.object(emitter, "_get_messages_from_buffer", return_value=msgs) as get_msgs:
emitter._publish_data() emitter._publish_data()
get_msgs.assert_called_once() get_msgs.assert_called_once()
if not msgs: if not msgs:
producer.send.assert_not_called() connector.send.assert_not_called()
return return
pipe = producer.pipeline() pipe = connector.pipeline()
msgs_bundle = messages.BundleMessage() msgs_bundle = messages.BundleMessage()
_, endpoint, _ = msgs[0] _, endpoint, _ = msgs[0]
for msg, endpoint, public in msgs: for msg, endpoint, public in msgs:
msg_dump = msg msg_dump = msg
msgs_bundle.append(msg_dump) msgs_bundle.append(msg_dump)
if public: if public:
producer.set.assert_has_calls( connector.set.assert_has_calls(
producer.set(public, msg_dump, pipe=pipe, expire=1800) connector.set(public, msg_dump, pipe=pipe, expire=1800)
) )
producer.send.assert_called_with(endpoint, msgs_bundle, pipe=pipe) connector.send.assert_called_with(endpoint, msgs_bundle, pipe=pipe)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -93,8 +93,8 @@ def test_publish_data(msgs):
], ],
) )
def test_add_message(msg, endpoint, public): def test_add_message(msg, endpoint, public):
producer = mock.MagicMock() connector = mock.MagicMock()
emitter = EmitterBase(producer) emitter = EmitterBase(connector)
emitter.add_message(msg, endpoint, public) emitter.add_message(msg, endpoint, public)
msgs = emitter._get_messages_from_buffer() msgs = emitter._get_messages_from_buffer()
out_msg, out_endpoint, out_public = msgs[0] out_msg, out_endpoint, out_public = msgs[0]

View File

@ -36,7 +36,7 @@ def load_ScanBundlerMock():
service_mock = mock.MagicMock() service_mock = mock.MagicMock()
service_mock.connector = ConnectorMock("") service_mock.connector = ConnectorMock("")
device_manager = ScanBundlerDeviceManagerMock(service_mock, "") device_manager = ScanBundlerDeviceManagerMock(service_mock, "")
device_manager.producer = service_mock.connector.producer() device_manager.connector = service_mock.connector
with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file: with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file:
device_manager._session = create_session_from_config(yaml.safe_load(session_file)) device_manager._session = create_session_from_config(yaml.safe_load(session_file))
device_manager._load_session() device_manager._load_session()
@ -74,7 +74,7 @@ def test_device_read_callback():
msg.topic = MessageEndpoints.device_read("samx") msg.topic = MessageEndpoints.device_read("samx")
with mock.patch.object(scan_bundler, "_add_device_to_storage") as add_dev: with mock.patch.object(scan_bundler, "_add_device_to_storage") as add_dev:
scan_bundler._device_read_callback(msg, scan_bundler) scan_bundler._device_read_callback(msg)
add_dev.assert_called_once_with([dev_msg], "samx") add_dev.assert_called_once_with([dev_msg], "samx")
@ -157,7 +157,7 @@ def test_wait_for_scanID(scanID, storageID, scan_msg):
) )
def test_get_scan_status_history(msgs): def test_get_scan_status_history(msgs):
sb = load_ScanBundlerMock() sb = load_ScanBundlerMock()
with mock.patch.object(sb.producer, "lrange", return_value=[msg for msg in msgs]) as lrange: with mock.patch.object(sb.connector, "lrange", return_value=[msg for msg in msgs]) as lrange:
res = sb._get_scan_status_history(5) res = sb._get_scan_status_history(5)
lrange.assert_called_once_with(MessageEndpoints.scan_status() + "_list", -5, -1) lrange.assert_called_once_with(MessageEndpoints.scan_status() + "_list", -5, -1)
assert res == msgs assert res == msgs
@ -371,7 +371,7 @@ def test_scan_queue_callback(queue_msg):
sb = load_ScanBundlerMock() sb = load_ScanBundlerMock()
msg = MessageMock() msg = MessageMock()
msg.value = queue_msg msg.value = queue_msg
sb._scan_queue_callback(msg, sb) sb._scan_queue_callback(msg)
assert sb.current_queue == queue_msg.content["queue"]["primary"].get("info") assert sb.current_queue == queue_msg.content["queue"]["primary"].get("info")
@ -399,7 +399,7 @@ def test_scan_status_callback(scan_msg):
msg.value = scan_msg msg.value = scan_msg
with mock.patch.object(sb, "handle_scan_status_message") as handle_scan_status_message_mock: with mock.patch.object(sb, "handle_scan_status_message") as handle_scan_status_message_mock:
sb._scan_status_callback(msg, sb) sb._scan_status_callback(msg)
handle_scan_status_message_mock.assert_called_once_with(scan_msg) handle_scan_status_message_mock.assert_called_once_with(scan_msg)
@ -744,10 +744,10 @@ def test_get_last_device_readback():
signals={"samx": {"samx": 0.51, "setpoint": 0.5, "motor_is_moving": 0}}, signals={"samx": {"samx": 0.51, "setpoint": 0.5, "motor_is_moving": 0}},
metadata={"scanID": "laksjd", "readout_priority": "monitored"}, metadata={"scanID": "laksjd", "readout_priority": "monitored"},
) )
with mock.patch.object(sb, "producer") as producer_mock: with mock.patch.object(sb, "connector") as connector_mock:
producer_mock.execute_pipeline.return_value = [dev_msg] connector_mock.execute_pipeline.return_value = [dev_msg]
ret = sb._get_last_device_readback([sb.device_manager.devices.samx]) ret = sb._get_last_device_readback([sb.device_manager.devices.samx])
assert producer_mock.get.mock_calls == [ assert connector_mock.get.mock_calls == [
mock.call(MessageEndpoints.device_readback("samx"), producer_mock.pipeline()) mock.call(MessageEndpoints.device_readback("samx"), connector_mock.pipeline())
] ]
assert ret == [dev_msg.content["signals"]] assert ret == [dev_msg.content["signals"]]

View File

@ -458,7 +458,7 @@ class LamNIFermatScan(ScanBase, LamNIMixin):
yield from self.stubs.kickoff(device="rtx") yield from self.stubs.kickoff(device="rtx")
while True: while True:
yield from self.stubs.read_and_wait(group="primary", wait_group="readout_primary") yield from self.stubs.read_and_wait(group="primary", wait_group="readout_primary")
msg = self.device_manager.producer.get(MessageEndpoints.device_status("rt_scan")) msg = self.device_manager.connector.get(MessageEndpoints.device_status("rt_scan"))
if msg: if msg:
status = msg status = msg
status_id = status.content.get("status", 1) status_id = status.content.get("status", 1)

View File

@ -163,7 +163,7 @@ class OwisGrid(AsyncFlyScanBase):
def scan_progress(self) -> int: def scan_progress(self) -> int:
"""Timeout of the progress bar. This gets updated in the frequency of scan segments""" """Timeout of the progress bar. This gets updated in the frequency of scan segments"""
msg = self.device_manager.producer.get(MessageEndpoints.device_progress("mcs")) msg = self.device_manager.connector.get(MessageEndpoints.device_progress("mcs"))
if not msg: if not msg:
self.timeout_progress += 1 self.timeout_progress += 1
return self.timeout_progress return self.timeout_progress

View File

@ -106,7 +106,7 @@ class SgalilGrid(AsyncFlyScanBase):
def scan_progress(self) -> int: def scan_progress(self) -> int:
"""Timeout of the progress bar. This gets updated in the frequency of scan segments""" """Timeout of the progress bar. This gets updated in the frequency of scan segments"""
msg = self.device_manager.producer.get(MessageEndpoints.device_progress("mcs")) msg = self.device_manager.connector.get(MessageEndpoints.device_progress("mcs"))
if not msg: if not msg:
self.timeout_progress += 1 self.timeout_progress += 1
return self.timeout_progress return self.timeout_progress

View File

@ -10,8 +10,8 @@ class DeviceValidation:
Mixin class for validation methods Mixin class for validation methods
""" """
def __init__(self, producer, worker): def __init__(self, connector, worker):
self.producer = producer self.connector = connector
self.worker = worker self.worker = worker
def get_device_status(self, endpoint: MessageEndpoints, devices: list) -> list: def get_device_status(self, endpoint: MessageEndpoints, devices: list) -> list:
@ -25,10 +25,10 @@ class DeviceValidation:
Returns: Returns:
list: List of BECMessage objects list: List of BECMessage objects
""" """
pipe = self.producer.pipeline() pipe = self.connector.pipeline()
for dev in devices: for dev in devices:
self.producer.get(endpoint(dev), pipe) self.connector.get(endpoint(dev), pipe)
return self.producer.execute_pipeline(pipe) return self.connector.execute_pipeline(pipe)
def devices_are_ready( def devices_are_ready(
self, self,

View File

@ -27,25 +27,19 @@ class ScanGuard:
self.parent = parent self.parent = parent
self.device_manager = self.parent.device_manager self.device_manager = self.parent.device_manager
self.connector = self.parent.connector self.connector = self.parent.connector
self.producer = self.connector.producer()
self._start_scan_queue_request_consumer()
def _start_scan_queue_request_consumer(self): self.connector.register(
self._scan_queue_request_consumer = self.connector.consumer(
MessageEndpoints.scan_queue_request(), MessageEndpoints.scan_queue_request(),
cb=self._scan_queue_request_callback, cb=self._scan_queue_request_callback,
parent=self, parent=self,
) )
self._scan_queue_modification_request_consumer = self.connector.consumer( self.connector.register(
MessageEndpoints.scan_queue_modification_request(), MessageEndpoints.scan_queue_modification_request(),
cb=self._scan_queue_modification_request_callback, cb=self._scan_queue_modification_request_callback,
parent=self, parent=self,
) )
self._scan_queue_request_consumer.start()
self._scan_queue_modification_request_consumer.start()
def _is_valid_scan_request(self, request) -> ScanStatus: def _is_valid_scan_request(self, request) -> ScanStatus:
try: try:
self._check_valid_request(request) self._check_valid_request(request)
@ -63,7 +57,7 @@ class ScanGuard:
raise ScanRejection("Invalid request.") raise ScanRejection("Invalid request.")
def _check_valid_scan(self, request) -> None: def _check_valid_scan(self, request) -> None:
avail_scans = self.producer.get(MessageEndpoints.available_scans()) avail_scans = self.connector.get(MessageEndpoints.available_scans())
scan_type = request.content.get("scan_type") scan_type = request.content.get("scan_type")
if scan_type not in avail_scans.resource: if scan_type not in avail_scans.resource:
raise ScanRejection(f"Unknown scan type {scan_type}.") raise ScanRejection(f"Unknown scan type {scan_type}.")
@ -140,7 +134,7 @@ class ScanGuard:
message=scan_status.message, message=scan_status.message,
metadata=metadata, metadata=metadata,
) )
self.device_manager.producer.send(sqrr, rrm) self.device_manager.connector.send(sqrr, rrm)
def _handle_scan_request(self, msg): def _handle_scan_request(self, msg):
""" """
@ -181,10 +175,10 @@ class ScanGuard:
self._send_scan_request_response(ScanStatus(), mod_msg.metadata) self._send_scan_request_response(ScanStatus(), mod_msg.metadata)
sqm = MessageEndpoints.scan_queue_modification() sqm = MessageEndpoints.scan_queue_modification()
self.device_manager.producer.send(sqm, mod_msg) self.device_manager.connector.send(sqm, mod_msg)
def _append_to_scan_queue(self, msg): def _append_to_scan_queue(self, msg):
logger.info("Appending new scan to queue") logger.info("Appending new scan to queue")
msg = msg msg = msg
sqi = MessageEndpoints.scan_queue_insert() sqi = MessageEndpoints.scan_queue_insert()
self.device_manager.producer.send(sqi, msg) self.device_manager.connector.send(sqi, msg)

View File

@ -101,7 +101,7 @@ class ScanManager:
def publish_available_scans(self): def publish_available_scans(self):
"""send all available scans to the broker""" """send all available scans to the broker"""
self.parent.producer.set( self.parent.connector.set(
MessageEndpoints.available_scans(), MessageEndpoints.available_scans(),
AvailableResourceMessage(resource=self.available_scans), AvailableResourceMessage(resource=self.available_scans),
) )

View File

@ -51,11 +51,10 @@ class QueueManager:
def __init__(self, parent) -> None: def __init__(self, parent) -> None:
self.parent = parent self.parent = parent
self.connector = parent.connector self.connector = parent.connector
self.producer = parent.producer
self.num_queues = 1 self.num_queues = 1
self.key = "" self.key = ""
self.queues = {} self.queues = {}
self._start_scan_queue_consumer() self._start_scan_queue_register()
self._lock = threading.RLock() self._lock = threading.RLock()
def add_to_queue(self, scan_queue: str, msg: messages.ScanQueueMessage, position=-1) -> None: def add_to_queue(self, scan_queue: str, msg: messages.ScanQueueMessage, position=-1) -> None:
@ -91,17 +90,15 @@ class QueueManager:
self.queues[queue_name] = ScanQueue(self, queue_name=queue_name) self.queues[queue_name] = ScanQueue(self, queue_name=queue_name)
self.queues[queue_name].start_worker() self.queues[queue_name].start_worker()
def _start_scan_queue_consumer(self) -> None: def _start_scan_queue_register(self) -> None:
self._scan_queue_consumer = self.connector.consumer( self.connector.register(
MessageEndpoints.scan_queue_insert(), cb=self._scan_queue_callback, parent=self MessageEndpoints.scan_queue_insert(), cb=self._scan_queue_callback, parent=self
) )
self._scan_queue_modification_consumer = self.connector.consumer( self.connector.register(
MessageEndpoints.scan_queue_modification(), MessageEndpoints.scan_queue_modification(),
cb=self._scan_queue_modification_callback, cb=self._scan_queue_modification_callback,
parent=self, parent=self,
) )
self._scan_queue_consumer.start()
self._scan_queue_modification_consumer.start()
@staticmethod @staticmethod
def _scan_queue_callback(msg, parent, **_kwargs) -> None: def _scan_queue_callback(msg, parent, **_kwargs) -> None:
@ -233,7 +230,7 @@ class QueueManager:
logger.info("New scan queue:") logger.info("New scan queue:")
for queue in self.describe_queue(): for queue in self.describe_queue():
logger.info(f"\n {queue}") logger.info(f"\n {queue}")
self.producer.set_and_publish( self.connector.set_and_publish(
MessageEndpoints.scan_queue_status(), MessageEndpoints.scan_queue_status(),
messages.ScanQueueStatusMessage(queue=queue_export), messages.ScanQueueStatusMessage(queue=queue_export),
) )
@ -685,7 +682,7 @@ class InstructionQueueItem:
self.instructions = [] self.instructions = []
self.parent = parent self.parent = parent
self.queue = RequestBlockQueue(instruction_queue=self, assembler=assembler) self.queue = RequestBlockQueue(instruction_queue=self, assembler=assembler)
self.producer = self.parent.queue_manager.producer self.connector = self.parent.queue_manager.connector
self._is_scan = False self._is_scan = False
self.is_active = False # set to true while a worker is processing the instructions self.is_active = False # set to true while a worker is processing the instructions
self.completed = False self.completed = False
@ -790,7 +787,7 @@ class InstructionQueueItem:
msg = messages.ScanQueueHistoryMessage( msg = messages.ScanQueueHistoryMessage(
status=self.status.name, queueID=self.queue_id, info=self.describe() status=self.status.name, queueID=self.queue_id, info=self.describe()
) )
self.parent.queue_manager.producer.lpush( self.parent.queue_manager.connector.lpush(
MessageEndpoints.scan_queue_history(), msg, max_size=100 MessageEndpoints.scan_queue_history(), msg, max_size=100
) )

View File

@ -23,7 +23,6 @@ class ScanServer(BECService):
def __init__(self, config: ServiceConfig, connector_cls: ConnectorBase): def __init__(self, config: ServiceConfig, connector_cls: ConnectorBase):
super().__init__(config, connector_cls, unique_service=True) super().__init__(config, connector_cls, unique_service=True)
self.producer = self.connector.producer()
self._start_scan_manager() self._start_scan_manager()
self._start_queue_manager() self._start_queue_manager()
self._start_device_manager() self._start_device_manager()
@ -52,15 +51,12 @@ class ScanServer(BECService):
self.scan_guard = ScanGuard(parent=self) self.scan_guard = ScanGuard(parent=self)
def _start_alarm_handler(self): def _start_alarm_handler(self):
self._alarm_consumer = self.connector.consumer( self.connector.register(MessageEndpoints.alarm(), cb=self._alarm_callback, parent=self)
MessageEndpoints.alarm(), cb=self._alarm_callback, parent=self
)
self._alarm_consumer.start()
def _reset_scan_number(self): def _reset_scan_number(self):
if self.producer.get(MessageEndpoints.scan_number()) is None: if self.connector.get(MessageEndpoints.scan_number()) is None:
self.scan_number = 1 self.scan_number = 1
if self.producer.get(MessageEndpoints.dataset_number()) is None: if self.connector.get(MessageEndpoints.dataset_number()) is None:
self.dataset_number = 1 self.dataset_number = 1
@staticmethod @staticmethod
@ -74,25 +70,24 @@ class ScanServer(BECService):
@property @property
def scan_number(self) -> int: def scan_number(self) -> int:
"""get the current scan number""" """get the current scan number"""
return int(self.producer.get(MessageEndpoints.scan_number())) return int(self.connector.get(MessageEndpoints.scan_number()))
@scan_number.setter @scan_number.setter
def scan_number(self, val: int): def scan_number(self, val: int):
"""set the current scan number""" """set the current scan number"""
self.producer.set(MessageEndpoints.scan_number(), val) self.connector.set(MessageEndpoints.scan_number(), val)
@property @property
def dataset_number(self) -> int: def dataset_number(self) -> int:
"""get the current dataset number""" """get the current dataset number"""
return int(self.producer.get(MessageEndpoints.dataset_number())) return int(self.connector.get(MessageEndpoints.dataset_number()))
@dataset_number.setter @dataset_number.setter
def dataset_number(self, val: int): def dataset_number(self, val: int):
"""set the current dataset number""" """set the current dataset number"""
self.producer.set(MessageEndpoints.dataset_number(), val) self.connector.set(MessageEndpoints.dataset_number(), val)
def shutdown(self) -> None: def shutdown(self) -> None:
"""shutdown the scan server""" """shutdown the scan server"""
self.device_manager.shutdown() self.device_manager.shutdown()
self.queue_manager.shutdown() self.queue_manager.shutdown()

View File

@ -5,7 +5,9 @@ import uuid
from collections.abc import Callable from collections.abc import Callable
import numpy as np import numpy as np
from bec_lib import MessageEndpoints, ProducerConnector, Status, bec_logger, messages
from bec_lib import MessageEndpoints, Status, bec_logger, messages
from bec_lib.connector import ConnectorBase
from .errors import DeviceMessageError, ScanAbortion from .errors import DeviceMessageError, ScanAbortion
@ -13,8 +15,8 @@ logger = bec_logger.logger
class ScanStubs: class ScanStubs:
def __init__(self, producer: ProducerConnector, device_msg_callback: Callable = None) -> None: def __init__(self, connector: ConnectorBase, device_msg_callback: Callable = None) -> None:
self.producer = producer self.connector = connector
self.device_msg_metadata = ( self.device_msg_metadata = (
device_msg_callback if device_msg_callback is not None else lambda: {} device_msg_callback if device_msg_callback is not None else lambda: {}
) )
@ -62,7 +64,7 @@ class ScanStubs:
def _get_from_rpc(self, rpc_id): def _get_from_rpc(self, rpc_id):
while True: while True:
msg = self.producer.get(MessageEndpoints.device_rpc(rpc_id)) msg = self.connector.get(MessageEndpoints.device_rpc(rpc_id))
if msg: if msg:
break break
time.sleep(0.001) time.sleep(0.001)
@ -81,7 +83,7 @@ class ScanStubs:
if not isinstance(return_val, dict): if not isinstance(return_val, dict):
return return_val return return_val
if return_val.get("type") == "status" and return_val.get("RID"): if return_val.get("type") == "status" and return_val.get("RID"):
return Status(self.producer, return_val.get("RID")) return Status(self.connector, return_val.get("RID"))
return return_val return return_val
def set_and_wait(self, *, device: list[str], positions: list | np.ndarray): def set_and_wait(self, *, device: list[str], positions: list | np.ndarray):
@ -182,7 +184,7 @@ class ScanStubs:
DIID (int): device instruction ID DIID (int): device instruction ID
""" """
msg = self.producer.get(MessageEndpoints.device_req_status(device)) msg = self.connector.get(MessageEndpoints.device_req_status(device))
if not msg: if not msg:
return 0 return 0
matching_RID = msg.metadata.get("RID") == RID matching_RID = msg.metadata.get("RID") == RID
@ -199,7 +201,7 @@ class ScanStubs:
RID (str): request ID RID (str): request ID
""" """
msg = self.producer.get(MessageEndpoints.device_progress(device)) msg = self.connector.get(MessageEndpoints.device_progress(device))
if not msg: if not msg:
return None return None
matching_RID = msg.metadata.get("RID") == RID matching_RID = msg.metadata.get("RID") == RID

View File

@ -41,7 +41,7 @@ class ScanWorker(threading.Thread):
self._groups = {} self._groups = {}
self.interception_msg = None self.interception_msg = None
self.reset() self.reset()
self.validate = DeviceValidation(self.device_manager.producer, self) self.validate = DeviceValidation(self.device_manager.connector, self)
def open_scan(self, instr: messages.DeviceInstructionMessage) -> None: def open_scan(self, instr: messages.DeviceInstructionMessage) -> None:
""" """
@ -138,7 +138,7 @@ class ScanWorker(threading.Thread):
""" """
devices = [dev.name for dev in self.device_manager.devices.get_software_triggered_devices()] devices = [dev.name for dev in self.device_manager.devices.get_software_triggered_devices()]
self._last_trigger = instr self._last_trigger = instr
self.device_manager.producer.send( self.device_manager.connector.send(
MessageEndpoints.device_instructions(), MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage( messages.DeviceInstructionMessage(
device=devices, device=devices,
@ -157,7 +157,7 @@ class ScanWorker(threading.Thread):
""" """
# send instruction # send instruction
self.device_manager.producer.send(MessageEndpoints.device_instructions(), instr) self.device_manager.connector.send(MessageEndpoints.device_instructions(), instr)
def read_devices(self, instr: messages.DeviceInstructionMessage) -> None: def read_devices(self, instr: messages.DeviceInstructionMessage) -> None:
""" """
@ -171,7 +171,7 @@ class ScanWorker(threading.Thread):
self._publish_readback(instr) self._publish_readback(instr)
return return
producer = self.device_manager.producer connector = self.device_manager.connector
devices = instr.content.get("device") devices = instr.content.get("device")
if devices is None: if devices is None:
@ -181,7 +181,7 @@ class ScanWorker(threading.Thread):
readout_priority=self.readout_priority readout_priority=self.readout_priority
) )
] ]
producer.send( connector.send(
MessageEndpoints.device_instructions(), MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage( messages.DeviceInstructionMessage(
device=devices, device=devices,
@ -201,7 +201,7 @@ class ScanWorker(threading.Thread):
""" """
# logger.info("kickoff") # logger.info("kickoff")
self.device_manager.producer.send( self.device_manager.connector.send(
MessageEndpoints.device_instructions(), MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage( messages.DeviceInstructionMessage(
device=instr.content.get("device"), device=instr.content.get("device"),
@ -225,7 +225,7 @@ class ScanWorker(threading.Thread):
devices = instr.content.get("device") devices = instr.content.get("device")
if not isinstance(devices, list): if not isinstance(devices, list):
devices = [devices] devices = [devices]
self.device_manager.producer.send( self.device_manager.connector.send(
MessageEndpoints.device_instructions(), MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage( messages.DeviceInstructionMessage(
device=devices, device=devices,
@ -251,7 +251,7 @@ class ScanWorker(threading.Thread):
) )
] ]
params = instr.content["parameter"] params = instr.content["parameter"]
self.device_manager.producer.send( self.device_manager.connector.send(
MessageEndpoints.device_instructions(), MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage( messages.DeviceInstructionMessage(
device=baseline_devices, action="read", parameter=params, metadata=instr.metadata device=baseline_devices, action="read", parameter=params, metadata=instr.metadata
@ -266,7 +266,7 @@ class ScanWorker(threading.Thread):
instr (DeviceInstructionMessage): Device instruction received from the scan assembler instr (DeviceInstructionMessage): Device instruction received from the scan assembler
""" """
devices = [dev.name for dev in self.device_manager.devices.enabled_devices] devices = [dev.name for dev in self.device_manager.devices.enabled_devices]
self.device_manager.producer.send( self.device_manager.connector.send(
MessageEndpoints.device_instructions(), MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage( messages.DeviceInstructionMessage(
device=devices, device=devices,
@ -285,7 +285,7 @@ class ScanWorker(threading.Thread):
Args: Args:
instr (DeviceInstructionMessage): Device instruction received from the scan assembler instr (DeviceInstructionMessage): Device instruction received from the scan assembler
""" """
producer = self.device_manager.producer connector = self.device_manager.connector
data = instr.content["parameter"]["data"] data = instr.content["parameter"]["data"]
devices = instr.content["device"] devices = instr.content["device"]
if not isinstance(devices, list): if not isinstance(devices, list):
@ -294,7 +294,7 @@ class ScanWorker(threading.Thread):
data = [data] data = [data]
for device, dev_data in zip(devices, data): for device, dev_data in zip(devices, data):
msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata) msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata)
producer.set_and_publish(MessageEndpoints.device_read(device), msg) connector.set_and_publish(MessageEndpoints.device_read(device), msg)
def send_rpc(self, instr: messages.DeviceInstructionMessage) -> None: def send_rpc(self, instr: messages.DeviceInstructionMessage) -> None:
""" """
@ -304,7 +304,7 @@ class ScanWorker(threading.Thread):
instr (DeviceInstructionMessage): Device instruction received from the scan assembler instr (DeviceInstructionMessage): Device instruction received from the scan assembler
""" """
self.device_manager.producer.send(MessageEndpoints.device_instructions(), instr) self.device_manager.connector.send(MessageEndpoints.device_instructions(), instr)
def process_scan_report_instruction(self, instr): def process_scan_report_instruction(self, instr):
""" """
@ -333,7 +333,7 @@ class ScanWorker(threading.Thread):
if dev.name not in async_devices if dev.name not in async_devices
] ]
for det in async_devices: for det in async_devices:
self.device_manager.producer.send( self.device_manager.connector.send(
MessageEndpoints.device_instructions(), MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage( messages.DeviceInstructionMessage(
device=det, device=det,
@ -344,7 +344,7 @@ class ScanWorker(threading.Thread):
) )
self._staged_devices.update(async_devices) self._staged_devices.update(async_devices)
self.device_manager.producer.send( self.device_manager.connector.send(
MessageEndpoints.device_instructions(), MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage( messages.DeviceInstructionMessage(
device=devices, device=devices,
@ -375,7 +375,7 @@ class ScanWorker(threading.Thread):
parameter = {} if not instr else instr.content["parameter"] parameter = {} if not instr else instr.content["parameter"]
metadata = {} if not instr else instr.metadata metadata = {} if not instr else instr.metadata
self._staged_devices.difference_update(devices) self._staged_devices.difference_update(devices)
self.device_manager.producer.send( self.device_manager.connector.send(
MessageEndpoints.device_instructions(), MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage( messages.DeviceInstructionMessage(
device=devices, action="unstage", parameter=parameter, metadata=metadata device=devices, action="unstage", parameter=parameter, metadata=metadata
@ -459,7 +459,7 @@ class ScanWorker(threading.Thread):
matching_DIID = device_status[ind].metadata.get("DIID") >= devices[ind][1] matching_DIID = device_status[ind].metadata.get("DIID") >= devices[ind][1]
matching_RID = device_status[ind].metadata.get("RID") == instr.metadata["RID"] matching_RID = device_status[ind].metadata.get("RID") == instr.metadata["RID"]
if matching_DIID and matching_RID: if matching_DIID and matching_RID:
last_pos_msg = self.device_manager.producer.get( last_pos_msg = self.device_manager.connector.get(
MessageEndpoints.device_readback(failed_device[0]) MessageEndpoints.device_readback(failed_device[0])
) )
last_pos = last_pos_msg.content["signals"][failed_device[0]]["value"] last_pos = last_pos_msg.content["signals"][failed_device[0]]["value"]
@ -603,25 +603,25 @@ class ScanWorker(threading.Thread):
def _publish_readback( def _publish_readback(
self, instr: messages.DeviceInstructionMessage, devices: list = None self, instr: messages.DeviceInstructionMessage, devices: list = None
) -> None: ) -> None:
producer = self.device_manager.producer connector = self.device_manager.connector
if not devices: if not devices:
devices = instr.content.get("device") devices = instr.content.get("device")
# cached readout # cached readout
readouts = self._get_readback(devices) readouts = self._get_readback(devices)
pipe = producer.pipeline() pipe = connector.pipeline()
for readout, device in zip(readouts, devices): for readout, device in zip(readouts, devices):
msg = messages.DeviceMessage(signals=readout, metadata=instr.metadata) msg = messages.DeviceMessage(signals=readout, metadata=instr.metadata)
producer.set_and_publish(MessageEndpoints.device_read(device), msg, pipe) connector.set_and_publish(MessageEndpoints.device_read(device), msg, pipe)
return pipe.execute() return pipe.execute()
def _get_readback(self, devices: list) -> list: def _get_readback(self, devices: list) -> list:
producer = self.device_manager.producer connector = self.device_manager.connector
# cached readout # cached readout
pipe = producer.pipeline() pipe = connector.pipeline()
for dev in devices: for dev in devices:
producer.get(MessageEndpoints.device_readback(dev), pipe=pipe) connector.get(MessageEndpoints.device_readback(dev), pipe=pipe)
return producer.execute_pipeline(pipe) return connector.execute_pipeline(pipe)
def _check_for_interruption(self) -> None: def _check_for_interruption(self) -> None:
if self.status == InstructionQueueStatus.PAUSED: if self.status == InstructionQueueStatus.PAUSED:
@ -700,11 +700,13 @@ class ScanWorker(threading.Thread):
scanID=self.current_scanID, status=status, info=self.current_scan_info scanID=self.current_scanID, status=status, info=self.current_scan_info
) )
expire = None if status in ["open", "paused"] else 1800 expire = None if status in ["open", "paused"] else 1800
pipe = self.device_manager.producer.pipeline() pipe = self.device_manager.connector.pipeline()
self.device_manager.producer.set( self.device_manager.connector.set(
MessageEndpoints.public_scan_info(self.current_scanID), msg, pipe=pipe, expire=expire MessageEndpoints.public_scan_info(self.current_scanID), msg, pipe=pipe, expire=expire
) )
self.device_manager.producer.set_and_publish(MessageEndpoints.scan_status(), msg, pipe=pipe) self.device_manager.connector.set_and_publish(
MessageEndpoints.scan_status(), msg, pipe=pipe
)
pipe.execute() pipe.execute()
def _process_instructions(self, queue: InstructionQueueItem) -> None: def _process_instructions(self, queue: InstructionQueueItem) -> None:

View File

@ -213,7 +213,7 @@ class RequestBase(ABC):
if metadata is None: if metadata is None:
self.metadata = {} self.metadata = {}
self.stubs = ScanStubs( self.stubs = ScanStubs(
producer=self.device_manager.producer, device_msg_callback=self.device_msg_metadata connector=self.device_manager.connector, device_msg_callback=self.device_msg_metadata
) )
@property @property
@ -239,7 +239,7 @@ class RequestBase(ABC):
def run_pre_scan_macros(self): def run_pre_scan_macros(self):
"""run pre scan macros if any""" """run pre scan macros if any"""
macros = self.device_manager.producer.lrange(MessageEndpoints.pre_scan_macros(), 0, -1) macros = self.device_manager.connector.lrange(MessageEndpoints.pre_scan_macros(), 0, -1)
for macro in macros: for macro in macros:
macro = macro.decode().strip() macro = macro.decode().strip()
func_name = self._get_func_name_from_macro(macro) func_name = self._get_func_name_from_macro(macro)
@ -558,12 +558,12 @@ class SyncFlyScanBase(ScanBase, ABC):
def _get_flyer_status(self) -> list: def _get_flyer_status(self) -> list:
flyer = self.scan_motors[0] flyer = self.scan_motors[0]
producer = self.device_manager.producer connector = self.device_manager.connector
pipe = producer.pipeline() pipe = connector.pipeline()
producer.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe) connector.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe)
producer.get(MessageEndpoints.device_readback(flyer), pipe) connector.get(MessageEndpoints.device_readback(flyer), pipe)
return producer.execute_pipeline(pipe) return connector.execute_pipeline(pipe)
@abstractmethod @abstractmethod
def scan_core(self): def scan_core(self):
@ -1098,7 +1098,7 @@ class RoundScanFlySim(SyncFlyScanBase):
while True: while True:
yield from self.stubs.read_and_wait(group="primary", wait_group="readout_primary") yield from self.stubs.read_and_wait(group="primary", wait_group="readout_primary")
status = self.device_manager.producer.get(MessageEndpoints.device_status(self.flyer)) status = self.device_manager.connector.get(MessageEndpoints.device_status(self.flyer))
if status: if status:
device_is_idle = status.content.get("status", 1) == 0 device_is_idle = status.content.get("status", 1) == 0
matching_RID = self.metadata.get("RID") == status.metadata.get("RID") matching_RID = self.metadata.get("RID") == status.metadata.get("RID")
@ -1318,12 +1318,12 @@ class MonitorScan(ScanBase):
self._check_limits() self._check_limits()
def _get_flyer_status(self) -> list: def _get_flyer_status(self) -> list:
producer = self.device_manager.producer connector = self.device_manager.connector
pipe = producer.pipeline() pipe = connector.pipeline()
producer.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe) connector.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe)
producer.get(MessageEndpoints.device_readback(self.flyer), pipe) connector.get(MessageEndpoints.device_readback(self.flyer), pipe)
return producer.execute_pipeline(pipe) return connector.execute_pipeline(pipe)
def scan_core(self): def scan_core(self):
yield from self.stubs.set( yield from self.stubs.set(

View File

@ -12,7 +12,7 @@ from scan_server.scan_guard import ScanGuard, ScanRejection, ScanStatus
@pytest.fixture @pytest.fixture
def scan_guard_mock(scan_server_mock): def scan_guard_mock(scan_server_mock):
sg = ScanGuard(parent=scan_server_mock) sg = ScanGuard(parent=scan_server_mock)
sg.device_manager.producer = mock.MagicMock() sg.device_manager.connector = mock.MagicMock()
yield sg yield sg
@ -113,8 +113,8 @@ def test_valid_request(scan_server_mock, scan_queue_msg, valid):
def test_check_valid_scan_raises_for_unknown_scan(scan_guard_mock): def test_check_valid_scan_raises_for_unknown_scan(scan_guard_mock):
sg = scan_guard_mock sg = scan_guard_mock
sg.producer = mock.MagicMock() sg.connector = mock.MagicMock()
sg.producer.get.return_value = messages.AvailableResourceMessage( sg.connector.get.return_value = messages.AvailableResourceMessage(
resource={"fermat_scan": "fermat_scan"} resource={"fermat_scan": "fermat_scan"}
) )
@ -130,8 +130,8 @@ def test_check_valid_scan_raises_for_unknown_scan(scan_guard_mock):
def test_check_valid_scan_accepts_known_scan(scan_guard_mock): def test_check_valid_scan_accepts_known_scan(scan_guard_mock):
sg = scan_guard_mock sg = scan_guard_mock
sg.producer = mock.MagicMock() sg.connector = mock.MagicMock()
sg.producer.get.return_value = messages.AvailableResourceMessage( sg.connector.get.return_value = messages.AvailableResourceMessage(
resource={"fermat_scan": "fermat_scan"} resource={"fermat_scan": "fermat_scan"}
) )
@ -146,8 +146,8 @@ def test_check_valid_scan_accepts_known_scan(scan_guard_mock):
def test_check_valid_scan_device_rpc(scan_guard_mock): def test_check_valid_scan_device_rpc(scan_guard_mock):
sg = scan_guard_mock sg = scan_guard_mock
sg.producer = mock.MagicMock() sg.connector = mock.MagicMock()
sg.producer.get.return_value = messages.AvailableResourceMessage( sg.connector.get.return_value = messages.AvailableResourceMessage(
resource={"device_rpc": "device_rpc"} resource={"device_rpc": "device_rpc"}
) )
request = messages.ScanQueueMessage( request = messages.ScanQueueMessage(
@ -162,8 +162,8 @@ def test_check_valid_scan_device_rpc(scan_guard_mock):
def test_check_valid_scan_device_rpc_raises(scan_guard_mock): def test_check_valid_scan_device_rpc_raises(scan_guard_mock):
sg = scan_guard_mock sg = scan_guard_mock
sg.producer = mock.MagicMock() sg.connector = mock.MagicMock()
sg.producer.get.return_value = messages.AvailableResourceMessage( sg.connector.get.return_value = messages.AvailableResourceMessage(
resource={"device_rpc": "device_rpc"} resource={"device_rpc": "device_rpc"}
) )
request = messages.ScanQueueMessage( request = messages.ScanQueueMessage(
@ -184,7 +184,7 @@ def test_handle_scan_modification_request(scan_guard_mock):
msg = messages.ScanQueueModificationMessage( msg = messages.ScanQueueModificationMessage(
scanID="scanID", action="abort", parameter={}, metadata={"RID": "RID"} scanID="scanID", action="abort", parameter={}, metadata={"RID": "RID"}
) )
with mock.patch.object(sg.device_manager.producer, "send") as send: with mock.patch.object(sg.device_manager.connector, "send") as send:
sg._handle_scan_modification_request(msg) sg._handle_scan_modification_request(msg)
send.assert_called_once_with(MessageEndpoints.scan_queue_modification(), msg) send.assert_called_once_with(MessageEndpoints.scan_queue_modification(), msg)
@ -207,7 +207,7 @@ def test_append_to_scan_queue(scan_guard_mock):
parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}}, parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}},
queue="primary", queue="primary",
) )
with mock.patch.object(sg.device_manager.producer, "send") as send: with mock.patch.object(sg.device_manager.connector, "send") as send:
sg._append_to_scan_queue(msg) sg._append_to_scan_queue(msg)
send.assert_called_once_with(MessageEndpoints.scan_queue_insert(), msg) send.assert_called_once_with(MessageEndpoints.scan_queue_insert(), msg)
@ -251,7 +251,7 @@ def test_scan_queue_modification_request_callback(scan_guard_mock):
def test_send_scan_request_response(scan_guard_mock): def test_send_scan_request_response(scan_guard_mock):
sg = scan_guard_mock sg = scan_guard_mock
with mock.patch.object(sg.device_manager.producer, "send") as send: with mock.patch.object(sg.device_manager.connector, "send") as send:
sg._send_scan_request_response(ScanStatus(), {"RID": "RID"}) sg._send_scan_request_response(ScanStatus(), {"RID": "RID"})
send.assert_called_once_with( send.assert_called_once_with(
MessageEndpoints.scan_queue_request_response(), MessageEndpoints.scan_queue_request_response(),

View File

@ -166,40 +166,40 @@ def test_set_halt_disables_return_to_start(queuemanager_mock):
def test_set_pause(queuemanager_mock): def test_set_pause(queuemanager_mock):
queue_manager = queuemanager_mock() queue_manager = queuemanager_mock()
queue_manager.producer.message_sent = [] queue_manager.connector.message_sent = []
queue_manager.set_pause(queue="primary") queue_manager.set_pause(queue="primary")
assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED
assert len(queue_manager.producer.message_sent) == 1 assert len(queue_manager.connector.message_sent) == 1
assert ( assert (
queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status() queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
) )
def test_set_deferred_pause(queuemanager_mock): def test_set_deferred_pause(queuemanager_mock):
queue_manager = queuemanager_mock() queue_manager = queuemanager_mock()
queue_manager.producer.message_sent = [] queue_manager.connector.message_sent = []
queue_manager.set_deferred_pause(queue="primary") queue_manager.set_deferred_pause(queue="primary")
assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED
assert len(queue_manager.producer.message_sent) == 1 assert len(queue_manager.connector.message_sent) == 1
assert ( assert (
queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status() queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
) )
def test_set_continue(queuemanager_mock): def test_set_continue(queuemanager_mock):
queue_manager = queuemanager_mock() queue_manager = queuemanager_mock()
queue_manager.producer.message_sent = [] queue_manager.connector.message_sent = []
queue_manager.set_continue(queue="primary") queue_manager.set_continue(queue="primary")
assert queue_manager.queues["primary"].status == ScanQueueStatus.RUNNING assert queue_manager.queues["primary"].status == ScanQueueStatus.RUNNING
assert len(queue_manager.producer.message_sent) == 1 assert len(queue_manager.connector.message_sent) == 1
assert ( assert (
queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status() queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
) )
def test_set_abort(queuemanager_mock): def test_set_abort(queuemanager_mock):
queue_manager = queuemanager_mock() queue_manager = queuemanager_mock()
queue_manager.producer.message_sent = [] queue_manager.connector.message_sent = []
msg = messages.ScanQueueMessage( msg = messages.ScanQueueMessage(
scan_type="mv", scan_type="mv",
parameter={"args": {"samx": (1,)}, "kwargs": {}}, parameter={"args": {"samx": (1,)}, "kwargs": {}},
@ -210,23 +210,23 @@ def test_set_abort(queuemanager_mock):
queue_manager.add_to_queue(scan_queue="primary", msg=msg) queue_manager.add_to_queue(scan_queue="primary", msg=msg)
queue_manager.set_abort(queue="primary") queue_manager.set_abort(queue="primary")
assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED
assert len(queue_manager.producer.message_sent) == 2 assert len(queue_manager.connector.message_sent) == 2
assert ( assert (
queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status() queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
) )
def test_set_abort_with_empty_queue(queuemanager_mock): def test_set_abort_with_empty_queue(queuemanager_mock):
queue_manager = queuemanager_mock() queue_manager = queuemanager_mock()
queue_manager.producer.message_sent = [] queue_manager.connector.message_sent = []
queue_manager.set_abort(queue="primary") queue_manager.set_abort(queue="primary")
assert queue_manager.queues["primary"].status == ScanQueueStatus.RUNNING assert queue_manager.queues["primary"].status == ScanQueueStatus.RUNNING
assert len(queue_manager.producer.message_sent) == 0 assert len(queue_manager.connector.message_sent) == 0
def test_set_clear_sends_message(queuemanager_mock): def test_set_clear_sends_message(queuemanager_mock):
queue_manager = queuemanager_mock() queue_manager = queuemanager_mock()
queue_manager.producer.message_sent = [] queue_manager.connector.message_sent = []
setter_mock = mock.Mock(wraps=ScanQueue.worker_status.fset) setter_mock = mock.Mock(wraps=ScanQueue.worker_status.fset)
# pylint: disable=assignment-from-no-return # pylint: disable=assignment-from-no-return
# pylint: disable=too-many-function-args # pylint: disable=too-many-function-args
@ -238,9 +238,9 @@ def test_set_clear_sends_message(queuemanager_mock):
mock_property.fset.assert_called_once_with( mock_property.fset.assert_called_once_with(
queue_manager.queues["primary"], InstructionQueueStatus.STOPPED queue_manager.queues["primary"], InstructionQueueStatus.STOPPED
) )
assert len(queue_manager.producer.message_sent) == 1 assert len(queue_manager.connector.message_sent) == 1
assert ( assert (
queue_manager.producer.message_sent[0].get("queue") queue_manager.connector.message_sent[0].get("queue")
== MessageEndpoints.scan_queue_status() == MessageEndpoints.scan_queue_status()
) )

View File

@ -11,7 +11,7 @@ from scan_server.scan_stubs import ScanAbortion, ScanStubs
@pytest.fixture @pytest.fixture
def stubs(): def stubs():
connector = ConnectorMock("") connector = ConnectorMock("")
yield ScanStubs(connector.producer()) yield ScanStubs(connector)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -36,7 +36,11 @@ def stubs():
device="rtx", device="rtx",
action="kickoff", action="kickoff",
parameter={ parameter={
"configure": {"num_pos": 5, "positions": [1, 2, 3, 4, 5], "exp_time": 2}, "configure": {
"num_pos": 5,
"positions": [1, 2, 3, 4, 5],
"exp_time": 2,
},
"wait_group": "kickoff", "wait_group": "kickoff",
}, },
metadata={}, metadata={},
@ -45,6 +49,8 @@ def stubs():
], ],
) )
def test_kickoff(stubs, device, parameter, metadata, reference_msg): def test_kickoff(stubs, device, parameter, metadata, reference_msg):
connector = ConnectorMock("")
stubs = ScanStubs(connector)
msg = list(stubs.kickoff(device=device, parameter=parameter, metadata=metadata)) msg = list(stubs.kickoff(device=device, parameter=parameter, metadata=metadata))
assert msg[0] == reference_msg assert msg[0] == reference_msg
@ -52,12 +58,19 @@ def test_kickoff(stubs, device, parameter, metadata, reference_msg):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"msg,raised_error", "msg,raised_error",
[ [
(messages.DeviceRPCMessage(device="samx", return_val="", out="", success=True), None), (
messages.DeviceRPCMessage(device="samx", return_val="", out="", success=True),
None,
),
( (
messages.DeviceRPCMessage( messages.DeviceRPCMessage(
device="samx", device="samx",
return_val="", return_val="",
out={"error": "TypeError", "msg": "some weird error", "traceback": "traceback"}, out={
"error": "TypeError",
"msg": "some weird error",
"traceback": "traceback",
},
success=False, success=False,
), ),
ScanAbortion, ScanAbortion,
@ -69,8 +82,7 @@ def test_kickoff(stubs, device, parameter, metadata, reference_msg):
], ],
) )
def test_rpc_raises_scan_abortion(stubs, msg, raised_error): def test_rpc_raises_scan_abortion(stubs, msg, raised_error):
msg = msg with mock.patch.object(stubs.connector, "get", return_value=msg) as prod_get:
with mock.patch.object(stubs.producer, "get", return_value=msg) as prod_get:
if raised_error is None: if raised_error is None:
stubs._get_from_rpc("rpc-id") stubs._get_from_rpc("rpc-id")
else: else:
@ -106,8 +118,8 @@ def test_rpc_raises_scan_abortion(stubs, msg, raised_error):
def test_device_progress(stubs, msg, ret_value, raised_error): def test_device_progress(stubs, msg, ret_value, raised_error):
if raised_error: if raised_error:
with pytest.raises(DeviceMessageError): with pytest.raises(DeviceMessageError):
with mock.patch.object(stubs.producer, "get", return_value=msg): with mock.patch.object(stubs.connector, "get", return_value=msg):
assert stubs.get_device_progress(device="samx", RID="rid") == ret_value assert stubs.get_device_progress(device="samx", RID="rid") == ret_value
return return
with mock.patch.object(stubs.producer, "get", return_value=msg): with mock.patch.object(stubs.connector, "get", return_value=msg):
assert stubs.get_device_progress(device="samx", RID="rid") == ret_value assert stubs.get_device_progress(device="samx", RID="rid") == ret_value

View File

@ -4,7 +4,7 @@ from unittest import mock
import pytest import pytest
from bec_lib import MessageEndpoints, messages from bec_lib import MessageEndpoints, messages
from bec_lib.tests.utils import ProducerMock, dm, dm_with_devices from bec_lib.tests.utils import ConnectorMock, dm, dm_with_devices
from utils import scan_server_mock from utils import scan_server_mock
from scan_server.errors import DeviceMessageError, ScanAbortion from scan_server.errors import DeviceMessageError, ScanAbortion
@ -22,7 +22,7 @@ from scan_server.scan_worker import ScanWorker
@pytest.fixture @pytest.fixture
def scan_worker_mock(scan_server_mock) -> ScanWorker: def scan_worker_mock(scan_server_mock) -> ScanWorker:
scan_server_mock.device_manager.producer = mock.MagicMock() scan_server_mock.device_manager.connector = mock.MagicMock()
scan_worker = ScanWorker(parent=scan_server_mock) scan_worker = ScanWorker(parent=scan_server_mock)
yield scan_worker yield scan_worker
@ -295,7 +295,7 @@ def test_wait_for_devices(scan_worker_mock, instructions, wait_type):
def test_complete_devices(scan_worker_mock, instructions): def test_complete_devices(scan_worker_mock, instructions):
worker = scan_worker_mock worker = scan_worker_mock
with mock.patch.object(worker, "_wait_for_status") as wait_for_status_mock: with mock.patch.object(worker, "_wait_for_status") as wait_for_status_mock:
with mock.patch.object(worker.device_manager.producer, "send") as send_mock: with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.complete_devices(instructions) worker.complete_devices(instructions)
if instructions.content["device"]: if instructions.content["device"]:
devices = instructions.content["device"] devices = instructions.content["device"]
@ -328,7 +328,7 @@ def test_complete_devices(scan_worker_mock, instructions):
) )
def test_pre_scan(scan_worker_mock, instructions): def test_pre_scan(scan_worker_mock, instructions):
worker = scan_worker_mock worker = scan_worker_mock
with mock.patch.object(worker.device_manager.producer, "send") as send_mock: with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
with mock.patch.object(worker, "_wait_for_status") as wait_for_status_mock: with mock.patch.object(worker, "_wait_for_status") as wait_for_status_mock:
worker.pre_scan(instructions) worker.pre_scan(instructions)
devices = [dev.name for dev in worker.device_manager.devices.enabled_devices] devices = [dev.name for dev in worker.device_manager.devices.enabled_devices]
@ -457,12 +457,12 @@ def test_pre_scan(scan_worker_mock, instructions):
) )
def test_check_for_failed_movements(scan_worker_mock, device_status, devices, instr, abort): def test_check_for_failed_movements(scan_worker_mock, device_status, devices, instr, abort):
worker = scan_worker_mock worker = scan_worker_mock
worker.device_manager.producer = ProducerMock() worker.device_manager.connector = ConnectorMock()
if abort: if abort:
with pytest.raises(ScanAbortion): with pytest.raises(ScanAbortion):
worker.device_manager.producer._get_buffer[MessageEndpoints.device_readback("samx")] = ( worker.device_manager.connector._get_buffer[
messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={}) MessageEndpoints.device_readback("samx")
) ] = messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
worker._check_for_failed_movements(device_status, devices, instr) worker._check_for_failed_movements(device_status, devices, instr)
else: else:
worker._check_for_failed_movements(device_status, devices, instr) worker._check_for_failed_movements(device_status, devices, instr)
@ -577,12 +577,12 @@ def test_check_for_failed_movements(scan_worker_mock, device_status, devices, in
) )
def test_wait_for_idle(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReqStatusMessage): def test_wait_for_idle(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReqStatusMessage):
worker = scan_worker_mock worker = scan_worker_mock
worker.device_manager.producer = ProducerMock() worker.device_manager.connector = ConnectorMock()
with mock.patch.object( with mock.patch.object(
worker.validate, "get_device_status", return_value=[req_msg] worker.validate, "get_device_status", return_value=[req_msg]
) as device_status: ) as device_status:
worker.device_manager.producer._get_buffer[MessageEndpoints.device_readback("samx")] = ( worker.device_manager.connector._get_buffer[MessageEndpoints.device_readback("samx")] = (
messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={}) messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
) )
@ -635,7 +635,7 @@ def test_wait_for_idle(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReq
) )
def test_wait_for_read(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReqStatusMessage): def test_wait_for_read(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReqStatusMessage):
worker = scan_worker_mock worker = scan_worker_mock
worker.device_manager.producer = ProducerMock() worker.device_manager.connector = ConnectorMock()
with mock.patch.object( with mock.patch.object(
worker.validate, "get_device_status", return_value=[req_msg] worker.validate, "get_device_status", return_value=[req_msg]
@ -643,9 +643,9 @@ def test_wait_for_read(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReq
with mock.patch.object(worker, "_check_for_interruption") as interruption_mock: with mock.patch.object(worker, "_check_for_interruption") as interruption_mock:
assert worker._groups == {} assert worker._groups == {}
worker._groups["scan_motor"] = {"samx": 3, "samy": 4} worker._groups["scan_motor"] = {"samx": 3, "samy": 4}
worker.device_manager.producer._get_buffer[MessageEndpoints.device_readback("samx")] = ( worker.device_manager.connector._get_buffer[
messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={}) MessageEndpoints.device_readback("samx")
) ] = messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
worker._add_wait_group(msg1) worker._add_wait_group(msg1)
worker._wait_for_read(msg2) worker._wait_for_read(msg2)
assert worker._groups == {"scan_motor": {"samy": 4}} assert worker._groups == {"scan_motor": {"samy": 4}}
@ -730,7 +730,7 @@ def test_wait_for_device_server(scan_worker_mock):
) )
def test_set_devices(scan_worker_mock, instr): def test_set_devices(scan_worker_mock, instr):
worker = scan_worker_mock worker = scan_worker_mock
with mock.patch.object(worker.device_manager.producer, "send") as send_mock: with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.set_devices(instr) worker.set_devices(instr)
send_mock.assert_called_once_with(MessageEndpoints.device_instructions(), instr) send_mock.assert_called_once_with(MessageEndpoints.device_instructions(), instr)
@ -755,7 +755,7 @@ def test_set_devices(scan_worker_mock, instr):
) )
def test_trigger_devices(scan_worker_mock, instr): def test_trigger_devices(scan_worker_mock, instr):
worker = scan_worker_mock worker = scan_worker_mock
with mock.patch.object(worker.device_manager.producer, "send") as send_mock: with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.trigger_devices(instr) worker.trigger_devices(instr)
devices = [ devices = [
dev.name for dev in worker.device_manager.devices.get_software_triggered_devices() dev.name for dev in worker.device_manager.devices.get_software_triggered_devices()
@ -797,7 +797,7 @@ def test_trigger_devices(scan_worker_mock, instr):
) )
def test_send_rpc(scan_worker_mock, instr): def test_send_rpc(scan_worker_mock, instr):
worker = scan_worker_mock worker = scan_worker_mock
with mock.patch.object(worker.device_manager.producer, "send") as send_mock: with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.send_rpc(instr) worker.send_rpc(instr)
send_mock.assert_called_once_with(MessageEndpoints.device_instructions(), instr) send_mock.assert_called_once_with(MessageEndpoints.device_instructions(), instr)
@ -840,7 +840,7 @@ def test_read_devices(scan_worker_mock, instr):
instr_devices = [] instr_devices = []
worker.readout_priority.update({"monitored": instr_devices}) worker.readout_priority.update({"monitored": instr_devices})
devices = [dev.name for dev in worker._get_devices_from_instruction(instr)] devices = [dev.name for dev in worker._get_devices_from_instruction(instr)]
with mock.patch.object(worker.device_manager.producer, "send") as send_mock: with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.read_devices(instr) worker.read_devices(instr)
if instr.content.get("device"): if instr.content.get("device"):
@ -888,7 +888,7 @@ def test_read_devices(scan_worker_mock, instr):
) )
def test_kickoff_devices(scan_worker_mock, instr, devices, parameter, metadata): def test_kickoff_devices(scan_worker_mock, instr, devices, parameter, metadata):
worker = scan_worker_mock worker = scan_worker_mock
with mock.patch.object(worker.device_manager.producer, "send") as send_mock: with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.kickoff_devices(instr) worker.kickoff_devices(instr)
send_mock.assert_called_once_with( send_mock.assert_called_once_with(
MessageEndpoints.device_instructions(), MessageEndpoints.device_instructions(),
@ -920,29 +920,27 @@ def test_kickoff_devices(scan_worker_mock, instr, devices, parameter, metadata):
def test_publish_readback(scan_worker_mock, instr, devices): def test_publish_readback(scan_worker_mock, instr, devices):
worker = scan_worker_mock worker = scan_worker_mock
with mock.patch.object(worker, "_get_readback", return_value=[{}]) as get_readback: with mock.patch.object(worker, "_get_readback", return_value=[{}]) as get_readback:
with mock.patch.object(worker.device_manager, "producer") as producer_mock: with mock.patch.object(worker.device_manager, "connector") as connector_mock:
worker._publish_readback(instr) worker._publish_readback(instr)
get_readback.assert_called_once_with(["samx"]) get_readback.assert_called_once_with(["samx"])
pipe = producer_mock.pipeline() pipe = connector_mock.pipeline()
msg = messages.DeviceMessage(signals={}, metadata=instr.metadata) msg = messages.DeviceMessage(signals={}, metadata=instr.metadata)
connector_mock.set_and_publish.assert_called_once_with(
producer_mock.set_and_publish.assert_called_once_with(
MessageEndpoints.device_read("samx"), msg, pipe MessageEndpoints.device_read("samx"), msg, pipe
) )
pipe.execute.assert_called_once()
def test_get_readback(scan_worker_mock): def test_get_readback(scan_worker_mock):
worker = scan_worker_mock worker = scan_worker_mock
devices = ["samx"] devices = ["samx"]
with mock.patch.object(worker.device_manager, "producer") as producer_mock: with mock.patch.object(worker.device_manager, "connector") as connector_mock:
worker._get_readback(devices) worker._get_readback(devices)
pipe = producer_mock.pipeline() pipe = connector_mock.pipeline()
producer_mock.get.assert_called_once_with( connector_mock.get.assert_called_once_with(
MessageEndpoints.device_readback("samx"), pipe=pipe MessageEndpoints.device_readback("samx"), pipe=pipe
) )
producer_mock.execute_pipeline.assert_called_once() connector_mock.execute_pipeline.assert_called_once()
def test_publish_data_as_read(scan_worker_mock): def test_publish_data_as_read(scan_worker_mock):
@ -958,12 +956,12 @@ def test_publish_data_as_read(scan_worker_mock):
"RID": "requestID", "RID": "requestID",
}, },
) )
with mock.patch.object(worker.device_manager, "producer") as producer_mock: with mock.patch.object(worker.device_manager, "connector") as connector_mock:
worker.publish_data_as_read(instr) worker.publish_data_as_read(instr)
msg = messages.DeviceMessage( msg = messages.DeviceMessage(
signals=instr.content["parameter"]["data"], metadata=instr.metadata signals=instr.content["parameter"]["data"], metadata=instr.metadata
) )
producer_mock.set_and_publish.assert_called_once_with( connector_mock.set_and_publish.assert_called_once_with(
MessageEndpoints.device_read("samx"), msg MessageEndpoints.device_read("samx"), msg
) )
@ -983,13 +981,13 @@ def test_publish_data_as_read_multiple(scan_worker_mock):
"RID": "requestID", "RID": "requestID",
}, },
) )
with mock.patch.object(worker.device_manager, "producer") as producer_mock: with mock.patch.object(worker.device_manager, "connector") as connector_mock:
worker.publish_data_as_read(instr) worker.publish_data_as_read(instr)
mock_calls = [] mock_calls = []
for device, dev_data in zip(devices, data): for device, dev_data in zip(devices, data):
msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata) msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata)
mock_calls.append(mock.call(MessageEndpoints.device_read(device), msg)) mock_calls.append(mock.call(MessageEndpoints.device_read(device), msg))
assert producer_mock.set_and_publish.mock_calls == mock_calls assert connector_mock.set_and_publish.mock_calls == mock_calls
def test_check_for_interruption(scan_worker_mock): def test_check_for_interruption(scan_worker_mock):
@ -1048,7 +1046,7 @@ def test_open_scan(scan_worker_mock, instr, corr_num_points, scan_id):
if "pointID" in instr.metadata: if "pointID" in instr.metadata:
worker.max_point_id = instr.metadata["pointID"] worker.max_point_id = instr.metadata["pointID"]
assert worker.parent.producer.get(MessageEndpoints.scan_number()) == None assert worker.parent.connector.get(MessageEndpoints.scan_number()) == None
with mock.patch.object(worker, "current_instruction_queue_item") as queue_mock: with mock.patch.object(worker, "current_instruction_queue_item") as queue_mock:
with mock.patch.object(worker, "_initialize_scan_info") as init_mock: with mock.patch.object(worker, "_initialize_scan_info") as init_mock:
@ -1181,7 +1179,7 @@ def test_stage_device(scan_worker_mock, msg):
worker.device_manager.devices["eiger"]._config["readoutPriority"] = "async" worker.device_manager.devices["eiger"]._config["readoutPriority"] = "async"
with mock.patch.object(worker, "_wait_for_stage") as wait_mock: with mock.patch.object(worker, "_wait_for_stage") as wait_mock:
with mock.patch.object(worker.device_manager.producer, "send") as send_mock: with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.stage_devices(msg) worker.stage_devices(msg)
async_devices = [dev.name for dev in worker.device_manager.devices.async_devices()] async_devices = [dev.name for dev in worker.device_manager.devices.async_devices()]
devices = [ devices = [
@ -1251,7 +1249,7 @@ def test_unstage_device(scan_worker_mock, msg, devices, parameter, metadata, cle
if not devices: if not devices:
devices = [dev.name for dev in worker.device_manager.devices.enabled_devices] devices = [dev.name for dev in worker.device_manager.devices.enabled_devices]
with mock.patch.object(worker.device_manager.producer, "send") as send_mock: with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
with mock.patch.object(worker, "_wait_for_stage") as wait_mock: with mock.patch.object(worker, "_wait_for_stage") as wait_mock:
worker.unstage_devices(msg, devices, cleanup) worker.unstage_devices(msg, devices, cleanup)
@ -1270,12 +1268,12 @@ def test_unstage_device(scan_worker_mock, msg, devices, parameter, metadata, cle
@pytest.mark.parametrize("status,expire", [("open", None), ("closed", 1800), ("aborted", 1800)]) @pytest.mark.parametrize("status,expire", [("open", None), ("closed", 1800), ("aborted", 1800)])
def test_send_scan_status(scan_worker_mock, status, expire): def test_send_scan_status(scan_worker_mock, status, expire):
worker = scan_worker_mock worker = scan_worker_mock
worker.device_manager.producer = ProducerMock() worker.device_manager.connector = ConnectorMock()
worker.current_scanID = str(uuid.uuid4()) worker.current_scanID = str(uuid.uuid4())
worker._send_scan_status(status) worker._send_scan_status(status)
scan_info_msgs = [ scan_info_msgs = [
msg msg
for msg in worker.device_manager.producer.message_sent for msg in worker.device_manager.connector.message_sent
if msg["queue"] == MessageEndpoints.public_scan_info(scanID=worker.current_scanID) if msg["queue"] == MessageEndpoints.public_scan_info(scanID=worker.current_scanID)
] ]
assert len(scan_info_msgs) == 1 assert len(scan_info_msgs) == 1

View File

@ -6,7 +6,7 @@ import numpy as np
import pytest import pytest
from bec_lib import messages from bec_lib import messages
from bec_lib.devicemanager import DeviceContainer from bec_lib.devicemanager import DeviceContainer
from bec_lib.tests.utils import ProducerMock from bec_lib.tests.utils import ConnectorMock
from scan_plugins.LamNIFermatScan import LamNIFermatScan from scan_plugins.LamNIFermatScan import LamNIFermatScan
from scan_plugins.otf_scan import OTFScan from scan_plugins.otf_scan import OTFScan
@ -80,7 +80,7 @@ class DeviceMock:
class DMMock: class DMMock:
devices = DeviceContainer() devices = DeviceContainer()
producer = ProducerMock() connector = ConnectorMock()
def add_device(self, name): def add_device(self, name):
self.devices[name] = DeviceMock(name) self.devices[name] = DeviceMock(name)
@ -1099,7 +1099,7 @@ def test_pre_scan_macro():
device_manager=device_manager, parameter=scan_msg.content["parameter"] device_manager=device_manager, parameter=scan_msg.content["parameter"]
) )
with mock.patch.object( with mock.patch.object(
request.device_manager.producer, request.device_manager.connector,
"lrange", "lrange",
new_callable=mock.PropertyMock, new_callable=mock.PropertyMock,
return_value=macros, return_value=macros,

View File

@ -26,9 +26,9 @@ dir_path = os.path.abspath(os.path.join(os.path.dirname(bec_lib.__file__), "./co
class ConfigHandler: class ConfigHandler:
def __init__(self, scibec_connector: SciBecConnector, connector: ConnectorBase) -> None: def __init__(self, scibec_connector: SciBecConnector, connector: ConnectorBase) -> None:
self.scibec_connector = scibec_connector self.scibec_connector = scibec_connector
self.connector = connector
self.device_manager = DeviceManager(self.scibec_connector.scihub) self.device_manager = DeviceManager(self.scibec_connector.scihub)
self.device_manager.initialize(scibec_connector.config.redis) self.device_manager.initialize(scibec_connector.config.redis)
self.producer = connector.producer()
self.validator = SciBecValidator(os.path.join(dir_path, "openapi_schema.json")) self.validator = SciBecValidator(os.path.join(dir_path, "openapi_schema.json"))
def parse_config_request(self, msg: messages.DeviceConfigMessage) -> None: def parse_config_request(self, msg: messages.DeviceConfigMessage) -> None:
@ -53,7 +53,7 @@ class ConfigHandler:
def send_config(self, msg: messages.DeviceConfigMessage) -> None: def send_config(self, msg: messages.DeviceConfigMessage) -> None:
"""broadcast a new config""" """broadcast a new config"""
self.producer.send(MessageEndpoints.device_config_update(), msg) self.connector.send(MessageEndpoints.device_config_update(), msg)
def send_config_request_reply(self, accepted, error_msg, metadata): def send_config_request_reply(self, accepted, error_msg, metadata):
"""send a config request reply""" """send a config request reply"""
@ -61,7 +61,7 @@ class ConfigHandler:
accepted=accepted, message=error_msg, metadata=metadata accepted=accepted, message=error_msg, metadata=metadata
) )
RID = metadata.get("RID") RID = metadata.get("RID")
self.producer.set(MessageEndpoints.device_config_request_response(RID), msg, expire=60) self.connector.set(MessageEndpoints.device_config_request_response(RID), msg, expire=60)
def _set_config(self, msg: messages.DeviceConfigMessage): def _set_config(self, msg: messages.DeviceConfigMessage):
config = msg.content["config"] config = msg.content["config"]
@ -127,14 +127,14 @@ class ConfigHandler:
def _update_device_server(self, RID: str, config: dict, action="update") -> None: def _update_device_server(self, RID: str, config: dict, action="update") -> None:
msg = messages.DeviceConfigMessage(action=action, config=config, metadata={"RID": RID}) msg = messages.DeviceConfigMessage(action=action, config=config, metadata={"RID": RID})
self.producer.send(MessageEndpoints.device_server_config_request(), msg) self.connector.send(MessageEndpoints.device_server_config_request(), msg)
def _wait_for_device_server_update(self, RID: str, timeout_time=10) -> bool: def _wait_for_device_server_update(self, RID: str, timeout_time=10) -> bool:
timeout = timeout_time timeout = timeout_time
time_step = 0.05 time_step = 0.05
elapsed_time = 0 elapsed_time = 0
while True: while True:
msg = self.producer.get(MessageEndpoints.device_config_request_response(RID)) msg = self.connector.get(MessageEndpoints.device_config_request_response(RID))
if msg: if msg:
return msg.content["accepted"], msg return msg.content["accepted"], msg
@ -188,11 +188,11 @@ class ConfigHandler:
self.validator.validate_device_patch(update) self.validator.validate_device_patch(update)
def update_config_in_redis(self, device): def update_config_in_redis(self, device):
config = self.device_manager.producer.get(MessageEndpoints.device_config()) config = self.device_manager.connector.get(MessageEndpoints.device_config())
config = config.content["resource"] config = config.content["resource"]
index = next( index = next(
index for index, dev_conf in enumerate(config) if dev_conf["name"] == device.name index for index, dev_conf in enumerate(config) if dev_conf["name"] == device.name
) )
config[index] = device._config config[index] = device._config
msg = messages.AvailableResourceMessage(resource=config) msg = messages.AvailableResourceMessage(resource=config)
self.device_manager.producer.set(MessageEndpoints.device_config(), msg) self.device_manager.connector.set(MessageEndpoints.device_config(), msg)

View File

@ -31,7 +31,6 @@ class SciBecConnector:
def __init__(self, scihub: SciHub, connector: ConnectorBase) -> None: def __init__(self, scihub: SciHub, connector: ConnectorBase) -> None:
self.scihub = scihub self.scihub = scihub
self.connector = connector self.connector = connector
self.producer = connector.producer()
self.scibec = None self.scibec = None
self.host = None self.host = None
self.target_bl = None self.target_bl = None
@ -132,25 +131,24 @@ class SciBecConnector:
""" """
Set the scibec account in redis Set the scibec account in redis
""" """
self.producer.set( self.connector.set(
MessageEndpoints.scibec(), MessageEndpoints.scibec(),
messages.CredentialsMessage(credentials={"url": self.host, "token": f"Bearer {token}"}), messages.CredentialsMessage(credentials={"url": self.host, "token": f"Bearer {token}"}),
) )
def set_redis_config(self, config): def set_redis_config(self, config):
msg = messages.AvailableResourceMessage(resource=config) msg = messages.AvailableResourceMessage(resource=config)
self.producer.set(MessageEndpoints.device_config(), msg) self.connector.set(MessageEndpoints.device_config(), msg)
def _start_metadata_handler(self) -> None: def _start_metadata_handler(self) -> None:
self._metadata_handler = SciBecMetadataHandler(self) self._metadata_handler = SciBecMetadataHandler(self)
def _start_config_request_handler(self) -> None: def _start_config_request_handler(self) -> None:
self._config_request_handler = self.connector.consumer( self._config_request_handler = self.connector.register(
MessageEndpoints.device_config_request(), MessageEndpoints.device_config_request(),
cb=self._device_config_request_callback, cb=self._device_config_request_callback,
parent=self, parent=self,
) )
self._config_request_handler.start()
@staticmethod @staticmethod
def _device_config_request_callback(msg, *, parent, **_kwargs) -> None: def _device_config_request_callback(msg, *, parent, **_kwargs) -> None:
@ -159,7 +157,7 @@ class SciBecConnector:
def connect_to_scibec(self): def connect_to_scibec(self):
""" """
Connect to SciBec and set the producer to the write account Connect to SciBec and set the connector to the write account
""" """
self._load_environment() self._load_environment()
if not self._env_configured: if not self._env_configured:
@ -205,7 +203,7 @@ class SciBecConnector:
write_account = self.scibec_info["activeExperiment"]["writeAccount"] write_account = self.scibec_info["activeExperiment"]["writeAccount"]
if write_account[0] == "p": if write_account[0] == "p":
write_account = write_account.replace("p", "e") write_account = write_account.replace("p", "e")
self.producer.set(MessageEndpoints.account(), write_account.encode()) self.connector.set(MessageEndpoints.account(), write_account.encode())
def shutdown(self): def shutdown(self):
""" """

View File

@ -15,22 +15,20 @@ if TYPE_CHECKING:
class SciBecMetadataHandler: class SciBecMetadataHandler:
def __init__(self, scibec_connector: SciBecConnector) -> None: def __init__(self, scibec_connector: SciBecConnector) -> None:
self.scibec_connector = scibec_connector self.scibec_connector = scibec_connector
self._scan_status_consumer = None self._scan_status_register = None
self._start_scan_subscription() self._start_scan_subscription()
self._file_subscription = None self._file_subscription = None
self._start_file_subscription() self._start_file_subscription()
def _start_scan_subscription(self): def _start_scan_subscription(self):
self._scan_status_consumer = self.scibec_connector.connector.consumer( self._scan_status_register = self.scibec_connector.connector.register(
MessageEndpoints.scan_status(), cb=self._handle_scan_status, parent=self MessageEndpoints.scan_status(), cb=self._handle_scan_status, parent=self
) )
self._scan_status_consumer.start()
def _start_file_subscription(self): def _start_file_subscription(self):
self._file_subscription = self.scibec_connector.connector.consumer( self._file_subscription = self.scibec_connector.connector.register(
MessageEndpoints.file_content(), cb=self._handle_file_content, parent=self MessageEndpoints.file_content(), cb=self._handle_file_content, parent=self
) )
self._file_subscription.start()
@staticmethod @staticmethod
def _handle_scan_status(msg, *, parent, **_kwargs) -> None: def _handle_scan_status(msg, *, parent, **_kwargs) -> None:
@ -171,7 +169,7 @@ class SciBecMetadataHandler:
""" """
Shutdown the metadata handler Shutdown the metadata handler
""" """
if self._scan_status_consumer: if self._scan_status_register:
self._scan_status_consumer.shutdown() self._scan_status_register.shutdown()
if self._file_subscription: if self._file_subscription:
self._file_subscription.shutdown() self._file_subscription.shutdown()

View File

@ -22,7 +22,6 @@ class SciLogConnector:
def __init__(self, scihub: SciHub, connector: RedisConnector) -> None: def __init__(self, scihub: SciHub, connector: RedisConnector) -> None:
self.scihub = scihub self.scihub = scihub
self.connector = connector self.connector = connector
self.producer = self.connector.producer()
self.host = None self.host = None
self.user = None self.user = None
self.user_secret = None self.user_secret = None
@ -44,7 +43,7 @@ class SciLogConnector:
def set_bec_token(self, token: str) -> None: def set_bec_token(self, token: str) -> None:
"""set the scilog token in redis""" """set the scilog token in redis"""
self.producer.set( self.connector.set(
MessageEndpoints.logbook(), MessageEndpoints.logbook(),
msgpack.dumps({"url": self.host, "user": self.user, "token": f"Bearer {token}"}), msgpack.dumps({"url": self.host, "user": self.user, "token": f"Bearer {token}"}),
) )

View File

@ -334,7 +334,7 @@ def test_config_handler_update_device_config_available_keys(config_handler, avai
def test_config_handler_wait_for_device_server_update(config_handler): def test_config_handler_wait_for_device_server_update(config_handler):
RID = "12345" RID = "12345"
with mock.patch.object(config_handler.producer, "get") as mock_get: with mock.patch.object(config_handler.connector, "get") as mock_get:
mock_get.side_effect = [ mock_get.side_effect = [
None, None,
None, None,
@ -346,7 +346,7 @@ def test_config_handler_wait_for_device_server_update(config_handler):
def test_config_handler_wait_for_device_server_update_timeout(config_handler): def test_config_handler_wait_for_device_server_update_timeout(config_handler):
RID = "12345" RID = "12345"
with mock.patch.object(config_handler.producer, "get", return_value=None) as mock_get: with mock.patch.object(config_handler.connector, "get", return_value=None) as mock_get:
with pytest.raises(TimeoutError): with pytest.raises(TimeoutError):
config_handler._wait_for_device_server_update(RID, timeout_time=0.1) config_handler._wait_for_device_server_update(RID, timeout_time=0.1)
mock_get.assert_called() mock_get.assert_called()

View File

@ -138,6 +138,6 @@ def test_scibec_update_experiment_info(SciBecMock):
def test_update_eaccount_in_redis(SciBecMock): def test_update_eaccount_in_redis(SciBecMock):
SciBecMock.scibec_info = {"activeExperiment": {"writeAccount": "p12345"}} SciBecMock.scibec_info = {"activeExperiment": {"writeAccount": "p12345"}}
with mock.patch.object(SciBecMock, "producer") as mock_producer: with mock.patch.object(SciBecMock, "connector") as mock_connector:
SciBecMock._update_eaccount_in_redis() SciBecMock._update_eaccount_in_redis()
mock_producer.set.assert_called_once_with(MessageEndpoints.account(), b"e12345") mock_connector.set.assert_called_once_with(MessageEndpoints.account(), b"e12345")