refactor!(connector): unify connector/redis_connector in one class

This commit is contained in:
guijar_m 2024-01-31 13:43:01 +01:00
parent 4edc5d02fe
commit b92a79b0c0
86 changed files with 1212 additions and 1745 deletions

View File

@ -38,16 +38,16 @@ class ReadbackDataMixin:
def get_request_done_msgs(self):
"""get all request-done messages"""
pipe = self.device_manager.producer.pipeline()
pipe = self.device_manager.connector.pipeline()
for dev in self.devices:
self.device_manager.producer.get(MessageEndpoints.device_req_status(dev), pipe)
return self.device_manager.producer.execute_pipeline(pipe)
self.device_manager.connector.get(MessageEndpoints.device_req_status(dev), pipe)
return self.device_manager.connector.execute_pipeline(pipe)
def wait_for_RID(self, request):
"""wait for the readback's metadata to match the request ID"""
while True:
msgs = [
self.device_manager.producer.get(MessageEndpoints.device_readback(dev))
self.device_manager.connector.get(MessageEndpoints.device_readback(dev))
for dev in self.devices
]
if all(msg.metadata.get("RID") == request.metadata["RID"] for msg in msgs if msg):

View File

@ -27,7 +27,7 @@ class LiveUpdatesScanProgress(LiveUpdatesTable):
Update the progressbar based on the device status message. Returns True if the scan is finished.
"""
self.check_alarms()
status = self.bec.producer.get(MessageEndpoints.device_progress(device_names[0]))
status = self.bec.connector.get(MessageEndpoints.device_progress(device_names[0]))
if not status:
logger.debug("waiting for new data point")
await asyncio.sleep(0.1)

View File

@ -13,7 +13,7 @@ from bec_client.callbacks.move_device import (
@pytest.fixture
def readback_data_mixin(bec_client):
with mock.patch.object(bec_client.device_manager, "producer"):
with mock.patch.object(bec_client.device_manager, "connector"):
yield ReadbackDataMixin(bec_client.device_manager, ["samx", "samy"])
@ -102,7 +102,7 @@ async def test_move_callback_with_report_instruction(bec_client):
def test_readback_data_mixin(readback_data_mixin):
readback_data_mixin.device_manager.producer.get.side_effect = [
readback_data_mixin.device_manager.connector.get.side_effect = [
messages.DeviceMessage(
signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}},
metadata={"device": "samx"},
@ -121,7 +121,7 @@ def test_readback_data_mixin_multiple_hints(readback_data_mixin):
"samx_setpoint",
"samx",
]
readback_data_mixin.device_manager.producer.get.side_effect = [
readback_data_mixin.device_manager.connector.get.side_effect = [
messages.DeviceMessage(
signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}},
metadata={"device": "samx"},
@ -137,7 +137,7 @@ def test_readback_data_mixin_multiple_hints(readback_data_mixin):
def test_readback_data_mixin_multiple_no_hints(readback_data_mixin):
readback_data_mixin.device_manager.devices.samx._info["hints"]["fields"] = []
readback_data_mixin.device_manager.producer.get.side_effect = [
readback_data_mixin.device_manager.connector.get.side_effect = [
messages.DeviceMessage(
signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}},
metadata={"device": "samx"},
@ -153,18 +153,18 @@ def test_readback_data_mixin_multiple_no_hints(readback_data_mixin):
def test_get_request_done_msgs(readback_data_mixin):
res = readback_data_mixin.get_request_done_msgs()
readback_data_mixin.device_manager.producer.pipeline.assert_called_once()
readback_data_mixin.device_manager.connector.pipeline.assert_called_once()
assert (
mock.call(
MessageEndpoints.device_req_status("samx"),
readback_data_mixin.device_manager.producer.pipeline.return_value,
readback_data_mixin.device_manager.connector.pipeline.return_value,
)
in readback_data_mixin.device_manager.producer.get.call_args_list
in readback_data_mixin.device_manager.connector.get.call_args_list
)
assert (
mock.call(
MessageEndpoints.device_req_status("samy"),
readback_data_mixin.device_manager.producer.pipeline.return_value,
readback_data_mixin.device_manager.connector.pipeline.return_value,
)
in readback_data_mixin.device_manager.producer.get.call_args_list
in readback_data_mixin.device_manager.connector.get.call_args_list
)

View File

@ -15,7 +15,7 @@ async def test_update_progressbar_continues_without_device_data():
live_update = LiveUpdatesScanProgress(bec=bec, report_instruction={}, request=request)
progressbar = mock.MagicMock()
bec.producer.get.return_value = None
bec.connector.get.return_value = None
res = await live_update._update_progressbar(progressbar, "async_dev1")
assert res is False
@ -29,7 +29,7 @@ async def test_update_progressbar_continues_when_scanID_doesnt_match():
live_update.scan_item = mock.MagicMock()
live_update.scan_item.scanID = "scanID2"
bec.producer.get.return_value = messages.ProgressMessage(
bec.connector.get.return_value = messages.ProgressMessage(
value=1, max_value=10, done=False, metadata={"scanID": "scanID"}
)
res = await live_update._update_progressbar(progressbar, "async_dev1")
@ -45,7 +45,7 @@ async def test_update_progressbar_continues_when_msg_specifies_no_value():
live_update.scan_item = mock.MagicMock()
live_update.scan_item.scanID = "scanID"
bec.producer.get.return_value = messages.ProgressMessage(
bec.connector.get.return_value = messages.ProgressMessage(
value=None, max_value=None, done=None, metadata={"scanID": "scanID"}
)
res = await live_update._update_progressbar(progressbar, "async_dev1")
@ -61,7 +61,7 @@ async def test_update_progressbar_updates_max_value():
live_update.scan_item = mock.MagicMock()
live_update.scan_item.scanID = "scanID"
bec.producer.get.return_value = messages.ProgressMessage(
bec.connector.get.return_value = messages.ProgressMessage(
value=10, max_value=20, done=False, metadata={"scanID": "scanID"}
)
res = await live_update._update_progressbar(progressbar, "async_dev1")
@ -79,7 +79,7 @@ async def test_update_progressbar_returns_true_when_max_value_is_reached():
live_update.scan_item = mock.MagicMock()
live_update.scan_item.scanID = "scanID"
bec.producer.get.return_value = messages.ProgressMessage(
bec.connector.get.return_value = messages.ProgressMessage(
value=10, max_value=10, done=True, metadata={"scanID": "scanID"}
)
res = await live_update._update_progressbar(progressbar, "async_dev1")

View File

@ -463,11 +463,11 @@ def test_file_writer(client):
md={"datasetID": 325},
)
assert len(scan.scan.data) == 100
msg = bec.device_manager.producer.get(MessageEndpoints.public_file(scan.scan.scanID, "master"))
msg = bec.device_manager.connector.get(MessageEndpoints.public_file(scan.scan.scanID, "master"))
while True:
if msg:
break
msg = bec.device_manager.producer.get(
msg = bec.device_manager.connector.get(
MessageEndpoints.public_file(scan.scan.scanID, "master")
)

View File

@ -3,7 +3,6 @@ from bec_lib.bec_service import BECService
from bec_lib.channel_monitor import channel_monitor_launch
from bec_lib.client import BECClient
from bec_lib.config_helper import ConfigHelper
from bec_lib.connector import ProducerConnector
from bec_lib.device import DeviceBase, DeviceStatus, Status
from bec_lib.devicemanager import DeviceConfigError, DeviceContainer, DeviceManagerBase
from bec_lib.endpoints import MessageEndpoints

View File

@ -48,23 +48,21 @@ class AlarmBase(Exception):
class AlarmHandler:
def __init__(self, connector: RedisConnector) -> None:
self.connector = connector
self.alarm_consumer = None
self.alarms_stack = deque(maxlen=100)
self._raised_alarms = deque(maxlen=100)
self._lock = threading.RLock()
def start(self):
"""start the alarm handler and its subscriptions"""
self.alarm_consumer = self.connector.consumer(
self.connector.register(
topics=MessageEndpoints.alarm(),
name="AlarmHandler",
cb=self._alarm_consumer_callback,
cb=self._alarm_register_callback,
parent=self,
)
self.alarm_consumer.start()
@staticmethod
def _alarm_consumer_callback(msg, *, parent, **_kwargs):
def _alarm_register_callback(msg, *, parent, **_kwargs):
parent.add_alarm(msg.value)
@threadlocked
@ -136,4 +134,4 @@ class AlarmHandler:
def shutdown(self):
"""shutdown the alarm handler"""
self.alarm_consumer.shutdown()
self.connector.shutdown()

View File

@ -8,12 +8,12 @@ from bec_lib.endpoints import MessageEndpoints
if TYPE_CHECKING:
from bec_lib import messages
from bec_lib.redis_connector import RedisProducer
from bec_lib.connector import ConnectorBase
class AsyncDataHandler:
def __init__(self, producer: RedisProducer):
self.producer = producer
def __init__(self, connector: ConnectorBase):
self.connector = connector
def get_async_data_for_scan(self, scan_id: str) -> dict[list]:
"""
@ -25,7 +25,9 @@ class AsyncDataHandler:
Returns:
dict[list]: the async data for the scan sorted by device name
"""
async_device_keys = self.producer.keys(MessageEndpoints.device_async_readback(scan_id, "*"))
async_device_keys = self.connector.keys(
MessageEndpoints.device_async_readback(scan_id, "*")
)
async_data = {}
for device_key in async_device_keys:
key = device_key.decode()
@ -50,7 +52,7 @@ class AsyncDataHandler:
list: the async data for the device
"""
key = MessageEndpoints.device_async_readback(scan_id, device_name)
msgs = self.producer.xrange(key, min="-", max="+")
msgs = self.connector.xrange(key, min="-", max="+")
if not msgs:
return []
return self.process_async_data(msgs)

View File

@ -54,14 +54,14 @@ class BECWidgetsConnector:
def __init__(self, gui_id: str, bec_client: BECClient = None) -> None:
self._client = bec_client
self.gui_id = gui_id
# TODO replace with a global producer
# TODO replace with a global connector
if self._client is None:
if "bec" in builtins.__dict__:
self._client = builtins.bec
else:
self._client = BECClient()
self._client.start()
self._producer = self._client.connector.producer()
self._connector = self._client.connector
def set_plot_config(self, plot_id: str, config: dict) -> None:
"""
@ -72,7 +72,7 @@ class BECWidgetsConnector:
config (dict): The config to set.
"""
msg = messages.GUIConfigMessage(config=config)
self._producer.set_and_publish(MessageEndpoints.gui_config(plot_id), msg)
self._connector.set_and_publish(MessageEndpoints.gui_config(plot_id), msg)
def close(self, plot_id: str) -> None:
"""
@ -82,7 +82,7 @@ class BECWidgetsConnector:
plot_id (str): The id of the plot.
"""
msg = messages.GUIInstructionMessage(action="close", parameter={})
self._producer.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)
self._connector.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)
def config_dialog(self, plot_id: str) -> None:
"""
@ -92,7 +92,7 @@ class BECWidgetsConnector:
plot_id (str): The id of the plot.
"""
msg = messages.GUIInstructionMessage(action="config_dialog", parameter={})
self._producer.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)
self._connector.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)
def send_data(self, plot_id: str, data: dict) -> None:
"""
@ -103,9 +103,9 @@ class BECWidgetsConnector:
data (dict): The data to send.
"""
msg = messages.GUIDataMessage(data=data)
self._producer.set_and_publish(topic=MessageEndpoints.gui_data(plot_id), msg=msg)
self._connector.set_and_publish(topic=MessageEndpoints.gui_data(plot_id), msg=msg)
# TODO bec_dispatcher can only handle set_and_publish ATM
# self._producer.xadd(topic=MessageEndpoints.gui_data(plot_id),msg= {"data": msg})
# self._connector.xadd(topic=MessageEndpoints.gui_data(plot_id),msg= {"data": msg})
def clear(self, plot_id: str) -> None:
"""
@ -115,7 +115,7 @@ class BECWidgetsConnector:
plot_id (str): The id of the plot.
"""
msg = messages.GUIInstructionMessage(action="clear", parameter={})
self._producer.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)
self._connector.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)
class BECPlotter:

View File

@ -41,7 +41,6 @@ class BECService:
self.connector = connector_cls(self.bootstrap_server)
self._unique_service = unique_service
self.wait_for_server = wait_for_server
self.producer = self.connector.producer()
self.__service_id = str(uuid.uuid4())
self._user = getpass.getuser()
self._hostname = socket.gethostname()
@ -110,11 +109,11 @@ class BECService:
)
def _update_existing_services(self):
service_keys = self.producer.keys(MessageEndpoints.service_status("*"))
service_keys = self.connector.keys(MessageEndpoints.service_status("*"))
if not service_keys:
return
services = [service.decode().split(":", maxsplit=1)[0] for service in service_keys]
msgs = [self.producer.get(service) for service in services]
msgs = [self.connector.get(service) for service in services]
self._services_info = {msg.content["name"]: msg for msg in msgs if msg is not None}
def _update_service_info(self):
@ -124,7 +123,7 @@ class BECService:
self._service_info_event.wait(timeout=3)
def _send_service_status(self):
self.producer.set_and_publish(
self.connector.set_and_publish(
topic=MessageEndpoints.service_status(self._service_id),
msg=messages.StatusMessage(
name=self._service_name,
@ -189,7 +188,7 @@ class BECService:
)
)
msg = messages.ServiceMetricMessage(name=self.__class__.__name__, metrics=data)
self.producer.send(MessageEndpoints.metrics(self._service_id), msg)
self.connector.send(MessageEndpoints.metrics(self._service_id), msg)
self._metrics_emitter_event.wait(timeout=1)
def set_global_var(self, name: str, val: Any) -> None:
@ -200,7 +199,7 @@ class BECService:
val (Any): Value of the variable
"""
self.producer.set(MessageEndpoints.global_vars(name), messages.VariableMessage(value=val))
self.connector.set(MessageEndpoints.global_vars(name), messages.VariableMessage(value=val))
def get_global_var(self, name: str) -> Any:
"""Get a global variable from Redis
@ -211,7 +210,7 @@ class BECService:
Returns:
Any: Value of the variable
"""
msg = self.producer.get(MessageEndpoints.global_vars(name))
msg = self.connector.get(MessageEndpoints.global_vars(name))
if msg:
return msg.content.get("value")
return None
@ -223,12 +222,12 @@ class BECService:
name (str): Name of the variable
"""
self.producer.delete(MessageEndpoints.global_vars(name))
self.connector.delete(MessageEndpoints.global_vars(name))
def global_vars(self) -> str:
"""Get all available global variables"""
# sadly, this cannot be a property as it causes side effects with IPython's tab completion
available_keys = self.producer.keys(MessageEndpoints.global_vars("*"))
available_keys = self.connector.keys(MessageEndpoints.global_vars("*"))
def get_endpoint_from_topic(topic: str) -> str:
return topic.decode().split(MessageEndpoints.global_vars(""))[-1]
@ -252,6 +251,7 @@ class BECService:
def shutdown(self):
"""shutdown the BECService"""
self.connector.shutdown()
self._service_info_event.set()
if self._service_info_thread:
self._service_info_thread.join()

View File

@ -16,11 +16,11 @@ def channel_callback(msg, **_kwargs):
print(json.dumps(out, indent=4, default=lambda o: "<not serializable object>"))
def _start_consumer(config_path, topic):
def _start_register(config_path, topic):
config = ServiceConfig(config_path)
connector = RedisConnector(config.redis)
consumer = connector.consumer(topics=topic, cb=channel_callback)
consumer.start()
register = connector.register(topics=topic, cb=channel_callback)
register.start()
event = threading.Event()
event.wait()
@ -38,4 +38,4 @@ def channel_monitor_launch():
config_path = clargs.config
topic = clargs.channel
_start_consumer(config_path, topic)
_start_register(config_path, topic)

View File

@ -91,7 +91,7 @@ class BECClient(BECService, UserScriptsMixin):
@property
def active_account(self) -> str:
"""get the currently active target (e)account"""
return self.producer.get(MessageEndpoints.account())
return self.connector.get(MessageEndpoints.account())
def start(self):
"""start the client"""
@ -133,13 +133,13 @@ class BECClient(BECService, UserScriptsMixin):
@property
def pre_scan_hooks(self):
"""currently stored pre-scan hooks"""
return self.producer.lrange(MessageEndpoints.pre_scan_macros(), 0, -1)
return self.connector.lrange(MessageEndpoints.pre_scan_macros(), 0, -1)
@pre_scan_hooks.setter
def pre_scan_hooks(self, hooks: list):
self.producer.delete(MessageEndpoints.pre_scan_macros())
self.connector.delete(MessageEndpoints.pre_scan_macros())
for hook in hooks:
self.producer.lpush(MessageEndpoints.pre_scan_macros(), hook)
self.connector.lpush(MessageEndpoints.pre_scan_macros(), hook)
def _load_scans(self):
self.scans = Scans(self)

View File

@ -26,7 +26,6 @@ logger = bec_logger.logger
class ConfigHelper:
def __init__(self, connector: RedisConnector, service_name: str = None) -> None:
self.connector = connector
self.producer = connector.producer()
self._service_name = service_name
def update_session_with_file(self, file_path: str, save_recovery: bool = True) -> None:
@ -71,7 +70,7 @@ class ConfigHelper:
print(f"Config was written to {file_path}.")
def _save_config_to_file(self, file_path: str, raise_on_error: bool = True) -> bool:
config = self.producer.get(MessageEndpoints.device_config())
config = self.connector.get(MessageEndpoints.device_config())
if not config:
if raise_on_error:
raise DeviceConfigError("No config found in the session.")
@ -99,7 +98,7 @@ class ConfigHelper:
if action in ["update", "add", "set"] and not config:
raise DeviceConfigError(f"Config cannot be empty for an {action} request.")
RID = str(uuid.uuid4())
self.producer.send(
self.connector.send(
MessageEndpoints.device_config_request(),
DeviceConfigMessage(action=action, config=config, metadata={"RID": RID}),
)
@ -145,7 +144,7 @@ class ConfigHelper:
elapsed_time = 0
max_time = timeout
while True:
service_messages = self.producer.lrange(MessageEndpoints.service_response(RID), 0, -1)
service_messages = self.connector.lrange(MessageEndpoints.service_response(RID), 0, -1)
if not service_messages:
time.sleep(0.005)
elapsed_time += 0.005
@ -185,7 +184,7 @@ class ConfigHelper:
"""
start = 0
while True:
msg = self.producer.get(MessageEndpoints.device_config_request_response(RID))
msg = self.connector.get(MessageEndpoints.device_config_request_response(RID))
if msg is None:
time.sleep(0.01)
start += 0.01

View File

@ -6,7 +6,8 @@ import threading
import traceback
from bec_lib.logger import bec_logger
from bec_lib.messages import BECMessage
from bec_lib.messages import BECMessage, LogMessage
from bec_lib.endpoints import MessageEndpoints
logger = bec_logger.logger
@ -33,154 +34,98 @@ class MessageObject:
return f"MessageObject(topic={self.topic}, value={self._value})"
class ConnectorBase(abc.ABC):
"""
ConnectorBase implements producer and consumer clients for communicating with a broker.
One ought to inherit from this base class and provide at least customized producer and consumer methods.
class StoreInterface(abc.ABC):
"""StoreBase defines the interface for storing data"""
"""
def __init__(self, bootstrap_server: list):
self.bootstrap = bootstrap_server
self._threads = []
def producer(self, **kwargs) -> ProducerConnector:
raise NotImplementedError
def consumer(self, **kwargs) -> ConsumerConnectorThreaded:
raise NotImplementedError
def shutdown(self):
for t in self._threads:
t.signal_event.set()
t.join()
def raise_warning(self, msg):
raise NotImplementedError
def send_log(self, msg):
raise NotImplementedError
def poll_messages(self):
"""Poll for new messages, receive them and execute callbacks"""
def __init__(self, store):
pass
def pipeline(self):
pass
class ProducerConnector(abc.ABC):
def execute_pipeline(self):
pass
def lpush(
self, topic: str, msg: str, pipe=None, max_size: int = None, expire: int = None
) -> None:
raise NotImplementedError
def lset(self, topic: str, index: int, msg: str, pipe=None) -> None:
raise NotImplementedError
def rpush(self, topic: str, msg: str, pipe=None) -> int:
raise NotImplementedError
def lrange(self, topic: str, start: int, end: int, pipe=None):
raise NotImplementedError
def set(self, topic: str, msg, pipe=None, expire: int = None) -> None:
raise NotImplementedError
def keys(self, pattern: str) -> list:
raise NotImplementedError
def delete(self, topic, pipe=None):
raise NotImplementedError
def get(self, topic: str, pipe=None):
raise NotImplementedError
def xadd(self, topic: str, msg: dict, max_size=None, pipe=None, expire: int = None):
raise NotImplementedError
def xread(
self,
topic: str,
id: str = None,
count: int = None,
block: int = None,
pipe=None,
from_start=False,
) -> list:
raise NotImplementedError
def xrange(self, topic: str, min: str, max: str, count: int = None, pipe=None):
raise NotImplementedError
class PubSubInterface(abc.ABC):
def raw_send(self, topic: str, msg: bytes) -> None:
raise NotImplementedError
def send(self, topic: str, msg: BECMessage) -> None:
raise NotImplementedError
class ConsumerConnector(abc.ABC):
def __init__(
self, bootstrap_server, cb, topics=None, pattern=None, group_id=None, event=None, **kwargs
):
"""
ConsumerConnector class defines the communication with the broker for consuming messages.
An implementation ought to inherit from this class and implement the initialize_connector and poll_messages methods.
Args:
bootstrap_server: list of bootstrap servers, e.g. ["localhost:9092", "localhost:9093"]
topics: the topic(s) to which the connector should attach
event: external event to trigger start and stop of the connector
cb: callback function; will be triggered from within poll_messages
kwargs: additional keyword arguments
"""
self.bootstrap = bootstrap_server
self.topics = topics
self.pattern = pattern
self.group_id = group_id
self.connector = None
self.cb = cb
self.kwargs = kwargs
if not self.topics and not self.pattern:
raise ConsumerConnectorError("Either a topic or a patter must be specified.")
def initialize_connector(self) -> None:
"""
initialize the connector instance self.connector
The connector will be initialized once the thread is started
"""
def register(self, topics=None, pattern=None, cb=None, start_thread=True, **kwargs):
raise NotImplementedError
def poll_messages(self) -> None:
"""
Poll messages from self.connector and call the callback function self.cb
def poll_messages(self, timeout=None):
"""Poll for new messages, receive them and execute callbacks"""
raise NotImplementedError
"""
raise NotImplementedError()
class ConsumerConnectorThreaded(ConsumerConnector, threading.Thread):
def __init__(
self,
bootstrap_server,
cb,
topics=None,
pattern=None,
group_id=None,
event=None,
name=None,
**kwargs,
):
"""
ConsumerConnectorThreaded class defines the threaded communication with the broker for consuming messages.
An implementation ought to inherit from this class and implement the initialize_connector and poll_messages methods.
Once started, the connector is expected to poll new messages until the signal_event is set.
Args:
bootstrap_server: list of bootstrap servers, e.g. ["localhost:9092", "localhost:9093"]
topics: the topic(s) to which the connector should attach
event: external event to trigger start and stop of the connector
cb: callback function; will be triggered from within poll_messages
kwargs: additional keyword arguments
"""
super().__init__(
bootstrap_server=bootstrap_server,
topics=topics,
pattern=pattern,
group_id=group_id,
event=event,
cb=cb,
**kwargs,
)
if name is not None:
thread_kwargs = {"name": name, "daemon": True}
else:
thread_kwargs = {"daemon": True}
super(ConsumerConnector, self).__init__(**thread_kwargs)
self.signal_event = event if event is not None else threading.Event()
def run(self):
self.initialize_connector()
while True:
try:
self.poll_messages()
except Exception as e:
logger.error(traceback.format_exc())
_thread.interrupt_main()
raise e
finally:
if self.signal_event.is_set():
self.shutdown()
break
def run_messages_loop(self):
raise NotImplementedError
def shutdown(self):
self.signal_event.set()
raise NotImplementedError
# def stop(self) -> None:
# """
# Stop consumer
# Returns:
# """
# self.signal_event.set()
# self.connector.close()
# self.join()
class ConnectorBase(PubSubInterface, StoreInterface):
def raise_warning(self, msg):
raise NotImplementedError
def log_warning(self, msg):
"""send a warning"""
self.send(MessageEndpoints.log(), LogMessage(log_type="warning", log_msg=msg))
def log_message(self, msg):
"""send a log message"""
self.send(MessageEndpoints.log(), LogMessage(log_type="log", log_msg=msg))
def log_error(self, msg):
"""send an error as log"""
self.send(MessageEndpoints.log(), LogMessage(log_type="error", log_msg=msg))
def set_and_publish(self, topic: str, msg, pipe=None, expire: int = None) -> None:
raise NotImplementedError

View File

@ -78,7 +78,7 @@ class DAPPluginObjectBase:
converted_kwargs[key] = val
kwargs = converted_kwargs
request_id = str(uuid.uuid4())
self._client.producer.set_and_publish(
self._client.connector.set_and_publish(
MessageEndpoints.dap_request(),
messages.DAPRequestMessage(
dap_cls=self._plugin_info["class"],
@ -110,7 +110,7 @@ class DAPPluginObjectBase:
while True:
if time.time() - start_time > timeout:
raise TimeoutError("Timeout waiting for DAP response.")
response = self._client.producer.get(MessageEndpoints.dap_response(request_id))
response = self._client.connector.get(MessageEndpoints.dap_response(request_id))
if not response:
time.sleep(0.005)
continue
@ -128,7 +128,7 @@ class DAPPluginObjectBase:
return
self._plugin_config["class_args"] = self._plugin_info.get("class_args")
self._plugin_config["class_kwargs"] = self._plugin_info.get("class_kwargs")
self._client.producer.set_and_publish(
self._client.connector.set_and_publish(
MessageEndpoints.dap_request(),
messages.DAPRequestMessage(
dap_cls=self._plugin_info["class"],
@ -149,7 +149,7 @@ class DAPPluginObject(DAPPluginObjectBase):
"""
Get the data from last run.
"""
msg = self._client.producer.get_last(MessageEndpoints.processed_data(self._service_name))
msg = self._client.connector.get_last(MessageEndpoints.processed_data(self._service_name))
if not msg:
return None
return self._convert_result(msg)

View File

@ -38,7 +38,7 @@ class DAPPlugins:
service for service in available_services if service.startswith("DAPServer/")
]
for service in dap_services:
available_plugins = self._parent.producer.get(
available_plugins = self._parent.connector.get(
MessageEndpoints.dap_available_plugins(service)
)
if available_plugins is None:

View File

@ -10,7 +10,7 @@ from typeguard import typechecked
from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from bec_lib.redis_connector import RedisProducer
from bec_lib.redis_connector import RedisConnector
logger = bec_logger.logger
@ -61,15 +61,15 @@ class ReadoutPriority(str, enum.Enum):
class Status:
def __init__(self, producer: RedisProducer, RID: str) -> None:
def __init__(self, connector: RedisConnector, RID: str) -> None:
"""
Status object for RPC calls
Args:
producer (RedisProducer): Redis producer
connector (RedisConnector): Redis connector
RID (str): Request ID
"""
self._producer = producer
self._connector = connector
self._RID = RID
def __eq__(self, __value: object) -> bool:
@ -91,7 +91,7 @@ class Status:
raise TimeoutError()
while True:
request_status = self._producer.lrange(
request_status = self._connector.lrange(
MessageEndpoints.device_req_status(self._RID), 0, -1
)
if request_status:
@ -251,7 +251,7 @@ class DeviceBase:
if not isinstance(return_val, dict):
return return_val
if return_val.get("type") == "status" and return_val.get("RID"):
return Status(self.root.parent.producer, return_val.get("RID"))
return Status(self.root.parent.connector, return_val.get("RID"))
return return_val
def _get_rpc_response(self, request_id, rpc_id) -> Any:
@ -267,7 +267,7 @@ class DeviceBase:
f" {scan_queue_request.response.content['message']}"
)
while True:
msg = self.root.parent.producer.get(MessageEndpoints.device_rpc(rpc_id))
msg = self.root.parent.connector.get(MessageEndpoints.device_rpc(rpc_id))
if msg:
break
time.sleep(0.01)
@ -296,7 +296,7 @@ class DeviceBase:
msg = self._prepare_rpc_msg(rpc_id, request_id, device, func_call, *args, **kwargs)
# send RPC message
self.root.parent.producer.send(MessageEndpoints.scan_queue_request(), msg)
self.root.parent.connector.send(MessageEndpoints.scan_queue_request(), msg)
# wait for RPC response
if not wait_for_rpc_response:
@ -496,7 +496,7 @@ class DeviceBase:
# def read(self, cached, filter_readback=True):
# """get the last reading from a device"""
# val = self.parent.producer.get(MessageEndpoints.device_read(self.name))
# val = self.parent.connector.get(MessageEndpoints.device_read(self.name))
# if not val:
# return None
# if filter_readback:
@ -505,7 +505,7 @@ class DeviceBase:
#
# def readback(self, filter_readback=True):
# """get the last readback value from a device"""
# val = self.parent.producer.get(MessageEndpoints.device_readback(self.name))
# val = self.parent.connector.get(MessageEndpoints.device_readback(self.name))
# if not val:
# return None
# if filter_readback:
@ -515,7 +515,7 @@ class DeviceBase:
# @property
# def device_status(self):
# """get the current status of the device"""
# val = self.parent.producer.get(MessageEndpoints.device_status(self.name))
# val = self.parent.connector.get(MessageEndpoints.device_status(self.name))
# if val is None:
# return val
# val = DeviceStatusMessage.loads(val)
@ -524,7 +524,7 @@ class DeviceBase:
# @property
# def signals(self):
# """get the last signals from a device"""
# val = self.parent.producer.get(MessageEndpoints.device_read(self.name))
# val = self.parent.connector.get(MessageEndpoints.device_read(self.name))
# if val is None:
# return None
# self._signals = DeviceMessage.loads(val).content["signals"]
@ -593,11 +593,11 @@ class OphydInterfaceBase(DeviceBase):
if is_config_signal:
return self.read_configuration(cached=cached)
if use_readback:
val = self.root.parent.producer.get(
val = self.root.parent.connector.get(
MessageEndpoints.device_readback(self.root.name)
)
else:
val = self.root.parent.producer.get(MessageEndpoints.device_read(self.root.name))
val = self.root.parent.connector.get(MessageEndpoints.device_read(self.root.name))
if not val:
return None
@ -623,7 +623,7 @@ class OphydInterfaceBase(DeviceBase):
if is_signal and not is_config_signal:
return self.read(cached=True)
val = self.root.parent.producer.get(
val = self.root.parent.connector.get(
MessageEndpoints.device_read_configuration(self.root.name)
)
if not val:
@ -766,7 +766,7 @@ class AdjustableMixin:
"""
Returns the device limits.
"""
limit_msg = self.root.parent.producer.get(MessageEndpoints.device_limits(self.root.name))
limit_msg = self.root.parent.connector.get(MessageEndpoints.device_limits(self.root.name))
if not limit_msg:
return [0, 0]
limits = [

View File

@ -370,8 +370,7 @@ class DeviceManagerBase:
_request_config_parsed = None # parsed config request
_response = None # response message
_connector_base_consumer = {}
producer = None
_connector_base_register = {}
config_helper = None
_device_cls = DeviceBase
_status_cb = []
@ -464,7 +463,7 @@ class DeviceManagerBase:
"""
if not msg.metadata.get("RID"):
return
self.producer.lpush(
self.connector.lpush(
MessageEndpoints.service_response(msg.metadata["RID"]),
messages.ServiceResponseMessage(
# pylint: disable=no-member
@ -487,25 +486,20 @@ class DeviceManagerBase:
self._remove_device(dev)
def _start_connectors(self, bootstrap_server) -> None:
self._start_base_consumer()
self.producer = self.connector.producer()
self._start_custom_connectors(bootstrap_server)
self._start_base_register()
def _start_base_consumer(self) -> None:
def _start_base_register(self) -> None:
"""
Start consuming messages for all base topics. This method will be called upon startup.
Returns:
"""
self._connector_base_consumer["device_config_update"] = self.connector.consumer(
self.connector.register(
MessageEndpoints.device_config_update(),
cb=self._device_config_update_callback,
parent=self,
)
# self._connector_base_consumer["log"].start()
self._connector_base_consumer["device_config_update"].start()
@staticmethod
def _log_callback(msg, *, parent, **kwargs) -> None:
"""
@ -541,48 +535,11 @@ class DeviceManagerBase:
self._load_session()
def _get_redis_device_config(self) -> list:
devices = self.producer.get(MessageEndpoints.device_config())
devices = self.connector.get(MessageEndpoints.device_config())
if not devices:
return []
return devices.content["resource"]
def _stop_base_consumer(self):
"""
Stop all base consumers by setting the corresponding event
Returns:
"""
if self.connector is not None:
for _, con in self._connector_base_consumer.items():
con.signal_event.set()
con.join()
def _stop_consumer(self):
"""
Stop all consumers
Returns:
"""
self._stop_base_consumer()
self._stop_custom_consumer()
def _start_custom_connectors(self, bootstrap_server) -> None:
"""
Override this method in a derived class to start custom connectors upon initialization.
Args:
bootstrap_server: Kafka bootstrap server
Returns:
"""
def _stop_custom_consumer(self) -> None:
"""
Stop all custom consumers. Override this method in a derived class.
Returns:
"""
def _add_device(self, dev: dict, msg: messages.DeviceInfoMessage):
name = msg.content["device"]
info = msg.content["info"]
@ -621,8 +578,7 @@ class DeviceManagerBase:
logger.error(f"Failed to load device {dev}: {content}")
def _get_device_info(self, device_name) -> DeviceInfoMessage:
msg = self.producer.get(MessageEndpoints.device_info(device_name))
return msg
return self.connector.get(MessageEndpoints.device_info(device_name))
def check_request_validity(self, msg: DeviceConfigMessage) -> None:
"""
@ -663,10 +619,7 @@ class DeviceManagerBase:
"""
Shutdown all connectors.
"""
try:
self.connector.shutdown()
except RuntimeError as runtime_error:
logger.error(f"Failed to shutdown connector. {runtime_error}")
self.connector.shutdown()
def __del__(self):
self.shutdown()

View File

@ -24,7 +24,6 @@ if TYPE_CHECKING:
class LogbookConnector:
def __init__(self, connector: RedisConnector) -> None:
self.connector = connector
self.producer = connector.producer()
self.connected = False
self._scilog_module = None
self._connect()
@ -34,12 +33,12 @@ class LogbookConnector:
if "scilog" not in sys.modules:
return
msg = self.producer.get(MessageEndpoints.logbook())
msg = self.connector.get(MessageEndpoints.logbook())
if not msg:
return
msg = msgpack.loads(msg)
account = self.producer.get(MessageEndpoints.account())
account = self.connector.get(MessageEndpoints.account())
if not account:
return
account = account.decode()
@ -54,7 +53,7 @@ class LogbookConnector:
try:
logbooks = self.log.get_logbooks(readACL={"inq": [account]})
except HTTPError:
self.producer.set(MessageEndpoints.logbook(), b"")
self.connector.set(MessageEndpoints.logbook(), b"")
return
if len(logbooks) > 1:
logger.warning("Found two logbooks. Taking the first one.")

View File

@ -45,7 +45,6 @@ class BECLogger:
self.bootstrap_server = None
self.connector = None
self.service_name = None
self.producer = None
self.logger = loguru_logger
self._log_level = LogLevel.INFO
self.level = self._log_level
@ -73,7 +72,6 @@ class BECLogger:
self.bootstrap_server = bootstrap_server
self.connector = connector_cls(bootstrap_server)
self.service_name = service_name
self.producer = self.connector.producer()
self._configured = True
self._update_sinks()
@ -82,7 +80,7 @@ class BECLogger:
return
msg = json.loads(msg)
msg["service_name"] = self.service_name
self.producer.send(
self.connector.send(
topic=MessageEndpoints.log(),
msg=bec_lib.messages.LogMessage(log_type=msg["record"]["level"]["name"], log_msg=msg),
)

View File

@ -152,7 +152,7 @@ class ObserverManager:
def _get_installed_observer(self):
# get current observer list from Redis
observer_msg = self.device_manager.producer.get(MessageEndpoints.observer())
observer_msg = self.device_manager.connector.get(MessageEndpoints.observer())
if observer_msg is None:
return []
return [Observer.from_dict(obs) for obs in observer_msg.content["observer"]]

View File

@ -122,12 +122,16 @@ class QueueStorage:
if history < 0:
history *= -1
return self.scan_manager.producer.lrange(MessageEndpoints.scan_queue_history(), 0, history)
return self.scan_manager.connector.lrange(
MessageEndpoints.scan_queue_history(),
0,
history,
)
@property
def current_scan_queue(self) -> dict:
"""get the current scan queue from redis"""
msg = self.scan_manager.producer.get(MessageEndpoints.scan_queue_status())
msg = self.scan_manager.connector.get(MessageEndpoints.scan_queue_status())
if msg:
self._current_scan_queue = msg.content["queue"]
return self._current_scan_queue

View File

@ -1,19 +1,18 @@
from __future__ import annotations
import time
import collections
import queue
import sys
import threading
import warnings
from functools import wraps
from typing import TYPE_CHECKING
import louie
import redis
import redis.client
from bec_lib.connector import (
ConnectorBase,
ConsumerConnector,
ConsumerConnectorThreaded,
MessageObject,
ProducerConnector,
)
from bec_lib.connector import ConnectorBase, MessageObject
from bec_lib.endpoints import MessageEndpoints
from bec_lib.messages import AlarmMessage, BECMessage, LogMessage
from bec_lib.serialization import MsgpackSerialization
@ -31,7 +30,6 @@ def catch_connection_error(func):
return func(*args, **kwargs)
except redis.exceptions.ConnectionError:
warnings.warn("Failed to connect to redis. Is the server running?")
time.sleep(0.1)
return None
return wrapper
@ -40,149 +38,80 @@ def catch_connection_error(func):
class RedisConnector(ConnectorBase):
def __init__(self, bootstrap: list, redis_cls=None):
super().__init__(bootstrap)
self.redis_cls = redis_cls
self.host, self.port = (
bootstrap[0].split(":") if isinstance(bootstrap, list) else bootstrap.split(":")
)
self._notifications_producer = RedisProducer(
host=self.host, port=self.port, redis_cls=self.redis_cls
)
def producer(self, **kwargs):
return RedisProducer(host=self.host, port=self.port, redis_cls=self.redis_cls)
if redis_cls:
self._redis_conn = redis_cls(host=self.host, port=self.port)
else:
self._redis_conn = redis.Redis(host=self.host, port=self.port)
# pylint: disable=too-many-arguments
def consumer(
self,
topics=None,
pattern=None,
group_id=None,
event=None,
cb=None,
threaded=True,
name=None,
**kwargs,
):
if cb is None:
raise ValueError("The callback function must be specified.")
# main pubsub connection
self._pubsub_conn = self._redis_conn.pubsub()
self._pubsub_conn.ignore_subscribe_messages = True
# keep track of topics and callbacks
self._topics_cb = collections.defaultdict(list)
if threaded:
if topics is None and pattern is None:
raise ValueError("Topics must be set for threaded consumer")
listener = RedisConsumerThreaded(
self.host,
self.port,
topics,
pattern,
group_id,
event,
cb,
redis_cls=self.redis_cls,
name=name,
**kwargs,
)
self._threads.append(listener)
return listener
return RedisConsumer(
self.host,
self.port,
topics,
pattern,
group_id,
event,
cb,
redis_cls=self.redis_cls,
**kwargs,
)
self._events_listener_thread = None
self._events_dispatcher_thread = None
self._messages_queue = queue.Queue()
self._stop_events_listener_thread = threading.Event()
def stream_consumer(
self,
topics=None,
pattern=None,
group_id=None,
event=None,
cb=None,
from_start=False,
newest_only=False,
**kwargs,
):
"""
Threaded stream consumer for redis streams.
self.stream_keys = {}
Args:
topics (str, list): topics to subscribe to
pattern (str, list): pattern to subscribe to
group_id (str): group id
event (threading.Event): event to stop the consumer
cb (function): callback function
from_start (bool): read from start. Defaults to False.
newest_only (bool): read only the newest message. Defaults to False.
"""
if cb is None:
raise ValueError("The callback function must be specified.")
if pattern:
raise ValueError("Pattern is currently not supported for stream consumer.")
if topics is None and pattern is None:
raise ValueError("Topics must be set for stream consumer.")
listener = RedisStreamConsumerThreaded(
self.host,
self.port,
topics,
pattern,
group_id,
event,
cb,
redis_cls=self.redis_cls,
from_start=from_start,
newest_only=newest_only,
**kwargs,
)
self._threads.append(listener)
return listener
def shutdown(self):
if self._events_listener_thread:
self._stop_events_listener_thread.set()
self._events_listener_thread.join()
self._events_listener_thread = None
if self._events_dispatcher_thread:
self._messages_queue.put(StopIteration)
self._events_dispatcher_thread.join()
self._events_dispatcher_thread = None
# release all connections
self._pubsub_conn.close()
self._redis_conn.close()
@catch_connection_error
def log_warning(self, msg):
"""send a warning"""
self._notifications_producer.send(
MessageEndpoints.log(), LogMessage(log_type="warning", log_msg=msg)
)
self.send(MessageEndpoints.log(), LogMessage(log_type="warning", log_msg=msg))
@catch_connection_error
def log_message(self, msg):
"""send a log message"""
self._notifications_producer.send(
MessageEndpoints.log(), LogMessage(log_type="log", log_msg=msg)
)
self.send(MessageEndpoints.log(), LogMessage(log_type="log", log_msg=msg))
@catch_connection_error
def log_error(self, msg):
"""send an error as log"""
self._notifications_producer.send(
MessageEndpoints.log(), LogMessage(log_type="error", log_msg=msg)
)
self.send(MessageEndpoints.log(), LogMessage(log_type="error", log_msg=msg))
@catch_connection_error
def raise_alarm(self, severity: Alarms, alarm_type: str, source: str, msg: str, metadata: dict):
def raise_alarm(
self,
severity: Alarms,
alarm_type: str,
source: str,
msg: str,
metadata: dict,
):
"""raise an alarm"""
self._notifications_producer.set_and_publish(
MessageEndpoints.alarm(),
AlarmMessage(
severity=severity, alarm_type=alarm_type, source=source, msg=msg, metadata=metadata
),
alarm_msg = AlarmMessage(
severity=severity,
alarm_type=alarm_type,
source=source,
msg=msg,
metadata=metadata,
)
self.set_and_publish(MessageEndpoints.alarm(), alarm_msg)
def pipeline(self):
"""Create a new pipeline"""
return self._redis_conn.pipeline()
class RedisProducer(ProducerConnector):
def __init__(self, host: str, port: int, redis_cls=None) -> None:
# pylint: disable=invalid-name
if redis_cls:
self.r = redis_cls(host=host, port=port)
return
self.r = redis.Redis(host=host, port=port)
self.stream_keys = {}
@catch_connection_error
def execute_pipeline(self, pipeline):
"""Execute the pipeline and returns the results with decoded BECMessages"""
ret = []
@ -197,7 +126,7 @@ class RedisProducer(ProducerConnector):
@catch_connection_error
def raw_send(self, topic: str, msg: bytes, pipe=None):
"""send to redis without any check on message type"""
client = pipe if pipe is not None else self.r
client = pipe if pipe is not None else self._redis_conn
client.publish(topic, msg)
def send(self, topic: str, msg: BECMessage, pipe=None) -> None:
@ -206,6 +135,95 @@ class RedisProducer(ProducerConnector):
raise TypeError(f"Message {msg} is not a BECMessage")
self.raw_send(topic, MsgpackSerialization.dumps(msg), pipe)
def register(self, topics=None, patterns=None, cb=None, start_thread=True, **kwargs):
if self._events_listener_thread is None:
# create the thread that will get all messages for this connector;
# under the hood, it uses asyncio - this lets the possibility to stop
# the loop on demand
self._events_listener_thread = threading.Thread(
target=self._get_messages_loop,
args=(self._pubsub_conn,),
)
self._events_listener_thread.start()
# make a weakref from the callable, using louie;
# it can create safe refs for simple functions as well as methods
cb_ref = louie.saferef.safe_ref(cb)
if patterns is not None:
if isinstance(patterns, str):
patterns = [patterns]
self._pubsub_conn.psubscribe(patterns)
for pattern in patterns:
self._topics_cb[pattern].append((cb_ref, kwargs))
else:
if isinstance(topics, str):
topics = [topics]
self._pubsub_conn.subscribe(topics)
for topic in topics:
self._topics_cb[topic].append((cb_ref, kwargs))
if start_thread and self._events_dispatcher_thread is None:
# start dispatcher thread
self._events_dispatcher_thread = threading.Thread(target=self.dispatch_events)
self._events_dispatcher_thread.start()
def _get_messages_loop(self, pubsub) -> None:
"""
Start a listening coroutine to deal with redis events and wait for completion
"""
while not self._stop_events_listener_thread.is_set():
try:
msg = pubsub.get_message(timeout=1)
except Exception:
sys.excepthook(*sys.exc_info())
else:
if msg is not None:
self._messages_queue.put(msg)
def _handle_message(self, msg):
if msg["type"].endswith("subscribe"):
# ignore subscribe messages
return False
channel = msg["channel"].decode()
if msg["pattern"] is not None:
callbacks = self._topics_cb[msg["pattern"].decode()]
else:
callbacks = self._topics_cb[channel]
msg = MessageObject(
topic=channel,
value=MsgpackSerialization.loads(msg["data"]),
)
for cb_ref, kwargs in callbacks:
cb = cb_ref()
if cb:
try:
cb(msg, **kwargs)
except Exception:
sys.excepthook(*sys.exc_info())
return True
def poll_messages(self, timeout=None) -> None:
while True:
try:
msg = self._messages_queue.get(timeout=timeout)
except queue.Empty:
raise TimeoutError(
f"{self}: poll_messages: did not receive a message within {timeout} seconds"
)
else:
if msg is StopIteration:
return False
if self._handle_message(msg):
return True
else:
continue
def dispatch_events(self):
while self.poll_messages():
...
@catch_connection_error
def lpush(
self, topic: str, msg: str, pipe=None, max_size: int = None, expire: int = None
@ -229,7 +247,7 @@ class RedisProducer(ProducerConnector):
@catch_connection_error
def lset(self, topic: str, index: int, msg: str, pipe=None) -> None:
client = pipe if pipe is not None else self.r
client = pipe if pipe is not None else self._redis_conn
if isinstance(msg, BECMessage):
msg = MsgpackSerialization.dumps(msg)
return client.lset(topic, index, msg)
@ -241,7 +259,7 @@ class RedisProducer(ProducerConnector):
values at the tail of the list stored at key. If key does not exist,
it is created as empty list before performing the push operation. When
key holds a value that is not a list, an error is returned."""
client = pipe if pipe is not None else self.r
client = pipe if pipe is not None else self._redis_conn
if isinstance(msg, BECMessage):
msg = MsgpackSerialization.dumps(msg)
return client.rpush(topic, msg)
@ -254,7 +272,7 @@ class RedisProducer(ProducerConnector):
of the list stored at key. The offsets start and stop are zero-based indexes,
with 0 being the first element of the list (the head of the list), 1 being
the next element and so on."""
client = pipe if pipe is not None else self.r
client = pipe if pipe is not None else self._redis_conn
cmd_result = client.lrange(topic, start, end)
if pipe:
return cmd_result
@ -268,23 +286,21 @@ class RedisProducer(ProducerConnector):
ret.append(msg)
return ret
@catch_connection_error
def set_and_publish(self, topic: str, msg, pipe=None, expire: int = None) -> None:
"""piped combination of self.publish and self.set"""
client = pipe if pipe is not None else self.pipeline()
if isinstance(msg, BECMessage):
msg = MsgpackSerialization.dumps(msg)
client.publish(topic, msg)
client.set(topic, msg)
if expire:
client.expire(topic, expire)
if not isinstance(msg, BECMessage):
raise TypeError(f"Message {msg} is not a BECMessage")
msg = MsgpackSerialization.dumps(msg)
self.set(topic, msg, pipe=client, expire=expire)
self.raw_send(topic, msg, pipe=client)
if not pipe:
client.execute()
@catch_connection_error
def set(self, topic: str, msg, pipe=None, expire: int = None) -> None:
"""set redis value"""
client = pipe if pipe is not None else self.r
client = pipe if pipe is not None else self._redis_conn
if isinstance(msg, BECMessage):
msg = MsgpackSerialization.dumps(msg)
client.set(topic, msg, ex=expire)
@ -292,23 +308,18 @@ class RedisProducer(ProducerConnector):
@catch_connection_error
def keys(self, pattern: str) -> list:
"""returns all keys matching a pattern"""
return self.r.keys(pattern)
@catch_connection_error
def pipeline(self):
"""create a new pipeline"""
return self.r.pipeline()
return self._redis_conn.keys(pattern)
@catch_connection_error
def delete(self, topic, pipe=None):
"""delete topic"""
client = pipe if pipe is not None else self.r
client = pipe if pipe is not None else self._redis_conn
client.delete(topic)
@catch_connection_error
def get(self, topic: str, pipe=None):
"""retrieve entry, either via hgetall or get"""
client = pipe if pipe is not None else self.r
client = pipe if pipe is not None else self._redis_conn
data = client.get(topic)
if pipe:
return data
@ -339,7 +350,7 @@ class RedisProducer(ProducerConnector):
elif expire:
client = self.pipeline()
else:
client = self.r
client = self._redis_conn
for key, msg in msg_dict.items():
msg_dict[key] = MsgpackSerialization.dumps(msg)
@ -356,7 +367,7 @@ class RedisProducer(ProducerConnector):
@catch_connection_error
def get_last(self, topic: str, key="data"):
"""retrieve last entry from stream"""
client = self.r
client = self._redis_conn
try:
_, msg_dict = client.xrevrange(topic, "+", "-", count=1)[0]
except TypeError:
@ -370,7 +381,12 @@ class RedisProducer(ProducerConnector):
@catch_connection_error
def xread(
self, topic: str, id: str = None, count: int = None, block: int = None, from_start=False
self,
topic: str,
id: str = None,
count: int = None,
block: int = None,
from_start=False,
) -> list:
"""
read from stream
@ -395,13 +411,13 @@ class RedisProducer(ProducerConnector):
>>> key = msg[0][1][0][0]
>>> next_msg = redis.xread("test", key, count=1)
"""
client = self.r
client = self._redis_conn
if from_start:
self.stream_keys[topic] = "0-0"
if topic not in self.stream_keys:
if id is None:
try:
msg = self.r.xrevrange(topic, "+", "-", count=1)
msg = client.xrevrange(topic, "+", "-", count=1)
if msg:
self.stream_keys[topic] = msg[0][0]
out = {}
@ -438,7 +454,7 @@ class RedisProducer(ProducerConnector):
max (str): max id. Use "+" to read to end
count (int, optional): number of messages to read. Defaults to None.
"""
client = self.r
client = self._redis_conn
msgs = []
for reading in client.xrange(topic, min, max, count=count):
index, msg_dict = reading
@ -446,270 +462,3 @@ class RedisProducer(ProducerConnector):
{k.decode(): MsgpackSerialization.loads(msg) for k, msg in msg_dict.items()}
)
return msgs
class RedisConsumerMixin:
def _init_topics_and_pattern(self, topics, pattern):
if topics:
if not isinstance(topics, list):
topics = [topics]
if pattern:
if not isinstance(pattern, list):
pattern = [pattern]
return topics, pattern
def _init_redis_cls(self, redis_cls):
# pylint: disable=invalid-name
if redis_cls:
self.r = redis_cls(host=self.host, port=self.port)
else:
self.r = redis.Redis(host=self.host, port=self.port)
@catch_connection_error
def initialize_connector(self) -> None:
if self.pattern is not None:
self.pubsub.psubscribe(self.pattern)
else:
self.pubsub.subscribe(self.topics)
class RedisConsumer(RedisConsumerMixin, ConsumerConnector):
# pylint: disable=too-many-arguments
def __init__(
self,
host,
port,
topics=None,
pattern=None,
group_id=None,
event=None,
cb=None,
redis_cls=None,
**kwargs,
):
self.host = host
self.port = port
bootstrap_server = "".join([host, ":", port])
topics, pattern = self._init_topics_and_pattern(topics, pattern)
super().__init__(
bootstrap_server=bootstrap_server,
topics=topics,
pattern=pattern,
group_id=group_id,
event=event,
cb=cb,
**kwargs,
)
self.error_message_sent = False
self._init_redis_cls(redis_cls)
self.pubsub = self.r.pubsub()
self.initialize_connector()
@catch_connection_error
def poll_messages(self) -> None:
"""
Poll messages from self.connector and call the callback function self.cb
"""
message = self.pubsub.get_message(ignore_subscribe_messages=True)
if message is not None:
msg = MessageObject(
topic=message["channel"], value=MsgpackSerialization.loads(message["data"])
)
return self.cb(msg, **self.kwargs)
time.sleep(0.01)
return None
def shutdown(self):
"""shutdown the consumer"""
self.pubsub.close()
class RedisStreamConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded):
# pylint: disable=too-many-arguments
def __init__(
self,
host,
port,
topics=None,
pattern=None,
group_id=None,
event=None,
cb=None,
redis_cls=None,
from_start=False,
newest_only=False,
**kwargs,
):
self.host = host
self.port = port
self.from_start = from_start
self.newest_only = newest_only
bootstrap_server = "".join([host, ":", port])
topics, pattern = self._init_topics_and_pattern(topics, pattern)
super().__init__(
bootstrap_server=bootstrap_server,
topics=topics,
pattern=pattern,
group_id=group_id,
event=event,
cb=cb,
**kwargs,
)
self._init_redis_cls(redis_cls)
self.sleep_times = [0.005, 0.1]
self.last_received_msg = 0
self.idle_time = 30
self.error_message_sent = False
self.stream_keys = {}
def initialize_connector(self) -> None:
pass
def _init_topics_and_pattern(self, topics, pattern):
if topics:
if not isinstance(topics, list):
topics = [topics]
if pattern:
if not isinstance(pattern, list):
pattern = [pattern]
return topics, pattern
def get_id(self, topic: str) -> str:
"""
Get the stream key for the given topic.
Args:
topic (str): topic to get the stream key for
"""
if topic not in self.stream_keys:
return "0-0"
return self.stream_keys.get(topic)
def get_newest_message(self, container: list, append=True) -> None:
"""
Get the newest message from the stream and update the stream key. If
append is True, append the message to the container.
Args:
container (list): container to append the message to
append (bool, optional): append to container. Defaults to True.
"""
for topic in self.topics:
msg = self.r.xrevrange(topic, "+", "-", count=1)
if msg:
if append:
container.append((topic, msg[0][1]))
self.stream_keys[topic] = msg[0][0]
else:
self.stream_keys[topic] = "0-0"
@catch_connection_error
def poll_messages(self) -> None:
"""
Poll messages from self.connector and call the callback function self.cb
"""
if self.pattern is not None:
topics = [key.decode() for key in self.r.scan_iter(match=self.pattern, _type="stream")]
else:
topics = self.topics
messages = []
if self.newest_only:
self.get_newest_message(messages)
elif not self.from_start and not self.stream_keys:
self.get_newest_message(messages, append=False)
else:
streams = {topic: self.get_id(topic) for topic in topics}
read_msgs = self.r.xread(streams, count=1)
if read_msgs:
for msg in read_msgs:
topic = msg[0].decode()
messages.append((topic, msg[1][0][1]))
self.stream_keys[topic] = msg[1][-1][0]
if messages:
if MessageEndpoints.log() not in topics:
# no need to update the update frequency just for logs
self.last_received_msg = time.time()
for topic, msg in messages:
try:
msg = MsgpackSerialization.loads(msg[b"data"])
except RuntimeError:
msg = msg[b"data"]
msg_obj = MessageObject(topic=topic, value=msg)
self.cb(msg_obj, **self.kwargs)
else:
sleep_time = int(bool(time.time() - self.last_received_msg > self.idle_time))
if self.sleep_times[sleep_time]:
time.sleep(self.sleep_times[sleep_time])
class RedisConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded):
# pylint: disable=too-many-arguments
def __init__(
self,
host,
port,
topics=None,
pattern=None,
group_id=None,
event=None,
cb=None,
redis_cls=None,
name=None,
**kwargs,
):
self.host = host
self.port = port
bootstrap_server = "".join([host, ":", port])
topics, pattern = self._init_topics_and_pattern(topics, pattern)
super().__init__(
bootstrap_server=bootstrap_server,
topics=topics,
pattern=pattern,
group_id=group_id,
event=event,
cb=cb,
name=name,
**kwargs,
)
self._init_redis_cls(redis_cls)
self.pubsub = self.r.pubsub()
self.sleep_times = [0.005, 0.1]
self.last_received_msg = 0
self.idle_time = 30
self.error_message_sent = False
@catch_connection_error
def poll_messages(self) -> None:
"""
Poll messages from self.connector and call the callback function self.cb
Note: pubsub messages are supposed to be BECMessage objects only
"""
messages = self.pubsub.get_message(ignore_subscribe_messages=True)
if messages is not None:
if f"{MessageEndpoints.log()}".encode() not in messages["channel"]:
# no need to update the update frequency just for logs
self.last_received_msg = time.time()
msg = MessageObject(
topic=messages["channel"].decode(),
value=MsgpackSerialization.loads(messages["data"]),
)
self.cb(msg, **self.kwargs)
else:
sleep_time = int(bool(time.time() - self.last_received_msg > self.idle_time))
if self.sleep_times[sleep_time]:
time.sleep(self.sleep_times[sleep_time])
def shutdown(self):
super().shutdown()
self.pubsub.close()

View File

@ -46,7 +46,7 @@ class ScanItem:
self.data = ScanData()
self.async_data = {}
self.baseline = ScanData()
self._async_data_handler = AsyncDataHandler(scan_manager.producer)
self._async_data_handler = AsyncDataHandler(scan_manager.connector)
self.open_scan_defs = set()
self.open_queue_group = None
self.num_points = None

View File

@ -25,44 +25,31 @@ class ScanManager:
connector (BECConnector): BECConnector instance
"""
self.connector = connector
self.producer = self.connector.producer()
self.queue_storage = QueueStorage(scan_manager=self)
self.request_storage = RequestStorage(scan_manager=self)
self.scan_storage = ScanStorage(scan_manager=self)
self._scan_queue_consumer = self.connector.consumer(
self.connector.register(
topics=MessageEndpoints.scan_queue_status(),
cb=self._scan_queue_status_callback,
parent=self,
)
self._scan_queue_request_consumer = self.connector.consumer(
self.connector.register(
topics=MessageEndpoints.scan_queue_request(),
cb=self._scan_queue_request_callback,
parent=self,
)
self._scan_queue_request_response_consumer = self.connector.consumer(
self.connector.register(
topics=MessageEndpoints.scan_queue_request_response(),
cb=self._scan_queue_request_response_callback,
parent=self,
)
self._scan_status_consumer = self.connector.consumer(
topics=MessageEndpoints.scan_status(), cb=self._scan_status_callback, parent=self
self.connector.register(
topics=MessageEndpoints.scan_status(), cb=self._scan_status_callback
)
self._scan_segment_consumer = self.connector.consumer(
topics=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self
self.connector.register(
topics=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback
)
self._baseline_consumer = self.connector.consumer(
topics=MessageEndpoints.scan_baseline(), cb=self._baseline_callback, parent=self
)
self._scan_queue_consumer.start()
self._scan_queue_request_consumer.start()
self._scan_queue_request_response_consumer.start()
self._scan_status_consumer.start()
self._scan_segment_consumer.start()
self._baseline_consumer.start()
self.connector.register(topics=MessageEndpoints.scan_baseline(), cb=self._baseline_callback)
def update_with_queue_status(self, queue: messages.ScanQueueStatusMessage) -> None:
"""update storage with a new queue status message"""
@ -84,7 +71,7 @@ class ScanManager:
action = "deferred_pause" if deferred_pause else "pause"
logger.info(f"Requesting {action}")
return self.producer.send(
return self.connector.send(
MessageEndpoints.scan_queue_modification_request(),
messages.ScanQueueModificationMessage(scanID=scanID, action=action, parameter={}),
)
@ -99,7 +86,7 @@ class ScanManager:
if scanID is None:
scanID = self.scan_storage.current_scanID
logger.info("Requesting scan abortion")
self.producer.send(
self.connector.send(
MessageEndpoints.scan_queue_modification_request(),
messages.ScanQueueModificationMessage(scanID=scanID, action="abort", parameter={}),
)
@ -114,7 +101,7 @@ class ScanManager:
if scanID is None:
scanID = self.scan_storage.current_scanID
logger.info("Requesting scan halt")
self.producer.send(
self.connector.send(
MessageEndpoints.scan_queue_modification_request(),
messages.ScanQueueModificationMessage(scanID=scanID, action="halt", parameter={}),
)
@ -129,7 +116,7 @@ class ScanManager:
if scanID is None:
scanID = self.scan_storage.current_scanID
logger.info("Requesting scan continuation")
self.producer.send(
self.connector.send(
MessageEndpoints.scan_queue_modification_request(),
messages.ScanQueueModificationMessage(scanID=scanID, action="continue", parameter={}),
)
@ -137,7 +124,7 @@ class ScanManager:
def request_queue_reset(self):
"""request a scan queue reset"""
logger.info("Requesting a queue reset")
self.producer.send(
self.connector.send(
MessageEndpoints.scan_queue_modification_request(),
messages.ScanQueueModificationMessage(scanID=None, action="clear", parameter={}),
)
@ -151,7 +138,7 @@ class ScanManager:
logger.info("Requesting to abort and repeat a scan")
position = "replace" if replace else "append"
self.producer.send(
self.connector.send(
MessageEndpoints.scan_queue_modification_request(),
messages.ScanQueueModificationMessage(
scanID=scanID, action="restart", parameter={"position": position, "RID": requestID}
@ -162,7 +149,7 @@ class ScanManager:
@property
def next_scan_number(self):
"""get the next scan number from redis"""
num = self.producer.get(MessageEndpoints.scan_number())
num = self.connector.get(MessageEndpoints.scan_number())
if num is None:
logger.warning("Failed to retrieve scan number from redis.")
return -1
@ -172,63 +159,51 @@ class ScanManager:
@typechecked
def next_scan_number(self, val: int):
"""set the next scan number in redis"""
return self.producer.set(MessageEndpoints.scan_number(), val)
return self.connector.set(MessageEndpoints.scan_number(), val)
@property
def next_dataset_number(self):
"""get the next dataset number from redis"""
return int(self.producer.get(MessageEndpoints.dataset_number()))
return int(self.connector.get(MessageEndpoints.dataset_number()))
@next_dataset_number.setter
@typechecked
def next_dataset_number(self, val: int):
"""set the next dataset number in redis"""
return self.producer.set(MessageEndpoints.dataset_number(), val)
return self.connector.set(MessageEndpoints.dataset_number(), val)
@staticmethod
def _scan_queue_status_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
def _scan_queue_status_callback(self, msg, **_kwargs) -> None:
queue_status = msg.value
if not queue_status:
return
parent.update_with_queue_status(queue_status)
self.update_with_queue_status(queue_status)
@staticmethod
def _scan_queue_request_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
def _scan_queue_request_callback(self, msg, **_kwargs) -> None:
request = msg.value
parent.request_storage.update_with_request(request)
self.request_storage.update_with_request(request)
@staticmethod
def _scan_queue_request_response_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
def _scan_queue_request_response_callback(self, msg, **_kwargs) -> None:
response = msg.value
logger.debug(response)
parent.request_storage.update_with_response(response)
self.request_storage.update_with_response(response)
@staticmethod
def _scan_status_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
def _scan_status_callback(self, msg, **_kwargs) -> None:
scan = msg.value
parent.scan_storage.update_with_scan_status(scan)
self.scan_storage.update_with_scan_status(scan)
@staticmethod
def _scan_segment_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
def _scan_segment_callback(self, msg, **_kwargs) -> None:
scan_msgs = msg.value
if not isinstance(scan_msgs, list):
scan_msgs = [scan_msgs]
for scan_msg in scan_msgs:
parent.scan_storage.add_scan_segment(scan_msg)
self.scan_storage.add_scan_segment(scan_msg)
@staticmethod
def _baseline_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
def _baseline_callback(self, msg, **_kwargs) -> None:
msg = msg.value
parent.scan_storage.add_scan_baseline(msg)
self.scan_storage.add_scan_baseline(msg)
def __str__(self) -> str:
return "\n".join(self.queue_storage.describe_queue())
def shutdown(self):
"""stop the scan manager's threads"""
self._scan_queue_consumer.shutdown()
self._scan_queue_request_consumer.shutdown()
self._scan_queue_request_response_consumer.shutdown()
self._scan_status_consumer.shutdown()
self._scan_segment_consumer.shutdown()
self._baseline_consumer.shutdown()
pass

View File

@ -89,7 +89,7 @@ class ScanReport:
def _get_mv_status(self) -> bool:
"""get the status of a move request"""
motors = list(self.request.request.content["parameter"]["args"].keys())
request_status = self._client.device_manager.producer.lrange(
request_status = self._client.device_manager.connector.lrange(
MessageEndpoints.device_req_status(self.request.requestID), 0, -1
)
if len(request_status) == len(motors):

View File

@ -100,9 +100,9 @@ class ScanObject:
return None
return self.scan_info.get("scan_report_hint")
def _start_consumer(self, request: messages.ScanQueueMessage) -> ConsumerConnector:
"""Start a consumer for the given request"""
consumer = self.client.device_manager.connector.consumer(
def _start_register(self, request: messages.ScanQueueMessage) -> ConsumerConnector:
"""Start a register for the given request"""
register = self.client.device_manager.connector.register(
[
MessageEndpoints.device_readback(dev)
for dev in request.content["parameter"]["args"].keys()
@ -110,11 +110,11 @@ class ScanObject:
threaded=False,
cb=(lambda msg: msg),
)
return consumer
return register
def _send_scan_request(self, request: messages.ScanQueueMessage) -> None:
"""Send a scan request to the scan server"""
self.client.device_manager.producer.send(MessageEndpoints.scan_queue_request(), request)
self.client.device_manager.connector.send(MessageEndpoints.scan_queue_request(), request)
class Scans:
@ -136,7 +136,7 @@ class Scans:
def _import_scans(self):
"""Import scans from the scan server"""
available_scans = self.parent.producer.get(MessageEndpoints.available_scans())
available_scans = self.parent.connector.get(MessageEndpoints.available_scans())
if available_scans is None:
logger.warning("No scans available. Are redis and the BEC server running?")
return

View File

@ -14,7 +14,9 @@ logger = bec_logger.logger
DEFAULT_SERVICE_CONFIG = {
"redis": {"host": "localhost", "port": 6379},
"service_config": {"file_writer": {"plugin": "default_NeXus_format", "base_path": "./"}},
"service_config": {
"file_writer": {"plugin": "default_NeXus_format", "base_path": os.path.dirname(__file__)}
},
}
@ -32,7 +34,7 @@ class ServiceConfig:
self._update_config(service_config=config, redis=redis)
self.service_config = self.config.get(
"service_config", {"file_writer": {"plugin": "default_NeXus_format", "base_path": "./"}}
"service_config", DEFAULT_SERVICE_CONFIG["service_config"]
)
def _update_config(self, **kwargs):

View File

@ -56,7 +56,7 @@ def queue_is_empty(queue) -> bool: # pragma: no cover
def get_queue(bec): # pragma: no cover
return bec.queue.producer.get(MessageEndpoints.scan_queue_status())
return bec.queue.connector.get(MessageEndpoints.scan_queue_status())
def wait_for_empty_queue(bec): # pragma: no cover
@ -484,7 +484,6 @@ def bec_client():
with open(f"{dir_path}/tests/test_config.yaml", "r", encoding="utf-8") as f:
builtins.__dict__["test_session"] = create_session_from_config(yaml.safe_load(f))
device_manager._session = builtins.__dict__["test_session"]
device_manager.producer = device_manager.connector.producer()
client.wait_for_service = lambda service_name: None
device_manager._load_session()
for name, dev in device_manager.devices.items():
@ -497,37 +496,23 @@ def bec_client():
class PipelineMock: # pragma: no cover
_pipe_buffer = []
_producer = None
_connector = None
def __init__(self, producer) -> None:
self._producer = producer
def __init__(self, connector) -> None:
self._connector = connector
def execute(self):
if not self._producer.store_data:
if not self._connector.store_data:
self._pipe_buffer = []
return []
res = [
getattr(self._producer, method)(*args, **kwargs)
getattr(self._connector, method)(*args, **kwargs)
for method, args, kwargs in self._pipe_buffer
]
self._pipe_buffer = []
return res
class ConsumerMock: # pragma: no cover
def __init__(self) -> None:
self.signal_event = SignalMock()
def start(self):
pass
def join(self):
pass
def shutdown(self):
pass
class SignalMock: # pragma: no cover
def __init__(self) -> None:
self.is_set = False
@ -536,12 +521,36 @@ class SignalMock: # pragma: no cover
self.is_set = True
class ProducerMock: # pragma: no cover
def __init__(self, store_data=True) -> None:
class ConnectorMock(ConnectorBase): # pragma: no cover
def __init__(self, bootstrap_server="localhost:0000", store_data=True):
super().__init__(bootstrap_server)
self.message_sent = []
self._get_buffer = {}
self.store_data = store_data
def raise_alarm(
self, severity: Alarms, alarm_type: str, source: str, msg: dict, metadata: dict
):
pass
def log_error(self, *args, **kwargs):
pass
def shutdown(self):
pass
def register(self, *args, **kwargs):
pass
def set(self, *args, **kwargs):
pass
def set_and_publish(self, *args, **kwargs):
pass
def keys(self, *args, **kwargs):
return []
def set(self, topic, msg, pipe=None, expire: int = None):
if pipe:
pipe._pipe_buffer.append(("set", (topic, msg), {"expire": expire}))
@ -592,9 +601,6 @@ class ProducerMock: # pragma: no cover
self._get_buffer.pop(topic, None)
return val
def keys(self, pattern: str) -> list:
return []
def pipeline(self):
return PipelineMock(self)
@ -609,29 +615,6 @@ class ProducerMock: # pragma: no cover
return
class ConnectorMock(ConnectorBase): # pragma: no cover
def __init__(self, bootstrap_server: list, store_data=True):
super().__init__(bootstrap_server)
self.store_data = store_data
def consumer(self, *args, **kwargs) -> ConsumerMock:
return ConsumerMock()
def producer(self, *args, **kwargs):
return ProducerMock(self.store_data)
def raise_alarm(
self, severity: Alarms, alarm_type: str, source: str, msg: dict, metadata: dict
):
pass
def log_error(self, *args, **kwargs):
pass
def shutdown(self):
pass
def create_session_from_config(config: dict) -> dict:
device_configs = []
session_id = str(uuid.uuid4())

View File

@ -5,6 +5,8 @@ __version__ = "1.12.1"
if __name__ == "__main__":
setup(
install_requires=[
"hiredis",
"louie",
"numpy",
"scipy",
"msgpack",
@ -22,7 +24,16 @@ if __name__ == "__main__":
"lmfit",
],
extras_require={
"dev": ["pytest", "pytest-random-order", "coverage", "pandas", "black", "pylint"]
"dev": [
"pytest",
"pytest-random-order",
"pytest-redis",
"pytest-timeout",
"coverage",
"pandas",
"black",
"pylint",
]
},
entry_points={"console_scripts": ["bec-channel-monitor = bec_lib:channel_monitor_launch"]},
package_data={"bec_lib.tests": ["*.yaml"], "bec_lib.configs": ["*.yaml", "*.json"]},

View File

@ -17,7 +17,7 @@ def test_bec_widgets_connector_set_plot_config(bec_client):
config = {"x": "test", "y": "test", "color": "test", "size": "test", "shape": "test"}
connector.set_plot_config(plot_id="plot_id", config=config)
msg = messages.GUIConfigMessage(config=config)
bec_client.connector.producer().set_and_publish.assert_called_once_with(
bec_client.connector.set_and_publish.assert_called_once_with(
MessageEndpoints.gui_config("plot_id"), msg
) is None
@ -26,7 +26,7 @@ def test_bec_widgets_connector_close(bec_client):
connector = BECWidgetsConnector(gui_id="gui_id", bec_client=bec_client)
connector.close("plot_id")
msg = messages.GUIInstructionMessage(action="close", parameter={})
bec_client.connector.producer().set_and_publish.assert_called_once_with(
bec_client.connector.set_and_publish.assert_called_once_with(
MessageEndpoints.gui_instructions("plot_id"), msg
)
@ -36,7 +36,7 @@ def test_bec_widgets_connector_send_data(bec_client):
data = {"x": [1, 2, 3], "y": [1, 2, 3]}
connector.send_data("plot_id", data)
msg = messages.GUIDataMessage(data=data)
bec_client.connector.producer().set_and_publish.assert_called_once_with(
bec_client.connector.set_and_publish.assert_called_once_with(
topic=MessageEndpoints.gui_data("plot_id"), msg=msg
)
@ -45,7 +45,7 @@ def test_bec_widgets_connector_clear(bec_client):
connector = BECWidgetsConnector(gui_id="gui_id", bec_client=bec_client)
connector.clear("plot_id")
msg = messages.GUIInstructionMessage(action="clear", parameter={})
bec_client.connector.producer().set_and_publish.assert_called_once_with(
bec_client.connector.set_and_publish.assert_called_once_with(
MessageEndpoints.gui_instructions("plot_id"), msg
)

View File

@ -124,8 +124,8 @@ def test_bec_service_update_existing_services():
messages.StatusMessage(name="service2", status=BECStatus.IDLE, info={}, metadata={}),
]
connector_cls = mock.MagicMock()
connector_cls().producer().keys.return_value = service_keys
connector_cls().producer().get.side_effect = [msg for msg in service_msgs]
connector_cls().keys.return_value = service_keys
connector_cls().get.side_effect = [msg for msg in service_msgs]
service = BECService(
config=f"{os.path.dirname(bec_lib.__file__)}/tests/test_service_config.yaml",
connector_cls=connector_cls,
@ -144,8 +144,8 @@ def test_bec_service_update_existing_services_ignores_wrong_msgs():
None,
]
connector_cls = mock.MagicMock()
connector_cls().producer().keys.return_value = service_keys
connector_cls().producer().get.side_effect = [service_msgs[0], None]
connector_cls().keys.return_value = service_keys
connector_cls().get.side_effect = [service_msgs[0], None]
service = BECService(
config=f"{os.path.dirname(bec_lib.__file__)}/tests/test_service_config.yaml",
connector_cls=connector_cls,

View File

@ -13,7 +13,7 @@ def test_channel_monitor_callback():
mock_print.assert_called_once()
def test_channel_monitor_start_consumer():
def test_channel_monitor_start_register():
with mock.patch("bec_lib.channel_monitor.argparse") as mock_argparse:
with mock.patch("bec_lib.channel_monitor.ServiceConfig") as mock_config:
with mock.patch("bec_lib.channel_monitor.RedisConnector") as mock_connector:
@ -26,6 +26,6 @@ def test_channel_monitor_start_consumer():
mock_config.return_value = mock.MagicMock()
mock_connector.return_value = mock.MagicMock()
channel_monitor_launch()
mock_connector().consumer.assert_called_once()
mock_connector().consumer.return_value.start.assert_called_once()
mock_connector().register.assert_called_once()
mock_connector().register.return_value.start.assert_called_once()
mock_threading.Event().wait.assert_called_once()

View File

@ -49,7 +49,7 @@ def test_config_helper_save_current_session():
connector = mock.MagicMock()
config_helper = ConfigHelper(connector)
connector.producer().get.return_value = messages.AvailableResourceMessage(
connector.get.return_value = messages.AvailableResourceMessage(
resource=[
{
"id": "648c817f67d3c7cd6a354e8e",
@ -158,9 +158,7 @@ def test_send_config_request_raises_for_rejected_update(config_helper):
def test_wait_for_config_reply():
connector = mock.MagicMock()
config_helper = ConfigHelper(connector)
connector.producer().get.return_value = messages.RequestResponseMessage(
accepted=True, message="test"
)
connector.get.return_value = messages.RequestResponseMessage(accepted=True, message="test")
res = config_helper.wait_for_config_reply("test")
assert res == messages.RequestResponseMessage(accepted=True, message="test")
@ -169,7 +167,7 @@ def test_wait_for_config_reply():
def test_wait_for_config_raises_timeout():
connector = mock.MagicMock()
config_helper = ConfigHelper(connector)
connector.producer().get.return_value = None
connector.get.return_value = None
with pytest.raises(DeviceConfigError):
config_helper.wait_for_config_reply("test", timeout=0.3)
@ -178,7 +176,7 @@ def test_wait_for_config_raises_timeout():
def test_wait_for_service_response():
connector = mock.MagicMock()
config_helper = ConfigHelper(connector)
connector.producer().lrange.side_effect = [
connector.lrange.side_effect = [
[],
[
messages.ServiceResponseMessage(
@ -196,7 +194,7 @@ def test_wait_for_service_response():
def test_wait_for_service_response_raises_timeout():
connector = mock.MagicMock()
config_helper = ConfigHelper(connector)
connector.producer().lrange.return_value = []
connector.lrange.return_value = []
with pytest.raises(DeviceConfigError):
config_helper.wait_for_service_response("test", timeout=0.3)

View File

@ -349,7 +349,7 @@ def dap(dap_plugin_message):
}
client = mock.MagicMock()
client.service_status = dap_services
client.producer.get.return_value = dap_plugin_message
client.connector.get.return_value = dap_plugin_message
dap_plugins = DAPPlugins(client)
yield dap_plugins
@ -367,7 +367,7 @@ def test_dap_plugins_construction(dap):
def test_dap_plugin_fit(dap):
with mock.patch.object(dap.GaussianModel, "_wait_for_dap_response") as mock_wait:
dap.GaussianModel.fit()
dap._parent.producer.set_and_publish.assert_called_once()
dap._parent.connector.set_and_publish.assert_called_once()
mock_wait.assert_called_once()
@ -380,7 +380,7 @@ def test_dap_auto_run(dap):
def test_dap_wait_for_dap_response_waits_for_RID(dap):
dap._parent.producer.get.return_value = messages.DAPResponseMessage(
dap._parent.connector.get.return_value = messages.DAPResponseMessage(
success=True, data={}, metadata={"RID": "wrong_ID"}
)
with pytest.raises(TimeoutError):
@ -388,7 +388,7 @@ def test_dap_wait_for_dap_response_waits_for_RID(dap):
def test_dap_wait_for_dap_respnse_returns(dap):
dap._parent.producer.get.return_value = messages.DAPResponseMessage(
dap._parent.connector.get.return_value = messages.DAPResponseMessage(
success=True, data={}, metadata={"RID": "1234"}
)
val = dap.GaussianModel._wait_for_dap_response(request_id="1234", timeout=0.1)
@ -429,11 +429,11 @@ def test_dap_select_raises_on_wrong_device(dap):
def test_dap_get_data(dap):
dap._parent.producer.get_last.return_value = messages.ProcessedDataMessage(
dap._parent.connector.get_last.return_value = messages.ProcessedDataMessage(
data=[{"x": [1, 2, 3], "y": [4, 5, 6]}, {"fit_parameters": {"amplitude": 1}}]
)
data = dap.GaussianModel.get_data()
dap._parent.producer.get_last.assert_called_once_with(
dap._parent.connector.get_last.assert_called_once_with(
MessageEndpoints.processed_data("GaussianModel")
)
@ -443,13 +443,13 @@ def test_dap_get_data(dap):
def test_dap_update_dap_config_not_called_without_device(dap):
dap.GaussianModel._update_dap_config(request_id="1234")
dap._parent.producer.set_and_publish.assert_not_called()
dap._parent.connector.set_and_publish.assert_not_called()
def test_dap_update_dap_config(dap):
dap.GaussianModel._plugin_config["selected_device"] = ["samx", "samx"]
dap.GaussianModel._update_dap_config(request_id="1234")
dap._parent.producer.set_and_publish.assert_called_with(
dap._parent.connector.set_and_publish.assert_called_with(
MessageEndpoints.dap_request(),
messages.DAPRequestMessage(
dap_cls="LmfitService1D",

View File

@ -91,15 +91,14 @@ def test_get_config_calls_load(dm):
dm, "_get_redis_device_config", return_value={"devices": [{}]}
) as get_redis_config:
with mock.patch.object(dm, "_load_session") as load_session:
with mock.patch.object(dm, "producer") as producer:
dm._get_config()
get_redis_config.assert_called_once()
load_session.assert_called_once()
dm._get_config()
get_redis_config.assert_called_once()
load_session.assert_called_once()
def test_get_redis_device_config(dm):
with mock.patch.object(dm, "producer") as producer:
producer.get.return_value = messages.AvailableResourceMessage(resource={"devices": [{}]})
with mock.patch.object(dm, "connector") as connector:
connector.get.return_value = messages.AvailableResourceMessage(resource={"devices": [{}]})
assert dm._get_redis_device_config() == {"devices": [{}]}

View File

@ -23,7 +23,7 @@ def test_nested_device_root(dev):
def test_read(dev):
with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get:
with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
mock_get.return_value = messages.DeviceMessage(
signals={
"samx": {"value": 0, "timestamp": 1701105880.1711318},
@ -42,7 +42,7 @@ def test_read(dev):
def test_read_filtered_hints(dev):
with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get:
with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
mock_get.return_value = messages.DeviceMessage(
signals={
"samx": {"value": 0, "timestamp": 1701105880.1711318},
@ -57,7 +57,7 @@ def test_read_filtered_hints(dev):
def test_read_use_read(dev):
with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get:
with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
data = {
"samx": {"value": 0, "timestamp": 1701105880.1711318},
"samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492},
@ -72,7 +72,7 @@ def test_read_use_read(dev):
def test_read_nested_device(dev):
with mock.patch.object(dev.dyn_signals.root.parent.producer, "get") as mock_get:
with mock.patch.object(dev.dyn_signals.root.parent.connector, "get") as mock_get:
data = {
"dyn_signals_messages_message1": {"value": 0, "timestamp": 1701105880.0716832},
"dyn_signals_messages_message2": {"value": 0, "timestamp": 1701105880.071722},
@ -93,7 +93,7 @@ def test_read_nested_device(dev):
)
def test_read_kind_hinted(dev, kind, cached):
with mock.patch.object(dev.samx.readback, "_run") as mock_run:
with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get:
with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
data = {
"samx": {"value": 0, "timestamp": 1701105880.1711318},
"samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492},
@ -138,7 +138,7 @@ def test_read_configuration_cached(dev, is_signal, is_config_signal, method):
with mock.patch.object(
dev.samx.readback, "_get_rpc_signal_info", return_value=(is_signal, is_config_signal, True)
):
with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get:
with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
mock_get.return_value = messages.DeviceMessage(
signals={
"samx": {"value": 0, "timestamp": 1701105880.1711318},
@ -182,7 +182,7 @@ def test_get_rpc_func_name_read(dev):
)
def test_get_rpc_func_name_readback_get(dev, kind, cached):
with mock.patch.object(dev.samx.readback, "_run") as mock_rpc:
with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get:
with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
mock_get.return_value = messages.DeviceMessage(
signals={
"samx": {"value": 0, "timestamp": 1701105880.1711318},
@ -219,9 +219,7 @@ def test_handle_rpc_response_returns_status(dev, bec_client):
msg = messages.DeviceRPCMessage(
device="samx", return_val={"type": "status", "RID": "request_id"}, out="done", success=True
)
assert dev.samx._handle_rpc_response(msg) == Status(
bec_client.device_manager.producer, "request_id"
)
assert dev.samx._handle_rpc_response(msg) == Status(bec_client.device_manager, "request_id")
def test_handle_rpc_response_raises(dev):
@ -348,7 +346,7 @@ def test_device_update_user_parameter(device_obj, user_param, val, out, raised_e
def test_status_wait():
producer = mock.MagicMock()
connector = mock.MagicMock()
def lrange_mock(*args, **kwargs):
yield False
@ -358,8 +356,8 @@ def test_status_wait():
return next(lmock)
lmock = lrange_mock()
producer.lrange = get_lrange
status = Status(producer, "test")
connector.lrange = get_lrange
status = Status(connector, "test")
status.wait()
@ -561,7 +559,7 @@ def test_show_all():
def test_adjustable_mixin_limits():
adj = AdjustableMixin()
adj.root = mock.MagicMock()
adj.root.parent.producer.get.return_value = messages.DeviceMessage(
adj.root.parent.connector.get.return_value = messages.DeviceMessage(
signals={"low": -12, "high": 12}, metadata={}
)
assert adj.limits == [-12, 12]
@ -570,7 +568,7 @@ def test_adjustable_mixin_limits():
def test_adjustable_mixin_limits_missing():
adj = AdjustableMixin()
adj.root = mock.MagicMock()
adj.root.parent.producer.get.return_value = None
adj.root.parent.connector.get.return_value = None
assert adj.limits == [0, 0]
@ -585,7 +583,7 @@ def test_adjustable_mixin_set_low_limit():
adj = AdjustableMixin()
adj.update_config = mock.MagicMock()
adj.root = mock.MagicMock()
adj.root.parent.producer.get.return_value = messages.DeviceMessage(
adj.root.parent.connector.get.return_value = messages.DeviceMessage(
signals={"low": -12, "high": 12}, metadata={}
)
adj.low_limit = -20
@ -596,7 +594,7 @@ def test_adjustable_mixin_set_high_limit():
adj = AdjustableMixin()
adj.update_config = mock.MagicMock()
adj.root = mock.MagicMock()
adj.root.parent.producer.get.return_value = messages.DeviceMessage(
adj.root.parent.connector.get.return_value = messages.DeviceMessage(
signals={"low": -12, "high": 12}, metadata={}
)
adj.high_limit = 20

View File

@ -97,9 +97,9 @@ def device_manager(dm_with_devices):
def test_observer_manager_None(device_manager):
with mock.patch.object(device_manager.producer, "get", return_value=None) as producer_get:
with mock.patch.object(device_manager.connector, "get", return_value=None) as connector_get:
observer_manager = ObserverManager(device_manager=device_manager)
producer_get.assert_called_once_with(MessageEndpoints.observer())
connector_get.assert_called_once_with(MessageEndpoints.observer())
assert len(observer_manager._observer) == 0
@ -115,9 +115,9 @@ def test_observer_manager_msg(device_manager):
}
]
)
with mock.patch.object(device_manager.producer, "get", return_value=msg) as producer_get:
with mock.patch.object(device_manager.connector, "get", return_value=msg) as connector_get:
observer_manager = ObserverManager(device_manager=device_manager)
producer_get.assert_called_once_with(MessageEndpoints.observer())
connector_get.assert_called_once_with(MessageEndpoints.observer())
assert len(observer_manager._observer) == 1
@ -139,7 +139,7 @@ def test_observer_manager_msg(device_manager):
],
)
def test_add_observer(device_manager, observer, raises_error):
with mock.patch.object(device_manager.producer, "get", return_value=None) as producer_get:
with mock.patch.object(device_manager.connector, "get", return_value=None) as connector_get:
observer_manager = ObserverManager(device_manager=device_manager)
observer_manager.add_observer(observer)
with pytest.raises(AttributeError):
@ -185,7 +185,7 @@ def test_add_observer_existing_device(device_manager, observer, raises_error):
"limits": [380, None],
}
)
with mock.patch.object(device_manager.producer, "get", return_value=None) as producer_get:
with mock.patch.object(device_manager.connector, "get", return_value=None) as connector_get:
observer_manager = ObserverManager(device_manager=device_manager)
observer_manager.add_observer(default_observer)
if raises_error:

View File

@ -12,113 +12,71 @@ from bec_lib.messages import AlarmMessage, BECMessage, LogMessage
from bec_lib.redis_connector import (
MessageObject,
RedisConnector,
RedisConsumer,
RedisConsumerMixin,
RedisConsumerThreaded,
RedisProducer,
RedisStreamConsumerThreaded,
)
from bec_lib.serialization import MsgpackSerialization
@pytest.fixture
def producer():
with mock.patch("bec_lib.redis_connector.redis.Redis"):
prod = RedisProducer("localhost", 1)
yield prod
@pytest.fixture
def connector():
with mock.patch("bec_lib.redis_connector.redis.Redis"):
connector = RedisConnector("localhost:1")
try:
yield connector
finally:
connector.shutdown()
@pytest.fixture
def connected_connector(redis_proc):
connector = RedisConnector(f"localhost:{redis_proc.port}")
try:
yield connector
@pytest.fixture
def consumer():
with mock.patch("bec_lib.redis_connector.redis.Redis"):
consumer = RedisConsumer("localhost", "1", topics="topics")
yield consumer
@pytest.fixture
def consumer_threaded():
with mock.patch("bec_lib.redis_connector.redis.Redis"):
consumer_threaded = RedisConsumerThreaded("localhost", "1", topics="topics")
yield consumer_threaded
@pytest.fixture
def mixin():
with mock.patch("bec_lib.redis_connector.redis.Redis"):
mixin = RedisConsumerMixin
yield mixin
def test_redis_connector_producer(connector):
ret = connector.producer()
assert isinstance(ret, RedisProducer)
finally:
connector.shutdown()
@pytest.mark.parametrize(
"topics, threaded", [["topics", True], ["topics", False], [None, True], [None, False]]
"topics, threaded",
[["topics", True], ["topics", False], [None, True], [None, False]],
)
def test_redis_connector_consumer(connector, threaded, topics):
pattern = None
len_of_threads = len(connector._threads)
if threaded:
if topics is None and pattern is None:
with pytest.raises(ValueError) as exc_info:
ret = connector.consumer(
topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ...
)
assert exc_info.value.args[0] == "Topics must be set for threaded consumer"
else:
ret = connector.consumer(
topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ...
def test_redis_connector_register(connected_connector, threaded, topics):
breakpoint()
connector = connected_connector
if topics is None:
with pytest.raises(TypeError):
ret = connector.register(
topics=topics, cb=lambda *args, **kwargs: ..., start_thread=threaded
)
assert len(connector._threads) == len_of_threads + 1
assert isinstance(ret, RedisConsumerThreaded)
else:
if not topics:
with pytest.raises(ConsumerConnectorError):
ret = connector.consumer(
topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ...
)
return
ret = connector.consumer(topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ...)
assert isinstance(ret, RedisConsumer)
ret = connector.register(
topics=topics, cb=lambda *args, **kwargs: ..., start_thread=threaded
)
if threaded:
assert connector._events_listener_thread is not None
def test_redis_connector_log_warning(connector):
connector._notifications_producer.send = mock.MagicMock()
connector.log_warning("msg")
connector._notifications_producer.send.assert_called_once_with(
MessageEndpoints.log(), LogMessage(log_type="warning", log_msg="msg")
)
with mock.patch.object(connector, "send", return_value=None):
connector.log_warning("msg")
connector.send.assert_called_once_with(
MessageEndpoints.log(), LogMessage(log_type="warning", log_msg="msg")
)
def test_redis_connector_log_message(connector):
connector._notifications_producer.send = mock.MagicMock()
connector.log_message("msg")
connector._notifications_producer.send.assert_called_once_with(
MessageEndpoints.log(), LogMessage(log_type="log", log_msg="msg")
)
with mock.patch.object(connector, "send", return_value=None):
connector.log_message("msg")
connector.send.assert_called_once_with(
MessageEndpoints.log(), LogMessage(log_type="log", log_msg="msg")
)
def test_redis_connector_log_error(connector):
connector._notifications_producer.send = mock.MagicMock()
connector.log_error("msg")
connector._notifications_producer.send.assert_called_once_with(
MessageEndpoints.log(), LogMessage(log_type="error", log_msg="msg")
)
with mock.patch.object(connector, "send", return_value=None):
connector.log_error("msg")
connector.send.assert_called_once_with(
MessageEndpoints.log(), LogMessage(log_type="error", log_msg="msg")
)
@pytest.mark.parametrize(
@ -130,20 +88,24 @@ def test_redis_connector_log_error(connector):
],
)
def test_redis_connector_raise_alarm(connector, severity, alarm_type, source, msg, metadata):
connector._notifications_producer.set_and_publish = mock.MagicMock()
with mock.patch.object(connector, "set_and_publish", return_value=None):
connector.raise_alarm(severity, alarm_type, source, msg, metadata)
connector.raise_alarm(severity, alarm_type, source, msg, metadata)
connector._notifications_producer.set_and_publish.assert_called_once_with(
MessageEndpoints.alarm(),
AlarmMessage(
severity=severity, alarm_type=alarm_type, source=source, msg=msg, metadata=metadata
),
)
connector.set_and_publish.assert_called_once_with(
MessageEndpoints.alarm(),
AlarmMessage(
severity=severity,
alarm_type=alarm_type,
source=source,
msg=msg,
metadata=metadata,
),
)
@dataclass(eq=False)
class TestMessage(BECMessage):
__test__ = False # just for pytest to ignore this class
msg_type = "test_message"
msg: str
# have to add this field here,
@ -160,30 +122,36 @@ bec_messages.TestMessage = TestMessage
@pytest.mark.parametrize(
"topic , msg", [["topic1", TestMessage("msg1")], ["topic2", TestMessage("msg2")]]
)
def test_redis_producer_send(producer, topic, msg):
producer.send(topic, msg)
producer.r.publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg))
def test_redis_connector_send(connector, topic, msg):
connector.send(topic, msg)
connector._redis_conn.publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg))
producer.send(topic, msg, pipe=producer.pipeline())
producer.r.pipeline().publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg))
connector.send(topic, msg, pipe=connector.pipeline())
connector._redis_conn.pipeline().publish.assert_called_once_with(
topic, MsgpackSerialization.dumps(msg)
)
@pytest.mark.parametrize(
"topic, msgs, max_size, expire",
[["topic1", "msgs", None, None], ["topic1", "msgs", 10, None], ["topic1", "msgs", None, 100]],
[
["topic1", "msgs", None, None],
["topic1", "msgs", 10, None],
["topic1", "msgs", None, 100],
],
)
def test_redis_producer_lpush(producer, topic, msgs, max_size, expire):
def test_redis_connector_lpush(connector, topic, msgs, max_size, expire):
pipe = None
producer.lpush(topic, msgs, pipe, max_size, expire)
connector.lpush(topic, msgs, pipe, max_size, expire)
producer.r.pipeline().lpush.assert_called_once_with(topic, msgs)
connector._redis_conn.pipeline().lpush.assert_called_once_with(topic, msgs)
if max_size:
producer.r.pipeline().ltrim.assert_called_once_with(topic, 0, max_size)
connector._redis_conn.pipeline().ltrim.assert_called_once_with(topic, 0, max_size)
if expire:
producer.r.pipeline().expire.assert_called_once_with(topic, expire)
connector._redis_conn.pipeline().expire.assert_called_once_with(topic, expire)
if not pipe:
producer.r.pipeline().execute.assert_called_once()
connector._redis_conn.pipeline().execute.assert_called_once()
@pytest.mark.parametrize(
@ -194,68 +162,76 @@ def test_redis_producer_lpush(producer, topic, msgs, max_size, expire):
["topic1", TestMessage("msgs"), None, 100],
],
)
def test_redis_producer_lpush_BECMessage(producer, topic, msgs, max_size, expire):
def test_redis_connector_lpush_BECMessage(connector, topic, msgs, max_size, expire):
pipe = None
producer.lpush(topic, msgs, pipe, max_size, expire)
connector.lpush(topic, msgs, pipe, max_size, expire)
producer.r.pipeline().lpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs))
connector._redis_conn.pipeline().lpush.assert_called_once_with(
topic, MsgpackSerialization.dumps(msgs)
)
if max_size:
producer.r.pipeline().ltrim.assert_called_once_with(topic, 0, max_size)
connector._redis_conn.pipeline().ltrim.assert_called_once_with(topic, 0, max_size)
if expire:
producer.r.pipeline().expire.assert_called_once_with(topic, expire)
connector._redis_conn.pipeline().expire.assert_called_once_with(topic, expire)
if not pipe:
producer.r.pipeline().execute.assert_called_once()
connector._redis_conn.pipeline().execute.assert_called_once()
@pytest.mark.parametrize(
"topic , index , msgs, use_pipe", [["topic1", 1, "msg1", True], ["topic2", 4, "msg2", False]]
"topic , index , msgs, use_pipe",
[["topic1", 1, "msg1", True], ["topic2", 4, "msg2", False]],
)
def test_redis_producer_lset(producer, topic, index, msgs, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe)
def test_redis_connector_lset(connector, topic, index, msgs, use_pipe):
pipe = use_pipe_fcn(connector, use_pipe)
ret = producer.lset(topic, index, msgs, pipe)
ret = connector.lset(topic, index, msgs, pipe)
if pipe:
producer.r.pipeline().lset.assert_called_once_with(topic, index, msgs)
connector._redis_conn.pipeline().lset.assert_called_once_with(topic, index, msgs)
assert ret == redis.Redis().pipeline().lset()
else:
producer.r.lset.assert_called_once_with(topic, index, msgs)
connector._redis_conn.lset.assert_called_once_with(topic, index, msgs)
assert ret == redis.Redis().lset()
@pytest.mark.parametrize(
"topic , index , msgs, use_pipe",
[["topic1", 1, TestMessage("msg1"), True], ["topic2", 4, TestMessage("msg2"), False]],
[
["topic1", 1, TestMessage("msg1"), True],
["topic2", 4, TestMessage("msg2"), False],
],
)
def test_redis_producer_lset_BECMessage(producer, topic, index, msgs, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe)
def test_redis_connector_lset_BECMessage(connector, topic, index, msgs, use_pipe):
pipe = use_pipe_fcn(connector, use_pipe)
ret = producer.lset(topic, index, msgs, pipe)
ret = connector.lset(topic, index, msgs, pipe)
if pipe:
producer.r.pipeline().lset.assert_called_once_with(
connector._redis_conn.pipeline().lset.assert_called_once_with(
topic, index, MsgpackSerialization.dumps(msgs)
)
assert ret == redis.Redis().pipeline().lset()
else:
producer.r.lset.assert_called_once_with(topic, index, MsgpackSerialization.dumps(msgs))
connector._redis_conn.lset.assert_called_once_with(
topic, index, MsgpackSerialization.dumps(msgs)
)
assert ret == redis.Redis().lset()
@pytest.mark.parametrize(
"topic, msgs, use_pipe", [["topic1", "msg1", True], ["topic2", "msg2", False]]
)
def test_redis_producer_rpush(producer, topic, msgs, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe)
def test_redis_connector_rpush(connector, topic, msgs, use_pipe):
pipe = use_pipe_fcn(connector, use_pipe)
ret = producer.rpush(topic, msgs, pipe)
ret = connector.rpush(topic, msgs, pipe)
if pipe:
producer.r.pipeline().rpush.assert_called_once_with(topic, msgs)
connector._redis_conn.pipeline().rpush.assert_called_once_with(topic, msgs)
assert ret == redis.Redis().pipeline().rpush()
else:
producer.r.rpush.assert_called_once_with(topic, msgs)
connector._redis_conn.rpush.assert_called_once_with(topic, msgs)
assert ret == redis.Redis().rpush()
@ -263,421 +239,349 @@ def test_redis_producer_rpush(producer, topic, msgs, use_pipe):
"topic, msgs, use_pipe",
[["topic1", TestMessage("msg1"), True], ["topic2", TestMessage("msg2"), False]],
)
def test_redis_producer_rpush_BECMessage(producer, topic, msgs, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe)
def test_redis_connector_rpush_BECMessage(connector, topic, msgs, use_pipe):
pipe = use_pipe_fcn(connector, use_pipe)
ret = producer.rpush(topic, msgs, pipe)
ret = connector.rpush(topic, msgs, pipe)
if pipe:
producer.r.pipeline().rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs))
connector._redis_conn.pipeline().rpush.assert_called_once_with(
topic, MsgpackSerialization.dumps(msgs)
)
assert ret == redis.Redis().pipeline().rpush()
else:
producer.r.rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs))
connector._redis_conn.rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs))
assert ret == redis.Redis().rpush()
@pytest.mark.parametrize(
"topic, start, end, use_pipe", [["topic1", 0, 4, True], ["topic2", 3, 7, False]]
)
def test_redis_producer_lrange(producer, topic, start, end, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe)
def test_redis_connector_lrange(connector, topic, start, end, use_pipe):
pipe = use_pipe_fcn(connector, use_pipe)
ret = producer.lrange(topic, start, end, pipe)
ret = connector.lrange(topic, start, end, pipe)
if pipe:
producer.r.pipeline().lrange.assert_called_once_with(topic, start, end)
connector._redis_conn.pipeline().lrange.assert_called_once_with(topic, start, end)
assert ret == redis.Redis().pipeline().lrange()
else:
producer.r.lrange.assert_called_once_with(topic, start, end)
connector._redis_conn.lrange.assert_called_once_with(topic, start, end)
assert ret == []
@pytest.mark.parametrize(
"topic, msg, pipe, expire", [["topic1", "msg1", None, 400], ["topic2", "msg2", None, None]]
"topic, msg, pipe, expire",
[
["topic1", TestMessage("msg1"), None, 400],
["topic2", TestMessage("msg2"), None, None],
["topic3", "msg3", None, None],
],
)
def test_redis_producer_set_and_publish(producer, topic, msg, pipe, expire):
producer.set_and_publish(topic, msg, pipe, expire)
def test_redis_connector_set_and_publish(connector, topic, msg, pipe, expire):
if not isinstance(msg, BECMessage):
with pytest.raises(TypeError):
connector.set_and_publish(topic, msg, pipe, expire)
else:
connector.set_and_publish(topic, msg, pipe, expire)
producer.r.pipeline().publish.assert_called_once_with(topic, msg)
producer.r.pipeline().set.assert_called_once_with(topic, msg)
if expire:
producer.r.pipeline().expire.assert_called_once_with(topic, expire)
if not pipe:
producer.r.pipeline().execute.assert_called_once()
connector._redis_conn.pipeline().publish.assert_called_once_with(
topic, MsgpackSerialization.dumps(msg)
)
connector._redis_conn.pipeline().set.assert_called_once_with(
topic, MsgpackSerialization.dumps(msg), ex=expire
)
if not pipe:
connector._redis_conn.pipeline().execute.assert_called_once()
@pytest.mark.parametrize("topic, msg, expire", [["topic1", "msg1", None], ["topic2", "msg2", 400]])
def test_redis_producer_set(producer, topic, msg, expire):
def test_redis_connector_set(connector, topic, msg, expire):
pipe = None
producer.set(topic, msg, pipe, expire)
connector.set(topic, msg, pipe, expire)
if pipe:
producer.r.pipeline().set.assert_called_once_with(topic, msg, ex=expire)
connector._redis_conn.pipeline().set.assert_called_once_with(topic, msg, ex=expire)
else:
producer.r.set.assert_called_once_with(topic, msg, ex=expire)
connector._redis_conn.set.assert_called_once_with(topic, msg, ex=expire)
@pytest.mark.parametrize("pattern", ["samx", "samy"])
def test_redis_producer_keys(producer, pattern):
ret = producer.keys(pattern)
producer.r.keys.assert_called_once_with(pattern)
def test_redis_connector_keys(connector, pattern):
ret = connector.keys(pattern)
connector._redis_conn.keys.assert_called_once_with(pattern)
assert ret == redis.Redis().keys()
def test_redis_producer_pipeline(producer):
ret = producer.pipeline()
producer.r.pipeline.assert_called_once()
def test_redis_connector_pipeline(connector):
ret = connector.pipeline()
connector._redis_conn.pipeline.assert_called_once()
assert ret == redis.Redis().pipeline()
@pytest.mark.parametrize("topic,use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_producer_delete(producer, topic, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe)
producer.delete(topic, pipe)
if pipe:
producer.pipeline().delete.assert_called_once_with(topic)
else:
producer.r.delete.assert_called_once_with(topic)
@pytest.mark.parametrize("topic, use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_producer_get(producer, topic, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe)
ret = producer.get(topic, pipe)
if pipe:
producer.pipeline().get.assert_called_once_with(topic)
assert ret == redis.Redis().pipeline().get()
else:
producer.r.get.assert_called_once_with(topic)
assert ret == redis.Redis().get()
def use_pipe_fcn(producer, use_pipe):
def use_pipe_fcn(connector, use_pipe):
if use_pipe:
return producer.pipeline()
return connector.pipeline()
return None
@pytest.mark.parametrize("topic,use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_connector_delete(connector, topic, use_pipe):
pipe = use_pipe_fcn(connector, use_pipe)
connector.delete(topic, pipe)
if pipe:
connector.pipeline().delete.assert_called_once_with(topic)
else:
connector._redis_conn.delete.assert_called_once_with(topic)
@pytest.mark.parametrize("topic, use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_connector_get(connector, topic, use_pipe):
pipe = use_pipe_fcn(connector, use_pipe)
ret = connector.get(topic, pipe)
if pipe:
connector.pipeline().get.assert_called_once_with(topic)
assert ret == redis.Redis().pipeline().get()
else:
connector._redis_conn.get.assert_called_once_with(topic)
assert ret == redis.Redis().get()
@pytest.mark.parametrize(
"topics, pattern",
"subscribed_topics, subscribed_patterns, msgs",
[
["topics1", None],
[["topics1", "topics2"], None],
[None, "pattern1"],
[None, ["pattern1", "pattern2"]],
["topics1", None, ["topics1"]],
[["topics1", "topics2"], None, ["topics1", "topics2"]],
[None, "pattern1", ["pattern1"]],
[None, ["patt*", "top*"], ["pattern1", "topics1"]],
],
)
def test_redis_consumer_init(consumer, topics, pattern):
with mock.patch("bec_lib.redis_connector.redis.Redis"):
consumer = RedisConsumer(
"localhost", "1", topics, pattern, redis_cls=redis.Redis, cb=lambda *args, **kwargs: ...
def test_redis_connector_register(
redisdb, connected_connector, subscribed_topics, subscribed_patterns, msgs
):
connector = connected_connector
test_msg = TestMessage("test")
cb_mock = mock.Mock(spec=[]) # spec is here to remove all attributes
if subscribed_topics:
connector.register(
subscribed_topics, subscribed_patterns, cb=cb_mock, start_thread=False, a=1
)
if topics:
if isinstance(topics, list):
assert consumer.topics == topics
else:
assert consumer.topics == [topics]
if pattern:
if isinstance(pattern, list):
assert consumer.pattern == pattern
else:
assert consumer.pattern == [pattern]
assert consumer.r == redis.Redis()
assert consumer.pubsub == consumer.r.pubsub()
assert consumer.host == "localhost"
assert consumer.port == "1"
for msg in msgs:
connector.send(msg, TestMessage(msg))
connector.poll_messages()
msg_object = MessageObject(msg, TestMessage(msg))
cb_mock.assert_called_with(msg_object, a=1)
if subscribed_patterns:
connector.register(
subscribed_topics, subscribed_patterns, cb=cb_mock, start_thread=False, a=1
)
for msg in msgs:
connector.send(msg, TestMessage(msg))
connector.poll_messages()
msg_object = MessageObject(msg, TestMessage(msg))
cb_mock.assert_called_with(msg_object, a=1)
@pytest.mark.parametrize("pattern, topics", [["pattern", "topics1"], [None, "topics2"]])
def test_redis_consumer_initialize_connector(consumer, pattern, topics):
consumer.pattern = pattern
consumer.topics = topics
consumer.initialize_connector()
if consumer.pattern is not None:
consumer.pubsub.psubscribe.assert_called_once_with(consumer.pattern)
else:
consumer.pubsub.subscribe.assert_called_with(consumer.topics)
def test_redis_consumer_poll_messages(consumer):
def test_redis_register_poll_messages(redisdb, connected_connector):
connector = connected_connector
cb_fcn_has_been_called = False
def cb_fcn(msg, **kwargs):
nonlocal cb_fcn_has_been_called
cb_fcn_has_been_called = True
print(msg)
consumer.cb = cb_fcn
assert kwargs["a"] == 1
test_msg = TestMessage("test")
consumer.pubsub.get_message.return_value = {
"channel": "",
"data": MsgpackSerialization.dumps(test_msg),
}
ret = consumer.poll_messages()
consumer.pubsub.get_message.assert_called_once_with(ignore_subscribe_messages=True)
connector.register("test", cb=cb_fcn, a=1, start_thread=False)
redisdb.publish("test", MsgpackSerialization.dumps(test_msg))
connector.poll_messages(timeout=1)
assert cb_fcn_has_been_called
def test_redis_consumer_shutdown(consumer):
consumer.shutdown()
consumer.pubsub.close.assert_called_once()
with pytest.raises(TimeoutError):
connector.poll_messages(timeout=0.1)
def test_redis_consumer_additional_kwargs(connector):
cons = connector.consumer(topics="topic1", parent="here", cb=lambda *args, **kwargs: ...)
assert "parent" in cons.kwargs
def test_redis_connector_xadd(connector):
connector.xadd("topic1", {"key": "value"})
connector._redis_conn.xadd.assert_called_once_with("topic1", {"key": "value"})
@pytest.mark.parametrize(
"topics, pattern",
[
["topics1", None],
[["topics1", "topics2"], None],
[None, "pattern1"],
[None, ["pattern1", "pattern2"]],
],
)
def test_mixin_init_topics_and_pattern(mixin, topics, pattern):
ret_topics, ret_pattern = mixin._init_topics_and_pattern(mixin, topics, pattern)
if topics:
if isinstance(topics, list):
assert ret_topics == topics
else:
assert ret_topics == [topics]
if pattern:
if isinstance(pattern, list):
assert ret_pattern == pattern
else:
assert ret_pattern == [pattern]
def test_redis_connector_xadd_with_maxlen(connector):
connector.xadd("topic1", {"key": "value"}, max_size=100)
connector._redis_conn.xadd.assert_called_once_with("topic1", {"key": "value"}, maxlen=100)
def test_mixin_init_redis_cls(mixin, consumer):
mixin._init_redis_cls(consumer, None)
assert consumer.r == redis.Redis(host="localhost", port=1)
def test_redis_connector_xadd_with_expire(connector):
connector.xadd("topic1", {"key": "value"}, expire=100)
connector._redis_conn.pipeline().xadd.assert_called_once_with("topic1", {"key": "value"})
connector._redis_conn.pipeline().expire.assert_called_once_with("topic1", 100)
connector._redis_conn.pipeline().execute.assert_called_once()
@pytest.mark.parametrize(
"topics, pattern",
[
["topics1", None],
[["topics1", "topics2"], None],
[None, "pattern1"],
[None, ["pattern1", "pattern2"]],
],
)
def test_redis_consumer_threaded_init(consumer_threaded, topics, pattern):
with mock.patch("bec_lib.redis_connector.redis.Redis"):
consumer_threaded = RedisConsumerThreaded(
"localhost", "1", topics, pattern, redis_cls=redis.Redis, cb=lambda *args, **kwargs: ...
)
if topics:
if isinstance(topics, list):
assert consumer_threaded.topics == topics
else:
assert consumer_threaded.topics == [topics]
if pattern:
if isinstance(pattern, list):
assert consumer_threaded.pattern == pattern
else:
assert consumer_threaded.pattern == [pattern]
assert consumer_threaded.r == redis.Redis()
assert consumer_threaded.pubsub == consumer_threaded.r.pubsub()
assert consumer_threaded.host == "localhost"
assert consumer_threaded.port == "1"
assert consumer_threaded.sleep_times == [0.005, 0.1]
assert consumer_threaded.last_received_msg == 0
assert consumer_threaded.idle_time == 30
def test_redis_connector_xread(connector):
connector.xread("topic1", "id")
connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)
def test_redis_connector_xadd(producer):
producer.xadd("topic1", {"key": "value"})
producer.r.xadd.assert_called_once_with("topic1", {"key": MsgpackSerialization.dumps("value")})
def test_redis_connector_xadd(connector):
connector.xadd("topic1", {"key": "value"})
connector._redis_conn.xadd.assert_called_once_with(
"topic1", {"key": MsgpackSerialization.dumps("value")}
)
test_msg = TestMessage("test")
producer.xadd("topic1", {"data": test_msg})
producer.r.xadd.assert_called_with("topic1", {"data": MsgpackSerialization.dumps(test_msg)})
producer.r.xrevrange.return_value = [
connector.xadd("topic1", {"data": test_msg})
connector._redis_conn.xadd.assert_called_with(
"topic1", {"data": MsgpackSerialization.dumps(test_msg)}
)
connector._redis_conn.xrevrange.return_value = [
(b"1707391599960-0", {b"data": MsgpackSerialization.dumps(test_msg)})
]
msg = producer.get_last("topic1")
msg = connector.get_last("topic1")
assert msg == test_msg
def test_redis_connector_xadd_with_maxlen(producer):
producer.xadd("topic1", {"key": "value"}, max_size=100)
producer.r.xadd.assert_called_once_with(
def test_redis_connector_xadd_with_maxlen(connector):
connector.xadd("topic1", {"key": "value"}, max_size=100)
connector._redis_conn.xadd.assert_called_once_with(
"topic1", {"key": MsgpackSerialization.dumps("value")}, maxlen=100
)
def test_redis_connector_xadd_with_expire(producer):
producer.xadd("topic1", {"key": "value"}, expire=100)
producer.r.pipeline().xadd.assert_called_once_with(
def test_redis_connector_xadd_with_expire(connector):
connector.xadd("topic1", {"key": "value"}, expire=100)
connector._redis_conn.pipeline().xadd.assert_called_once_with(
"topic1", {"key": MsgpackSerialization.dumps("value")}
)
producer.r.pipeline().expire.assert_called_once_with("topic1", 100)
producer.r.pipeline().execute.assert_called_once()
connector._redis_conn.pipeline().expire.assert_called_once_with("topic1", 100)
connector._redis_conn.pipeline().execute.assert_called_once()
def test_redis_connector_xread(producer):
producer.xread("topic1", "id")
producer.r.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)
def test_redis_connector_xread(connector):
connector.xread("topic1", "id")
connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)
def test_redis_connector_xread_without_id(producer):
producer.xread("topic1", from_start=True)
producer.r.xread.assert_called_once_with({"topic1": "0-0"}, count=None, block=None)
producer.r.xread.reset_mock()
def test_redis_connector_xread_without_id(connector):
connector.xread("topic1", from_start=True)
connector._redis_conn.xread.assert_called_once_with({"topic1": "0-0"}, count=None, block=None)
connector._redis_conn.xread.reset_mock()
producer.stream_keys["topic1"] = "id"
producer.xread("topic1")
producer.r.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)
connector.stream_keys["topic1"] = "id"
connector.xread("topic1")
connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)
def test_redis_connector_xread_from_end(producer):
producer.xread("topic1", from_start=False)
producer.r.xrevrange.assert_called_once_with("topic1", "+", "-", count=1)
def test_redis_connector_xread_from_end(connector):
connector.xread("topic1", from_start=False)
connector._redis_conn.xrevrange.assert_called_once_with("topic1", "+", "-", count=1)
def test_redis_connector_get_last(producer):
producer.r.xrevrange.return_value = [
def test_redis_connector_get_last(connector):
connector._redis_conn.xrevrange.return_value = [
(b"1707391599960-0", {b"key": MsgpackSerialization.dumps("value")})
]
msg = producer.get_last("topic1")
producer.r.xrevrange.assert_called_once_with("topic1", "+", "-", count=1)
msg = connector.get_last("topic1")
connector._redis_conn.xrevrange.assert_called_once_with("topic1", "+", "-", count=1)
assert msg is None # no key given, default is b'data'
assert producer.get_last("topic1", "key") == "value"
assert producer.get_last("topic1", None) == {"key": "value"}
assert connector.get_last("topic1", "key") == "value"
assert connector.get_last("topic1", None) == {"key": "value"}
def test_redis_xrange(producer):
producer.xrange("topic1", "start", "end")
producer.r.xrange.assert_called_once_with("topic1", "start", "end", count=None)
def test_redis_connector_xread_without_id(connector):
connector.xread("topic1", from_start=True)
connector._redis_conn.xread.assert_called_once_with({"topic1": "0-0"}, count=None, block=None)
connector._redis_conn.xread.reset_mock()
connector.stream_keys["topic1"] = "id"
connector.xread("topic1")
connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)
def test_redis_xrange_topic_with_suffix(producer):
producer.xrange("topic1", "start", "end")
producer.r.xrange.assert_called_once_with("topic1", "start", "end", count=None)
def test_redis_xrange(connector):
connector.xrange("topic1", "start", "end")
connector._redis_conn.xrange.assert_called_once_with("topic1", "start", "end", count=None)
def test_redis_consumer_threaded_no_cb_without_messages(consumer_threaded):
with mock.patch.object(consumer_threaded.pubsub, "get_message", return_value=None):
consumer_threaded.cb = mock.MagicMock()
consumer_threaded.poll_messages()
consumer_threaded.cb.assert_not_called()
def test_redis_xrange_topic_with_suffix(connector):
connector.xrange("topic1", "start", "end")
connector._redis_conn.xrange.assert_called_once_with("topic1", "start", "end", count=None)
def test_redis_consumer_threaded_cb_called_with_messages(consumer_threaded):
message = {"channel": b"topic1", "data": MsgpackSerialization.dumps(TestMessage("test"))}
with mock.patch.object(consumer_threaded.pubsub, "get_message", return_value=message):
consumer_threaded.cb = mock.MagicMock()
consumer_threaded.poll_messages()
msg_object = MessageObject("topic1", TestMessage("test"))
consumer_threaded.cb.assert_called_once_with(msg_object)
# def test_redis_stream_register_threaded_get_id():
# register = RedisStreamConsumerThreaded(
# "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
# )
# register.stream_keys["topic1"] = b"1691610882756-0"
# assert register.get_id("topic1") == b"1691610882756-0"
# assert register.get_id("doesnt_exist") == "0-0"
def test_redis_consumer_threaded_shutdown(consumer_threaded):
consumer_threaded.shutdown()
consumer_threaded.pubsub.close.assert_called_once()
# def test_redis_stream_register_threaded_poll_messages():
# register = RedisStreamConsumerThreaded(
# "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
# )
# with mock.patch.object(
# register, "get_newest_message", return_value=None
# ) as mock_get_newest_message:
# register.poll_messages()
# mock_get_newest_message.assert_called_once()
# register._redis_conn.xread.assert_not_called()
def test_redis_stream_consumer_threaded_get_newest_message():
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
consumer.r.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})]
msgs = []
consumer.get_newest_message(msgs)
assert "topic1" in consumer.stream_keys
assert consumer.stream_keys["topic1"] == b"1691610882756-0"
# def test_redis_stream_register_threaded_poll_messages_newest_only():
# register = RedisStreamConsumerThreaded(
# "localhost",
# "1",
# topics="topic1",
# cb=mock.MagicMock(),
# redis_cls=mock.MagicMock(),
# newest_only=True,
# )
#
# register._redis_conn.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})]
# register.poll_messages()
# register._redis_conn.xread.assert_not_called()
# register.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))
def test_redis_stream_consumer_threaded_get_newest_message_no_msg():
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
consumer.r.xrevrange.return_value = []
msgs = []
consumer.get_newest_message(msgs)
assert "topic1" in consumer.stream_keys
assert consumer.stream_keys["topic1"] == "0-0"
# def test_redis_stream_register_threaded_poll_messages_read():
# register = RedisStreamConsumerThreaded(
# "localhost",
# "1",
# topics="topic1",
# cb=mock.MagicMock(),
# redis_cls=mock.MagicMock(),
# )
# register.stream_keys["topic1"] = "0-0"
#
# msg = [[b"topic1", [(b"1691610714612-0", {b"data": b"msg"})]]]
#
# register._redis_conn.xread.return_value = msg
# register.poll_messages()
# register._redis_conn.xread.assert_called_once_with({"topic1": "0-0"}, count=1)
# register.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))
def test_redis_stream_consumer_threaded_get_id():
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
consumer.stream_keys["topic1"] = b"1691610882756-0"
assert consumer.get_id("topic1") == b"1691610882756-0"
assert consumer.get_id("doesnt_exist") == "0-0"
def test_redis_stream_consumer_threaded_poll_messages():
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
with mock.patch.object(
consumer, "get_newest_message", return_value=None
) as mock_get_newest_message:
consumer.poll_messages()
mock_get_newest_message.assert_called_once()
consumer.r.xread.assert_not_called()
def test_redis_stream_consumer_threaded_poll_messages_newest_only():
consumer = RedisStreamConsumerThreaded(
"localhost",
"1",
topics="topic1",
cb=mock.MagicMock(),
redis_cls=mock.MagicMock(),
newest_only=True,
)
consumer.r.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})]
consumer.poll_messages()
consumer.r.xread.assert_not_called()
consumer.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))
def test_redis_stream_consumer_threaded_poll_messages_read():
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
consumer.stream_keys["topic1"] = "0-0"
msg = [[b"topic1", [(b"1691610714612-0", {b"data": b"msg"})]]]
consumer.r.xread.return_value = msg
consumer.poll_messages()
consumer.r.xread.assert_called_once_with({"topic1": "0-0"}, count=1)
consumer.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))
@pytest.mark.parametrize(
"topics,expected",
[
("topic1", ["topic1"]),
(["topic1"], ["topic1"]),
(["topic1", "topic2"], ["topic1", "topic2"]),
],
)
def test_redis_stream_consumer_threaded_init_topics(topics, expected):
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics=topics, cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
assert consumer.topics == expected
# @pytest.mark.parametrize(
# "topics,expected",
# [
# ("topic1", ["topic1"]),
# (["topic1"], ["topic1"]),
# (["topic1", "topic2"], ["topic1", "topic2"]),
# ],
# )
# def test_redis_stream_register_threaded_init_topics(topics, expected):
# register = RedisStreamConsumerThreaded(
# "localhost",
# "1",
# topics=topics,
# cb=mock.MagicMock(),
# redis_cls=mock.MagicMock(),
# )
# assert register.topics == expected

View File

@ -63,7 +63,7 @@ from bec_lib.tests.utils import ConnectorMock
)
def test_update_with_queue_status(queue_msg):
scan_manager = ScanManager(ConnectorMock(""))
scan_manager.producer._get_buffer[MessageEndpoints.scan_queue_status()] = queue_msg
scan_manager.connector._get_buffer[MessageEndpoints.scan_queue_status()] = queue_msg
scan_manager.update_with_queue_status(queue_msg)
assert (
scan_manager.scan_storage.find_scan_by_ID("bfa582aa-f9cd-4258-ab5d-3e5d54d3dde5")

View File

@ -105,6 +105,6 @@ def test_scan_report_get_mv_status(scan_report, lrange_return, expected):
scan_report.request.request = messages.ScanQueueMessage(
scan_type="mv", parameter={"args": {"samx": [5], "samy": [5]}}
)
with mock.patch.object(scan_report._client.device_manager.producer, "lrange") as mock_lrange:
with mock.patch.object(scan_report._client.device_manager.connector, "lrange") as mock_lrange:
mock_lrange.return_value = lrange_return
assert scan_report._get_mv_status() == expected

View File

@ -14,7 +14,6 @@ parser.add_argument("--redis", default="localhost:6379", help="redis host and po
clargs = parser.parse_args()
connector = RedisConnector(clargs.redis)
producer = connector.producer()
with open(clargs.config, "r", encoding="utf-8") as stream:
data = yaml.safe_load(stream)
@ -22,4 +21,4 @@ for name, device in data.items():
device["name"] = name
config_data = list(data.values())
msg = messages.AvailableResourceMessage(resource=config_data)
producer.set(MessageEndpoints.device_config(), msg)
connector.set(MessageEndpoints.device_config(), msg)

View File

@ -11,7 +11,6 @@ class DAPServiceManager:
def __init__(self, services: list) -> None:
self.connector = None
self.producer = None
self._started = False
self.client = None
self._dap_request_thread = None
@ -24,13 +23,11 @@ class DAPServiceManager:
"""
Start the dap request consumer.
"""
self._dap_request_thread = self.connector.consumer(
topics=MessageEndpoints.dap_request(), cb=self._dap_request_callback, parent=self
self.connector.register(
topics=MessageEndpoints.dap_request(), cb=self._dap_request_callback
)
self._dap_request_thread.start()
@staticmethod
def _dap_request_callback(msg: MessageObject, *, parent: DAPServiceManager) -> None:
def _dap_request_callback(self, msg: MessageObject) -> None:
"""
Callback function for dap request consumer.
@ -41,7 +38,7 @@ class DAPServiceManager:
dap_request_msg = messages.DAPRequestMessage.loads(msg.value)
if not dap_request_msg:
return
parent.process_dap_request(dap_request_msg)
self.process_dap_request(dap_request_msg)
def process_dap_request(self, dap_request_msg: messages.DAPRequestMessage) -> None:
"""
@ -153,7 +150,7 @@ class DAPServiceManager:
dap_response_msg = messages.DAPResponseMessage(
success=success, data=data, error=error, dap_request=dap_request_msg, metadata=metadata
)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.dap_response(metadata.get("RID")), dap_response_msg, expire=60
)
@ -168,7 +165,6 @@ class DAPServiceManager:
return
self.client = client
self.connector = client.connector
self.producer = self.connector.producer()
self._start_dap_request_consumer()
self.update_available_dap_services()
self.publish_available_services()
@ -264,12 +260,12 @@ class DAPServiceManager:
"""send all available dap services to the broker"""
msg = messages.AvailableResourceMessage(resource=self.available_dap_services)
# pylint: disable=protected-access
self.producer.set(
self.connector.set(
MessageEndpoints.dap_available_plugins(f"DAPServer/{self.client._service_id}"), msg
)
def shutdown(self) -> None:
if not self._started:
return
self._dap_request_thread.stop()
self.connector.shutdown()
self._started = False

View File

@ -148,7 +148,7 @@ class LmfitService1D(DAPServiceBase):
out = self.process()
if out:
stream_output, metadata = out
self.client.producer.xadd(
self.client.connector.xadd(
MessageEndpoints.processed_data(self.model.__class__.__name__),
msg={
"data": MsgpackSerialization.dumps(

View File

@ -72,7 +72,7 @@ def test_DAPServiceManager_init(service_manager):
def test_DAPServiceManager_request_callback(service_manager, msg, process_called):
msg_obj = MessageObject(value=msg, topic="topic")
with mock.patch.object(service_manager, "process_dap_request") as mock_process_dap_request:
service_manager._dap_request_callback(msg_obj, parent=service_manager)
service_manager._dap_request_callback(msg_obj)
if process_called:
mock_process_dap_request.assert_called_once_with(msg)

View File

@ -134,7 +134,7 @@ def test_LmfitService1D_process_until_finished(lmfit_service):
lmfit_service.process_until_finished(event)
assert get_data.call_count == 2
assert process.call_count == 2
assert lmfit_service.client.producer.xadd.call_count == 2
assert lmfit_service.client.connector.xadd.call_count == 2
def test_LmfitService1D_configure(lmfit_service):

View File

@ -18,7 +18,7 @@ from device_server.rpc_mixin import RPCMixin
logger = bec_logger.logger
consumer_stop = threading.Event()
register_stop = threading.Event()
class DisabledDeviceError(Exception):
@ -38,14 +38,10 @@ class DeviceServer(RPCMixin, BECService):
super().__init__(config, connector_cls, unique_service=True)
self._tasks = []
self.device_manager = None
self.threads = []
self.sig_thread = None
self.sig_thread = self.connector.consumer(
self.connector.register(
MessageEndpoints.scan_queue_modification(),
cb=self.consumer_interception_callback,
parent=self,
cb=self.register_interception_callback,
)
self.sig_thread.start()
self.executor = ThreadPoolExecutor(max_workers=4)
self._start_device_manager()
@ -55,19 +51,16 @@ class DeviceServer(RPCMixin, BECService):
def start(self) -> None:
"""start the device server"""
if consumer_stop.is_set():
consumer_stop.clear()
if register_stop.is_set():
register_stop.clear()
self.connector.register(
MessageEndpoints.device_instructions(),
event=register_stop,
cb=self.instructions_callback,
parent=self,
)
self.threads = [
self.connector.consumer(
MessageEndpoints.device_instructions(),
event=consumer_stop,
cb=self.instructions_callback,
parent=self,
)
]
for thread in self.threads:
thread.start()
self.status = BECStatus.RUNNING
def update_status(self, status: BECStatus):
@ -76,17 +69,13 @@ class DeviceServer(RPCMixin, BECService):
def stop(self) -> None:
"""stop the device server"""
consumer_stop.set()
for thread in self.threads:
thread.join()
register_stop.set()
self.status = BECStatus.IDLE
def shutdown(self) -> None:
"""shutdown the device server"""
super().shutdown()
self.stop()
self.sig_thread.signal_event.set()
self.sig_thread.join()
self.device_manager.shutdown()
def _update_device_metadata(self, instr) -> None:
@ -97,8 +86,7 @@ class DeviceServer(RPCMixin, BECService):
device_root = dev.split(".")[0]
self.device_manager.devices.get(device_root).metadata = instr.metadata
@staticmethod
def consumer_interception_callback(msg, *, parent, **_kwargs) -> None:
def register_interception_callback(self, msg, **_kwargs) -> None:
"""callback for receiving scan modifications / interceptions"""
mvalue = msg.value
if mvalue is None:
@ -106,7 +94,7 @@ class DeviceServer(RPCMixin, BECService):
return
logger.info(f"Receiving: {mvalue.content}")
if mvalue.content.get("action") in ["pause", "abort", "halt"]:
parent.stop_devices()
self.stop_devices()
def stop_devices(self) -> None:
"""stop all enabled devices"""
@ -279,7 +267,7 @@ class DeviceServer(RPCMixin, BECService):
devices = instr.content["device"]
if not isinstance(devices, list):
devices = [devices]
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
for dev in devices:
obj = self.device_manager.devices.get(dev)
obj.metadata = instr.metadata
@ -288,11 +276,11 @@ class DeviceServer(RPCMixin, BECService):
dev_msg = messages.DeviceReqStatusMessage(
device=dev, success=True, metadata=instr.metadata
)
self.producer.set_and_publish(MessageEndpoints.device_req_status(dev), dev_msg, pipe)
self.connector.set_and_publish(MessageEndpoints.device_req_status(dev), dev_msg, pipe)
pipe.execute()
def _status_callback(self, status):
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
if hasattr(status, "device"):
obj = status.device
else:
@ -302,12 +290,12 @@ class DeviceServer(RPCMixin, BECService):
device=device_name, success=status.success, metadata=status.instruction.metadata
)
logger.debug(f"req status for device {device_name}: {status.success}")
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.device_req_status(device_name), dev_msg, pipe
)
response = status.instruction.metadata.get("response")
if response:
self.producer.lpush(
self.connector.lpush(
MessageEndpoints.device_req_status(status.instruction.metadata["RID"]),
dev_msg,
pipe,
@ -328,7 +316,7 @@ class DeviceServer(RPCMixin, BECService):
dev_config_msg = messages.DeviceMessage(
signals=obj.root.read_configuration(), metadata=metadata
)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.device_read_configuration(obj.root.name), dev_config_msg, pipe
)
@ -342,7 +330,7 @@ class DeviceServer(RPCMixin, BECService):
def _read_and_update_devices(self, devices: list[str], metadata: dict) -> list:
start = time.time()
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
signal_container = []
for dev in devices:
device_root = dev.split(".")[0]
@ -354,17 +342,17 @@ class DeviceServer(RPCMixin, BECService):
except Exception as exc:
signals = self._retry_obj_method(dev, obj, "read", exc)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.device_read(device_root),
messages.DeviceMessage(signals=signals, metadata=metadata),
pipe,
)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.device_readback(device_root),
messages.DeviceMessage(signals=signals, metadata=metadata),
pipe,
)
self.producer.set(
self.connector.set(
MessageEndpoints.device_status(device_root),
messages.DeviceStatusMessage(device=device_root, status=0, metadata=metadata),
pipe,
@ -377,7 +365,7 @@ class DeviceServer(RPCMixin, BECService):
def _read_config_and_update_devices(self, devices: list[str], metadata: dict) -> list:
start = time.time()
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
signal_container = []
for dev in devices:
self.device_manager.devices.get(dev).metadata = metadata
@ -387,7 +375,7 @@ class DeviceServer(RPCMixin, BECService):
signal_container.append(signals)
except Exception as exc:
signals = self._retry_obj_method(dev, obj, "read_configuration", exc)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.device_read_configuration(dev),
messages.DeviceMessage(signals=signals, metadata=metadata),
pipe,
@ -420,9 +408,11 @@ class DeviceServer(RPCMixin, BECService):
f"Failed to run {method} on device {device_root}. Trying to load an old value."
)
if method == "read":
old_msg = self.producer.get(MessageEndpoints.device_read(device_root))
old_msg = self.connector.get(MessageEndpoints.device_read(device_root))
elif method == "read_configuration":
old_msg = self.producer.get(MessageEndpoints.device_read_configuration(device_root))
old_msg = self.connector.get(
MessageEndpoints.device_read_configuration(device_root)
)
else:
raise ValueError(f"Unknown method {method}.")
if not old_msg:
@ -435,7 +425,7 @@ class DeviceServer(RPCMixin, BECService):
if not isinstance(devices, list):
devices = [devices]
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
for dev in devices:
obj = self.device_manager.devices[dev].obj
if hasattr(obj, "_staged"):
@ -444,7 +434,7 @@ class DeviceServer(RPCMixin, BECService):
logger.info(f"Device {obj.name} was already staged and will be first unstaged.")
self.device_manager.devices[dev].obj.unstage()
self.device_manager.devices[dev].obj.stage()
self.producer.set(
self.connector.set(
MessageEndpoints.device_staged(dev),
messages.DeviceStatusMessage(device=dev, status=1, metadata=instr.metadata),
pipe,
@ -456,7 +446,7 @@ class DeviceServer(RPCMixin, BECService):
if not isinstance(devices, list):
devices = [devices]
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
for dev in devices:
obj = self.device_manager.devices[dev].obj
if hasattr(obj, "_staged"):
@ -465,7 +455,7 @@ class DeviceServer(RPCMixin, BECService):
self.device_manager.devices[dev].obj.unstage()
else:
logger.debug(f"Device {obj.name} was already unstaged.")
self.producer.set(
self.connector.set(
MessageEndpoints.device_staged(dev),
messages.DeviceStatusMessage(device=dev, status=0, metadata=instr.metadata),
pipe,

View File

@ -16,17 +16,11 @@ class ConfigUpdateHandler:
def __init__(self, device_manager: DeviceManagerDS) -> None:
self.device_manager = device_manager
self.connector = self.device_manager.connector
self._config_request_handler = None
self._start_config_handler()
def _start_config_handler(self) -> None:
self._config_request_handler = self.connector.consumer(
self.connector.register(
MessageEndpoints.device_server_config_request(),
cb=self._device_config_callback,
parent=self,
)
self._config_request_handler.start()
@staticmethod
def _device_config_callback(msg, *, parent, **_kwargs) -> None:
@ -74,7 +68,7 @@ class ConfigUpdateHandler:
accepted=accepted, message=error_msg, metadata=metadata
)
RID = metadata.get("RID")
self.device_manager.producer.set(
self.device_manager.connector.set(
MessageEndpoints.device_config_request_response(RID), msg, expire=60
)
@ -97,7 +91,7 @@ class ConfigUpdateHandler:
"low": device.obj.low_limit_travel.get(),
"high": device.obj.high_limit_travel.get(),
}
self.device_manager.producer.set_and_publish(
self.device_manager.connector.set_and_publish(
MessageEndpoints.device_limits(device.name),
messages.DeviceMessage(signals=limits),
)

View File

@ -51,7 +51,7 @@ class DSDevice(DeviceBase):
self.metadata = {}
self.initialized = False
def initialize_device_buffer(self, producer):
def initialize_device_buffer(self, connector):
"""initialize the device read and readback buffer on redis with a new reading"""
dev_msg = messages.DeviceMessage(signals=self.obj.read(), metadata={})
dev_config_msg = messages.DeviceMessage(signals=self.obj.read_configuration(), metadata={})
@ -62,14 +62,14 @@ class DSDevice(DeviceBase):
}
else:
limits = None
pipe = producer.pipeline()
producer.set_and_publish(MessageEndpoints.device_readback(self.name), dev_msg, pipe=pipe)
producer.set(topic=MessageEndpoints.device_read(self.name), msg=dev_msg, pipe=pipe)
producer.set_and_publish(
pipe = connector.pipeline()
connector.set_and_publish(MessageEndpoints.device_readback(self.name), dev_msg, pipe=pipe)
connector.set(topic=MessageEndpoints.device_read(self.name), msg=dev_msg, pipe=pipe)
connector.set_and_publish(
MessageEndpoints.device_read_configuration(self.name), dev_config_msg, pipe=pipe
)
if limits is not None:
producer.set_and_publish(
connector.set_and_publish(
MessageEndpoints.device_limits(self.name),
messages.DeviceMessage(signals=limits),
pipe=pipe,
@ -318,7 +318,7 @@ class DeviceManagerDS(DeviceManagerBase):
self.update_config(obj, config)
# refresh the device info
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
self.reset_device_data(obj, pipe)
self.publish_device_info(obj, pipe)
pipe.execute()
@ -369,7 +369,7 @@ class DeviceManagerDS(DeviceManagerBase):
def initialize_enabled_device(self, opaas_obj):
"""connect to an enabled device and initialize the device buffer"""
self.connect_device(opaas_obj.obj)
opaas_obj.initialize_device_buffer(self.producer)
opaas_obj.initialize_device_buffer(self.connector)
@staticmethod
def disconnect_device(obj):
@ -420,7 +420,7 @@ class DeviceManagerDS(DeviceManagerBase):
"""
interface = get_device_info(obj, {})
self.producer.set(
self.connector.set(
MessageEndpoints.device_info(obj.name),
messages.DeviceInfoMessage(device=obj.name, info=interface),
pipe,
@ -428,9 +428,9 @@ class DeviceManagerDS(DeviceManagerBase):
def reset_device_data(self, obj: OphydObject, pipe=None) -> None:
"""delete all device data and device info"""
self.producer.delete(MessageEndpoints.device_status(obj.name), pipe)
self.producer.delete(MessageEndpoints.device_read(obj.name), pipe)
self.producer.delete(MessageEndpoints.device_info(obj.name), pipe)
self.connector.delete(MessageEndpoints.device_status(obj.name), pipe)
self.connector.delete(MessageEndpoints.device_read(obj.name), pipe)
self.connector.delete(MessageEndpoints.device_info(obj.name), pipe)
def _obj_callback_readback(self, *_args, obj: OphydObject, **kwargs):
if obj.connected:
@ -438,8 +438,8 @@ class DeviceManagerDS(DeviceManagerBase):
signals = obj.read()
metadata = self.devices.get(obj.root.name).metadata
dev_msg = messages.DeviceMessage(signals=signals, metadata=metadata)
pipe = self.producer.pipeline()
self.producer.set_and_publish(MessageEndpoints.device_readback(name), dev_msg, pipe)
pipe = self.connector.pipeline()
self.connector.set_and_publish(MessageEndpoints.device_readback(name), dev_msg, pipe)
pipe.execute()
@typechecked
@ -466,7 +466,7 @@ class DeviceManagerDS(DeviceManagerBase):
metadata = self.devices[name].metadata
msg = messages.DeviceMonitorMessage(device=name, data=value, metadata=metadata)
stream_msg = {"data": msg}
self.producer.xadd(
self.connector.xadd(
MessageEndpoints.device_monitor(name),
stream_msg,
max_size=min(100, int(max_size // dsize)),
@ -476,7 +476,7 @@ class DeviceManagerDS(DeviceManagerBase):
device = kwargs["obj"].root.name
status = 0
metadata = self.devices[device].metadata
self.producer.send(
self.connector.send(
MessageEndpoints.device_status(device),
messages.DeviceStatusMessage(device=device, status=status, metadata=metadata),
)
@ -489,8 +489,8 @@ class DeviceManagerDS(DeviceManagerBase):
device = kwargs["obj"].root.name
status = int(kwargs.get("value"))
metadata = self.devices[device].metadata
self.producer.set(
MessageEndpoints.device_status(kwargs["obj"].root.name),
self.connector.set(
MessageEndpoints.device_status(device),
messages.DeviceStatusMessage(device=device, status=status, metadata=metadata),
)
@ -521,12 +521,12 @@ class DeviceManagerDS(DeviceManagerBase):
)
)
ds_obj.emitted_points[metadata["scanID"]] = max_points
pipe = self.producer.pipeline()
self.producer.send(MessageEndpoints.device_read(obj.root.name), bundle, pipe=pipe)
pipe = self.connector.pipeline()
self.connector.send(MessageEndpoints.device_read(obj.root.name), bundle, pipe=pipe)
msg = messages.DeviceStatusMessage(
device=obj.root.name, status=max_points, metadata=metadata
)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.device_progress(obj.root.name), msg, pipe=pipe
)
pipe.execute()
@ -536,4 +536,4 @@ class DeviceManagerDS(DeviceManagerBase):
msg = messages.ProgressMessage(
value=value, max_value=max_value, done=done, metadata=metadata
)
self.producer.set_and_publish(MessageEndpoints.device_progress(obj.root.name), msg)
self.connector.set_and_publish(MessageEndpoints.device_progress(obj.root.name), msg)

View File

@ -70,7 +70,7 @@ class RPCMixin:
def _send_rpc_result_to_client(
self, device: str, instr_params: dict, res: Any, result: StringIO
):
self.producer.set(
self.connector.set(
MessageEndpoints.device_rpc(instr_params.get("rpc_id")),
messages.DeviceRPCMessage(
device=device, return_val=res, out=result.getvalue(), success=True
@ -175,7 +175,7 @@ class RPCMixin:
}
logger.info(f"Received exception: {exc_formatted}, {exc}")
instr_params = instr.content.get("parameter")
self.producer.set(
self.connector.set(
MessageEndpoints.device_rpc(instr_params.get("rpc_id")),
messages.DeviceRPCMessage(
device=instr.content["device"], return_val=None, out=exc_formatted, success=False

View File

@ -6,7 +6,7 @@ import numpy as np
import pytest
import yaml
from bec_lib import MessageEndpoints, messages
from bec_lib.tests.utils import ConnectorMock, ProducerMock, create_session_from_config
from bec_lib.tests.utils import ConnectorMock, create_session_from_config
from device_server.devices.devicemanager import DeviceManagerDS
@ -52,7 +52,7 @@ def load_device_manager():
service_mock = mock.MagicMock()
service_mock.connector = ConnectorMock("", store_data=False)
device_manager = DeviceManagerDS(service_mock, "")
device_manager.producer = service_mock.connector.producer()
device_manager.connector = service_mock.connector
device_manager.config_update_handler = mock.MagicMock()
with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file:
device_manager._session = create_session_from_config(yaml.safe_load(session_file))
@ -133,10 +133,10 @@ def test_flyer_event_callback():
device_manager._obj_flyer_callback(
obj=samx.obj, value={"data": {"idata": np.random.rand(20), "edata": np.random.rand(20)}}
)
pipe = device_manager.producer.pipeline()
pipe = device_manager.connector.pipeline()
bundle, progress = pipe._pipe_buffer[-2:]
# check producer method
# check connector method
assert bundle[0] == "send"
assert progress[0] == "set_and_publish"
@ -157,9 +157,9 @@ def test_obj_progress_callback():
samx = device_manager.devices.samx
samx.metadata = {"scanID": "12345"}
with mock.patch.object(device_manager, "producer") as mock_producer:
with mock.patch.object(device_manager, "connector") as mock_connector:
device_manager._obj_progress_callback(obj=samx.obj, value=1, max_value=2, done=False)
mock_producer.set_and_publish.assert_called_once_with(
mock_connector.set_and_publish.assert_called_once_with(
MessageEndpoints.device_progress("samx"),
messages.ProgressMessage(
value=1, max_value=2, done=False, metadata={"scanID": "12345"}
@ -176,9 +176,9 @@ def test_obj_monitor_callback(value):
eiger.metadata = {"scanID": "12345"}
value_size = len(value.tobytes()) / 1e6 # MB
max_size = 100
with mock.patch.object(device_manager, "producer") as mock_producer:
with mock.patch.object(device_manager, "connector") as mock_connector:
device_manager._obj_callback_monitor(obj=eiger.obj, value=value)
mock_producer.xadd.assert_called_once_with(
mock_connector.xadd.assert_called_once_with(
MessageEndpoints.device_monitor(eiger.name),
{
"data": messages.DeviceMonitorMessage(

View File

@ -7,7 +7,7 @@ from bec_lib import Alarms, MessageEndpoints, ServiceConfig, messages
from bec_lib.device import OnFailure
from bec_lib.messages import BECStatus
from bec_lib.redis_connector import MessageObject
from bec_lib.tests.utils import ConnectorMock, ConsumerMock
from bec_lib.tests.utils import ConnectorMock
from ophyd import Staged
from ophyd.utils import errors as ophyd_errors
from test_device_manager_ds import device_manager, load_device_manager
@ -54,8 +54,6 @@ def test_start(device_server_mock):
device_server.start()
assert device_server.threads
assert isinstance(device_server.threads[0], ConsumerMock)
assert device_server.status == BECStatus.RUNNING
@ -187,10 +185,10 @@ def test_stop_devices(device_server_mock):
),
],
)
def test_consumer_interception_callback(device_server_mock, msg, stop_called):
def test_register_interception_callback(device_server_mock, msg, stop_called):
device_server = device_server_mock
with mock.patch.object(device_server, "stop_devices") as stop:
device_server.consumer_interception_callback(msg, parent=device_server)
device_server.register_interception_callback(msg, parent=device_server)
if stop_called:
stop.assert_called_once()
else:
@ -640,7 +638,7 @@ def test_set_device(device_server_mock, instr):
while True:
res = [
msg
for msg in device_server.producer.message_sent
for msg in device_server.connector.message_sent
if msg["queue"] == MessageEndpoints.device_req_status("samx")
]
if res:
@ -676,7 +674,7 @@ def test_read_device(device_server_mock, instr):
for device in devices:
res = [
msg
for msg in device_server.producer.message_sent
for msg in device_server.connector.message_sent
if msg["queue"] == MessageEndpoints.device_read(device)
]
assert res[-1]["msg"].metadata["RID"] == instr.metadata["RID"]
@ -690,7 +688,7 @@ def test_read_config_and_update_devices(device_server_mock, devices):
for device in devices:
res = [
msg
for msg in device_server.producer.message_sent
for msg in device_server.connector.message_sent
if msg["queue"] == MessageEndpoints.device_read_configuration(device)
]
config = device_server.device_manager.devices[device].obj.read_configuration()
@ -755,8 +753,8 @@ def test_retry_obj_method_buffer(device_server_mock, instr):
return
signals_before = getattr(samx.obj, instr)()
device_server.producer = mock.MagicMock()
device_server.producer.get.return_value = messages.DeviceMessage(
device_server.connector = mock.MagicMock()
device_server.connector.get.return_value = messages.DeviceMessage(
signals=signals_before, metadata={"RID": "test", "stream": "primary"}
)

View File

@ -13,7 +13,7 @@ from device_server.rpc_mixin import RPCMixin
def rpc_cls():
rpc_mixin = RPCMixin()
rpc_mixin.connector = mock.MagicMock()
rpc_mixin.producer = mock.MagicMock()
rpc_mixin.connector = mock.MagicMock()
rpc_mixin.device_manager = mock.MagicMock()
yield rpc_mixin
@ -93,7 +93,7 @@ def test_get_result_from_rpc_list_from_stage(rpc_cls):
def test_send_rpc_exception(rpc_cls, instr):
rpc_cls._send_rpc_exception(Exception(), instr)
rpc_cls.producer.set.assert_called_once_with(
rpc_cls.connector.set.assert_called_once_with(
MessageEndpoints.device_rpc("rpc_id"),
messages.DeviceRPCMessage(
device="device",
@ -108,7 +108,7 @@ def test_send_rpc_result_to_client(rpc_cls):
result = mock.MagicMock()
result.getvalue.return_value = "result"
rpc_cls._send_rpc_result_to_client("device", {"rpc_id": "rpc_id"}, 1, result)
rpc_cls.producer.set.assert_called_once_with(
rpc_cls.connector.set.assert_called_once_with(
MessageEndpoints.device_rpc("rpc_id"),
messages.DeviceRPCMessage(device="device", return_val=1, out="result", success=True),
expire=1800,

View File

@ -325,7 +325,7 @@ class NexusFileWriter(FileWriter):
file_data[key] = val if not isinstance(val, list) else merge_dicts(val)
msg_data = {"file_path": file_path, "data": file_data}
msg = messages.FileContentMessage(**msg_data)
self.file_writer_manager.producer.set_and_publish(MessageEndpoints.file_content(), msg)
self.file_writer_manager.connector.set_and_publish(MessageEndpoints.file_content(), msg)
with h5py.File(file_path, "w") as file:
HDF5StorageWriter.write(writer_storage._storage, device_storage, file)

View File

@ -80,10 +80,13 @@ class FileWriterManager(BECService):
self._lock = threading.RLock()
self.file_writer_config = self._service_config.service_config.get("file_writer")
self.writer_mixin = FileWriterMixin(self.file_writer_config)
self.producer = self.connector.producer()
self._start_device_manager()
self._start_scan_segment_consumer()
self._start_scan_status_consumer()
self.connector.register(
patterns=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self
)
self.connector.register(
MessageEndpoints.scan_status(), cb=self._scan_status_callback, parent=self
)
self.scan_storage = {}
self.file_writer = NexusFileWriter(self)
@ -92,20 +95,7 @@ class FileWriterManager(BECService):
self.device_manager = DeviceManagerBase(self)
self.device_manager.initialize([self.bootstrap_server])
def _start_scan_segment_consumer(self):
self._scan_segment_consumer = self.connector.consumer(
pattern=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self
)
self._scan_segment_consumer.start()
def _start_scan_status_consumer(self):
self._scan_status_consumer = self.connector.consumer(
MessageEndpoints.scan_status(), cb=self._scan_status_callback, parent=self
)
self._scan_status_consumer.start()
@staticmethod
def _scan_segment_callback(msg: MessageObject, *, parent: FileWriterManager):
def _scan_segment_callback(self, msg: MessageObject, *, parent: FileWriterManager):
msgs = msg.value
for scan_msg in msgs:
parent.insert_to_scan_storage(scan_msg)
@ -188,7 +178,7 @@ class FileWriterManager(BECService):
return
if self.scan_storage[scanID].baseline:
return
baseline = self.producer.get(MessageEndpoints.public_scan_baseline(scanID))
baseline = self.connector.get(MessageEndpoints.public_scan_baseline(scanID))
if not baseline:
return
self.scan_storage[scanID].baseline = baseline.content["data"]
@ -205,13 +195,13 @@ class FileWriterManager(BECService):
"""
if not self.scan_storage.get(scanID):
return
msgs = self.producer.keys(MessageEndpoints.public_file(scanID, "*"))
msgs = self.connector.keys(MessageEndpoints.public_file(scanID, "*"))
if not msgs:
return
# extract name from 'public/<scanID>/file/<name>'
names = [msg.decode().split("/")[-1] for msg in msgs]
file_msgs = [self.producer.get(msg.decode()) for msg in msgs]
file_msgs = [self.connector.get(msg.decode()) for msg in msgs]
if not file_msgs:
return
for name, file_msg in zip(names, file_msgs):
@ -236,7 +226,7 @@ class FileWriterManager(BECService):
if not self.scan_storage.get(scanID):
return
# get all async devices
async_device_keys = self.producer.keys(MessageEndpoints.device_async_readback(scanID, "*"))
async_device_keys = self.connector.keys(MessageEndpoints.device_async_readback(scanID, "*"))
if not async_device_keys:
return
for device_key in async_device_keys:
@ -244,7 +234,7 @@ class FileWriterManager(BECService):
device_name = key.split(MessageEndpoints.device_async_readback(scanID, ""))[-1].split(
":"
)[0]
msgs = self.producer.xrange(key, min="-", max="+")
msgs = self.connector.xrange(key, min="-", max="+")
if not msgs:
continue
self._process_async_data(msgs, scanID, device_name)
@ -298,7 +288,7 @@ class FileWriterManager(BECService):
try:
file_path = self.writer_mixin.compile_full_filename(scan, file_suffix)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.public_file(scanID, "master"),
messages.FileMessage(file_path=file_path, done=False),
)
@ -319,7 +309,7 @@ class FileWriterManager(BECService):
)
successful = False
self.scan_storage.pop(scanID)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.public_file(scanID, "master"),
messages.FileMessage(file_path=file_path, successful=successful),
)

View File

@ -25,7 +25,7 @@ def load_FileWriter():
service_mock = mock.MagicMock()
service_mock.connector = ConnectorMock("")
device_manager = DeviceManagerBase(service_mock, "")
device_manager.producer = service_mock.connector.producer()
device_manager.connector = service_mock.connector
with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file:
device_manager._session = create_session_from_config(yaml.safe_load(session_file))
device_manager._load_session()
@ -152,13 +152,13 @@ def test_write_file_raises_alarm_on_error():
def test_update_baseline_reading():
file_manager = load_FileWriter()
file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID")
with mock.patch.object(file_manager, "producer") as mock_producer:
mock_producer.get.return_value = messages.ScanBaselineMessage(
with mock.patch.object(file_manager, "connector") as mock_connector:
mock_connector.get.return_value = messages.ScanBaselineMessage(
scanID="scanID", data={"data": "data"}
)
file_manager.update_baseline_reading("scanID")
assert file_manager.scan_storage["scanID"].baseline == {"data": "data"}
mock_producer.get.assert_called_once_with(MessageEndpoints.public_scan_baseline("scanID"))
mock_connector.get.assert_called_once_with(MessageEndpoints.public_scan_baseline("scanID"))
def test_scan_storage_append():
@ -178,30 +178,30 @@ def test_scan_storage_ready_to_write():
def test_update_file_references():
file_manager = load_FileWriter()
with mock.patch.object(file_manager, "producer") as mock_producer:
with mock.patch.object(file_manager, "connector") as mock_connector:
file_manager.update_file_references("scanID")
mock_producer.keys.assert_not_called()
mock_connector.keys.assert_not_called()
def test_update_file_references_gets_keys():
file_manager = load_FileWriter()
file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID")
with mock.patch.object(file_manager, "producer") as mock_producer:
with mock.patch.object(file_manager, "connector") as mock_connector:
file_manager.update_file_references("scanID")
mock_producer.keys.assert_called_once_with(MessageEndpoints.public_file("scanID", "*"))
mock_connector.keys.assert_called_once_with(MessageEndpoints.public_file("scanID", "*"))
def test_update_async_data():
file_manager = load_FileWriter()
file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID")
with mock.patch.object(file_manager, "producer") as mock_producer:
with mock.patch.object(file_manager, "connector") as mock_connector:
with mock.patch.object(file_manager, "_process_async_data") as mock_process:
key = MessageEndpoints.device_async_readback("scanID", "dev1")
mock_producer.keys.return_value = [key.encode()]
mock_connector.keys.return_value = [key.encode()]
data = [(b"0-0", b'{"data": "data"}')]
mock_producer.xrange.return_value = data
mock_connector.xrange.return_value = data
file_manager.update_async_data("scanID")
mock_producer.xrange.assert_called_once_with(key, min="-", max="+")
mock_connector.xrange.assert_called_once_with(key, min="-", max="+")
mock_process.assert_called_once_with(data, "scanID", "dev1")

View File

@ -14,7 +14,7 @@ if TYPE_CHECKING:
class BECEmitter(EmitterBase):
def __init__(self, scan_bundler: ScanBundler) -> None:
super().__init__(scan_bundler.producer)
super().__init__(scan_bundler.connector)
self.scan_bundler = scan_bundler
def on_scan_point_emit(self, scanID: str, pointID: int):
@ -46,9 +46,16 @@ class BECEmitter(EmitterBase):
data=sb.sync_storage[scanID]["baseline"],
metadata=sb.sync_storage[scanID]["info"],
)
pipe = sb.producer.pipeline()
sb.producer.set(
MessageEndpoints.public_scan_baseline(scanID=scanID), msg, expire=1800, pipe=pipe
pipe = sb.connector.pipeline()
sb.connector.set(
MessageEndpoints.public_scan_baseline(scanID=scanID),
msg,
expire=1800,
pipe=pipe,
)
sb.connector.set_and_publish(
MessageEndpoints.scan_baseline(),
msg,
pipe=pipe,
)
sb.producer.set_and_publish(MessageEndpoints.scan_baseline(), msg, pipe=pipe)
pipe.execute()

View File

@ -17,7 +17,7 @@ if TYPE_CHECKING:
class BlueskyEmitter(EmitterBase):
def __init__(self, scan_bundler: ScanBundler) -> None:
super().__init__(scan_bundler.producer)
super().__init__(scan_bundler.connector)
self.scan_bundler = scan_bundler
self.bluesky_metadata = {}
@ -27,7 +27,7 @@ class BlueskyEmitter(EmitterBase):
self.bluesky_metadata[scanID] = {}
doc = self._get_run_start_document(scanID)
self.bluesky_metadata[scanID]["start"] = doc
self.producer.raw_send(MessageEndpoints.bluesky_events(), msgpack.dumps(("start", doc)))
self.connector.raw_send(MessageEndpoints.bluesky_events(), msgpack.dumps(("start", doc)))
self.send_descriptor_document(scanID)
def _get_run_start_document(self, scanID) -> dict:
@ -71,7 +71,7 @@ class BlueskyEmitter(EmitterBase):
"""Bluesky only: send descriptor document"""
doc = self._get_descriptor_document(scanID)
self.bluesky_metadata[scanID]["descriptor"] = doc
self.producer.raw_send(
self.connector.raw_send(
MessageEndpoints.bluesky_events(), msgpack.dumps(("descriptor", doc))
)
@ -85,7 +85,7 @@ class BlueskyEmitter(EmitterBase):
logger.warning(f"Failed to remove {scanID} from {storage}.")
def send_bluesky_scan_point(self, scanID, pointID) -> None:
self.producer.raw_send(
self.connector.raw_send(
MessageEndpoints.bluesky_events(),
msgpack.dumps(("event", self._prepare_bluesky_event_data(scanID, pointID))),
)

View File

@ -6,16 +6,16 @@ from bec_lib import messages
class EmitterBase:
def __init__(self, producer) -> None:
def __init__(self, connector) -> None:
self._send_buffer = Queue()
self.producer = producer
self._start_buffered_producer()
self.connector = connector
self._start_buffered_connector()
def _start_buffered_producer(self):
self._buffered_producer_thread = threading.Thread(
def _start_buffered_connector(self):
self._buffered_connector_thread = threading.Thread(
target=self._buffered_publish, daemon=True, name="buffered_publisher"
)
self._buffered_producer_thread.start()
self._buffered_connector_thread.start()
def add_message(self, msg: messages.BECMessage, endpoint: str, public: str = None):
self._send_buffer.put((msg, endpoint, public))
@ -37,20 +37,20 @@ class EmitterBase:
time.sleep(0.1)
return
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
msgs = messages.BundleMessage()
_, endpoint, _ = msgs_to_send[0]
for msg, endpoint, public in msgs_to_send:
msg_dump = msg
msgs.append(msg_dump)
if public:
self.producer.set(
self.connector.set(
public,
msg_dump,
pipe=pipe,
expire=1800,
)
self.producer.send(endpoint, msgs, pipe=pipe)
self.connector.send(endpoint, msgs, pipe=pipe)
pipe.execute()
def on_init(self, scanID: str):

View File

@ -20,9 +20,23 @@ class ScanBundler(BECService):
self.device_manager = None
self._start_device_manager()
self._start_device_read_consumer()
self._start_scan_queue_consumer()
self._start_scan_status_consumer()
self.connector.register(
patterns=MessageEndpoints.device_read("*"),
cb=self._device_read_callback,
name="device_read_register",
)
self.connector.register(
MessageEndpoints.scan_queue_status(),
cb=self._scan_queue_callback,
group_id="scan_bundler",
name="scan_queue_register",
)
self.connector.register(
MessageEndpoints.scan_status(),
cb=self._scan_status_callback,
group_id="scan_bundler",
name="scan_status_register",
)
self.sync_storage = {}
self.monitored_devices = {}
@ -56,56 +70,24 @@ class ScanBundler(BECService):
self.device_manager = DeviceManagerBase(self)
self.device_manager.initialize(self.bootstrap_server)
def _start_device_read_consumer(self):
self._device_read_consumer = self.connector.consumer(
pattern=MessageEndpoints.device_read("*"),
cb=self._device_read_callback,
parent=self,
name="device_read_consumer",
)
self._device_read_consumer.start()
def _start_scan_queue_consumer(self):
self._scan_queue_consumer = self.connector.consumer(
MessageEndpoints.scan_queue_status(),
cb=self._scan_queue_callback,
group_id="scan_bundler",
parent=self,
name="scan_queue_consumer",
)
self._scan_queue_consumer.start()
def _start_scan_status_consumer(self):
self._scan_status_consumer = self.connector.consumer(
MessageEndpoints.scan_status(),
cb=self._scan_status_callback,
group_id="scan_bundler",
parent=self,
name="scan_status_consumer",
)
self._scan_status_consumer.start()
@staticmethod
def _device_read_callback(msg, parent, **_kwargs):
def _device_read_callback(self, msg, **_kwargs):
# pylint: disable=protected-access
dev = msg.topic.split(MessageEndpoints._device_read + "/")[-1]
msgs = msg.value
logger.debug(f"Received reading from device {dev}")
if not isinstance(msgs, list):
msgs = [msgs]
task = parent.executor.submit(parent._add_device_to_storage, msgs, dev)
parent.executor_tasks.append(task)
task = self.executor.submit(self._add_device_to_storage, msgs, dev)
self.executor_tasks.append(task)
@staticmethod
def _scan_queue_callback(msg, parent, **_kwargs):
def _scan_queue_callback(self, msg, **_kwargs):
msg = msg.value
logger.trace(msg)
parent.current_queue = msg.content["queue"]["primary"].get("info")
self.current_queue = msg.content["queue"]["primary"].get("info")
@staticmethod
def _scan_status_callback(msg, parent, **_kwargs):
def _scan_status_callback(self, msg, **_kwargs):
msg = msg.value
parent.handle_scan_status_message(msg)
self.handle_scan_status_message(msg)
def handle_scan_status_message(self, msg: messages.ScanStatusMessage) -> None:
"""handle scan status messages"""
@ -270,7 +252,7 @@ class ScanBundler(BECService):
}
def _get_scan_status_history(self, length):
return self.producer.lrange(MessageEndpoints.scan_status() + "_list", length * -1, -1)
return self.connector.lrange(MessageEndpoints.scan_status() + "_list", length * -1, -1)
def _wait_for_scanID(self, scanID, timeout_time=10):
elapsed_time = 0
@ -344,10 +326,10 @@ class ScanBundler(BECService):
self.sync_storage[scanID][pointID][dev.name] = read
def _get_last_device_readback(self, devices: list) -> list:
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
for dev in devices:
self.producer.get(MessageEndpoints.device_readback(dev.name), pipe)
return [msg.content["signals"] for msg in self.producer.execute_pipeline(pipe)]
self.connector.get(MessageEndpoints.device_readback(dev.name), pipe)
return [msg.content["signals"] for msg in self.connector.execute_pipeline(pipe)]
def cleanup_storage(self):
"""remove old scanIDs to free memory"""

View File

@ -57,16 +57,16 @@ def test_send_baseline_BEC():
sb.sync_storage[scanID] = {"info": {}, "status": "open", "sent": set()}
sb.sync_storage[scanID]["baseline"] = {}
msg = messages.ScanBaselineMessage(scanID=scanID, data=sb.sync_storage[scanID]["baseline"])
with mock.patch.object(sb, "producer") as producer:
with mock.patch.object(sb, "connector") as connector:
bec_emitter._send_baseline(scanID)
pipe = producer.pipeline()
producer.set.assert_called_once_with(
pipe = connector.pipeline()
connector.set.assert_called_once_with(
MessageEndpoints.public_scan_baseline(scanID),
msg,
expire=1800,
pipe=pipe,
)
producer.set_and_publish.assert_called_once_with(
connector.set_and_publish.assert_called_once_with(
MessageEndpoints.scan_baseline(),
msg,
pipe=pipe,

View File

@ -13,7 +13,7 @@ from scan_bundler.bluesky_emitter import BlueskyEmitter
def test_run_start_document(scanID):
sb = load_ScanBundlerMock()
bls_emitter = BlueskyEmitter(sb)
with mock.patch.object(bls_emitter.producer, "raw_send") as send:
with mock.patch.object(bls_emitter.connector, "raw_send") as send:
with mock.patch.object(bls_emitter, "send_descriptor_document") as send_descr:
with mock.patch.object(
bls_emitter, "_get_run_start_document", return_value={}
@ -45,7 +45,7 @@ def test_send_descriptor_document():
bls_emitter = BlueskyEmitter(sb)
scanID = "lkajsdl"
bls_emitter.bluesky_metadata[scanID] = {}
with mock.patch.object(bls_emitter.producer, "raw_send") as send:
with mock.patch.object(bls_emitter.connector, "raw_send") as send:
with mock.patch.object(
bls_emitter, "_get_descriptor_document", return_value={}
) as get_descr:

View File

@ -51,30 +51,30 @@ from scan_bundler.emitter import EmitterBase
],
)
def test_publish_data(msgs):
producer = mock.MagicMock()
with mock.patch.object(EmitterBase, "_start_buffered_producer") as start:
emitter = EmitterBase(producer)
connector = mock.MagicMock()
with mock.patch.object(EmitterBase, "_start_buffered_connector") as start:
emitter = EmitterBase(connector)
start.assert_called_once()
with mock.patch.object(emitter, "_get_messages_from_buffer", return_value=msgs) as get_msgs:
emitter._publish_data()
get_msgs.assert_called_once()
if not msgs:
producer.send.assert_not_called()
connector.send.assert_not_called()
return
pipe = producer.pipeline()
pipe = connector.pipeline()
msgs_bundle = messages.BundleMessage()
_, endpoint, _ = msgs[0]
for msg, endpoint, public in msgs:
msg_dump = msg
msgs_bundle.append(msg_dump)
if public:
producer.set.assert_has_calls(
producer.set(public, msg_dump, pipe=pipe, expire=1800)
connector.set.assert_has_calls(
connector.set(public, msg_dump, pipe=pipe, expire=1800)
)
producer.send.assert_called_with(endpoint, msgs_bundle, pipe=pipe)
connector.send.assert_called_with(endpoint, msgs_bundle, pipe=pipe)
@pytest.mark.parametrize(
@ -93,8 +93,8 @@ def test_publish_data(msgs):
],
)
def test_add_message(msg, endpoint, public):
producer = mock.MagicMock()
emitter = EmitterBase(producer)
connector = mock.MagicMock()
emitter = EmitterBase(connector)
emitter.add_message(msg, endpoint, public)
msgs = emitter._get_messages_from_buffer()
out_msg, out_endpoint, out_public = msgs[0]

View File

@ -36,7 +36,7 @@ def load_ScanBundlerMock():
service_mock = mock.MagicMock()
service_mock.connector = ConnectorMock("")
device_manager = ScanBundlerDeviceManagerMock(service_mock, "")
device_manager.producer = service_mock.connector.producer()
device_manager.connector = service_mock.connector
with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file:
device_manager._session = create_session_from_config(yaml.safe_load(session_file))
device_manager._load_session()
@ -74,7 +74,7 @@ def test_device_read_callback():
msg.topic = MessageEndpoints.device_read("samx")
with mock.patch.object(scan_bundler, "_add_device_to_storage") as add_dev:
scan_bundler._device_read_callback(msg, scan_bundler)
scan_bundler._device_read_callback(msg)
add_dev.assert_called_once_with([dev_msg], "samx")
@ -157,7 +157,7 @@ def test_wait_for_scanID(scanID, storageID, scan_msg):
)
def test_get_scan_status_history(msgs):
sb = load_ScanBundlerMock()
with mock.patch.object(sb.producer, "lrange", return_value=[msg for msg in msgs]) as lrange:
with mock.patch.object(sb.connector, "lrange", return_value=[msg for msg in msgs]) as lrange:
res = sb._get_scan_status_history(5)
lrange.assert_called_once_with(MessageEndpoints.scan_status() + "_list", -5, -1)
assert res == msgs
@ -371,7 +371,7 @@ def test_scan_queue_callback(queue_msg):
sb = load_ScanBundlerMock()
msg = MessageMock()
msg.value = queue_msg
sb._scan_queue_callback(msg, sb)
sb._scan_queue_callback(msg)
assert sb.current_queue == queue_msg.content["queue"]["primary"].get("info")
@ -399,7 +399,7 @@ def test_scan_status_callback(scan_msg):
msg.value = scan_msg
with mock.patch.object(sb, "handle_scan_status_message") as handle_scan_status_message_mock:
sb._scan_status_callback(msg, sb)
sb._scan_status_callback(msg)
handle_scan_status_message_mock.assert_called_once_with(scan_msg)
@ -744,10 +744,10 @@ def test_get_last_device_readback():
signals={"samx": {"samx": 0.51, "setpoint": 0.5, "motor_is_moving": 0}},
metadata={"scanID": "laksjd", "readout_priority": "monitored"},
)
with mock.patch.object(sb, "producer") as producer_mock:
producer_mock.execute_pipeline.return_value = [dev_msg]
with mock.patch.object(sb, "connector") as connector_mock:
connector_mock.execute_pipeline.return_value = [dev_msg]
ret = sb._get_last_device_readback([sb.device_manager.devices.samx])
assert producer_mock.get.mock_calls == [
mock.call(MessageEndpoints.device_readback("samx"), producer_mock.pipeline())
assert connector_mock.get.mock_calls == [
mock.call(MessageEndpoints.device_readback("samx"), connector_mock.pipeline())
]
assert ret == [dev_msg.content["signals"]]

View File

@ -458,7 +458,7 @@ class LamNIFermatScan(ScanBase, LamNIMixin):
yield from self.stubs.kickoff(device="rtx")
while True:
yield from self.stubs.read_and_wait(group="primary", wait_group="readout_primary")
msg = self.device_manager.producer.get(MessageEndpoints.device_status("rt_scan"))
msg = self.device_manager.connector.get(MessageEndpoints.device_status("rt_scan"))
if msg:
status = msg
status_id = status.content.get("status", 1)

View File

@ -163,7 +163,7 @@ class OwisGrid(AsyncFlyScanBase):
def scan_progress(self) -> int:
"""Timeout of the progress bar. This gets updated in the frequency of scan segments"""
msg = self.device_manager.producer.get(MessageEndpoints.device_progress("mcs"))
msg = self.device_manager.connector.get(MessageEndpoints.device_progress("mcs"))
if not msg:
self.timeout_progress += 1
return self.timeout_progress

View File

@ -106,7 +106,7 @@ class SgalilGrid(AsyncFlyScanBase):
def scan_progress(self) -> int:
"""Timeout of the progress bar. This gets updated in the frequency of scan segments"""
msg = self.device_manager.producer.get(MessageEndpoints.device_progress("mcs"))
msg = self.device_manager.connector.get(MessageEndpoints.device_progress("mcs"))
if not msg:
self.timeout_progress += 1
return self.timeout_progress

View File

@ -10,8 +10,8 @@ class DeviceValidation:
Mixin class for validation methods
"""
def __init__(self, producer, worker):
self.producer = producer
def __init__(self, connector, worker):
self.connector = connector
self.worker = worker
def get_device_status(self, endpoint: MessageEndpoints, devices: list) -> list:
@ -25,10 +25,10 @@ class DeviceValidation:
Returns:
list: List of BECMessage objects
"""
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
for dev in devices:
self.producer.get(endpoint(dev), pipe)
return self.producer.execute_pipeline(pipe)
self.connector.get(endpoint(dev), pipe)
return self.connector.execute_pipeline(pipe)
def devices_are_ready(
self,

View File

@ -27,25 +27,19 @@ class ScanGuard:
self.parent = parent
self.device_manager = self.parent.device_manager
self.connector = self.parent.connector
self.producer = self.connector.producer()
self._start_scan_queue_request_consumer()
def _start_scan_queue_request_consumer(self):
self._scan_queue_request_consumer = self.connector.consumer(
self.connector.register(
MessageEndpoints.scan_queue_request(),
cb=self._scan_queue_request_callback,
parent=self,
)
self._scan_queue_modification_request_consumer = self.connector.consumer(
self.connector.register(
MessageEndpoints.scan_queue_modification_request(),
cb=self._scan_queue_modification_request_callback,
parent=self,
)
self._scan_queue_request_consumer.start()
self._scan_queue_modification_request_consumer.start()
def _is_valid_scan_request(self, request) -> ScanStatus:
try:
self._check_valid_request(request)
@ -63,7 +57,7 @@ class ScanGuard:
raise ScanRejection("Invalid request.")
def _check_valid_scan(self, request) -> None:
avail_scans = self.producer.get(MessageEndpoints.available_scans())
avail_scans = self.connector.get(MessageEndpoints.available_scans())
scan_type = request.content.get("scan_type")
if scan_type not in avail_scans.resource:
raise ScanRejection(f"Unknown scan type {scan_type}.")
@ -140,7 +134,7 @@ class ScanGuard:
message=scan_status.message,
metadata=metadata,
)
self.device_manager.producer.send(sqrr, rrm)
self.device_manager.connector.send(sqrr, rrm)
def _handle_scan_request(self, msg):
"""
@ -181,10 +175,10 @@ class ScanGuard:
self._send_scan_request_response(ScanStatus(), mod_msg.metadata)
sqm = MessageEndpoints.scan_queue_modification()
self.device_manager.producer.send(sqm, mod_msg)
self.device_manager.connector.send(sqm, mod_msg)
def _append_to_scan_queue(self, msg):
logger.info("Appending new scan to queue")
msg = msg
sqi = MessageEndpoints.scan_queue_insert()
self.device_manager.producer.send(sqi, msg)
self.device_manager.connector.send(sqi, msg)

View File

@ -101,7 +101,7 @@ class ScanManager:
def publish_available_scans(self):
"""send all available scans to the broker"""
self.parent.producer.set(
self.parent.connector.set(
MessageEndpoints.available_scans(),
AvailableResourceMessage(resource=self.available_scans),
)

View File

@ -51,11 +51,10 @@ class QueueManager:
def __init__(self, parent) -> None:
self.parent = parent
self.connector = parent.connector
self.producer = parent.producer
self.num_queues = 1
self.key = ""
self.queues = {}
self._start_scan_queue_consumer()
self._start_scan_queue_register()
self._lock = threading.RLock()
def add_to_queue(self, scan_queue: str, msg: messages.ScanQueueMessage, position=-1) -> None:
@ -91,17 +90,15 @@ class QueueManager:
self.queues[queue_name] = ScanQueue(self, queue_name=queue_name)
self.queues[queue_name].start_worker()
def _start_scan_queue_consumer(self) -> None:
self._scan_queue_consumer = self.connector.consumer(
def _start_scan_queue_register(self) -> None:
self.connector.register(
MessageEndpoints.scan_queue_insert(), cb=self._scan_queue_callback, parent=self
)
self._scan_queue_modification_consumer = self.connector.consumer(
self.connector.register(
MessageEndpoints.scan_queue_modification(),
cb=self._scan_queue_modification_callback,
parent=self,
)
self._scan_queue_consumer.start()
self._scan_queue_modification_consumer.start()
@staticmethod
def _scan_queue_callback(msg, parent, **_kwargs) -> None:
@ -233,7 +230,7 @@ class QueueManager:
logger.info("New scan queue:")
for queue in self.describe_queue():
logger.info(f"\n {queue}")
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.scan_queue_status(),
messages.ScanQueueStatusMessage(queue=queue_export),
)
@ -685,7 +682,7 @@ class InstructionQueueItem:
self.instructions = []
self.parent = parent
self.queue = RequestBlockQueue(instruction_queue=self, assembler=assembler)
self.producer = self.parent.queue_manager.producer
self.connector = self.parent.queue_manager.connector
self._is_scan = False
self.is_active = False # set to true while a worker is processing the instructions
self.completed = False
@ -790,7 +787,7 @@ class InstructionQueueItem:
msg = messages.ScanQueueHistoryMessage(
status=self.status.name, queueID=self.queue_id, info=self.describe()
)
self.parent.queue_manager.producer.lpush(
self.parent.queue_manager.connector.lpush(
MessageEndpoints.scan_queue_history(), msg, max_size=100
)

View File

@ -23,7 +23,6 @@ class ScanServer(BECService):
def __init__(self, config: ServiceConfig, connector_cls: ConnectorBase):
super().__init__(config, connector_cls, unique_service=True)
self.producer = self.connector.producer()
self._start_scan_manager()
self._start_queue_manager()
self._start_device_manager()
@ -52,15 +51,12 @@ class ScanServer(BECService):
self.scan_guard = ScanGuard(parent=self)
def _start_alarm_handler(self):
self._alarm_consumer = self.connector.consumer(
MessageEndpoints.alarm(), cb=self._alarm_callback, parent=self
)
self._alarm_consumer.start()
self.connector.register(MessageEndpoints.alarm(), cb=self._alarm_callback, parent=self)
def _reset_scan_number(self):
if self.producer.get(MessageEndpoints.scan_number()) is None:
if self.connector.get(MessageEndpoints.scan_number()) is None:
self.scan_number = 1
if self.producer.get(MessageEndpoints.dataset_number()) is None:
if self.connector.get(MessageEndpoints.dataset_number()) is None:
self.dataset_number = 1
@staticmethod
@ -74,25 +70,24 @@ class ScanServer(BECService):
@property
def scan_number(self) -> int:
"""get the current scan number"""
return int(self.producer.get(MessageEndpoints.scan_number()))
return int(self.connector.get(MessageEndpoints.scan_number()))
@scan_number.setter
def scan_number(self, val: int):
"""set the current scan number"""
self.producer.set(MessageEndpoints.scan_number(), val)
self.connector.set(MessageEndpoints.scan_number(), val)
@property
def dataset_number(self) -> int:
"""get the current dataset number"""
return int(self.producer.get(MessageEndpoints.dataset_number()))
return int(self.connector.get(MessageEndpoints.dataset_number()))
@dataset_number.setter
def dataset_number(self, val: int):
"""set the current dataset number"""
self.producer.set(MessageEndpoints.dataset_number(), val)
self.connector.set(MessageEndpoints.dataset_number(), val)
def shutdown(self) -> None:
"""shutdown the scan server"""
self.device_manager.shutdown()
self.queue_manager.shutdown()

View File

@ -5,7 +5,9 @@ import uuid
from collections.abc import Callable
import numpy as np
from bec_lib import MessageEndpoints, ProducerConnector, Status, bec_logger, messages
from bec_lib import MessageEndpoints, Status, bec_logger, messages
from bec_lib.connector import ConnectorBase
from .errors import DeviceMessageError, ScanAbortion
@ -13,8 +15,8 @@ logger = bec_logger.logger
class ScanStubs:
def __init__(self, producer: ProducerConnector, device_msg_callback: Callable = None) -> None:
self.producer = producer
def __init__(self, connector: ConnectorBase, device_msg_callback: Callable = None) -> None:
self.connector = connector
self.device_msg_metadata = (
device_msg_callback if device_msg_callback is not None else lambda: {}
)
@ -62,7 +64,7 @@ class ScanStubs:
def _get_from_rpc(self, rpc_id):
while True:
msg = self.producer.get(MessageEndpoints.device_rpc(rpc_id))
msg = self.connector.get(MessageEndpoints.device_rpc(rpc_id))
if msg:
break
time.sleep(0.001)
@ -81,7 +83,7 @@ class ScanStubs:
if not isinstance(return_val, dict):
return return_val
if return_val.get("type") == "status" and return_val.get("RID"):
return Status(self.producer, return_val.get("RID"))
return Status(self.connector, return_val.get("RID"))
return return_val
def set_and_wait(self, *, device: list[str], positions: list | np.ndarray):
@ -182,7 +184,7 @@ class ScanStubs:
DIID (int): device instruction ID
"""
msg = self.producer.get(MessageEndpoints.device_req_status(device))
msg = self.connector.get(MessageEndpoints.device_req_status(device))
if not msg:
return 0
matching_RID = msg.metadata.get("RID") == RID
@ -199,7 +201,7 @@ class ScanStubs:
RID (str): request ID
"""
msg = self.producer.get(MessageEndpoints.device_progress(device))
msg = self.connector.get(MessageEndpoints.device_progress(device))
if not msg:
return None
matching_RID = msg.metadata.get("RID") == RID

View File

@ -41,7 +41,7 @@ class ScanWorker(threading.Thread):
self._groups = {}
self.interception_msg = None
self.reset()
self.validate = DeviceValidation(self.device_manager.producer, self)
self.validate = DeviceValidation(self.device_manager.connector, self)
def open_scan(self, instr: messages.DeviceInstructionMessage) -> None:
"""
@ -138,7 +138,7 @@ class ScanWorker(threading.Thread):
"""
devices = [dev.name for dev in self.device_manager.devices.get_software_triggered_devices()]
self._last_trigger = instr
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=devices,
@ -157,7 +157,7 @@ class ScanWorker(threading.Thread):
"""
# send instruction
self.device_manager.producer.send(MessageEndpoints.device_instructions(), instr)
self.device_manager.connector.send(MessageEndpoints.device_instructions(), instr)
def read_devices(self, instr: messages.DeviceInstructionMessage) -> None:
"""
@ -171,7 +171,7 @@ class ScanWorker(threading.Thread):
self._publish_readback(instr)
return
producer = self.device_manager.producer
connector = self.device_manager.connector
devices = instr.content.get("device")
if devices is None:
@ -181,7 +181,7 @@ class ScanWorker(threading.Thread):
readout_priority=self.readout_priority
)
]
producer.send(
connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=devices,
@ -201,7 +201,7 @@ class ScanWorker(threading.Thread):
"""
# logger.info("kickoff")
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=instr.content.get("device"),
@ -225,7 +225,7 @@ class ScanWorker(threading.Thread):
devices = instr.content.get("device")
if not isinstance(devices, list):
devices = [devices]
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=devices,
@ -251,7 +251,7 @@ class ScanWorker(threading.Thread):
)
]
params = instr.content["parameter"]
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=baseline_devices, action="read", parameter=params, metadata=instr.metadata
@ -266,7 +266,7 @@ class ScanWorker(threading.Thread):
instr (DeviceInstructionMessage): Device instruction received from the scan assembler
"""
devices = [dev.name for dev in self.device_manager.devices.enabled_devices]
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=devices,
@ -285,7 +285,7 @@ class ScanWorker(threading.Thread):
Args:
instr (DeviceInstructionMessage): Device instruction received from the scan assembler
"""
producer = self.device_manager.producer
connector = self.device_manager.connector
data = instr.content["parameter"]["data"]
devices = instr.content["device"]
if not isinstance(devices, list):
@ -294,7 +294,7 @@ class ScanWorker(threading.Thread):
data = [data]
for device, dev_data in zip(devices, data):
msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata)
producer.set_and_publish(MessageEndpoints.device_read(device), msg)
connector.set_and_publish(MessageEndpoints.device_read(device), msg)
def send_rpc(self, instr: messages.DeviceInstructionMessage) -> None:
"""
@ -304,7 +304,7 @@ class ScanWorker(threading.Thread):
instr (DeviceInstructionMessage): Device instruction received from the scan assembler
"""
self.device_manager.producer.send(MessageEndpoints.device_instructions(), instr)
self.device_manager.connector.send(MessageEndpoints.device_instructions(), instr)
def process_scan_report_instruction(self, instr):
"""
@ -333,7 +333,7 @@ class ScanWorker(threading.Thread):
if dev.name not in async_devices
]
for det in async_devices:
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=det,
@ -344,7 +344,7 @@ class ScanWorker(threading.Thread):
)
self._staged_devices.update(async_devices)
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=devices,
@ -375,7 +375,7 @@ class ScanWorker(threading.Thread):
parameter = {} if not instr else instr.content["parameter"]
metadata = {} if not instr else instr.metadata
self._staged_devices.difference_update(devices)
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=devices, action="unstage", parameter=parameter, metadata=metadata
@ -459,7 +459,7 @@ class ScanWorker(threading.Thread):
matching_DIID = device_status[ind].metadata.get("DIID") >= devices[ind][1]
matching_RID = device_status[ind].metadata.get("RID") == instr.metadata["RID"]
if matching_DIID and matching_RID:
last_pos_msg = self.device_manager.producer.get(
last_pos_msg = self.device_manager.connector.get(
MessageEndpoints.device_readback(failed_device[0])
)
last_pos = last_pos_msg.content["signals"][failed_device[0]]["value"]
@ -603,25 +603,25 @@ class ScanWorker(threading.Thread):
def _publish_readback(
self, instr: messages.DeviceInstructionMessage, devices: list = None
) -> None:
producer = self.device_manager.producer
connector = self.device_manager.connector
if not devices:
devices = instr.content.get("device")
# cached readout
readouts = self._get_readback(devices)
pipe = producer.pipeline()
pipe = connector.pipeline()
for readout, device in zip(readouts, devices):
msg = messages.DeviceMessage(signals=readout, metadata=instr.metadata)
producer.set_and_publish(MessageEndpoints.device_read(device), msg, pipe)
connector.set_and_publish(MessageEndpoints.device_read(device), msg, pipe)
return pipe.execute()
def _get_readback(self, devices: list) -> list:
producer = self.device_manager.producer
connector = self.device_manager.connector
# cached readout
pipe = producer.pipeline()
pipe = connector.pipeline()
for dev in devices:
producer.get(MessageEndpoints.device_readback(dev), pipe=pipe)
return producer.execute_pipeline(pipe)
connector.get(MessageEndpoints.device_readback(dev), pipe=pipe)
return connector.execute_pipeline(pipe)
def _check_for_interruption(self) -> None:
if self.status == InstructionQueueStatus.PAUSED:
@ -700,11 +700,13 @@ class ScanWorker(threading.Thread):
scanID=self.current_scanID, status=status, info=self.current_scan_info
)
expire = None if status in ["open", "paused"] else 1800
pipe = self.device_manager.producer.pipeline()
self.device_manager.producer.set(
pipe = self.device_manager.connector.pipeline()
self.device_manager.connector.set(
MessageEndpoints.public_scan_info(self.current_scanID), msg, pipe=pipe, expire=expire
)
self.device_manager.producer.set_and_publish(MessageEndpoints.scan_status(), msg, pipe=pipe)
self.device_manager.connector.set_and_publish(
MessageEndpoints.scan_status(), msg, pipe=pipe
)
pipe.execute()
def _process_instructions(self, queue: InstructionQueueItem) -> None:

View File

@ -213,7 +213,7 @@ class RequestBase(ABC):
if metadata is None:
self.metadata = {}
self.stubs = ScanStubs(
producer=self.device_manager.producer, device_msg_callback=self.device_msg_metadata
connector=self.device_manager.connector, device_msg_callback=self.device_msg_metadata
)
@property
@ -239,7 +239,7 @@ class RequestBase(ABC):
def run_pre_scan_macros(self):
"""run pre scan macros if any"""
macros = self.device_manager.producer.lrange(MessageEndpoints.pre_scan_macros(), 0, -1)
macros = self.device_manager.connector.lrange(MessageEndpoints.pre_scan_macros(), 0, -1)
for macro in macros:
macro = macro.decode().strip()
func_name = self._get_func_name_from_macro(macro)
@ -558,12 +558,12 @@ class SyncFlyScanBase(ScanBase, ABC):
def _get_flyer_status(self) -> list:
flyer = self.scan_motors[0]
producer = self.device_manager.producer
connector = self.device_manager.connector
pipe = producer.pipeline()
producer.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe)
producer.get(MessageEndpoints.device_readback(flyer), pipe)
return producer.execute_pipeline(pipe)
pipe = connector.pipeline()
connector.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe)
connector.get(MessageEndpoints.device_readback(flyer), pipe)
return connector.execute_pipeline(pipe)
@abstractmethod
def scan_core(self):
@ -1098,7 +1098,7 @@ class RoundScanFlySim(SyncFlyScanBase):
while True:
yield from self.stubs.read_and_wait(group="primary", wait_group="readout_primary")
status = self.device_manager.producer.get(MessageEndpoints.device_status(self.flyer))
status = self.device_manager.connector.get(MessageEndpoints.device_status(self.flyer))
if status:
device_is_idle = status.content.get("status", 1) == 0
matching_RID = self.metadata.get("RID") == status.metadata.get("RID")
@ -1318,12 +1318,12 @@ class MonitorScan(ScanBase):
self._check_limits()
def _get_flyer_status(self) -> list:
producer = self.device_manager.producer
connector = self.device_manager.connector
pipe = producer.pipeline()
producer.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe)
producer.get(MessageEndpoints.device_readback(self.flyer), pipe)
return producer.execute_pipeline(pipe)
pipe = connector.pipeline()
connector.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe)
connector.get(MessageEndpoints.device_readback(self.flyer), pipe)
return connector.execute_pipeline(pipe)
def scan_core(self):
yield from self.stubs.set(

View File

@ -12,7 +12,7 @@ from scan_server.scan_guard import ScanGuard, ScanRejection, ScanStatus
@pytest.fixture
def scan_guard_mock(scan_server_mock):
sg = ScanGuard(parent=scan_server_mock)
sg.device_manager.producer = mock.MagicMock()
sg.device_manager.connector = mock.MagicMock()
yield sg
@ -113,8 +113,8 @@ def test_valid_request(scan_server_mock, scan_queue_msg, valid):
def test_check_valid_scan_raises_for_unknown_scan(scan_guard_mock):
sg = scan_guard_mock
sg.producer = mock.MagicMock()
sg.producer.get.return_value = messages.AvailableResourceMessage(
sg.connector = mock.MagicMock()
sg.connector.get.return_value = messages.AvailableResourceMessage(
resource={"fermat_scan": "fermat_scan"}
)
@ -130,8 +130,8 @@ def test_check_valid_scan_raises_for_unknown_scan(scan_guard_mock):
def test_check_valid_scan_accepts_known_scan(scan_guard_mock):
sg = scan_guard_mock
sg.producer = mock.MagicMock()
sg.producer.get.return_value = messages.AvailableResourceMessage(
sg.connector = mock.MagicMock()
sg.connector.get.return_value = messages.AvailableResourceMessage(
resource={"fermat_scan": "fermat_scan"}
)
@ -146,8 +146,8 @@ def test_check_valid_scan_accepts_known_scan(scan_guard_mock):
def test_check_valid_scan_device_rpc(scan_guard_mock):
sg = scan_guard_mock
sg.producer = mock.MagicMock()
sg.producer.get.return_value = messages.AvailableResourceMessage(
sg.connector = mock.MagicMock()
sg.connector.get.return_value = messages.AvailableResourceMessage(
resource={"device_rpc": "device_rpc"}
)
request = messages.ScanQueueMessage(
@ -162,8 +162,8 @@ def test_check_valid_scan_device_rpc(scan_guard_mock):
def test_check_valid_scan_device_rpc_raises(scan_guard_mock):
sg = scan_guard_mock
sg.producer = mock.MagicMock()
sg.producer.get.return_value = messages.AvailableResourceMessage(
sg.connector = mock.MagicMock()
sg.connector.get.return_value = messages.AvailableResourceMessage(
resource={"device_rpc": "device_rpc"}
)
request = messages.ScanQueueMessage(
@ -184,7 +184,7 @@ def test_handle_scan_modification_request(scan_guard_mock):
msg = messages.ScanQueueModificationMessage(
scanID="scanID", action="abort", parameter={}, metadata={"RID": "RID"}
)
with mock.patch.object(sg.device_manager.producer, "send") as send:
with mock.patch.object(sg.device_manager.connector, "send") as send:
sg._handle_scan_modification_request(msg)
send.assert_called_once_with(MessageEndpoints.scan_queue_modification(), msg)
@ -207,7 +207,7 @@ def test_append_to_scan_queue(scan_guard_mock):
parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}},
queue="primary",
)
with mock.patch.object(sg.device_manager.producer, "send") as send:
with mock.patch.object(sg.device_manager.connector, "send") as send:
sg._append_to_scan_queue(msg)
send.assert_called_once_with(MessageEndpoints.scan_queue_insert(), msg)
@ -251,7 +251,7 @@ def test_scan_queue_modification_request_callback(scan_guard_mock):
def test_send_scan_request_response(scan_guard_mock):
sg = scan_guard_mock
with mock.patch.object(sg.device_manager.producer, "send") as send:
with mock.patch.object(sg.device_manager.connector, "send") as send:
sg._send_scan_request_response(ScanStatus(), {"RID": "RID"})
send.assert_called_once_with(
MessageEndpoints.scan_queue_request_response(),

View File

@ -166,40 +166,40 @@ def test_set_halt_disables_return_to_start(queuemanager_mock):
def test_set_pause(queuemanager_mock):
queue_manager = queuemanager_mock()
queue_manager.producer.message_sent = []
queue_manager.connector.message_sent = []
queue_manager.set_pause(queue="primary")
assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED
assert len(queue_manager.producer.message_sent) == 1
assert len(queue_manager.connector.message_sent) == 1
assert (
queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
)
def test_set_deferred_pause(queuemanager_mock):
queue_manager = queuemanager_mock()
queue_manager.producer.message_sent = []
queue_manager.connector.message_sent = []
queue_manager.set_deferred_pause(queue="primary")
assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED
assert len(queue_manager.producer.message_sent) == 1
assert len(queue_manager.connector.message_sent) == 1
assert (
queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
)
def test_set_continue(queuemanager_mock):
queue_manager = queuemanager_mock()
queue_manager.producer.message_sent = []
queue_manager.connector.message_sent = []
queue_manager.set_continue(queue="primary")
assert queue_manager.queues["primary"].status == ScanQueueStatus.RUNNING
assert len(queue_manager.producer.message_sent) == 1
assert len(queue_manager.connector.message_sent) == 1
assert (
queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
)
def test_set_abort(queuemanager_mock):
queue_manager = queuemanager_mock()
queue_manager.producer.message_sent = []
queue_manager.connector.message_sent = []
msg = messages.ScanQueueMessage(
scan_type="mv",
parameter={"args": {"samx": (1,)}, "kwargs": {}},
@ -210,23 +210,23 @@ def test_set_abort(queuemanager_mock):
queue_manager.add_to_queue(scan_queue="primary", msg=msg)
queue_manager.set_abort(queue="primary")
assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED
assert len(queue_manager.producer.message_sent) == 2
assert len(queue_manager.connector.message_sent) == 2
assert (
queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
)
def test_set_abort_with_empty_queue(queuemanager_mock):
queue_manager = queuemanager_mock()
queue_manager.producer.message_sent = []
queue_manager.connector.message_sent = []
queue_manager.set_abort(queue="primary")
assert queue_manager.queues["primary"].status == ScanQueueStatus.RUNNING
assert len(queue_manager.producer.message_sent) == 0
assert len(queue_manager.connector.message_sent) == 0
def test_set_clear_sends_message(queuemanager_mock):
queue_manager = queuemanager_mock()
queue_manager.producer.message_sent = []
queue_manager.connector.message_sent = []
setter_mock = mock.Mock(wraps=ScanQueue.worker_status.fset)
# pylint: disable=assignment-from-no-return
# pylint: disable=too-many-function-args
@ -238,9 +238,9 @@ def test_set_clear_sends_message(queuemanager_mock):
mock_property.fset.assert_called_once_with(
queue_manager.queues["primary"], InstructionQueueStatus.STOPPED
)
assert len(queue_manager.producer.message_sent) == 1
assert len(queue_manager.connector.message_sent) == 1
assert (
queue_manager.producer.message_sent[0].get("queue")
queue_manager.connector.message_sent[0].get("queue")
== MessageEndpoints.scan_queue_status()
)

View File

@ -11,7 +11,7 @@ from scan_server.scan_stubs import ScanAbortion, ScanStubs
@pytest.fixture
def stubs():
connector = ConnectorMock("")
yield ScanStubs(connector.producer())
yield ScanStubs(connector)
@pytest.mark.parametrize(
@ -36,7 +36,11 @@ def stubs():
device="rtx",
action="kickoff",
parameter={
"configure": {"num_pos": 5, "positions": [1, 2, 3, 4, 5], "exp_time": 2},
"configure": {
"num_pos": 5,
"positions": [1, 2, 3, 4, 5],
"exp_time": 2,
},
"wait_group": "kickoff",
},
metadata={},
@ -45,6 +49,8 @@ def stubs():
],
)
def test_kickoff(stubs, device, parameter, metadata, reference_msg):
connector = ConnectorMock("")
stubs = ScanStubs(connector)
msg = list(stubs.kickoff(device=device, parameter=parameter, metadata=metadata))
assert msg[0] == reference_msg
@ -52,12 +58,19 @@ def test_kickoff(stubs, device, parameter, metadata, reference_msg):
@pytest.mark.parametrize(
"msg,raised_error",
[
(messages.DeviceRPCMessage(device="samx", return_val="", out="", success=True), None),
(
messages.DeviceRPCMessage(device="samx", return_val="", out="", success=True),
None,
),
(
messages.DeviceRPCMessage(
device="samx",
return_val="",
out={"error": "TypeError", "msg": "some weird error", "traceback": "traceback"},
out={
"error": "TypeError",
"msg": "some weird error",
"traceback": "traceback",
},
success=False,
),
ScanAbortion,
@ -69,8 +82,7 @@ def test_kickoff(stubs, device, parameter, metadata, reference_msg):
],
)
def test_rpc_raises_scan_abortion(stubs, msg, raised_error):
msg = msg
with mock.patch.object(stubs.producer, "get", return_value=msg) as prod_get:
with mock.patch.object(stubs.connector, "get", return_value=msg) as prod_get:
if raised_error is None:
stubs._get_from_rpc("rpc-id")
else:
@ -106,8 +118,8 @@ def test_rpc_raises_scan_abortion(stubs, msg, raised_error):
def test_device_progress(stubs, msg, ret_value, raised_error):
if raised_error:
with pytest.raises(DeviceMessageError):
with mock.patch.object(stubs.producer, "get", return_value=msg):
with mock.patch.object(stubs.connector, "get", return_value=msg):
assert stubs.get_device_progress(device="samx", RID="rid") == ret_value
return
with mock.patch.object(stubs.producer, "get", return_value=msg):
with mock.patch.object(stubs.connector, "get", return_value=msg):
assert stubs.get_device_progress(device="samx", RID="rid") == ret_value

View File

@ -4,7 +4,7 @@ from unittest import mock
import pytest
from bec_lib import MessageEndpoints, messages
from bec_lib.tests.utils import ProducerMock, dm, dm_with_devices
from bec_lib.tests.utils import ConnectorMock, dm, dm_with_devices
from utils import scan_server_mock
from scan_server.errors import DeviceMessageError, ScanAbortion
@ -22,7 +22,7 @@ from scan_server.scan_worker import ScanWorker
@pytest.fixture
def scan_worker_mock(scan_server_mock) -> ScanWorker:
scan_server_mock.device_manager.producer = mock.MagicMock()
scan_server_mock.device_manager.connector = mock.MagicMock()
scan_worker = ScanWorker(parent=scan_server_mock)
yield scan_worker
@ -295,7 +295,7 @@ def test_wait_for_devices(scan_worker_mock, instructions, wait_type):
def test_complete_devices(scan_worker_mock, instructions):
worker = scan_worker_mock
with mock.patch.object(worker, "_wait_for_status") as wait_for_status_mock:
with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.complete_devices(instructions)
if instructions.content["device"]:
devices = instructions.content["device"]
@ -328,7 +328,7 @@ def test_complete_devices(scan_worker_mock, instructions):
)
def test_pre_scan(scan_worker_mock, instructions):
worker = scan_worker_mock
with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
with mock.patch.object(worker, "_wait_for_status") as wait_for_status_mock:
worker.pre_scan(instructions)
devices = [dev.name for dev in worker.device_manager.devices.enabled_devices]
@ -457,12 +457,12 @@ def test_pre_scan(scan_worker_mock, instructions):
)
def test_check_for_failed_movements(scan_worker_mock, device_status, devices, instr, abort):
worker = scan_worker_mock
worker.device_manager.producer = ProducerMock()
worker.device_manager.connector = ConnectorMock()
if abort:
with pytest.raises(ScanAbortion):
worker.device_manager.producer._get_buffer[MessageEndpoints.device_readback("samx")] = (
messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
)
worker.device_manager.connector._get_buffer[
MessageEndpoints.device_readback("samx")
] = messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
worker._check_for_failed_movements(device_status, devices, instr)
else:
worker._check_for_failed_movements(device_status, devices, instr)
@ -577,12 +577,12 @@ def test_check_for_failed_movements(scan_worker_mock, device_status, devices, in
)
def test_wait_for_idle(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReqStatusMessage):
worker = scan_worker_mock
worker.device_manager.producer = ProducerMock()
worker.device_manager.connector = ConnectorMock()
with mock.patch.object(
worker.validate, "get_device_status", return_value=[req_msg]
) as device_status:
worker.device_manager.producer._get_buffer[MessageEndpoints.device_readback("samx")] = (
worker.device_manager.connector._get_buffer[MessageEndpoints.device_readback("samx")] = (
messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
)
@ -635,7 +635,7 @@ def test_wait_for_idle(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReq
)
def test_wait_for_read(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReqStatusMessage):
worker = scan_worker_mock
worker.device_manager.producer = ProducerMock()
worker.device_manager.connector = ConnectorMock()
with mock.patch.object(
worker.validate, "get_device_status", return_value=[req_msg]
@ -643,9 +643,9 @@ def test_wait_for_read(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReq
with mock.patch.object(worker, "_check_for_interruption") as interruption_mock:
assert worker._groups == {}
worker._groups["scan_motor"] = {"samx": 3, "samy": 4}
worker.device_manager.producer._get_buffer[MessageEndpoints.device_readback("samx")] = (
messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
)
worker.device_manager.connector._get_buffer[
MessageEndpoints.device_readback("samx")
] = messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
worker._add_wait_group(msg1)
worker._wait_for_read(msg2)
assert worker._groups == {"scan_motor": {"samy": 4}}
@ -730,7 +730,7 @@ def test_wait_for_device_server(scan_worker_mock):
)
def test_set_devices(scan_worker_mock, instr):
worker = scan_worker_mock
with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.set_devices(instr)
send_mock.assert_called_once_with(MessageEndpoints.device_instructions(), instr)
@ -755,7 +755,7 @@ def test_set_devices(scan_worker_mock, instr):
)
def test_trigger_devices(scan_worker_mock, instr):
worker = scan_worker_mock
with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.trigger_devices(instr)
devices = [
dev.name for dev in worker.device_manager.devices.get_software_triggered_devices()
@ -797,7 +797,7 @@ def test_trigger_devices(scan_worker_mock, instr):
)
def test_send_rpc(scan_worker_mock, instr):
worker = scan_worker_mock
with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.send_rpc(instr)
send_mock.assert_called_once_with(MessageEndpoints.device_instructions(), instr)
@ -840,7 +840,7 @@ def test_read_devices(scan_worker_mock, instr):
instr_devices = []
worker.readout_priority.update({"monitored": instr_devices})
devices = [dev.name for dev in worker._get_devices_from_instruction(instr)]
with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.read_devices(instr)
if instr.content.get("device"):
@ -888,7 +888,7 @@ def test_read_devices(scan_worker_mock, instr):
)
def test_kickoff_devices(scan_worker_mock, instr, devices, parameter, metadata):
worker = scan_worker_mock
with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.kickoff_devices(instr)
send_mock.assert_called_once_with(
MessageEndpoints.device_instructions(),
@ -920,29 +920,27 @@ def test_kickoff_devices(scan_worker_mock, instr, devices, parameter, metadata):
def test_publish_readback(scan_worker_mock, instr, devices):
worker = scan_worker_mock
with mock.patch.object(worker, "_get_readback", return_value=[{}]) as get_readback:
with mock.patch.object(worker.device_manager, "producer") as producer_mock:
with mock.patch.object(worker.device_manager, "connector") as connector_mock:
worker._publish_readback(instr)
get_readback.assert_called_once_with(["samx"])
pipe = producer_mock.pipeline()
pipe = connector_mock.pipeline()
msg = messages.DeviceMessage(signals={}, metadata=instr.metadata)
producer_mock.set_and_publish.assert_called_once_with(
connector_mock.set_and_publish.assert_called_once_with(
MessageEndpoints.device_read("samx"), msg, pipe
)
pipe.execute.assert_called_once()
def test_get_readback(scan_worker_mock):
worker = scan_worker_mock
devices = ["samx"]
with mock.patch.object(worker.device_manager, "producer") as producer_mock:
with mock.patch.object(worker.device_manager, "connector") as connector_mock:
worker._get_readback(devices)
pipe = producer_mock.pipeline()
producer_mock.get.assert_called_once_with(
pipe = connector_mock.pipeline()
connector_mock.get.assert_called_once_with(
MessageEndpoints.device_readback("samx"), pipe=pipe
)
producer_mock.execute_pipeline.assert_called_once()
connector_mock.execute_pipeline.assert_called_once()
def test_publish_data_as_read(scan_worker_mock):
@ -958,12 +956,12 @@ def test_publish_data_as_read(scan_worker_mock):
"RID": "requestID",
},
)
with mock.patch.object(worker.device_manager, "producer") as producer_mock:
with mock.patch.object(worker.device_manager, "connector") as connector_mock:
worker.publish_data_as_read(instr)
msg = messages.DeviceMessage(
signals=instr.content["parameter"]["data"], metadata=instr.metadata
)
producer_mock.set_and_publish.assert_called_once_with(
connector_mock.set_and_publish.assert_called_once_with(
MessageEndpoints.device_read("samx"), msg
)
@ -983,13 +981,13 @@ def test_publish_data_as_read_multiple(scan_worker_mock):
"RID": "requestID",
},
)
with mock.patch.object(worker.device_manager, "producer") as producer_mock:
with mock.patch.object(worker.device_manager, "connector") as connector_mock:
worker.publish_data_as_read(instr)
mock_calls = []
for device, dev_data in zip(devices, data):
msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata)
mock_calls.append(mock.call(MessageEndpoints.device_read(device), msg))
assert producer_mock.set_and_publish.mock_calls == mock_calls
assert connector_mock.set_and_publish.mock_calls == mock_calls
def test_check_for_interruption(scan_worker_mock):
@ -1048,7 +1046,7 @@ def test_open_scan(scan_worker_mock, instr, corr_num_points, scan_id):
if "pointID" in instr.metadata:
worker.max_point_id = instr.metadata["pointID"]
assert worker.parent.producer.get(MessageEndpoints.scan_number()) == None
assert worker.parent.connector.get(MessageEndpoints.scan_number()) == None
with mock.patch.object(worker, "current_instruction_queue_item") as queue_mock:
with mock.patch.object(worker, "_initialize_scan_info") as init_mock:
@ -1181,7 +1179,7 @@ def test_stage_device(scan_worker_mock, msg):
worker.device_manager.devices["eiger"]._config["readoutPriority"] = "async"
with mock.patch.object(worker, "_wait_for_stage") as wait_mock:
with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
worker.stage_devices(msg)
async_devices = [dev.name for dev in worker.device_manager.devices.async_devices()]
devices = [
@ -1251,7 +1249,7 @@ def test_unstage_device(scan_worker_mock, msg, devices, parameter, metadata, cle
if not devices:
devices = [dev.name for dev in worker.device_manager.devices.enabled_devices]
with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
with mock.patch.object(worker, "_wait_for_stage") as wait_mock:
worker.unstage_devices(msg, devices, cleanup)
@ -1270,12 +1268,12 @@ def test_unstage_device(scan_worker_mock, msg, devices, parameter, metadata, cle
@pytest.mark.parametrize("status,expire", [("open", None), ("closed", 1800), ("aborted", 1800)])
def test_send_scan_status(scan_worker_mock, status, expire):
worker = scan_worker_mock
worker.device_manager.producer = ProducerMock()
worker.device_manager.connector = ConnectorMock()
worker.current_scanID = str(uuid.uuid4())
worker._send_scan_status(status)
scan_info_msgs = [
msg
for msg in worker.device_manager.producer.message_sent
for msg in worker.device_manager.connector.message_sent
if msg["queue"] == MessageEndpoints.public_scan_info(scanID=worker.current_scanID)
]
assert len(scan_info_msgs) == 1

View File

@ -6,7 +6,7 @@ import numpy as np
import pytest
from bec_lib import messages
from bec_lib.devicemanager import DeviceContainer
from bec_lib.tests.utils import ProducerMock
from bec_lib.tests.utils import ConnectorMock
from scan_plugins.LamNIFermatScan import LamNIFermatScan
from scan_plugins.otf_scan import OTFScan
@ -80,7 +80,7 @@ class DeviceMock:
class DMMock:
devices = DeviceContainer()
producer = ProducerMock()
connector = ConnectorMock()
def add_device(self, name):
self.devices[name] = DeviceMock(name)
@ -1099,7 +1099,7 @@ def test_pre_scan_macro():
device_manager=device_manager, parameter=scan_msg.content["parameter"]
)
with mock.patch.object(
request.device_manager.producer,
request.device_manager.connector,
"lrange",
new_callable=mock.PropertyMock,
return_value=macros,

View File

@ -26,9 +26,9 @@ dir_path = os.path.abspath(os.path.join(os.path.dirname(bec_lib.__file__), "./co
class ConfigHandler:
def __init__(self, scibec_connector: SciBecConnector, connector: ConnectorBase) -> None:
self.scibec_connector = scibec_connector
self.connector = connector
self.device_manager = DeviceManager(self.scibec_connector.scihub)
self.device_manager.initialize(scibec_connector.config.redis)
self.producer = connector.producer()
self.validator = SciBecValidator(os.path.join(dir_path, "openapi_schema.json"))
def parse_config_request(self, msg: messages.DeviceConfigMessage) -> None:
@ -53,7 +53,7 @@ class ConfigHandler:
def send_config(self, msg: messages.DeviceConfigMessage) -> None:
"""broadcast a new config"""
self.producer.send(MessageEndpoints.device_config_update(), msg)
self.connector.send(MessageEndpoints.device_config_update(), msg)
def send_config_request_reply(self, accepted, error_msg, metadata):
"""send a config request reply"""
@ -61,7 +61,7 @@ class ConfigHandler:
accepted=accepted, message=error_msg, metadata=metadata
)
RID = metadata.get("RID")
self.producer.set(MessageEndpoints.device_config_request_response(RID), msg, expire=60)
self.connector.set(MessageEndpoints.device_config_request_response(RID), msg, expire=60)
def _set_config(self, msg: messages.DeviceConfigMessage):
config = msg.content["config"]
@ -127,14 +127,14 @@ class ConfigHandler:
def _update_device_server(self, RID: str, config: dict, action="update") -> None:
msg = messages.DeviceConfigMessage(action=action, config=config, metadata={"RID": RID})
self.producer.send(MessageEndpoints.device_server_config_request(), msg)
self.connector.send(MessageEndpoints.device_server_config_request(), msg)
def _wait_for_device_server_update(self, RID: str, timeout_time=10) -> bool:
timeout = timeout_time
time_step = 0.05
elapsed_time = 0
while True:
msg = self.producer.get(MessageEndpoints.device_config_request_response(RID))
msg = self.connector.get(MessageEndpoints.device_config_request_response(RID))
if msg:
return msg.content["accepted"], msg
@ -188,11 +188,11 @@ class ConfigHandler:
self.validator.validate_device_patch(update)
def update_config_in_redis(self, device):
config = self.device_manager.producer.get(MessageEndpoints.device_config())
config = self.device_manager.connector.get(MessageEndpoints.device_config())
config = config.content["resource"]
index = next(
index for index, dev_conf in enumerate(config) if dev_conf["name"] == device.name
)
config[index] = device._config
msg = messages.AvailableResourceMessage(resource=config)
self.device_manager.producer.set(MessageEndpoints.device_config(), msg)
self.device_manager.connector.set(MessageEndpoints.device_config(), msg)

View File

@ -31,7 +31,6 @@ class SciBecConnector:
def __init__(self, scihub: SciHub, connector: ConnectorBase) -> None:
self.scihub = scihub
self.connector = connector
self.producer = connector.producer()
self.scibec = None
self.host = None
self.target_bl = None
@ -132,25 +131,24 @@ class SciBecConnector:
"""
Set the scibec account in redis
"""
self.producer.set(
self.connector.set(
MessageEndpoints.scibec(),
messages.CredentialsMessage(credentials={"url": self.host, "token": f"Bearer {token}"}),
)
def set_redis_config(self, config):
msg = messages.AvailableResourceMessage(resource=config)
self.producer.set(MessageEndpoints.device_config(), msg)
self.connector.set(MessageEndpoints.device_config(), msg)
def _start_metadata_handler(self) -> None:
self._metadata_handler = SciBecMetadataHandler(self)
def _start_config_request_handler(self) -> None:
self._config_request_handler = self.connector.consumer(
self._config_request_handler = self.connector.register(
MessageEndpoints.device_config_request(),
cb=self._device_config_request_callback,
parent=self,
)
self._config_request_handler.start()
@staticmethod
def _device_config_request_callback(msg, *, parent, **_kwargs) -> None:
@ -159,7 +157,7 @@ class SciBecConnector:
def connect_to_scibec(self):
"""
Connect to SciBec and set the producer to the write account
Connect to SciBec and set the connector to the write account
"""
self._load_environment()
if not self._env_configured:
@ -205,7 +203,7 @@ class SciBecConnector:
write_account = self.scibec_info["activeExperiment"]["writeAccount"]
if write_account[0] == "p":
write_account = write_account.replace("p", "e")
self.producer.set(MessageEndpoints.account(), write_account.encode())
self.connector.set(MessageEndpoints.account(), write_account.encode())
def shutdown(self):
"""

View File

@ -15,22 +15,20 @@ if TYPE_CHECKING:
class SciBecMetadataHandler:
def __init__(self, scibec_connector: SciBecConnector) -> None:
self.scibec_connector = scibec_connector
self._scan_status_consumer = None
self._scan_status_register = None
self._start_scan_subscription()
self._file_subscription = None
self._start_file_subscription()
def _start_scan_subscription(self):
self._scan_status_consumer = self.scibec_connector.connector.consumer(
self._scan_status_register = self.scibec_connector.connector.register(
MessageEndpoints.scan_status(), cb=self._handle_scan_status, parent=self
)
self._scan_status_consumer.start()
def _start_file_subscription(self):
self._file_subscription = self.scibec_connector.connector.consumer(
self._file_subscription = self.scibec_connector.connector.register(
MessageEndpoints.file_content(), cb=self._handle_file_content, parent=self
)
self._file_subscription.start()
@staticmethod
def _handle_scan_status(msg, *, parent, **_kwargs) -> None:
@ -171,7 +169,7 @@ class SciBecMetadataHandler:
"""
Shutdown the metadata handler
"""
if self._scan_status_consumer:
self._scan_status_consumer.shutdown()
if self._scan_status_register:
self._scan_status_register.shutdown()
if self._file_subscription:
self._file_subscription.shutdown()

View File

@ -22,7 +22,6 @@ class SciLogConnector:
def __init__(self, scihub: SciHub, connector: RedisConnector) -> None:
self.scihub = scihub
self.connector = connector
self.producer = self.connector.producer()
self.host = None
self.user = None
self.user_secret = None
@ -44,7 +43,7 @@ class SciLogConnector:
def set_bec_token(self, token: str) -> None:
"""set the scilog token in redis"""
self.producer.set(
self.connector.set(
MessageEndpoints.logbook(),
msgpack.dumps({"url": self.host, "user": self.user, "token": f"Bearer {token}"}),
)

View File

@ -334,7 +334,7 @@ def test_config_handler_update_device_config_available_keys(config_handler, avai
def test_config_handler_wait_for_device_server_update(config_handler):
RID = "12345"
with mock.patch.object(config_handler.producer, "get") as mock_get:
with mock.patch.object(config_handler.connector, "get") as mock_get:
mock_get.side_effect = [
None,
None,
@ -346,7 +346,7 @@ def test_config_handler_wait_for_device_server_update(config_handler):
def test_config_handler_wait_for_device_server_update_timeout(config_handler):
RID = "12345"
with mock.patch.object(config_handler.producer, "get", return_value=None) as mock_get:
with mock.patch.object(config_handler.connector, "get", return_value=None) as mock_get:
with pytest.raises(TimeoutError):
config_handler._wait_for_device_server_update(RID, timeout_time=0.1)
mock_get.assert_called()

View File

@ -138,6 +138,6 @@ def test_scibec_update_experiment_info(SciBecMock):
def test_update_eaccount_in_redis(SciBecMock):
SciBecMock.scibec_info = {"activeExperiment": {"writeAccount": "p12345"}}
with mock.patch.object(SciBecMock, "producer") as mock_producer:
with mock.patch.object(SciBecMock, "connector") as mock_connector:
SciBecMock._update_eaccount_in_redis()
mock_producer.set.assert_called_once_with(MessageEndpoints.account(), b"e12345")
mock_connector.set.assert_called_once_with(MessageEndpoints.account(), b"e12345")