refactor!(connector): unify connector/redis_connector in one class
parent 4edc5d02fe
commit b92a79b0c0
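Every call site that previously went through a dedicated `producer` (or a threaded `consumer`) now talks to the unified `connector` directly, as the hunks below show. A minimal before/after sketch of the call-site migration (the device name and endpoint are illustrative, not taken from this diff):

```python
from bec_lib.endpoints import MessageEndpoints

# before this refactor: storage access went through a separate producer object
# msg = device_manager.producer.get(MessageEndpoints.device_readback("samx"))

# after: the connector itself exposes the key-value store and pub/sub API
msg = device_manager.connector.get(MessageEndpoints.device_readback("samx"))
```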
@@ -38,16 +38,16 @@ class ReadbackDataMixin:
     def get_request_done_msgs(self):
         """get all request-done messages"""
-        pipe = self.device_manager.producer.pipeline()
+        pipe = self.device_manager.connector.pipeline()
         for dev in self.devices:
-            self.device_manager.producer.get(MessageEndpoints.device_req_status(dev), pipe)
-        return self.device_manager.producer.execute_pipeline(pipe)
+            self.device_manager.connector.get(MessageEndpoints.device_req_status(dev), pipe)
+        return self.device_manager.connector.execute_pipeline(pipe)

     def wait_for_RID(self, request):
         """wait for the readback's metadata to match the request ID"""
         while True:
             msgs = [
-                self.device_manager.producer.get(MessageEndpoints.device_readback(dev))
+                self.device_manager.connector.get(MessageEndpoints.device_readback(dev))
                 for dev in self.devices
             ]
             if all(msg.metadata.get("RID") == request.metadata["RID"] for msg in msgs if msg):
@@ -27,7 +27,7 @@ class LiveUpdatesScanProgress(LiveUpdatesTable):
         Update the progressbar based on the device status message. Returns True if the scan is finished.
         """
         self.check_alarms()
-        status = self.bec.producer.get(MessageEndpoints.device_progress(device_names[0]))
+        status = self.bec.connector.get(MessageEndpoints.device_progress(device_names[0]))
         if not status:
             logger.debug("waiting for new data point")
             await asyncio.sleep(0.1)
@@ -13,7 +13,7 @@ from bec_client.callbacks.move_device import (


 @pytest.fixture
 def readback_data_mixin(bec_client):
-    with mock.patch.object(bec_client.device_manager, "producer"):
+    with mock.patch.object(bec_client.device_manager, "connector"):
         yield ReadbackDataMixin(bec_client.device_manager, ["samx", "samy"])
@@ -102,7 +102,7 @@ async def test_move_callback_with_report_instruction(bec_client):


 def test_readback_data_mixin(readback_data_mixin):
-    readback_data_mixin.device_manager.producer.get.side_effect = [
+    readback_data_mixin.device_manager.connector.get.side_effect = [
         messages.DeviceMessage(
             signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}},
             metadata={"device": "samx"},
@@ -121,7 +121,7 @@ def test_readback_data_mixin_multiple_hints(readback_data_mixin):
         "samx_setpoint",
         "samx",
     ]
-    readback_data_mixin.device_manager.producer.get.side_effect = [
+    readback_data_mixin.device_manager.connector.get.side_effect = [
         messages.DeviceMessage(
             signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}},
             metadata={"device": "samx"},
@@ -137,7 +137,7 @@ def test_readback_data_mixin_multiple_hints(readback_data_mixin):


 def test_readback_data_mixin_multiple_no_hints(readback_data_mixin):
     readback_data_mixin.device_manager.devices.samx._info["hints"]["fields"] = []
-    readback_data_mixin.device_manager.producer.get.side_effect = [
+    readback_data_mixin.device_manager.connector.get.side_effect = [
         messages.DeviceMessage(
             signals={"samx": {"value": 10}, "samx_setpoint": {"value": 20}},
             metadata={"device": "samx"},
@@ -153,18 +153,18 @@ def test_readback_data_mixin_multiple_no_hints(readback_data_mixin):

 def test_get_request_done_msgs(readback_data_mixin):
     res = readback_data_mixin.get_request_done_msgs()
-    readback_data_mixin.device_manager.producer.pipeline.assert_called_once()
+    readback_data_mixin.device_manager.connector.pipeline.assert_called_once()
     assert (
         mock.call(
             MessageEndpoints.device_req_status("samx"),
-            readback_data_mixin.device_manager.producer.pipeline.return_value,
+            readback_data_mixin.device_manager.connector.pipeline.return_value,
         )
-        in readback_data_mixin.device_manager.producer.get.call_args_list
+        in readback_data_mixin.device_manager.connector.get.call_args_list
     )
     assert (
         mock.call(
             MessageEndpoints.device_req_status("samy"),
-            readback_data_mixin.device_manager.producer.pipeline.return_value,
+            readback_data_mixin.device_manager.connector.pipeline.return_value,
         )
-        in readback_data_mixin.device_manager.producer.get.call_args_list
+        in readback_data_mixin.device_manager.connector.get.call_args_list
     )
@@ -15,7 +15,7 @@ async def test_update_progressbar_continues_without_device_data():
     live_update = LiveUpdatesScanProgress(bec=bec, report_instruction={}, request=request)
     progressbar = mock.MagicMock()

-    bec.producer.get.return_value = None
+    bec.connector.get.return_value = None
     res = await live_update._update_progressbar(progressbar, "async_dev1")
     assert res is False
@@ -29,7 +29,7 @@ async def test_update_progressbar_continues_when_scanID_doesnt_match():
     live_update.scan_item = mock.MagicMock()
     live_update.scan_item.scanID = "scanID2"

-    bec.producer.get.return_value = messages.ProgressMessage(
+    bec.connector.get.return_value = messages.ProgressMessage(
         value=1, max_value=10, done=False, metadata={"scanID": "scanID"}
     )
     res = await live_update._update_progressbar(progressbar, "async_dev1")
@@ -45,7 +45,7 @@ async def test_update_progressbar_continues_when_msg_specifies_no_value():
     live_update.scan_item = mock.MagicMock()
     live_update.scan_item.scanID = "scanID"

-    bec.producer.get.return_value = messages.ProgressMessage(
+    bec.connector.get.return_value = messages.ProgressMessage(
         value=None, max_value=None, done=None, metadata={"scanID": "scanID"}
     )
     res = await live_update._update_progressbar(progressbar, "async_dev1")
@@ -61,7 +61,7 @@ async def test_update_progressbar_updates_max_value():
     live_update.scan_item = mock.MagicMock()
     live_update.scan_item.scanID = "scanID"

-    bec.producer.get.return_value = messages.ProgressMessage(
+    bec.connector.get.return_value = messages.ProgressMessage(
         value=10, max_value=20, done=False, metadata={"scanID": "scanID"}
     )
     res = await live_update._update_progressbar(progressbar, "async_dev1")
@@ -79,7 +79,7 @@ async def test_update_progressbar_returns_true_when_max_value_is_reached():
     live_update.scan_item = mock.MagicMock()
     live_update.scan_item.scanID = "scanID"

-    bec.producer.get.return_value = messages.ProgressMessage(
+    bec.connector.get.return_value = messages.ProgressMessage(
         value=10, max_value=10, done=True, metadata={"scanID": "scanID"}
     )
     res = await live_update._update_progressbar(progressbar, "async_dev1")
@@ -463,11 +463,11 @@ def test_file_writer(client):
         md={"datasetID": 325},
     )
     assert len(scan.scan.data) == 100
-    msg = bec.device_manager.producer.get(MessageEndpoints.public_file(scan.scan.scanID, "master"))
+    msg = bec.device_manager.connector.get(MessageEndpoints.public_file(scan.scan.scanID, "master"))
     while True:
         if msg:
             break
-        msg = bec.device_manager.producer.get(
+        msg = bec.device_manager.connector.get(
             MessageEndpoints.public_file(scan.scan.scanID, "master")
         )
@@ -3,7 +3,6 @@ from bec_lib.bec_service import BECService
 from bec_lib.channel_monitor import channel_monitor_launch
 from bec_lib.client import BECClient
 from bec_lib.config_helper import ConfigHelper
-from bec_lib.connector import ProducerConnector
 from bec_lib.device import DeviceBase, DeviceStatus, Status
 from bec_lib.devicemanager import DeviceConfigError, DeviceContainer, DeviceManagerBase
 from bec_lib.endpoints import MessageEndpoints
@@ -48,23 +48,21 @@ class AlarmBase(Exception):
 class AlarmHandler:
     def __init__(self, connector: RedisConnector) -> None:
         self.connector = connector
-        self.alarm_consumer = None
         self.alarms_stack = deque(maxlen=100)
         self._raised_alarms = deque(maxlen=100)
         self._lock = threading.RLock()

     def start(self):
         """start the alarm handler and its subscriptions"""
-        self.alarm_consumer = self.connector.consumer(
+        self.connector.register(
             topics=MessageEndpoints.alarm(),
             name="AlarmHandler",
-            cb=self._alarm_consumer_callback,
+            cb=self._alarm_register_callback,
             parent=self,
         )
-        self.alarm_consumer.start()

     @staticmethod
-    def _alarm_consumer_callback(msg, *, parent, **_kwargs):
+    def _alarm_register_callback(msg, *, parent, **_kwargs):
         parent.add_alarm(msg.value)

     @threadlocked
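Subscriptions change in the same way: `connector.consumer(...)` followed by an explicit `start()` becomes a single `connector.register(...)` call that manages its own listener and dispatcher threads. A minimal sketch under those assumptions (the Redis address is illustrative):

```python
from bec_lib.endpoints import MessageEndpoints
from bec_lib.redis_connector import RedisConnector

def on_alarm(msg, **_kwargs):
    # msg is a MessageObject; msg.value holds the deserialized BECMessage
    print("alarm received:", msg.value)

connector = RedisConnector(["localhost:6379"])
# register() subscribes and, with the default start_thread=True, starts the
# dispatcher thread itself - there is no consumer object to start or join
connector.register(topics=MessageEndpoints.alarm(), cb=on_alarm)
```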
@@ -136,4 +134,4 @@ class AlarmHandler:

     def shutdown(self):
         """shutdown the alarm handler"""
-        self.alarm_consumer.shutdown()
+        self.connector.shutdown()
@@ -8,12 +8,12 @@ from bec_lib.endpoints import MessageEndpoints

 if TYPE_CHECKING:
     from bec_lib import messages
-    from bec_lib.redis_connector import RedisProducer
+    from bec_lib.connector import ConnectorBase


 class AsyncDataHandler:
-    def __init__(self, producer: RedisProducer):
-        self.producer = producer
+    def __init__(self, connector: ConnectorBase):
+        self.connector = connector

     def get_async_data_for_scan(self, scan_id: str) -> dict[list]:
         """
@@ -25,7 +25,9 @@ class AsyncDataHandler:
         Returns:
             dict[list]: the async data for the scan sorted by device name
         """
-        async_device_keys = self.producer.keys(MessageEndpoints.device_async_readback(scan_id, "*"))
+        async_device_keys = self.connector.keys(
+            MessageEndpoints.device_async_readback(scan_id, "*")
+        )
         async_data = {}
         for device_key in async_device_keys:
             key = device_key.decode()
@@ -50,7 +52,7 @@ class AsyncDataHandler:
             list: the async data for the device
         """
         key = MessageEndpoints.device_async_readback(scan_id, device_name)
-        msgs = self.producer.xrange(key, min="-", max="+")
+        msgs = self.connector.xrange(key, min="-", max="+")
         if not msgs:
             return []
         return self.process_async_data(msgs)
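The `AsyncDataHandler` constructor now takes any `ConnectorBase` instead of a `RedisProducer`; the stream reads themselves are unchanged. An illustrative sketch (the handler's import path and the scan id are assumptions, not shown in this diff):

```python
# hypothetical wiring - AsyncDataHandler's module path is not shown here
handler = AsyncDataHandler(client.connector)  # was: AsyncDataHandler(client.producer)
data = handler.get_async_data_for_scan(scan_id="<scan id>")
# -> dict of async readback data, keyed by device name
```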
@@ -54,14 +54,14 @@ class BECWidgetsConnector:
     def __init__(self, gui_id: str, bec_client: BECClient = None) -> None:
         self._client = bec_client
         self.gui_id = gui_id
-        # TODO replace with a global producer
+        # TODO replace with a global connector
         if self._client is None:
             if "bec" in builtins.__dict__:
                 self._client = builtins.bec
             else:
                 self._client = BECClient()
                 self._client.start()
-        self._producer = self._client.connector.producer()
+        self._connector = self._client.connector

     def set_plot_config(self, plot_id: str, config: dict) -> None:
         """
@@ -72,7 +72,7 @@ class BECWidgetsConnector:
             config (dict): The config to set.
         """
         msg = messages.GUIConfigMessage(config=config)
-        self._producer.set_and_publish(MessageEndpoints.gui_config(plot_id), msg)
+        self._connector.set_and_publish(MessageEndpoints.gui_config(plot_id), msg)

     def close(self, plot_id: str) -> None:
         """
@@ -82,7 +82,7 @@ class BECWidgetsConnector:
             plot_id (str): The id of the plot.
         """
         msg = messages.GUIInstructionMessage(action="close", parameter={})
-        self._producer.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)
+        self._connector.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)

     def config_dialog(self, plot_id: str) -> None:
         """
@@ -92,7 +92,7 @@ class BECWidgetsConnector:
             plot_id (str): The id of the plot.
         """
         msg = messages.GUIInstructionMessage(action="config_dialog", parameter={})
-        self._producer.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)
+        self._connector.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)

     def send_data(self, plot_id: str, data: dict) -> None:
         """
@@ -103,9 +103,9 @@ class BECWidgetsConnector:
             data (dict): The data to send.
         """
         msg = messages.GUIDataMessage(data=data)
-        self._producer.set_and_publish(topic=MessageEndpoints.gui_data(plot_id), msg=msg)
+        self._connector.set_and_publish(topic=MessageEndpoints.gui_data(plot_id), msg=msg)
         # TODO bec_dispatcher can only handle set_and_publish ATM
-        # self._producer.xadd(topic=MessageEndpoints.gui_data(plot_id),msg= {"data": msg})
+        # self._connector.xadd(topic=MessageEndpoints.gui_data(plot_id),msg= {"data": msg})

     def clear(self, plot_id: str) -> None:
         """
@@ -115,7 +115,7 @@ class BECWidgetsConnector:
             plot_id (str): The id of the plot.
         """
         msg = messages.GUIInstructionMessage(action="clear", parameter={})
-        self._producer.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)
+        self._connector.set_and_publish(MessageEndpoints.gui_instructions(plot_id), msg)


 class BECPlotter:
@@ -41,7 +41,6 @@ class BECService:
         self.connector = connector_cls(self.bootstrap_server)
         self._unique_service = unique_service
         self.wait_for_server = wait_for_server
-        self.producer = self.connector.producer()
         self.__service_id = str(uuid.uuid4())
         self._user = getpass.getuser()
         self._hostname = socket.gethostname()
@@ -110,11 +109,11 @@ class BECService:
         )

     def _update_existing_services(self):
-        service_keys = self.producer.keys(MessageEndpoints.service_status("*"))
+        service_keys = self.connector.keys(MessageEndpoints.service_status("*"))
         if not service_keys:
             return
         services = [service.decode().split(":", maxsplit=1)[0] for service in service_keys]
-        msgs = [self.producer.get(service) for service in services]
+        msgs = [self.connector.get(service) for service in services]
         self._services_info = {msg.content["name"]: msg for msg in msgs if msg is not None}

     def _update_service_info(self):
@@ -124,7 +123,7 @@ class BECService:
         self._service_info_event.wait(timeout=3)

     def _send_service_status(self):
-        self.producer.set_and_publish(
+        self.connector.set_and_publish(
             topic=MessageEndpoints.service_status(self._service_id),
             msg=messages.StatusMessage(
                 name=self._service_name,
@@ -189,7 +188,7 @@ class BECService:
             )
         )
         msg = messages.ServiceMetricMessage(name=self.__class__.__name__, metrics=data)
-        self.producer.send(MessageEndpoints.metrics(self._service_id), msg)
+        self.connector.send(MessageEndpoints.metrics(self._service_id), msg)
         self._metrics_emitter_event.wait(timeout=1)

     def set_global_var(self, name: str, val: Any) -> None:
@@ -200,7 +199,7 @@ class BECService:
             val (Any): Value of the variable

         """
-        self.producer.set(MessageEndpoints.global_vars(name), messages.VariableMessage(value=val))
+        self.connector.set(MessageEndpoints.global_vars(name), messages.VariableMessage(value=val))

     def get_global_var(self, name: str) -> Any:
         """Get a global variable from Redis
@@ -211,7 +210,7 @@ class BECService:
         Returns:
             Any: Value of the variable
         """
-        msg = self.producer.get(MessageEndpoints.global_vars(name))
+        msg = self.connector.get(MessageEndpoints.global_vars(name))
        if msg:
            return msg.content.get("value")
        return None
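The global-variable helpers keep their behavior; only the transport attribute changes. A short usage sketch (`service` stands for any `BECService` instance; the variable name is arbitrary):

```python
service.set_global_var("my_var", 42)      # connector.set(...) under the hood
value = service.get_global_var("my_var")  # connector.get(...).content["value"]
assert value == 42
```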
@@ -223,12 +222,12 @@ class BECService:
             name (str): Name of the variable

         """
-        self.producer.delete(MessageEndpoints.global_vars(name))
+        self.connector.delete(MessageEndpoints.global_vars(name))

     def global_vars(self) -> str:
         """Get all available global variables"""
         # sadly, this cannot be a property as it causes side effects with IPython's tab completion
-        available_keys = self.producer.keys(MessageEndpoints.global_vars("*"))
+        available_keys = self.connector.keys(MessageEndpoints.global_vars("*"))

         def get_endpoint_from_topic(topic: str) -> str:
             return topic.decode().split(MessageEndpoints.global_vars(""))[-1]
@@ -252,6 +251,7 @@ class BECService:

     def shutdown(self):
         """shutdown the BECService"""
+        self.connector.shutdown()
         self._service_info_event.set()
         if self._service_info_thread:
             self._service_info_thread.join()
@@ -16,11 +16,11 @@ def channel_callback(msg, **_kwargs):
     print(json.dumps(out, indent=4, default=lambda o: "<not serializable object>"))


-def _start_consumer(config_path, topic):
+def _start_register(config_path, topic):
     config = ServiceConfig(config_path)
     connector = RedisConnector(config.redis)
-    consumer = connector.consumer(topics=topic, cb=channel_callback)
-    consumer.start()
+    register = connector.register(topics=topic, cb=channel_callback)
+    register.start()
     event = threading.Event()
     event.wait()
@@ -38,4 +38,4 @@ def channel_monitor_launch():
     config_path = clargs.config
     topic = clargs.channel

-    _start_consumer(config_path, topic)
+    _start_register(config_path, topic)
@@ -91,7 +91,7 @@ class BECClient(BECService, UserScriptsMixin):
     @property
     def active_account(self) -> str:
         """get the currently active target (e)account"""
-        return self.producer.get(MessageEndpoints.account())
+        return self.connector.get(MessageEndpoints.account())

     def start(self):
         """start the client"""
@@ -133,13 +133,13 @@ class BECClient(BECService, UserScriptsMixin):
     @property
     def pre_scan_hooks(self):
         """currently stored pre-scan hooks"""
-        return self.producer.lrange(MessageEndpoints.pre_scan_macros(), 0, -1)
+        return self.connector.lrange(MessageEndpoints.pre_scan_macros(), 0, -1)

     @pre_scan_hooks.setter
     def pre_scan_hooks(self, hooks: list):
-        self.producer.delete(MessageEndpoints.pre_scan_macros())
+        self.connector.delete(MessageEndpoints.pre_scan_macros())
         for hook in hooks:
-            self.producer.lpush(MessageEndpoints.pre_scan_macros(), hook)
+            self.connector.lpush(MessageEndpoints.pre_scan_macros(), hook)

     def _load_scans(self):
         self.scans = Scans(self)
@@ -26,7 +26,6 @@ logger = bec_logger.logger
 class ConfigHelper:
     def __init__(self, connector: RedisConnector, service_name: str = None) -> None:
         self.connector = connector
-        self.producer = connector.producer()
         self._service_name = service_name

     def update_session_with_file(self, file_path: str, save_recovery: bool = True) -> None:
@@ -71,7 +70,7 @@ class ConfigHelper:
         print(f"Config was written to {file_path}.")

     def _save_config_to_file(self, file_path: str, raise_on_error: bool = True) -> bool:
-        config = self.producer.get(MessageEndpoints.device_config())
+        config = self.connector.get(MessageEndpoints.device_config())
         if not config:
             if raise_on_error:
                 raise DeviceConfigError("No config found in the session.")
@@ -99,7 +98,7 @@ class ConfigHelper:
         if action in ["update", "add", "set"] and not config:
             raise DeviceConfigError(f"Config cannot be empty for an {action} request.")
         RID = str(uuid.uuid4())
-        self.producer.send(
+        self.connector.send(
             MessageEndpoints.device_config_request(),
             DeviceConfigMessage(action=action, config=config, metadata={"RID": RID}),
         )
@@ -145,7 +144,7 @@ class ConfigHelper:
         elapsed_time = 0
         max_time = timeout
         while True:
-            service_messages = self.producer.lrange(MessageEndpoints.service_response(RID), 0, -1)
+            service_messages = self.connector.lrange(MessageEndpoints.service_response(RID), 0, -1)
             if not service_messages:
                 time.sleep(0.005)
                 elapsed_time += 0.005
@@ -185,7 +184,7 @@ class ConfigHelper:
         """
         start = 0
         while True:
-            msg = self.producer.get(MessageEndpoints.device_config_request_response(RID))
+            msg = self.connector.get(MessageEndpoints.device_config_request_response(RID))
             if msg is None:
                 time.sleep(0.01)
                 start += 0.01
@@ -6,7 +6,8 @@ import threading
 import traceback

 from bec_lib.logger import bec_logger
-from bec_lib.messages import BECMessage
+from bec_lib.messages import BECMessage, LogMessage
+from bec_lib.endpoints import MessageEndpoints

 logger = bec_logger.logger
@@ -33,154 +34,98 @@ class MessageObject:
         return f"MessageObject(topic={self.topic}, value={self._value})"


-class ConnectorBase(abc.ABC):
-    """
-    ConnectorBase implements producer and consumer clients for communicating with a broker.
-    One ought to inherit from this base class and provide at least customized producer and consumer methods.
-    """
-
-    def __init__(self, bootstrap_server: list):
-        self.bootstrap = bootstrap_server
-        self._threads = []
-
-    def producer(self, **kwargs) -> ProducerConnector:
-        raise NotImplementedError
-
-    def consumer(self, **kwargs) -> ConsumerConnectorThreaded:
-        raise NotImplementedError
-
-    def shutdown(self):
-        for t in self._threads:
-            t.signal_event.set()
-            t.join()
-
-    def raise_warning(self, msg):
-        raise NotImplementedError
-
-    def send_log(self, msg):
-        raise NotImplementedError
-
-    def poll_messages(self):
-        """Poll for new messages, receive them and execute callbacks"""
-
-
-class ProducerConnector(abc.ABC):
+class StoreInterface(abc.ABC):
+    """StoreBase defines the interface for storing data"""
+
+    def __init__(self, store):
+        pass
+
+    def pipeline(self):
+        pass
+
+    def execute_pipeline(self):
+        pass

     def lpush(
         self, topic: str, msg: str, pipe=None, max_size: int = None, expire: int = None
     ) -> None:
         raise NotImplementedError

     def lset(self, topic: str, index: int, msg: str, pipe=None) -> None:
         raise NotImplementedError

     def rpush(self, topic: str, msg: str, pipe=None) -> int:
         raise NotImplementedError

     def lrange(self, topic: str, start: int, end: int, pipe=None):
         raise NotImplementedError

     def set(self, topic: str, msg, pipe=None, expire: int = None) -> None:
         raise NotImplementedError

     def keys(self, pattern: str) -> list:
         raise NotImplementedError

     def delete(self, topic, pipe=None):
         raise NotImplementedError

     def get(self, topic: str, pipe=None):
         raise NotImplementedError

     def xadd(self, topic: str, msg: dict, max_size=None, pipe=None, expire: int = None):
         raise NotImplementedError

     def xread(
         self,
         topic: str,
         id: str = None,
         count: int = None,
         block: int = None,
         pipe=None,
         from_start=False,
     ) -> list:
         raise NotImplementedError

     def xrange(self, topic: str, min: str, max: str, count: int = None, pipe=None):
         raise NotImplementedError


+class PubSubInterface(abc.ABC):
     def raw_send(self, topic: str, msg: bytes) -> None:
         raise NotImplementedError

     def send(self, topic: str, msg: BECMessage) -> None:
         raise NotImplementedError

-
-class ConsumerConnector(abc.ABC):
-    def __init__(
-        self, bootstrap_server, cb, topics=None, pattern=None, group_id=None, event=None, **kwargs
-    ):
-        """
-        ConsumerConnector class defines the communication with the broker for consuming messages.
-        An implementation ought to inherit from this class and implement the initialize_connector and poll_messages methods.
-
-        Args:
-            bootstrap_server: list of bootstrap servers, e.g. ["localhost:9092", "localhost:9093"]
-            topics: the topic(s) to which the connector should attach
-            event: external event to trigger start and stop of the connector
-            cb: callback function; will be triggered from within poll_messages
-            kwargs: additional keyword arguments
-
-        """
-        self.bootstrap = bootstrap_server
-        self.topics = topics
-        self.pattern = pattern
-        self.group_id = group_id
-        self.connector = None
-        self.cb = cb
-        self.kwargs = kwargs
-
-        if not self.topics and not self.pattern:
-            raise ConsumerConnectorError("Either a topic or a patter must be specified.")
-
-    def initialize_connector(self) -> None:
-        """
-        initialize the connector instance self.connector
-        The connector will be initialized once the thread is started
-        """
+    def register(self, topics=None, pattern=None, cb=None, start_thread=True, **kwargs):
         raise NotImplementedError

-    def poll_messages(self) -> None:
-        """
-        Poll messages from self.connector and call the callback function self.cb
-
-        """
-        raise NotImplementedError()
+    def poll_messages(self, timeout=None):
+        """Poll for new messages, receive them and execute callbacks"""
+        raise NotImplementedError

-
-class ConsumerConnectorThreaded(ConsumerConnector, threading.Thread):
-    def __init__(
-        self,
-        bootstrap_server,
-        cb,
-        topics=None,
-        pattern=None,
-        group_id=None,
-        event=None,
-        name=None,
-        **kwargs,
-    ):
-        """
-        ConsumerConnectorThreaded class defines the threaded communication with the broker for consuming messages.
-        An implementation ought to inherit from this class and implement the initialize_connector and poll_messages methods.
-        Once started, the connector is expected to poll new messages until the signal_event is set.
-
-        Args:
-            bootstrap_server: list of bootstrap servers, e.g. ["localhost:9092", "localhost:9093"]
-            topics: the topic(s) to which the connector should attach
-            event: external event to trigger start and stop of the connector
-            cb: callback function; will be triggered from within poll_messages
-            kwargs: additional keyword arguments
-
-        """
-        super().__init__(
-            bootstrap_server=bootstrap_server,
-            topics=topics,
-            pattern=pattern,
-            group_id=group_id,
-            event=event,
-            cb=cb,
-            **kwargs,
-        )
-        if name is not None:
-            thread_kwargs = {"name": name, "daemon": True}
-        else:
-            thread_kwargs = {"daemon": True}
-        super(ConsumerConnector, self).__init__(**thread_kwargs)
-        self.signal_event = event if event is not None else threading.Event()
-
-    def run(self):
-        self.initialize_connector()
-
-        while True:
-            try:
-                self.poll_messages()
-            except Exception as e:
-                logger.error(traceback.format_exc())
-                _thread.interrupt_main()
-                raise e
-            finally:
-                if self.signal_event.is_set():
-                    self.shutdown()
-                    break
+    def run_messages_loop(self):
+        raise NotImplementedError

     def shutdown(self):
-        self.signal_event.set()
+        raise NotImplementedError

-    # def stop(self) -> None:
-    #     """
-    #     Stop consumer
-    #     Returns:
-
-    #     """
-    #     self.signal_event.set()
-    #     self.connector.close()
-    #     self.join()
+
+class ConnectorBase(PubSubInterface, StoreInterface):
+    def raise_warning(self, msg):
+        raise NotImplementedError
+
+    def log_warning(self, msg):
+        """send a warning"""
+        self.send(MessageEndpoints.log(), LogMessage(log_type="warning", log_msg=msg))
+
+    def log_message(self, msg):
+        """send a log message"""
+        self.send(MessageEndpoints.log(), LogMessage(log_type="log", log_msg=msg))
+
+    def log_error(self, msg):
+        """send an error as log"""
+        self.send(MessageEndpoints.log(), LogMessage(log_type="error", log_msg=msg))
+
+    def set_and_publish(self, topic: str, msg, pipe=None, expire: int = None) -> None:
+        raise NotImplementedError
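The old `ConnectorBase`/`ProducerConnector`/`ConsumerConnector(Threaded)` trio collapses into two abstract interfaces: `StoreInterface` for key-value and stream access, and `PubSubInterface` for messaging, with the new `ConnectorBase` inheriting both. A sketch of how an implementation composes them (`MyConnector` is hypothetical; `RedisConnector` later in this commit is the real one):

```python
from bec_lib.connector import ConnectorBase

class MyConnector(ConnectorBase):
    # PubSubInterface side: raw_send/send/register/poll_messages/...
    def send(self, topic, msg):
        ...

    # StoreInterface side: get/set/lpush/xadd/pipeline/...
    def get(self, topic, pipe=None):
        ...
```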
@@ -78,7 +78,7 @@ class DAPPluginObjectBase:
             converted_kwargs[key] = val
         kwargs = converted_kwargs
         request_id = str(uuid.uuid4())
-        self._client.producer.set_and_publish(
+        self._client.connector.set_and_publish(
             MessageEndpoints.dap_request(),
             messages.DAPRequestMessage(
                 dap_cls=self._plugin_info["class"],
@@ -110,7 +110,7 @@ class DAPPluginObjectBase:
         while True:
             if time.time() - start_time > timeout:
                 raise TimeoutError("Timeout waiting for DAP response.")
-            response = self._client.producer.get(MessageEndpoints.dap_response(request_id))
+            response = self._client.connector.get(MessageEndpoints.dap_response(request_id))
             if not response:
                 time.sleep(0.005)
                 continue
@@ -128,7 +128,7 @@ class DAPPluginObjectBase:
             return
         self._plugin_config["class_args"] = self._plugin_info.get("class_args")
         self._plugin_config["class_kwargs"] = self._plugin_info.get("class_kwargs")
-        self._client.producer.set_and_publish(
+        self._client.connector.set_and_publish(
             MessageEndpoints.dap_request(),
             messages.DAPRequestMessage(
                 dap_cls=self._plugin_info["class"],
@@ -149,7 +149,7 @@ class DAPPluginObject(DAPPluginObjectBase):
         """
         Get the data from last run.
         """
-        msg = self._client.producer.get_last(MessageEndpoints.processed_data(self._service_name))
+        msg = self._client.connector.get_last(MessageEndpoints.processed_data(self._service_name))
         if not msg:
             return None
         return self._convert_result(msg)
@@ -38,7 +38,7 @@ class DAPPlugins:
             service for service in available_services if service.startswith("DAPServer/")
         ]
         for service in dap_services:
-            available_plugins = self._parent.producer.get(
+            available_plugins = self._parent.connector.get(
                 MessageEndpoints.dap_available_plugins(service)
             )
             if available_plugins is None:
@@ -10,7 +10,7 @@ from typeguard import typechecked
 from bec_lib import messages
 from bec_lib.endpoints import MessageEndpoints
 from bec_lib.logger import bec_logger
-from bec_lib.redis_connector import RedisProducer
+from bec_lib.redis_connector import RedisConnector

 logger = bec_logger.logger
@@ -61,15 +61,15 @@ class ReadoutPriority(str, enum.Enum):


 class Status:
-    def __init__(self, producer: RedisProducer, RID: str) -> None:
+    def __init__(self, connector: RedisConnector, RID: str) -> None:
         """
         Status object for RPC calls

         Args:
-            producer (RedisProducer): Redis producer
+            connector (RedisConnector): Redis connector
             RID (str): Request ID
         """
-        self._producer = producer
+        self._connector = connector
         self._RID = RID

     def __eq__(self, __value: object) -> bool:
@@ -91,7 +91,7 @@ class Status:
             raise TimeoutError()

         while True:
-            request_status = self._producer.lrange(
+            request_status = self._connector.lrange(
                 MessageEndpoints.device_req_status(self._RID), 0, -1
             )
             if request_status:
@@ -251,7 +251,7 @@ class DeviceBase:
         if not isinstance(return_val, dict):
             return return_val
         if return_val.get("type") == "status" and return_val.get("RID"):
-            return Status(self.root.parent.producer, return_val.get("RID"))
+            return Status(self.root.parent.connector, return_val.get("RID"))
         return return_val

     def _get_rpc_response(self, request_id, rpc_id) -> Any:
@@ -267,7 +267,7 @@ class DeviceBase:
                 f" {scan_queue_request.response.content['message']}"
             )
         while True:
-            msg = self.root.parent.producer.get(MessageEndpoints.device_rpc(rpc_id))
+            msg = self.root.parent.connector.get(MessageEndpoints.device_rpc(rpc_id))
             if msg:
                 break
             time.sleep(0.01)
@@ -296,7 +296,7 @@ class DeviceBase:
         msg = self._prepare_rpc_msg(rpc_id, request_id, device, func_call, *args, **kwargs)

         # send RPC message
-        self.root.parent.producer.send(MessageEndpoints.scan_queue_request(), msg)
+        self.root.parent.connector.send(MessageEndpoints.scan_queue_request(), msg)

         # wait for RPC response
         if not wait_for_rpc_response:
@@ -496,7 +496,7 @@ class DeviceBase:

     # def read(self, cached, filter_readback=True):
     #     """get the last reading from a device"""
-    #     val = self.parent.producer.get(MessageEndpoints.device_read(self.name))
+    #     val = self.parent.connector.get(MessageEndpoints.device_read(self.name))
     #     if not val:
     #         return None
     #     if filter_readback:
@@ -505,7 +505,7 @@ class DeviceBase:
     #
     # def readback(self, filter_readback=True):
     #     """get the last readback value from a device"""
-    #     val = self.parent.producer.get(MessageEndpoints.device_readback(self.name))
+    #     val = self.parent.connector.get(MessageEndpoints.device_readback(self.name))
     #     if not val:
     #         return None
     #     if filter_readback:
@@ -515,7 +515,7 @@ class DeviceBase:
     # @property
     # def device_status(self):
     #     """get the current status of the device"""
-    #     val = self.parent.producer.get(MessageEndpoints.device_status(self.name))
+    #     val = self.parent.connector.get(MessageEndpoints.device_status(self.name))
     #     if val is None:
     #         return val
     #     val = DeviceStatusMessage.loads(val)
@@ -524,7 +524,7 @@ class DeviceBase:
     # @property
     # def signals(self):
     #     """get the last signals from a device"""
-    #     val = self.parent.producer.get(MessageEndpoints.device_read(self.name))
+    #     val = self.parent.connector.get(MessageEndpoints.device_read(self.name))
     #     if val is None:
     #         return None
     #     self._signals = DeviceMessage.loads(val).content["signals"]
@@ -593,11 +593,11 @@ class OphydInterfaceBase(DeviceBase):
         if is_config_signal:
             return self.read_configuration(cached=cached)
         if use_readback:
-            val = self.root.parent.producer.get(
+            val = self.root.parent.connector.get(
                 MessageEndpoints.device_readback(self.root.name)
             )
         else:
-            val = self.root.parent.producer.get(MessageEndpoints.device_read(self.root.name))
+            val = self.root.parent.connector.get(MessageEndpoints.device_read(self.root.name))

         if not val:
             return None
@@ -623,7 +623,7 @@ class OphydInterfaceBase(DeviceBase):
         if is_signal and not is_config_signal:
             return self.read(cached=True)

-        val = self.root.parent.producer.get(
+        val = self.root.parent.connector.get(
             MessageEndpoints.device_read_configuration(self.root.name)
         )
         if not val:
@@ -766,7 +766,7 @@ class AdjustableMixin:
         """
         Returns the device limits.
         """
-        limit_msg = self.root.parent.producer.get(MessageEndpoints.device_limits(self.root.name))
+        limit_msg = self.root.parent.connector.get(MessageEndpoints.device_limits(self.root.name))
         if not limit_msg:
             return [0, 0]
         limits = [
@@ -370,8 +370,7 @@ class DeviceManagerBase:
     _request_config_parsed = None  # parsed config request
     _response = None  # response message

-    _connector_base_consumer = {}
-    producer = None
+    _connector_base_register = {}
     config_helper = None
     _device_cls = DeviceBase
     _status_cb = []
@@ -464,7 +463,7 @@ class DeviceManagerBase:
         """
         if not msg.metadata.get("RID"):
             return
-        self.producer.lpush(
+        self.connector.lpush(
             MessageEndpoints.service_response(msg.metadata["RID"]),
             messages.ServiceResponseMessage(
                 # pylint: disable=no-member
@@ -487,25 +486,20 @@ class DeviceManagerBase:
             self._remove_device(dev)

     def _start_connectors(self, bootstrap_server) -> None:
-        self._start_base_consumer()
-        self.producer = self.connector.producer()
         self._start_custom_connectors(bootstrap_server)
+        self._start_base_register()

-    def _start_base_consumer(self) -> None:
+    def _start_base_register(self) -> None:
         """
         Start consuming messages for all base topics. This method will be called upon startup.
         Returns:

         """
-        self._connector_base_consumer["device_config_update"] = self.connector.consumer(
+        self.connector.register(
             MessageEndpoints.device_config_update(),
             cb=self._device_config_update_callback,
             parent=self,
         )

-        # self._connector_base_consumer["log"].start()
-        self._connector_base_consumer["device_config_update"].start()
-
     @staticmethod
     def _log_callback(msg, *, parent, **kwargs) -> None:
         """
@@ -541,48 +535,11 @@ class DeviceManagerBase:
         self._load_session()

     def _get_redis_device_config(self) -> list:
-        devices = self.producer.get(MessageEndpoints.device_config())
+        devices = self.connector.get(MessageEndpoints.device_config())
         if not devices:
             return []
         return devices.content["resource"]

-    def _stop_base_consumer(self):
-        """
-        Stop all base consumers by setting the corresponding event
-        Returns:
-
-        """
-        if self.connector is not None:
-            for _, con in self._connector_base_consumer.items():
-                con.signal_event.set()
-                con.join()
-
-    def _stop_consumer(self):
-        """
-        Stop all consumers
-        Returns:
-
-        """
-        self._stop_base_consumer()
-        self._stop_custom_consumer()
-
-    def _start_custom_connectors(self, bootstrap_server) -> None:
-        """
-        Override this method in a derived class to start custom connectors upon initialization.
-        Args:
-            bootstrap_server: Kafka bootstrap server
-
-        Returns:
-
-        """
-
-    def _stop_custom_consumer(self) -> None:
-        """
-        Stop all custom consumers. Override this method in a derived class.
-        Returns:
-
-        """
-
     def _add_device(self, dev: dict, msg: messages.DeviceInfoMessage):
         name = msg.content["device"]
         info = msg.content["info"]
@@ -621,8 +578,7 @@ class DeviceManagerBase:
             logger.error(f"Failed to load device {dev}: {content}")

     def _get_device_info(self, device_name) -> DeviceInfoMessage:
-        msg = self.producer.get(MessageEndpoints.device_info(device_name))
-        return msg
+        return self.connector.get(MessageEndpoints.device_info(device_name))

     def check_request_validity(self, msg: DeviceConfigMessage) -> None:
         """
@@ -663,10 +619,7 @@ class DeviceManagerBase:
         """
         Shutdown all connectors.
         """
-        try:
-            self.connector.shutdown()
-        except RuntimeError as runtime_error:
-            logger.error(f"Failed to shutdown connector. {runtime_error}")
+        self.connector.shutdown()

     def __del__(self):
         self.shutdown()
@@ -24,7 +24,6 @@ if TYPE_CHECKING:
 class LogbookConnector:
     def __init__(self, connector: RedisConnector) -> None:
         self.connector = connector
-        self.producer = connector.producer()
         self.connected = False
         self._scilog_module = None
         self._connect()
@@ -34,12 +33,12 @@ class LogbookConnector:
         if "scilog" not in sys.modules:
             return

-        msg = self.producer.get(MessageEndpoints.logbook())
+        msg = self.connector.get(MessageEndpoints.logbook())
         if not msg:
             return
         msg = msgpack.loads(msg)

-        account = self.producer.get(MessageEndpoints.account())
+        account = self.connector.get(MessageEndpoints.account())
         if not account:
             return
         account = account.decode()
@@ -54,7 +53,7 @@ class LogbookConnector:
         try:
             logbooks = self.log.get_logbooks(readACL={"inq": [account]})
         except HTTPError:
-            self.producer.set(MessageEndpoints.logbook(), b"")
+            self.connector.set(MessageEndpoints.logbook(), b"")
             return
         if len(logbooks) > 1:
             logger.warning("Found two logbooks. Taking the first one.")
|
@ -45,7 +45,6 @@ class BECLogger:
|
||||
self.bootstrap_server = None
|
||||
self.connector = None
|
||||
self.service_name = None
|
||||
self.producer = None
|
||||
self.logger = loguru_logger
|
||||
self._log_level = LogLevel.INFO
|
||||
self.level = self._log_level
|
||||
@@ -73,7 +72,6 @@ class BECLogger:
         self.bootstrap_server = bootstrap_server
         self.connector = connector_cls(bootstrap_server)
         self.service_name = service_name
-        self.producer = self.connector.producer()
         self._configured = True
         self._update_sinks()
@@ -82,7 +80,7 @@ class BECLogger:
             return
         msg = json.loads(msg)
         msg["service_name"] = self.service_name
-        self.producer.send(
+        self.connector.send(
             topic=MessageEndpoints.log(),
             msg=bec_lib.messages.LogMessage(log_type=msg["record"]["level"]["name"], log_msg=msg),
         )
@@ -152,7 +152,7 @@ class ObserverManager:

     def _get_installed_observer(self):
         # get current observer list from Redis
-        observer_msg = self.device_manager.producer.get(MessageEndpoints.observer())
+        observer_msg = self.device_manager.connector.get(MessageEndpoints.observer())
         if observer_msg is None:
             return []
         return [Observer.from_dict(obs) for obs in observer_msg.content["observer"]]
@@ -122,12 +122,16 @@ class QueueStorage:
         if history < 0:
             history *= -1

-        return self.scan_manager.producer.lrange(MessageEndpoints.scan_queue_history(), 0, history)
+        return self.scan_manager.connector.lrange(
+            MessageEndpoints.scan_queue_history(),
+            0,
+            history,
+        )

     @property
     def current_scan_queue(self) -> dict:
         """get the current scan queue from redis"""
-        msg = self.scan_manager.producer.get(MessageEndpoints.scan_queue_status())
+        msg = self.scan_manager.connector.get(MessageEndpoints.scan_queue_status())
         if msg:
             self._current_scan_queue = msg.content["queue"]
         return self._current_scan_queue
@@ -1,19 +1,18 @@
 from __future__ import annotations

-import time
+import collections
+import queue
 import sys
 import threading
 import warnings
 from functools import wraps
 from typing import TYPE_CHECKING

+import louie
 import redis
+import redis.client

-from bec_lib.connector import (
-    ConnectorBase,
-    ConsumerConnector,
-    ConsumerConnectorThreaded,
-    MessageObject,
-    ProducerConnector,
-)
+from bec_lib.connector import ConnectorBase, MessageObject
 from bec_lib.endpoints import MessageEndpoints
 from bec_lib.messages import AlarmMessage, BECMessage, LogMessage
 from bec_lib.serialization import MsgpackSerialization
@@ -31,7 +30,6 @@ def catch_connection_error(func):
             return func(*args, **kwargs)
         except redis.exceptions.ConnectionError:
             warnings.warn("Failed to connect to redis. Is the server running?")
-            time.sleep(0.1)
             return None

     return wrapper
@@ -40,149 +38,80 @@
 class RedisConnector(ConnectorBase):
     def __init__(self, bootstrap: list, redis_cls=None):
         super().__init__(bootstrap)
         self.redis_cls = redis_cls
         self.host, self.port = (
             bootstrap[0].split(":") if isinstance(bootstrap, list) else bootstrap.split(":")
         )
-        self._notifications_producer = RedisProducer(
-            host=self.host, port=self.port, redis_cls=self.redis_cls
-        )
-
-    def producer(self, **kwargs):
-        return RedisProducer(host=self.host, port=self.port, redis_cls=self.redis_cls)
+        if redis_cls:
+            self._redis_conn = redis_cls(host=self.host, port=self.port)
+        else:
+            self._redis_conn = redis.Redis(host=self.host, port=self.port)

-    # pylint: disable=too-many-arguments
-    def consumer(
-        self,
-        topics=None,
-        pattern=None,
-        group_id=None,
-        event=None,
-        cb=None,
-        threaded=True,
-        name=None,
-        **kwargs,
-    ):
-        if cb is None:
-            raise ValueError("The callback function must be specified.")
+        # main pubsub connection
+        self._pubsub_conn = self._redis_conn.pubsub()
+        self._pubsub_conn.ignore_subscribe_messages = True
+        # keep track of topics and callbacks
+        self._topics_cb = collections.defaultdict(list)

-        if threaded:
-            if topics is None and pattern is None:
-                raise ValueError("Topics must be set for threaded consumer")
-            listener = RedisConsumerThreaded(
-                self.host,
-                self.port,
-                topics,
-                pattern,
-                group_id,
-                event,
-                cb,
-                redis_cls=self.redis_cls,
-                name=name,
-                **kwargs,
-            )
-            self._threads.append(listener)
-            return listener
-        return RedisConsumer(
-            self.host,
-            self.port,
-            topics,
-            pattern,
-            group_id,
-            event,
-            cb,
-            redis_cls=self.redis_cls,
-            **kwargs,
-        )
+        self._events_listener_thread = None
+        self._events_dispatcher_thread = None
+        self._messages_queue = queue.Queue()
+        self._stop_events_listener_thread = threading.Event()

-    def stream_consumer(
-        self,
-        topics=None,
-        pattern=None,
-        group_id=None,
-        event=None,
-        cb=None,
-        from_start=False,
-        newest_only=False,
-        **kwargs,
-    ):
-        """
-        Threaded stream consumer for redis streams.
-
-        Args:
-            topics (str, list): topics to subscribe to
-            pattern (str, list): pattern to subscribe to
-            group_id (str): group id
-            event (threading.Event): event to stop the consumer
-            cb (function): callback function
-            from_start (bool): read from start. Defaults to False.
-            newest_only (bool): read only the newest message. Defaults to False.
-        """
-        if cb is None:
-            raise ValueError("The callback function must be specified.")
-
-        if pattern:
-            raise ValueError("Pattern is currently not supported for stream consumer.")
-
-        if topics is None and pattern is None:
-            raise ValueError("Topics must be set for stream consumer.")
-        listener = RedisStreamConsumerThreaded(
-            self.host,
-            self.port,
-            topics,
-            pattern,
-            group_id,
-            event,
-            cb,
-            redis_cls=self.redis_cls,
-            from_start=from_start,
-            newest_only=newest_only,
-            **kwargs,
-        )
-        self._threads.append(listener)
-        return listener
+        self.stream_keys = {}
+
+    def shutdown(self):
+        if self._events_listener_thread:
+            self._stop_events_listener_thread.set()
+            self._events_listener_thread.join()
+            self._events_listener_thread = None
+        if self._events_dispatcher_thread:
+            self._messages_queue.put(StopIteration)
+            self._events_dispatcher_thread.join()
+            self._events_dispatcher_thread = None
+        # release all connections
+        self._pubsub_conn.close()
+        self._redis_conn.close()

     @catch_connection_error
     def log_warning(self, msg):
         """send a warning"""
-        self._notifications_producer.send(
-            MessageEndpoints.log(), LogMessage(log_type="warning", log_msg=msg)
-        )
+        self.send(MessageEndpoints.log(), LogMessage(log_type="warning", log_msg=msg))

     @catch_connection_error
     def log_message(self, msg):
         """send a log message"""
-        self._notifications_producer.send(
-            MessageEndpoints.log(), LogMessage(log_type="log", log_msg=msg)
-        )
+        self.send(MessageEndpoints.log(), LogMessage(log_type="log", log_msg=msg))

     @catch_connection_error
     def log_error(self, msg):
         """send an error as log"""
-        self._notifications_producer.send(
-            MessageEndpoints.log(), LogMessage(log_type="error", log_msg=msg)
-        )
+        self.send(MessageEndpoints.log(), LogMessage(log_type="error", log_msg=msg))

     @catch_connection_error
-    def raise_alarm(self, severity: Alarms, alarm_type: str, source: str, msg: str, metadata: dict):
+    def raise_alarm(
+        self,
+        severity: Alarms,
+        alarm_type: str,
+        source: str,
+        msg: str,
+        metadata: dict,
+    ):
         """raise an alarm"""
-        self._notifications_producer.set_and_publish(
-            MessageEndpoints.alarm(),
-            AlarmMessage(
-                severity=severity, alarm_type=alarm_type, source=source, msg=msg, metadata=metadata
-            ),
+        alarm_msg = AlarmMessage(
+            severity=severity,
+            alarm_type=alarm_type,
+            source=source,
+            msg=msg,
+            metadata=metadata,
         )
+        self.set_and_publish(MessageEndpoints.alarm(), alarm_msg)

-
-class RedisProducer(ProducerConnector):
-    def __init__(self, host: str, port: int, redis_cls=None) -> None:
-        # pylint: disable=invalid-name
-        if redis_cls:
-            self.r = redis_cls(host=host, port=port)
-            return
-        self.r = redis.Redis(host=host, port=port)
-        self.stream_keys = {}
+    def pipeline(self):
+        """Create a new pipeline"""
+        return self._redis_conn.pipeline()

     @catch_connection_error
     def execute_pipeline(self, pipeline):
         """Execute the pipeline and returns the results with decoded BECMessages"""
         ret = []
@@ -197,7 +126,7 @@ class RedisProducer(ProducerConnector):
     @catch_connection_error
     def raw_send(self, topic: str, msg: bytes, pipe=None):
         """send to redis without any check on message type"""
-        client = pipe if pipe is not None else self.r
+        client = pipe if pipe is not None else self._redis_conn
         client.publish(topic, msg)

     def send(self, topic: str, msg: BECMessage, pipe=None) -> None:
@@ -206,6 +135,95 @@
             raise TypeError(f"Message {msg} is not a BECMessage")
         self.raw_send(topic, MsgpackSerialization.dumps(msg), pipe)

+    def register(self, topics=None, patterns=None, cb=None, start_thread=True, **kwargs):
+        if self._events_listener_thread is None:
+            # create the thread that will get all messages for this connector;
+            # under the hood, it uses asyncio - this lets the possibility to stop
+            # the loop on demand
+            self._events_listener_thread = threading.Thread(
+                target=self._get_messages_loop,
+                args=(self._pubsub_conn,),
+            )
+            self._events_listener_thread.start()
+        # make a weakref from the callable, using louie;
+        # it can create safe refs for simple functions as well as methods
+        cb_ref = louie.saferef.safe_ref(cb)
+
+        if patterns is not None:
+            if isinstance(patterns, str):
+                patterns = [patterns]
+
+            self._pubsub_conn.psubscribe(patterns)
+            for pattern in patterns:
+                self._topics_cb[pattern].append((cb_ref, kwargs))
+        else:
+            if isinstance(topics, str):
+                topics = [topics]
+
+            self._pubsub_conn.subscribe(topics)
+            for topic in topics:
+                self._topics_cb[topic].append((cb_ref, kwargs))
+
+        if start_thread and self._events_dispatcher_thread is None:
+            # start dispatcher thread
+            self._events_dispatcher_thread = threading.Thread(target=self.dispatch_events)
+            self._events_dispatcher_thread.start()
+
+    def _get_messages_loop(self, pubsub) -> None:
+        """
+        Start a listening coroutine to deal with redis events and wait for completion
+        """
+        while not self._stop_events_listener_thread.is_set():
+            try:
+                msg = pubsub.get_message(timeout=1)
+            except Exception:
+                sys.excepthook(*sys.exc_info())
+            else:
+                if msg is not None:
+                    self._messages_queue.put(msg)
+
+    def _handle_message(self, msg):
+        if msg["type"].endswith("subscribe"):
+            # ignore subscribe messages
+            return False
+        channel = msg["channel"].decode()
+        if msg["pattern"] is not None:
+            callbacks = self._topics_cb[msg["pattern"].decode()]
+        else:
+            callbacks = self._topics_cb[channel]
+        msg = MessageObject(
+            topic=channel,
+            value=MsgpackSerialization.loads(msg["data"]),
+        )
+        for cb_ref, kwargs in callbacks:
+            cb = cb_ref()
+            if cb:
+                try:
+                    cb(msg, **kwargs)
+                except Exception:
+                    sys.excepthook(*sys.exc_info())
+        return True
+
+    def poll_messages(self, timeout=None) -> None:
+        while True:
+            try:
+                msg = self._messages_queue.get(timeout=timeout)
+            except queue.Empty:
+                raise TimeoutError(
+                    f"{self}: poll_messages: did not receive a message within {timeout} seconds"
+                )
+            else:
+                if msg is StopIteration:
+                    return False
+                if self._handle_message(msg):
+                    return True
+                else:
+                    continue
+
+    def dispatch_events(self):
+        while self.poll_messages():
+            ...
+
     @catch_connection_error
     def lpush(
         self, topic: str, msg: str, pipe=None, max_size: int = None, expire: int = None
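The new pub/sub machinery runs one listener thread that feeds a queue and, by default, one dispatcher thread that drains it. A minimal usage sketch (topic string and address are illustrative):

```python
connector = RedisConnector(["localhost:6379"])

def printer(msg_obj, **_kwargs):
    # msg_obj.topic is the channel, msg_obj.value the deserialized BECMessage
    print(msg_obj.topic, msg_obj.value)

# the default start_thread=True spawns the dispatcher thread automatically
connector.register(topics="demo-topic", cb=printer)

# with start_thread=False the caller drives dispatching explicitly:
# connector.register(topics="demo-topic", cb=printer, start_thread=False)
# connector.poll_messages(timeout=1)  # handles one message or raises TimeoutError
```

Note that `register()` stores only a weak reference to the callback (via `louie.saferef`), so the caller must keep a strong reference to it alive.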
@@ -229,7 +247,7 @@ class RedisProducer(ProducerConnector):

     @catch_connection_error
     def lset(self, topic: str, index: int, msg: str, pipe=None) -> None:
-        client = pipe if pipe is not None else self.r
+        client = pipe if pipe is not None else self._redis_conn
         if isinstance(msg, BECMessage):
             msg = MsgpackSerialization.dumps(msg)
         return client.lset(topic, index, msg)
@@ -241,7 +259,7 @@ class RedisProducer(ProducerConnector):
         values at the tail of the list stored at key. If key does not exist,
         it is created as empty list before performing the push operation. When
         key holds a value that is not a list, an error is returned."""
-        client = pipe if pipe is not None else self.r
+        client = pipe if pipe is not None else self._redis_conn
         if isinstance(msg, BECMessage):
             msg = MsgpackSerialization.dumps(msg)
         return client.rpush(topic, msg)
@@ -254,7 +272,7 @@ class RedisProducer(ProducerConnector):
         of the list stored at key. The offsets start and stop are zero-based indexes,
         with 0 being the first element of the list (the head of the list), 1 being
         the next element and so on."""
-        client = pipe if pipe is not None else self.r
+        client = pipe if pipe is not None else self._redis_conn
         cmd_result = client.lrange(topic, start, end)
         if pipe:
             return cmd_result
@@ -268,23 +286,21 @@ class RedisProducer(ProducerConnector):
             ret.append(msg)
         return ret

     @catch_connection_error
     def set_and_publish(self, topic: str, msg, pipe=None, expire: int = None) -> None:
         """piped combination of self.publish and self.set"""
         client = pipe if pipe is not None else self.pipeline()
-        if isinstance(msg, BECMessage):
-            msg = MsgpackSerialization.dumps(msg)
-        client.publish(topic, msg)
-        client.set(topic, msg)
-        if expire:
-            client.expire(topic, expire)
+        if not isinstance(msg, BECMessage):
+            raise TypeError(f"Message {msg} is not a BECMessage")
+        msg = MsgpackSerialization.dumps(msg)
+        self.set(topic, msg, pipe=client, expire=expire)
+        self.raw_send(topic, msg, pipe=client)
         if not pipe:
             client.execute()

     @catch_connection_error
     def set(self, topic: str, msg, pipe=None, expire: int = None) -> None:
         """set redis value"""
-        client = pipe if pipe is not None else self.r
+        client = pipe if pipe is not None else self._redis_conn
         if isinstance(msg, BECMessage):
             msg = MsgpackSerialization.dumps(msg)
         client.set(topic, msg, ex=expire)
@ -292,23 +308,18 @@ class RedisProducer(ProducerConnector):
|
||||
@catch_connection_error
|
||||
def keys(self, pattern: str) -> list:
|
||||
"""returns all keys matching a pattern"""
|
||||
return self.r.keys(pattern)
|
||||
|
||||
@catch_connection_error
|
||||
def pipeline(self):
|
||||
"""create a new pipeline"""
|
||||
return self.r.pipeline()
|
||||
return self._redis_conn.keys(pattern)
|
||||
|
||||
@catch_connection_error
|
||||
def delete(self, topic, pipe=None):
|
||||
"""delete topic"""
|
||||
client = pipe if pipe is not None else self.r
|
||||
client = pipe if pipe is not None else self._redis_conn
|
||||
client.delete(topic)
|
||||
|
||||
@catch_connection_error
|
||||
def get(self, topic: str, pipe=None):
|
||||
"""retrieve entry, either via hgetall or get"""
|
||||
client = pipe if pipe is not None else self.r
|
||||
client = pipe if pipe is not None else self._redis_conn
|
||||
data = client.get(topic)
|
||||
if pipe:
|
||||
return data
|
||||
@ -339,7 +350,7 @@ class RedisProducer(ProducerConnector):
|
||||
elif expire:
|
||||
client = self.pipeline()
|
||||
else:
|
||||
client = self.r
|
||||
client = self._redis_conn
|
||||
|
||||
for key, msg in msg_dict.items():
|
||||
msg_dict[key] = MsgpackSerialization.dumps(msg)
|
||||
@ -356,7 +367,7 @@ class RedisProducer(ProducerConnector):
|
||||
@catch_connection_error
|
||||
def get_last(self, topic: str, key="data"):
|
||||
"""retrieve last entry from stream"""
|
||||
client = self.r
|
||||
client = self._redis_conn
|
||||
try:
|
||||
_, msg_dict = client.xrevrange(topic, "+", "-", count=1)[0]
|
||||
except TypeError:
|
||||
@ -370,7 +381,12 @@ class RedisProducer(ProducerConnector):
|
||||
|
||||
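The set_and_publish/set/get trio above behaves as a key-value cache plus change notification; a minimal sketch (the topic name and message are assumptions, and raw_send is the internal publish that set_and_publish uses):

from bec_lib import messages
from bec_lib.redis_connector import RedisConnector

connector = RedisConnector("localhost:6379")
msg = messages.LogMessage(log_type="log", log_msg="hello")  # any BECMessage

# SET the key and PUBLISH on the channel of the same name in one pipeline;
# non-BECMessage payloads now raise TypeError
connector.set_and_publish("some-topic", msg, expire=60)

# plain cache round-trip; expire is passed through to Redis as "ex"
connector.set("some-topic", msg, expire=60)
cached = connector.get("some-topic")
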
    @catch_connection_error
    def xread(
        self, topic: str, id: str = None, count: int = None, block: int = None, from_start=False
        self,
        topic: str,
        id: str = None,
        count: int = None,
        block: int = None,
        from_start=False,
    ) -> list:
        """
        read from stream
@ -395,13 +411,13 @@ class RedisProducer(ProducerConnector):
        >>> key = msg[0][1][0][0]
        >>> next_msg = redis.xread("test", key, count=1)
        """
        client = self.r
        client = self._redis_conn
        if from_start:
            self.stream_keys[topic] = "0-0"
        if topic not in self.stream_keys:
            if id is None:
                try:
                    msg = self.r.xrevrange(topic, "+", "-", count=1)
                    msg = client.xrevrange(topic, "+", "-", count=1)
                    if msg:
                        self.stream_keys[topic] = msg[0][0]
                        out = {}
@ -438,7 +454,7 @@ class RedisProducer(ProducerConnector):
            max (str): max id. Use "+" to read to end
            count (int, optional): number of messages to read. Defaults to None.
        """
        client = self.r
        client = self._redis_conn
        msgs = []
        for reading in client.xrange(topic, min, max, count=count):
            index, msg_dict = reading
@ -446,270 +462,3 @@ class RedisProducer(ProducerConnector):
                {k.decode(): MsgpackSerialization.loads(msg) for k, msg in msg_dict.items()}
            )
        return msgs

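To complement the xread docstring example, xrange reads a closed interval of a stream; a short sketch reusing the connector from the previous snippets ("-" and "+" are the Redis tokens for the first and last entry, and the stream name is an assumption):

entries = connector.xrange("some-stream", "-", "+", count=10)
for entry in entries:
    # each entry is a dict of field name -> deserialized message
    print(entry)
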
class RedisConsumerMixin:
    def _init_topics_and_pattern(self, topics, pattern):
        if topics:
            if not isinstance(topics, list):
                topics = [topics]
        if pattern:
            if not isinstance(pattern, list):
                pattern = [pattern]
        return topics, pattern

    def _init_redis_cls(self, redis_cls):
        # pylint: disable=invalid-name
        if redis_cls:
            self.r = redis_cls(host=self.host, port=self.port)
        else:
            self.r = redis.Redis(host=self.host, port=self.port)

    @catch_connection_error
    def initialize_connector(self) -> None:
        if self.pattern is not None:
            self.pubsub.psubscribe(self.pattern)
        else:
            self.pubsub.subscribe(self.topics)


class RedisConsumer(RedisConsumerMixin, ConsumerConnector):
    # pylint: disable=too-many-arguments
    def __init__(
        self,
        host,
        port,
        topics=None,
        pattern=None,
        group_id=None,
        event=None,
        cb=None,
        redis_cls=None,
        **kwargs,
    ):
        self.host = host
        self.port = port

        bootstrap_server = "".join([host, ":", port])
        topics, pattern = self._init_topics_and_pattern(topics, pattern)
        super().__init__(
            bootstrap_server=bootstrap_server,
            topics=topics,
            pattern=pattern,
            group_id=group_id,
            event=event,
            cb=cb,
            **kwargs,
        )
        self.error_message_sent = False
        self._init_redis_cls(redis_cls)
        self.pubsub = self.r.pubsub()
        self.initialize_connector()

    @catch_connection_error
    def poll_messages(self) -> None:
        """
        Poll messages from self.connector and call the callback function self.cb
        """
        message = self.pubsub.get_message(ignore_subscribe_messages=True)
        if message is not None:
            msg = MessageObject(
                topic=message["channel"], value=MsgpackSerialization.loads(message["data"])
            )
            return self.cb(msg, **self.kwargs)

        time.sleep(0.01)
        return None

    def shutdown(self):
        """shutdown the consumer"""
        self.pubsub.close()

class RedisStreamConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded):
    # pylint: disable=too-many-arguments
    def __init__(
        self,
        host,
        port,
        topics=None,
        pattern=None,
        group_id=None,
        event=None,
        cb=None,
        redis_cls=None,
        from_start=False,
        newest_only=False,
        **kwargs,
    ):
        self.host = host
        self.port = port
        self.from_start = from_start
        self.newest_only = newest_only

        bootstrap_server = "".join([host, ":", port])
        topics, pattern = self._init_topics_and_pattern(topics, pattern)
        super().__init__(
            bootstrap_server=bootstrap_server,
            topics=topics,
            pattern=pattern,
            group_id=group_id,
            event=event,
            cb=cb,
            **kwargs,
        )

        self._init_redis_cls(redis_cls)

        self.sleep_times = [0.005, 0.1]
        self.last_received_msg = 0
        self.idle_time = 30
        self.error_message_sent = False
        self.stream_keys = {}

    def initialize_connector(self) -> None:
        pass

    def _init_topics_and_pattern(self, topics, pattern):
        if topics:
            if not isinstance(topics, list):
                topics = [topics]
        if pattern:
            if not isinstance(pattern, list):
                pattern = [pattern]
        return topics, pattern

    def get_id(self, topic: str) -> str:
        """
        Get the stream key for the given topic.

        Args:
            topic (str): topic to get the stream key for
        """
        if topic not in self.stream_keys:
            return "0-0"
        return self.stream_keys.get(topic)

    def get_newest_message(self, container: list, append=True) -> None:
        """
        Get the newest message from the stream and update the stream key. If
        append is True, append the message to the container.

        Args:
            container (list): container to append the message to
            append (bool, optional): append to container. Defaults to True.
        """
        for topic in self.topics:
            msg = self.r.xrevrange(topic, "+", "-", count=1)
            if msg:
                if append:
                    container.append((topic, msg[0][1]))
                self.stream_keys[topic] = msg[0][0]
            else:
                self.stream_keys[topic] = "0-0"

    @catch_connection_error
    def poll_messages(self) -> None:
        """
        Poll messages from self.connector and call the callback function self.cb
        """
        if self.pattern is not None:
            topics = [key.decode() for key in self.r.scan_iter(match=self.pattern, _type="stream")]
        else:
            topics = self.topics
        messages = []
        if self.newest_only:
            self.get_newest_message(messages)
        elif not self.from_start and not self.stream_keys:
            self.get_newest_message(messages, append=False)
        else:
            streams = {topic: self.get_id(topic) for topic in topics}
            read_msgs = self.r.xread(streams, count=1)
            if read_msgs:
                for msg in read_msgs:
                    topic = msg[0].decode()
                    messages.append((topic, msg[1][0][1]))
                    self.stream_keys[topic] = msg[1][-1][0]

        if messages:
            if MessageEndpoints.log() not in topics:
                # no need to update the update frequency just for logs
                self.last_received_msg = time.time()
            for topic, msg in messages:
                try:
                    msg = MsgpackSerialization.loads(msg[b"data"])
                except RuntimeError:
                    msg = msg[b"data"]
                msg_obj = MessageObject(topic=topic, value=msg)
                self.cb(msg_obj, **self.kwargs)
        else:
            sleep_time = int(bool(time.time() - self.last_received_msg > self.idle_time))
            if self.sleep_times[sleep_time]:
                time.sleep(self.sleep_times[sleep_time])

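The stream-key bookkeeping in the threaded stream consumer above is easiest to see in isolation; a conceptual sketch with made-up stream data (the index layout mirrors what redis-py returns for xread):

streams = {"topic_a": "0-0"}  # last seen id per topic; "0-0" means "from the start"
read_msgs = [(b"topic_a", [(b"1700000000001-0", {b"data": b"..."})])]
for msg in read_msgs:
    topic = msg[0].decode()
    payload = msg[1][0][1]          # first entry of this reading
    streams[topic] = msg[1][-1][0]  # remember the newest id for the next xread
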
class RedisConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded):
    # pylint: disable=too-many-arguments
    def __init__(
        self,
        host,
        port,
        topics=None,
        pattern=None,
        group_id=None,
        event=None,
        cb=None,
        redis_cls=None,
        name=None,
        **kwargs,
    ):
        self.host = host
        self.port = port

        bootstrap_server = "".join([host, ":", port])
        topics, pattern = self._init_topics_and_pattern(topics, pattern)
        super().__init__(
            bootstrap_server=bootstrap_server,
            topics=topics,
            pattern=pattern,
            group_id=group_id,
            event=event,
            cb=cb,
            name=name,
            **kwargs,
        )

        self._init_redis_cls(redis_cls)
        self.pubsub = self.r.pubsub()

        self.sleep_times = [0.005, 0.1]
        self.last_received_msg = 0
        self.idle_time = 30
        self.error_message_sent = False

    @catch_connection_error
    def poll_messages(self) -> None:
        """
        Poll messages from self.connector and call the callback function self.cb

        Note: pubsub messages are supposed to be BECMessage objects only
        """
        messages = self.pubsub.get_message(ignore_subscribe_messages=True)
        if messages is not None:
            if f"{MessageEndpoints.log()}".encode() not in messages["channel"]:
                # no need to update the update frequency just for logs
                self.last_received_msg = time.time()
            msg = MessageObject(
                topic=messages["channel"].decode(),
                value=MsgpackSerialization.loads(messages["data"]),
            )
            self.cb(msg, **self.kwargs)
        else:
            sleep_time = int(bool(time.time() - self.last_received_msg > self.idle_time))
            if self.sleep_times[sleep_time]:
                time.sleep(self.sleep_times[sleep_time])

    def shutdown(self):
        super().shutdown()
        self.pubsub.close()

@ -46,7 +46,7 @@ class ScanItem:
        self.data = ScanData()
        self.async_data = {}
        self.baseline = ScanData()
        self._async_data_handler = AsyncDataHandler(scan_manager.producer)
        self._async_data_handler = AsyncDataHandler(scan_manager.connector)
        self.open_scan_defs = set()
        self.open_queue_group = None
        self.num_points = None

@ -25,44 +25,31 @@ class ScanManager:
            connector (BECConnector): BECConnector instance
        """
        self.connector = connector
        self.producer = self.connector.producer()
        self.queue_storage = QueueStorage(scan_manager=self)
        self.request_storage = RequestStorage(scan_manager=self)
        self.scan_storage = ScanStorage(scan_manager=self)

        self._scan_queue_consumer = self.connector.consumer(
        self.connector.register(
            topics=MessageEndpoints.scan_queue_status(),
            cb=self._scan_queue_status_callback,
            parent=self,
        )
        self._scan_queue_request_consumer = self.connector.consumer(
        self.connector.register(
            topics=MessageEndpoints.scan_queue_request(),
            cb=self._scan_queue_request_callback,
            parent=self,
        )
        self._scan_queue_request_response_consumer = self.connector.consumer(
        self.connector.register(
            topics=MessageEndpoints.scan_queue_request_response(),
            cb=self._scan_queue_request_response_callback,
            parent=self,
        )
        self._scan_status_consumer = self.connector.consumer(
            topics=MessageEndpoints.scan_status(), cb=self._scan_status_callback, parent=self
        self.connector.register(
            topics=MessageEndpoints.scan_status(), cb=self._scan_status_callback
        )

        self._scan_segment_consumer = self.connector.consumer(
            topics=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self
        self.connector.register(
            topics=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback
        )

        self._baseline_consumer = self.connector.consumer(
            topics=MessageEndpoints.scan_baseline(), cb=self._baseline_callback, parent=self
        )

        self._scan_queue_consumer.start()
        self._scan_queue_request_consumer.start()
        self._scan_queue_request_response_consumer.start()
        self._scan_status_consumer.start()
        self._scan_segment_consumer.start()
        self._baseline_consumer.start()
        self.connector.register(topics=MessageEndpoints.scan_baseline(), cb=self._baseline_callback)

    def update_with_queue_status(self, queue: messages.ScanQueueStatusMessage) -> None:
        """update storage with a new queue status message"""
@ -84,7 +71,7 @@ class ScanManager:

        action = "deferred_pause" if deferred_pause else "pause"
        logger.info(f"Requesting {action}")
        return self.producer.send(
        return self.connector.send(
            MessageEndpoints.scan_queue_modification_request(),
            messages.ScanQueueModificationMessage(scanID=scanID, action=action, parameter={}),
        )
@ -99,7 +86,7 @@ class ScanManager:
        if scanID is None:
            scanID = self.scan_storage.current_scanID
        logger.info("Requesting scan abortion")
        self.producer.send(
        self.connector.send(
            MessageEndpoints.scan_queue_modification_request(),
            messages.ScanQueueModificationMessage(scanID=scanID, action="abort", parameter={}),
        )
@ -114,7 +101,7 @@ class ScanManager:
        if scanID is None:
            scanID = self.scan_storage.current_scanID
        logger.info("Requesting scan halt")
        self.producer.send(
        self.connector.send(
            MessageEndpoints.scan_queue_modification_request(),
            messages.ScanQueueModificationMessage(scanID=scanID, action="halt", parameter={}),
        )
@ -129,7 +116,7 @@ class ScanManager:
        if scanID is None:
            scanID = self.scan_storage.current_scanID
        logger.info("Requesting scan continuation")
        self.producer.send(
        self.connector.send(
            MessageEndpoints.scan_queue_modification_request(),
            messages.ScanQueueModificationMessage(scanID=scanID, action="continue", parameter={}),
        )
@ -137,7 +124,7 @@ class ScanManager:
    def request_queue_reset(self):
        """request a scan queue reset"""
        logger.info("Requesting a queue reset")
        self.producer.send(
        self.connector.send(
            MessageEndpoints.scan_queue_modification_request(),
            messages.ScanQueueModificationMessage(scanID=None, action="clear", parameter={}),
        )
@ -151,7 +138,7 @@ class ScanManager:
        logger.info("Requesting to abort and repeat a scan")
        position = "replace" if replace else "append"

        self.producer.send(
        self.connector.send(
            MessageEndpoints.scan_queue_modification_request(),
            messages.ScanQueueModificationMessage(
                scanID=scanID, action="restart", parameter={"position": position, "RID": requestID}
@ -162,7 +149,7 @@ class ScanManager:
    @property
    def next_scan_number(self):
        """get the next scan number from redis"""
        num = self.producer.get(MessageEndpoints.scan_number())
        num = self.connector.get(MessageEndpoints.scan_number())
        if num is None:
            logger.warning("Failed to retrieve scan number from redis.")
            return -1
@ -172,63 +159,51 @@ class ScanManager:
    @typechecked
    def next_scan_number(self, val: int):
        """set the next scan number in redis"""
        return self.producer.set(MessageEndpoints.scan_number(), val)
        return self.connector.set(MessageEndpoints.scan_number(), val)

    @property
    def next_dataset_number(self):
        """get the next dataset number from redis"""
        return int(self.producer.get(MessageEndpoints.dataset_number()))
        return int(self.connector.get(MessageEndpoints.dataset_number()))

    @next_dataset_number.setter
    @typechecked
    def next_dataset_number(self, val: int):
        """set the next dataset number in redis"""
        return self.producer.set(MessageEndpoints.dataset_number(), val)
        return self.connector.set(MessageEndpoints.dataset_number(), val)

    @staticmethod
    def _scan_queue_status_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
    def _scan_queue_status_callback(self, msg, **_kwargs) -> None:
        queue_status = msg.value
        if not queue_status:
            return
        parent.update_with_queue_status(queue_status)
        self.update_with_queue_status(queue_status)

    @staticmethod
    def _scan_queue_request_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
    def _scan_queue_request_callback(self, msg, **_kwargs) -> None:
        request = msg.value
        parent.request_storage.update_with_request(request)
        self.request_storage.update_with_request(request)

    @staticmethod
    def _scan_queue_request_response_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
    def _scan_queue_request_response_callback(self, msg, **_kwargs) -> None:
        response = msg.value
        logger.debug(response)
        parent.request_storage.update_with_response(response)
        self.request_storage.update_with_response(response)

    @staticmethod
    def _scan_status_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
    def _scan_status_callback(self, msg, **_kwargs) -> None:
        scan = msg.value
        parent.scan_storage.update_with_scan_status(scan)
        self.scan_storage.update_with_scan_status(scan)

    @staticmethod
    def _scan_segment_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
    def _scan_segment_callback(self, msg, **_kwargs) -> None:
        scan_msgs = msg.value
        if not isinstance(scan_msgs, list):
            scan_msgs = [scan_msgs]
        for scan_msg in scan_msgs:
            parent.scan_storage.add_scan_segment(scan_msg)
            self.scan_storage.add_scan_segment(scan_msg)

    @staticmethod
    def _baseline_callback(msg, *, parent: ScanManager, **_kwargs) -> None:
    def _baseline_callback(self, msg, **_kwargs) -> None:
        msg = msg.value
        parent.scan_storage.add_scan_baseline(msg)
        self.scan_storage.add_scan_baseline(msg)

    def __str__(self) -> str:
        return "\n".join(self.queue_storage.describe_queue())

    def shutdown(self):
        """stop the scan manager's threads"""
        self._scan_queue_consumer.shutdown()
        self._scan_queue_request_consumer.shutdown()
        self._scan_queue_request_response_consumer.shutdown()
        self._scan_status_consumer.shutdown()
        self._scan_segment_consumer.shutdown()
        self._baseline_consumer.shutdown()
        pass
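The ScanManager changes above follow one pattern throughout the refactor: consumer objects plus staticmethod callbacks receiving parent= become bound methods registered directly on the connector. Condensed into a hypothetical minimal class (handle() is a placeholder, not a real BEC method):

class OldStyle:
    def __init__(self, connector):
        self._consumer = connector.consumer(
            topics="scan_status", cb=self._cb, parent=self
        )
        self._consumer.start()

    @staticmethod
    def _cb(msg, *, parent, **_kwargs):
        parent.handle(msg.value)  # placeholder for the real update logic


class NewStyle:
    def __init__(self, connector):
        connector.register(topics="scan_status", cb=self._cb)

    def _cb(self, msg, **_kwargs):
        self.handle(msg.value)  # placeholder for the real update logic
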
@ -89,7 +89,7 @@ class ScanReport:
    def _get_mv_status(self) -> bool:
        """get the status of a move request"""
        motors = list(self.request.request.content["parameter"]["args"].keys())
        request_status = self._client.device_manager.producer.lrange(
        request_status = self._client.device_manager.connector.lrange(
            MessageEndpoints.device_req_status(self.request.requestID), 0, -1
        )
        if len(request_status) == len(motors):

@ -100,9 +100,9 @@ class ScanObject:
            return None
        return self.scan_info.get("scan_report_hint")

    def _start_consumer(self, request: messages.ScanQueueMessage) -> ConsumerConnector:
        """Start a consumer for the given request"""
        consumer = self.client.device_manager.connector.consumer(
    def _start_register(self, request: messages.ScanQueueMessage) -> ConsumerConnector:
        """Start a register for the given request"""
        register = self.client.device_manager.connector.register(
            [
                MessageEndpoints.device_readback(dev)
                for dev in request.content["parameter"]["args"].keys()
@ -110,11 +110,11 @@ class ScanObject:
            threaded=False,
            cb=(lambda msg: msg),
        )
        return consumer
        return register

    def _send_scan_request(self, request: messages.ScanQueueMessage) -> None:
        """Send a scan request to the scan server"""
        self.client.device_manager.producer.send(MessageEndpoints.scan_queue_request(), request)
        self.client.device_manager.connector.send(MessageEndpoints.scan_queue_request(), request)


class Scans:
@ -136,7 +136,7 @@ class Scans:

    def _import_scans(self):
        """Import scans from the scan server"""
        available_scans = self.parent.producer.get(MessageEndpoints.available_scans())
        available_scans = self.parent.connector.get(MessageEndpoints.available_scans())
        if available_scans is None:
            logger.warning("No scans available. Are redis and the BEC server running?")
            return
@ -14,7 +14,9 @@ logger = bec_logger.logger

DEFAULT_SERVICE_CONFIG = {
    "redis": {"host": "localhost", "port": 6379},
    "service_config": {"file_writer": {"plugin": "default_NeXus_format", "base_path": "./"}},
    "service_config": {
        "file_writer": {"plugin": "default_NeXus_format", "base_path": os.path.dirname(__file__)}
    },
}


@ -32,7 +34,7 @@ class ServiceConfig:
        self._update_config(service_config=config, redis=redis)

        self.service_config = self.config.get(
            "service_config", {"file_writer": {"plugin": "default_NeXus_format", "base_path": "./"}}
            "service_config", DEFAULT_SERVICE_CONFIG["service_config"]
        )

    def _update_config(self, **kwargs):
@ -56,7 +56,7 @@ def queue_is_empty(queue) -> bool:  # pragma: no cover


def get_queue(bec):  # pragma: no cover
    return bec.queue.producer.get(MessageEndpoints.scan_queue_status())
    return bec.queue.connector.get(MessageEndpoints.scan_queue_status())


def wait_for_empty_queue(bec):  # pragma: no cover
@ -484,7 +484,6 @@ def bec_client():
    with open(f"{dir_path}/tests/test_config.yaml", "r", encoding="utf-8") as f:
        builtins.__dict__["test_session"] = create_session_from_config(yaml.safe_load(f))
    device_manager._session = builtins.__dict__["test_session"]
    device_manager.producer = device_manager.connector.producer()
    client.wait_for_service = lambda service_name: None
    device_manager._load_session()
    for name, dev in device_manager.devices.items():
@ -497,37 +496,23 @@ def bec_client():

class PipelineMock:  # pragma: no cover
    _pipe_buffer = []
    _producer = None
    _connector = None

    def __init__(self, producer) -> None:
        self._producer = producer
    def __init__(self, connector) -> None:
        self._connector = connector

    def execute(self):
        if not self._producer.store_data:
        if not self._connector.store_data:
            self._pipe_buffer = []
            return []
        res = [
            getattr(self._producer, method)(*args, **kwargs)
            getattr(self._connector, method)(*args, **kwargs)
            for method, args, kwargs in self._pipe_buffer
        ]
        self._pipe_buffer = []
        return res


class ConsumerMock:  # pragma: no cover
    def __init__(self) -> None:
        self.signal_event = SignalMock()

    def start(self):
        pass

    def join(self):
        pass

    def shutdown(self):
        pass


class SignalMock:  # pragma: no cover
    def __init__(self) -> None:
        self.is_set = False
@ -536,12 +521,36 @@ class SignalMock:  # pragma: no cover
        self.is_set = True


class ProducerMock:  # pragma: no cover
    def __init__(self, store_data=True) -> None:
class ConnectorMock(ConnectorBase):  # pragma: no cover
    def __init__(self, bootstrap_server="localhost:0000", store_data=True):
        super().__init__(bootstrap_server)
        self.message_sent = []
        self._get_buffer = {}
        self.store_data = store_data

    def raise_alarm(
        self, severity: Alarms, alarm_type: str, source: str, msg: dict, metadata: dict
    ):
        pass

    def log_error(self, *args, **kwargs):
        pass

    def shutdown(self):
        pass

    def register(self, *args, **kwargs):
        pass

    def set(self, *args, **kwargs):
        pass

    def set_and_publish(self, *args, **kwargs):
        pass

    def keys(self, *args, **kwargs):
        return []

    def set(self, topic, msg, pipe=None, expire: int = None):
        if pipe:
            pipe._pipe_buffer.append(("set", (topic, msg), {"expire": expire}))
@ -592,9 +601,6 @@ class ProducerMock:  # pragma: no cover
        self._get_buffer.pop(topic, None)
        return val

    def keys(self, pattern: str) -> list:
        return []

    def pipeline(self):
        return PipelineMock(self)

@ -609,29 +615,6 @@ class ProducerMock:  # pragma: no cover
        return


class ConnectorMock(ConnectorBase):  # pragma: no cover
    def __init__(self, bootstrap_server: list, store_data=True):
        super().__init__(bootstrap_server)
        self.store_data = store_data

    def consumer(self, *args, **kwargs) -> ConsumerMock:
        return ConsumerMock()

    def producer(self, *args, **kwargs):
        return ProducerMock(self.store_data)

    def raise_alarm(
        self, severity: Alarms, alarm_type: str, source: str, msg: dict, metadata: dict
    ):
        pass

    def log_error(self, *args, **kwargs):
        pass

    def shutdown(self):
        pass


def create_session_from_config(config: dict) -> dict:
    device_configs = []
    session_id = str(uuid.uuid4())
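A hypothetical test-side usage of the unified mock above: ConnectorMock now stands in for the whole connector (no separate ProducerMock/ConsumerMock), and PipelineMock buffers calls until execute():

connector = ConnectorMock(store_data=True)
pipe = connector.pipeline()
connector.set("topic", "msg", pipe=pipe)  # buffered in pipe._pipe_buffer
results = pipe.execute()                  # replays the buffered calls on the mock
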
@ -5,6 +5,8 @@ __version__ = "1.12.1"
if __name__ == "__main__":
    setup(
        install_requires=[
            "hiredis",
            "louie",
            "numpy",
            "scipy",
            "msgpack",
@ -22,7 +24,16 @@ if __name__ == "__main__":
            "lmfit",
        ],
        extras_require={
            "dev": ["pytest", "pytest-random-order", "coverage", "pandas", "black", "pylint"]
            "dev": [
                "pytest",
                "pytest-random-order",
                "pytest-redis",
                "pytest-timeout",
                "coverage",
                "pandas",
                "black",
                "pylint",
            ]
        },
        entry_points={"console_scripts": ["bec-channel-monitor = bec_lib:channel_monitor_launch"]},
        package_data={"bec_lib.tests": ["*.yaml"], "bec_lib.configs": ["*.yaml", "*.json"]},
@ -17,7 +17,7 @@ def test_bec_widgets_connector_set_plot_config(bec_client):
    config = {"x": "test", "y": "test", "color": "test", "size": "test", "shape": "test"}
    connector.set_plot_config(plot_id="plot_id", config=config)
    msg = messages.GUIConfigMessage(config=config)
    bec_client.connector.producer().set_and_publish.assert_called_once_with(
    bec_client.connector.set_and_publish.assert_called_once_with(
        MessageEndpoints.gui_config("plot_id"), msg
    ) is None


@ -26,7 +26,7 @@ def test_bec_widgets_connector_close(bec_client):
    connector = BECWidgetsConnector(gui_id="gui_id", bec_client=bec_client)
    connector.close("plot_id")
    msg = messages.GUIInstructionMessage(action="close", parameter={})
    bec_client.connector.producer().set_and_publish.assert_called_once_with(
    bec_client.connector.set_and_publish.assert_called_once_with(
        MessageEndpoints.gui_instructions("plot_id"), msg
    )


@ -36,7 +36,7 @@ def test_bec_widgets_connector_send_data(bec_client):
    data = {"x": [1, 2, 3], "y": [1, 2, 3]}
    connector.send_data("plot_id", data)
    msg = messages.GUIDataMessage(data=data)
    bec_client.connector.producer().set_and_publish.assert_called_once_with(
    bec_client.connector.set_and_publish.assert_called_once_with(
        topic=MessageEndpoints.gui_data("plot_id"), msg=msg
    )


@ -45,7 +45,7 @@ def test_bec_widgets_connector_clear(bec_client):
    connector = BECWidgetsConnector(gui_id="gui_id", bec_client=bec_client)
    connector.clear("plot_id")
    msg = messages.GUIInstructionMessage(action="clear", parameter={})
    bec_client.connector.producer().set_and_publish.assert_called_once_with(
    bec_client.connector.set_and_publish.assert_called_once_with(
        MessageEndpoints.gui_instructions("plot_id"), msg
    )

@ -124,8 +124,8 @@ def test_bec_service_update_existing_services():
        messages.StatusMessage(name="service2", status=BECStatus.IDLE, info={}, metadata={}),
    ]
    connector_cls = mock.MagicMock()
    connector_cls().producer().keys.return_value = service_keys
    connector_cls().producer().get.side_effect = [msg for msg in service_msgs]
    connector_cls().keys.return_value = service_keys
    connector_cls().get.side_effect = [msg for msg in service_msgs]
    service = BECService(
        config=f"{os.path.dirname(bec_lib.__file__)}/tests/test_service_config.yaml",
        connector_cls=connector_cls,
@ -144,8 +144,8 @@ def test_bec_service_update_existing_services_ignores_wrong_msgs():
        None,
    ]
    connector_cls = mock.MagicMock()
    connector_cls().producer().keys.return_value = service_keys
    connector_cls().producer().get.side_effect = [service_msgs[0], None]
    connector_cls().keys.return_value = service_keys
    connector_cls().get.side_effect = [service_msgs[0], None]
    service = BECService(
        config=f"{os.path.dirname(bec_lib.__file__)}/tests/test_service_config.yaml",
        connector_cls=connector_cls,
@ -13,7 +13,7 @@ def test_channel_monitor_callback():
    mock_print.assert_called_once()


def test_channel_monitor_start_consumer():
def test_channel_monitor_start_register():
    with mock.patch("bec_lib.channel_monitor.argparse") as mock_argparse:
        with mock.patch("bec_lib.channel_monitor.ServiceConfig") as mock_config:
            with mock.patch("bec_lib.channel_monitor.RedisConnector") as mock_connector:
@ -26,6 +26,6 @@ def test_channel_monitor_start_consumer():
                mock_config.return_value = mock.MagicMock()
                mock_connector.return_value = mock.MagicMock()
                channel_monitor_launch()
                mock_connector().consumer.assert_called_once()
                mock_connector().consumer.return_value.start.assert_called_once()
                mock_connector().register.assert_called_once()
                mock_connector().register.return_value.start.assert_called_once()
                mock_threading.Event().wait.assert_called_once()
@ -49,7 +49,7 @@ def test_config_helper_save_current_session():
    connector = mock.MagicMock()

    config_helper = ConfigHelper(connector)
    connector.producer().get.return_value = messages.AvailableResourceMessage(
    connector.get.return_value = messages.AvailableResourceMessage(
        resource=[
            {
                "id": "648c817f67d3c7cd6a354e8e",
@ -158,9 +158,7 @@ def test_send_config_request_raises_for_rejected_update(config_helper):
def test_wait_for_config_reply():
    connector = mock.MagicMock()
    config_helper = ConfigHelper(connector)
    connector.producer().get.return_value = messages.RequestResponseMessage(
        accepted=True, message="test"
    )
    connector.get.return_value = messages.RequestResponseMessage(accepted=True, message="test")

    res = config_helper.wait_for_config_reply("test")
    assert res == messages.RequestResponseMessage(accepted=True, message="test")
@ -169,7 +167,7 @@ def test_wait_for_config_reply():
def test_wait_for_config_raises_timeout():
    connector = mock.MagicMock()
    config_helper = ConfigHelper(connector)
    connector.producer().get.return_value = None
    connector.get.return_value = None

    with pytest.raises(DeviceConfigError):
        config_helper.wait_for_config_reply("test", timeout=0.3)
@ -178,7 +176,7 @@ def test_wait_for_config_raises_timeout():
def test_wait_for_service_response():
    connector = mock.MagicMock()
    config_helper = ConfigHelper(connector)
    connector.producer().lrange.side_effect = [
    connector.lrange.side_effect = [
        [],
        [
            messages.ServiceResponseMessage(
@ -196,7 +194,7 @@ def test_wait_for_service_response():
def test_wait_for_service_response_raises_timeout():
    connector = mock.MagicMock()
    config_helper = ConfigHelper(connector)
    connector.producer().lrange.return_value = []
    connector.lrange.return_value = []

    with pytest.raises(DeviceConfigError):
        config_helper.wait_for_service_response("test", timeout=0.3)
@ -349,7 +349,7 @@ def dap(dap_plugin_message):
    }
    client = mock.MagicMock()
    client.service_status = dap_services
    client.producer.get.return_value = dap_plugin_message
    client.connector.get.return_value = dap_plugin_message
    dap_plugins = DAPPlugins(client)
    yield dap_plugins

@ -367,7 +367,7 @@ def test_dap_plugins_construction(dap):
def test_dap_plugin_fit(dap):
    with mock.patch.object(dap.GaussianModel, "_wait_for_dap_response") as mock_wait:
        dap.GaussianModel.fit()
        dap._parent.producer.set_and_publish.assert_called_once()
        dap._parent.connector.set_and_publish.assert_called_once()
        mock_wait.assert_called_once()


@ -380,7 +380,7 @@ def test_dap_auto_run(dap):


def test_dap_wait_for_dap_response_waits_for_RID(dap):
    dap._parent.producer.get.return_value = messages.DAPResponseMessage(
    dap._parent.connector.get.return_value = messages.DAPResponseMessage(
        success=True, data={}, metadata={"RID": "wrong_ID"}
    )
    with pytest.raises(TimeoutError):
@ -388,7 +388,7 @@ def test_dap_wait_for_dap_response_waits_for_RID(dap):


def test_dap_wait_for_dap_respnse_returns(dap):
    dap._parent.producer.get.return_value = messages.DAPResponseMessage(
    dap._parent.connector.get.return_value = messages.DAPResponseMessage(
        success=True, data={}, metadata={"RID": "1234"}
    )
    val = dap.GaussianModel._wait_for_dap_response(request_id="1234", timeout=0.1)
@ -429,11 +429,11 @@ def test_dap_select_raises_on_wrong_device(dap):


def test_dap_get_data(dap):
    dap._parent.producer.get_last.return_value = messages.ProcessedDataMessage(
    dap._parent.connector.get_last.return_value = messages.ProcessedDataMessage(
        data=[{"x": [1, 2, 3], "y": [4, 5, 6]}, {"fit_parameters": {"amplitude": 1}}]
    )
    data = dap.GaussianModel.get_data()
    dap._parent.producer.get_last.assert_called_once_with(
    dap._parent.connector.get_last.assert_called_once_with(
        MessageEndpoints.processed_data("GaussianModel")
    )

@ -443,13 +443,13 @@ def test_dap_get_data(dap):

def test_dap_update_dap_config_not_called_without_device(dap):
    dap.GaussianModel._update_dap_config(request_id="1234")
    dap._parent.producer.set_and_publish.assert_not_called()
    dap._parent.connector.set_and_publish.assert_not_called()


def test_dap_update_dap_config(dap):
    dap.GaussianModel._plugin_config["selected_device"] = ["samx", "samx"]
    dap.GaussianModel._update_dap_config(request_id="1234")
    dap._parent.producer.set_and_publish.assert_called_with(
    dap._parent.connector.set_and_publish.assert_called_with(
        MessageEndpoints.dap_request(),
        messages.DAPRequestMessage(
            dap_cls="LmfitService1D",
@ -91,15 +91,14 @@ def test_get_config_calls_load(dm):
        dm, "_get_redis_device_config", return_value={"devices": [{}]}
    ) as get_redis_config:
        with mock.patch.object(dm, "_load_session") as load_session:
            with mock.patch.object(dm, "producer") as producer:
                dm._get_config()
                get_redis_config.assert_called_once()
                load_session.assert_called_once()
            dm._get_config()
            get_redis_config.assert_called_once()
            load_session.assert_called_once()


def test_get_redis_device_config(dm):
    with mock.patch.object(dm, "producer") as producer:
        producer.get.return_value = messages.AvailableResourceMessage(resource={"devices": [{}]})
    with mock.patch.object(dm, "connector") as connector:
        connector.get.return_value = messages.AvailableResourceMessage(resource={"devices": [{}]})
        assert dm._get_redis_device_config() == {"devices": [{}]}

@ -23,7 +23,7 @@ def test_nested_device_root(dev):


def test_read(dev):
    with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get:
    with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
        mock_get.return_value = messages.DeviceMessage(
            signals={
                "samx": {"value": 0, "timestamp": 1701105880.1711318},
@ -42,7 +42,7 @@ def test_read(dev):


def test_read_filtered_hints(dev):
    with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get:
    with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
        mock_get.return_value = messages.DeviceMessage(
            signals={
                "samx": {"value": 0, "timestamp": 1701105880.1711318},
@ -57,7 +57,7 @@ def test_read_filtered_hints(dev):


def test_read_use_read(dev):
    with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get:
    with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
        data = {
            "samx": {"value": 0, "timestamp": 1701105880.1711318},
            "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492},
@ -72,7 +72,7 @@ def test_read_use_read(dev):


def test_read_nested_device(dev):
    with mock.patch.object(dev.dyn_signals.root.parent.producer, "get") as mock_get:
    with mock.patch.object(dev.dyn_signals.root.parent.connector, "get") as mock_get:
        data = {
            "dyn_signals_messages_message1": {"value": 0, "timestamp": 1701105880.0716832},
            "dyn_signals_messages_message2": {"value": 0, "timestamp": 1701105880.071722},
@ -93,7 +93,7 @@ def test_read_nested_device(dev):
)
def test_read_kind_hinted(dev, kind, cached):
    with mock.patch.object(dev.samx.readback, "_run") as mock_run:
        with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get:
        with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
            data = {
                "samx": {"value": 0, "timestamp": 1701105880.1711318},
                "samx_setpoint": {"value": 0, "timestamp": 1701105880.1693492},
@ -138,7 +138,7 @@ def test_read_configuration_cached(dev, is_signal, is_config_signal, method):
    with mock.patch.object(
        dev.samx.readback, "_get_rpc_signal_info", return_value=(is_signal, is_config_signal, True)
    ):
        with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get:
        with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
            mock_get.return_value = messages.DeviceMessage(
                signals={
                    "samx": {"value": 0, "timestamp": 1701105880.1711318},
@ -182,7 +182,7 @@ def test_get_rpc_func_name_read(dev):
)
def test_get_rpc_func_name_readback_get(dev, kind, cached):
    with mock.patch.object(dev.samx.readback, "_run") as mock_rpc:
        with mock.patch.object(dev.samx.root.parent.producer, "get") as mock_get:
        with mock.patch.object(dev.samx.root.parent.connector, "get") as mock_get:
            mock_get.return_value = messages.DeviceMessage(
                signals={
                    "samx": {"value": 0, "timestamp": 1701105880.1711318},
@ -219,9 +219,7 @@ def test_handle_rpc_response_returns_status(dev, bec_client):
    msg = messages.DeviceRPCMessage(
        device="samx", return_val={"type": "status", "RID": "request_id"}, out="done", success=True
    )
    assert dev.samx._handle_rpc_response(msg) == Status(
        bec_client.device_manager.producer, "request_id"
    )
    assert dev.samx._handle_rpc_response(msg) == Status(bec_client.device_manager, "request_id")


def test_handle_rpc_response_raises(dev):
@ -348,7 +346,7 @@ def test_device_update_user_parameter(device_obj, user_param, val, out, raised_e

def test_status_wait():
    producer = mock.MagicMock()
    connector = mock.MagicMock()

    def lrange_mock(*args, **kwargs):
        yield False
@ -358,8 +356,8 @@ def test_status_wait():
        return next(lmock)

    lmock = lrange_mock()
    producer.lrange = get_lrange
    status = Status(producer, "test")
    connector.lrange = get_lrange
    status = Status(connector, "test")
    status.wait()

@ -561,7 +559,7 @@ def test_show_all():
def test_adjustable_mixin_limits():
    adj = AdjustableMixin()
    adj.root = mock.MagicMock()
    adj.root.parent.producer.get.return_value = messages.DeviceMessage(
    adj.root.parent.connector.get.return_value = messages.DeviceMessage(
        signals={"low": -12, "high": 12}, metadata={}
    )
    assert adj.limits == [-12, 12]
@ -570,7 +568,7 @@ def test_adjustable_mixin_limits_missing():
def test_adjustable_mixin_limits_missing():
    adj = AdjustableMixin()
    adj.root = mock.MagicMock()
    adj.root.parent.producer.get.return_value = None
    adj.root.parent.connector.get.return_value = None
    assert adj.limits == [0, 0]


@ -585,7 +583,7 @@ def test_adjustable_mixin_set_low_limit():
    adj = AdjustableMixin()
    adj.update_config = mock.MagicMock()
    adj.root = mock.MagicMock()
    adj.root.parent.producer.get.return_value = messages.DeviceMessage(
    adj.root.parent.connector.get.return_value = messages.DeviceMessage(
        signals={"low": -12, "high": 12}, metadata={}
    )
    adj.low_limit = -20
@ -596,7 +594,7 @@ def test_adjustable_mixin_set_high_limit():
    adj = AdjustableMixin()
    adj.update_config = mock.MagicMock()
    adj.root = mock.MagicMock()
    adj.root.parent.producer.get.return_value = messages.DeviceMessage(
    adj.root.parent.connector.get.return_value = messages.DeviceMessage(
        signals={"low": -12, "high": 12}, metadata={}
    )
    adj.high_limit = 20

@ -97,9 +97,9 @@ def device_manager(dm_with_devices):


def test_observer_manager_None(device_manager):
    with mock.patch.object(device_manager.producer, "get", return_value=None) as producer_get:
    with mock.patch.object(device_manager.connector, "get", return_value=None) as connector_get:
        observer_manager = ObserverManager(device_manager=device_manager)
        producer_get.assert_called_once_with(MessageEndpoints.observer())
        connector_get.assert_called_once_with(MessageEndpoints.observer())
    assert len(observer_manager._observer) == 0


@ -115,9 +115,9 @@ def test_observer_manager_msg(device_manager):
            }
        ]
    )
    with mock.patch.object(device_manager.producer, "get", return_value=msg) as producer_get:
    with mock.patch.object(device_manager.connector, "get", return_value=msg) as connector_get:
        observer_manager = ObserverManager(device_manager=device_manager)
        producer_get.assert_called_once_with(MessageEndpoints.observer())
        connector_get.assert_called_once_with(MessageEndpoints.observer())
    assert len(observer_manager._observer) == 1


@ -139,7 +139,7 @@ def test_observer_manager_msg(device_manager):
    ],
)
def test_add_observer(device_manager, observer, raises_error):
    with mock.patch.object(device_manager.producer, "get", return_value=None) as producer_get:
    with mock.patch.object(device_manager.connector, "get", return_value=None) as connector_get:
        observer_manager = ObserverManager(device_manager=device_manager)
        observer_manager.add_observer(observer)
        with pytest.raises(AttributeError):
@ -185,7 +185,7 @@ def test_add_observer_existing_device(device_manager, observer, raises_error):
            "limits": [380, None],
        }
    )
    with mock.patch.object(device_manager.producer, "get", return_value=None) as producer_get:
    with mock.patch.object(device_manager.connector, "get", return_value=None) as connector_get:
        observer_manager = ObserverManager(device_manager=device_manager)
        observer_manager.add_observer(default_observer)
        if raises_error:
@ -12,113 +12,71 @@ from bec_lib.messages import AlarmMessage, BECMessage, LogMessage
from bec_lib.redis_connector import (
    MessageObject,
    RedisConnector,
    RedisConsumer,
    RedisConsumerMixin,
    RedisConsumerThreaded,
    RedisProducer,
    RedisStreamConsumerThreaded,
)
from bec_lib.serialization import MsgpackSerialization


@pytest.fixture
def producer():
    with mock.patch("bec_lib.redis_connector.redis.Redis"):
        prod = RedisProducer("localhost", 1)
        yield prod


@pytest.fixture
def connector():
    with mock.patch("bec_lib.redis_connector.redis.Redis"):
        connector = RedisConnector("localhost:1")
        try:
            yield connector
        finally:
            connector.shutdown()


@pytest.fixture
def connected_connector(redis_proc):
    connector = RedisConnector(f"localhost:{redis_proc.port}")
    try:
        yield connector
    finally:
        connector.shutdown()


@pytest.fixture
def consumer():
    with mock.patch("bec_lib.redis_connector.redis.Redis"):
        consumer = RedisConsumer("localhost", "1", topics="topics")
        yield consumer


@pytest.fixture
def consumer_threaded():
    with mock.patch("bec_lib.redis_connector.redis.Redis"):
        consumer_threaded = RedisConsumerThreaded("localhost", "1", topics="topics")
        yield consumer_threaded


@pytest.fixture
def mixin():
    with mock.patch("bec_lib.redis_connector.redis.Redis"):
        mixin = RedisConsumerMixin
        yield mixin


def test_redis_connector_producer(connector):
    ret = connector.producer()
    assert isinstance(ret, RedisProducer)


@pytest.mark.parametrize(
    "topics, threaded", [["topics", True], ["topics", False], [None, True], [None, False]]
    "topics, threaded",
    [["topics", True], ["topics", False], [None, True], [None, False]],
)
def test_redis_connector_consumer(connector, threaded, topics):
    pattern = None
    len_of_threads = len(connector._threads)

    if threaded:
        if topics is None and pattern is None:
            with pytest.raises(ValueError) as exc_info:
                ret = connector.consumer(
                    topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ...
                )

            assert exc_info.value.args[0] == "Topics must be set for threaded consumer"
        else:
            ret = connector.consumer(
                topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ...
def test_redis_connector_register(connected_connector, threaded, topics):
    connector = connected_connector
    if topics is None:
        with pytest.raises(TypeError):
            ret = connector.register(
                topics=topics, cb=lambda *args, **kwargs: ..., start_thread=threaded
            )
        assert len(connector._threads) == len_of_threads + 1
        assert isinstance(ret, RedisConsumerThreaded)

    else:
        if not topics:
            with pytest.raises(ConsumerConnectorError):
                ret = connector.consumer(
                    topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ...
                )
            return
        ret = connector.consumer(topics=topics, threaded=threaded, cb=lambda *args, **kwargs: ...)
        assert isinstance(ret, RedisConsumer)
    ret = connector.register(
        topics=topics, cb=lambda *args, **kwargs: ..., start_thread=threaded
    )
    if threaded:
        assert connector._events_listener_thread is not None

def test_redis_connector_log_warning(connector):
    connector._notifications_producer.send = mock.MagicMock()

    connector.log_warning("msg")
    connector._notifications_producer.send.assert_called_once_with(
        MessageEndpoints.log(), LogMessage(log_type="warning", log_msg="msg")
    )
    with mock.patch.object(connector, "send", return_value=None):
        connector.log_warning("msg")
        connector.send.assert_called_once_with(
            MessageEndpoints.log(), LogMessage(log_type="warning", log_msg="msg")
        )


def test_redis_connector_log_message(connector):
    connector._notifications_producer.send = mock.MagicMock()

    connector.log_message("msg")
    connector._notifications_producer.send.assert_called_once_with(
        MessageEndpoints.log(), LogMessage(log_type="log", log_msg="msg")
    )
    with mock.patch.object(connector, "send", return_value=None):
        connector.log_message("msg")
        connector.send.assert_called_once_with(
            MessageEndpoints.log(), LogMessage(log_type="log", log_msg="msg")
        )


def test_redis_connector_log_error(connector):
    connector._notifications_producer.send = mock.MagicMock()

    connector.log_error("msg")
    connector._notifications_producer.send.assert_called_once_with(
        MessageEndpoints.log(), LogMessage(log_type="error", log_msg="msg")
    )
    with mock.patch.object(connector, "send", return_value=None):
        connector.log_error("msg")
        connector.send.assert_called_once_with(
            MessageEndpoints.log(), LogMessage(log_type="error", log_msg="msg")
        )


@pytest.mark.parametrize(
@ -130,20 +88,24 @@ def test_redis_connector_log_error(connector):
    ],
)
def test_redis_connector_raise_alarm(connector, severity, alarm_type, source, msg, metadata):
    connector._notifications_producer.set_and_publish = mock.MagicMock()
    with mock.patch.object(connector, "set_and_publish", return_value=None):
        connector.raise_alarm(severity, alarm_type, source, msg, metadata)

    connector.raise_alarm(severity, alarm_type, source, msg, metadata)

    connector._notifications_producer.set_and_publish.assert_called_once_with(
        MessageEndpoints.alarm(),
        AlarmMessage(
            severity=severity, alarm_type=alarm_type, source=source, msg=msg, metadata=metadata
        ),
    )
        connector.set_and_publish.assert_called_once_with(
            MessageEndpoints.alarm(),
            AlarmMessage(
                severity=severity,
                alarm_type=alarm_type,
                source=source,
                msg=msg,
                metadata=metadata,
            ),
        )

@dataclass(eq=False)
|
||||
class TestMessage(BECMessage):
|
||||
__test__ = False # just for pytest to ignore this class
|
||||
msg_type = "test_message"
|
||||
msg: str
|
||||
# have to add this field here,
|
||||
@ -160,30 +122,36 @@ bec_messages.TestMessage = TestMessage
|
||||
@pytest.mark.parametrize(
|
||||
"topic , msg", [["topic1", TestMessage("msg1")], ["topic2", TestMessage("msg2")]]
|
||||
)
|
||||
def test_redis_producer_send(producer, topic, msg):
|
||||
producer.send(topic, msg)
|
||||
producer.r.publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg))
|
||||
def test_redis_connector_send(connector, topic, msg):
|
||||
connector.send(topic, msg)
|
||||
connector._redis_conn.publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg))
|
||||
|
||||
producer.send(topic, msg, pipe=producer.pipeline())
|
||||
producer.r.pipeline().publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg))
|
||||
connector.send(topic, msg, pipe=connector.pipeline())
|
||||
connector._redis_conn.pipeline().publish.assert_called_once_with(
|
||||
topic, MsgpackSerialization.dumps(msg)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"topic, msgs, max_size, expire",
|
||||
[["topic1", "msgs", None, None], ["topic1", "msgs", 10, None], ["topic1", "msgs", None, 100]],
|
||||
[
|
||||
["topic1", "msgs", None, None],
|
||||
["topic1", "msgs", 10, None],
|
||||
["topic1", "msgs", None, 100],
|
||||
],
|
||||
)
|
||||
def test_redis_producer_lpush(producer, topic, msgs, max_size, expire):
|
||||
def test_redis_connector_lpush(connector, topic, msgs, max_size, expire):
|
||||
pipe = None
|
||||
producer.lpush(topic, msgs, pipe, max_size, expire)
|
||||
connector.lpush(topic, msgs, pipe, max_size, expire)
|
||||
|
||||
producer.r.pipeline().lpush.assert_called_once_with(topic, msgs)
|
||||
connector._redis_conn.pipeline().lpush.assert_called_once_with(topic, msgs)
|
||||
|
||||
if max_size:
|
||||
producer.r.pipeline().ltrim.assert_called_once_with(topic, 0, max_size)
|
||||
connector._redis_conn.pipeline().ltrim.assert_called_once_with(topic, 0, max_size)
|
||||
if expire:
|
||||
producer.r.pipeline().expire.assert_called_once_with(topic, expire)
|
||||
connector._redis_conn.pipeline().expire.assert_called_once_with(topic, expire)
|
||||
if not pipe:
|
||||
producer.r.pipeline().execute.assert_called_once()
|
||||
connector._redis_conn.pipeline().execute.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -194,68 +162,76 @@ def test_redis_producer_lpush(producer, topic, msgs, max_size, expire):
|
||||
["topic1", TestMessage("msgs"), None, 100],
|
||||
],
|
||||
)
|
||||
def test_redis_producer_lpush_BECMessage(producer, topic, msgs, max_size, expire):
|
||||
def test_redis_connector_lpush_BECMessage(connector, topic, msgs, max_size, expire):
|
||||
pipe = None
|
||||
producer.lpush(topic, msgs, pipe, max_size, expire)
|
||||
connector.lpush(topic, msgs, pipe, max_size, expire)
|
||||
|
||||
producer.r.pipeline().lpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs))
|
||||
connector._redis_conn.pipeline().lpush.assert_called_once_with(
|
||||
topic, MsgpackSerialization.dumps(msgs)
|
||||
)
|
||||
|
||||
if max_size:
|
||||
producer.r.pipeline().ltrim.assert_called_once_with(topic, 0, max_size)
|
||||
connector._redis_conn.pipeline().ltrim.assert_called_once_with(topic, 0, max_size)
|
||||
if expire:
|
||||
producer.r.pipeline().expire.assert_called_once_with(topic, expire)
|
||||
connector._redis_conn.pipeline().expire.assert_called_once_with(topic, expire)
|
||||
if not pipe:
|
||||
producer.r.pipeline().execute.assert_called_once()
|
||||
connector._redis_conn.pipeline().execute.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"topic , index , msgs, use_pipe", [["topic1", 1, "msg1", True], ["topic2", 4, "msg2", False]]
|
||||
"topic , index , msgs, use_pipe",
|
||||
[["topic1", 1, "msg1", True], ["topic2", 4, "msg2", False]],
|
||||
)
|
||||
def test_redis_producer_lset(producer, topic, index, msgs, use_pipe):
|
||||
pipe = use_pipe_fcn(producer, use_pipe)
|
||||
def test_redis_connector_lset(connector, topic, index, msgs, use_pipe):
|
||||
pipe = use_pipe_fcn(connector, use_pipe)
|
||||
|
||||
ret = producer.lset(topic, index, msgs, pipe)
|
||||
ret = connector.lset(topic, index, msgs, pipe)
|
||||
|
||||
if pipe:
|
||||
producer.r.pipeline().lset.assert_called_once_with(topic, index, msgs)
|
||||
connector._redis_conn.pipeline().lset.assert_called_once_with(topic, index, msgs)
|
||||
assert ret == redis.Redis().pipeline().lset()
|
||||
else:
|
||||
producer.r.lset.assert_called_once_with(topic, index, msgs)
|
||||
connector._redis_conn.lset.assert_called_once_with(topic, index, msgs)
|
||||
assert ret == redis.Redis().lset()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"topic , index , msgs, use_pipe",
|
||||
[["topic1", 1, TestMessage("msg1"), True], ["topic2", 4, TestMessage("msg2"), False]],
|
||||
[
|
||||
["topic1", 1, TestMessage("msg1"), True],
|
||||
["topic2", 4, TestMessage("msg2"), False],
|
||||
],
|
||||
)
|
||||
def test_redis_producer_lset_BECMessage(producer, topic, index, msgs, use_pipe):
|
||||
pipe = use_pipe_fcn(producer, use_pipe)
|
||||
def test_redis_connector_lset_BECMessage(connector, topic, index, msgs, use_pipe):
|
||||
pipe = use_pipe_fcn(connector, use_pipe)
|
||||
|
||||
ret = producer.lset(topic, index, msgs, pipe)
|
||||
ret = connector.lset(topic, index, msgs, pipe)
|
||||
|
||||
if pipe:
|
||||
producer.r.pipeline().lset.assert_called_once_with(
|
||||
connector._redis_conn.pipeline().lset.assert_called_once_with(
|
||||
topic, index, MsgpackSerialization.dumps(msgs)
|
||||
)
|
||||
assert ret == redis.Redis().pipeline().lset()
|
||||
else:
|
||||
producer.r.lset.assert_called_once_with(topic, index, MsgpackSerialization.dumps(msgs))
|
||||
connector._redis_conn.lset.assert_called_once_with(
|
||||
topic, index, MsgpackSerialization.dumps(msgs)
|
||||
)
|
||||
assert ret == redis.Redis().lset()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"topic, msgs, use_pipe", [["topic1", "msg1", True], ["topic2", "msg2", False]]
|
||||
)
|
||||
def test_redis_producer_rpush(producer, topic, msgs, use_pipe):
|
||||
pipe = use_pipe_fcn(producer, use_pipe)
|
||||
def test_redis_connector_rpush(connector, topic, msgs, use_pipe):
|
||||
pipe = use_pipe_fcn(connector, use_pipe)
|
||||
|
||||
ret = producer.rpush(topic, msgs, pipe)
|
||||
ret = connector.rpush(topic, msgs, pipe)
|
||||
|
||||
if pipe:
|
||||
producer.r.pipeline().rpush.assert_called_once_with(topic, msgs)
|
||||
connector._redis_conn.pipeline().rpush.assert_called_once_with(topic, msgs)
|
||||
assert ret == redis.Redis().pipeline().rpush()
|
||||
else:
|
||||
producer.r.rpush.assert_called_once_with(topic, msgs)
|
||||
connector._redis_conn.rpush.assert_called_once_with(topic, msgs)
|
||||
assert ret == redis.Redis().rpush()
|
||||
|
||||
|
||||
@ -263,421 +239,349 @@ def test_redis_producer_rpush(producer, topic, msgs, use_pipe):
"topic, msgs, use_pipe",
[["topic1", TestMessage("msg1"), True], ["topic2", TestMessage("msg2"), False]],
)
def test_redis_producer_rpush_BECMessage(producer, topic, msgs, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe)
def test_redis_connector_rpush_BECMessage(connector, topic, msgs, use_pipe):
pipe = use_pipe_fcn(connector, use_pipe)

ret = producer.rpush(topic, msgs, pipe)
ret = connector.rpush(topic, msgs, pipe)

if pipe:
producer.r.pipeline().rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs))
connector._redis_conn.pipeline().rpush.assert_called_once_with(
topic, MsgpackSerialization.dumps(msgs)
)
assert ret == redis.Redis().pipeline().rpush()
else:
producer.r.rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs))
connector._redis_conn.rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs))
assert ret == redis.Redis().rpush()


@pytest.mark.parametrize(
"topic, start, end, use_pipe", [["topic1", 0, 4, True], ["topic2", 3, 7, False]]
)
def test_redis_producer_lrange(producer, topic, start, end, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe)
def test_redis_connector_lrange(connector, topic, start, end, use_pipe):
pipe = use_pipe_fcn(connector, use_pipe)

ret = producer.lrange(topic, start, end, pipe)
ret = connector.lrange(topic, start, end, pipe)

if pipe:
producer.r.pipeline().lrange.assert_called_once_with(topic, start, end)
connector._redis_conn.pipeline().lrange.assert_called_once_with(topic, start, end)
assert ret == redis.Redis().pipeline().lrange()
else:
producer.r.lrange.assert_called_once_with(topic, start, end)
connector._redis_conn.lrange.assert_called_once_with(topic, start, end)
assert ret == []


@pytest.mark.parametrize(
"topic, msg, pipe, expire", [["topic1", "msg1", None, 400], ["topic2", "msg2", None, None]]
"topic, msg, pipe, expire",
[
["topic1", TestMessage("msg1"), None, 400],
["topic2", TestMessage("msg2"), None, None],
["topic3", "msg3", None, None],
],
)
def test_redis_producer_set_and_publish(producer, topic, msg, pipe, expire):
producer.set_and_publish(topic, msg, pipe, expire)
def test_redis_connector_set_and_publish(connector, topic, msg, pipe, expire):
if not isinstance(msg, BECMessage):
with pytest.raises(TypeError):
connector.set_and_publish(topic, msg, pipe, expire)
else:
connector.set_and_publish(topic, msg, pipe, expire)

producer.r.pipeline().publish.assert_called_once_with(topic, msg)
producer.r.pipeline().set.assert_called_once_with(topic, msg)
if expire:
producer.r.pipeline().expire.assert_called_once_with(topic, expire)
if not pipe:
producer.r.pipeline().execute.assert_called_once()
connector._redis_conn.pipeline().publish.assert_called_once_with(
topic, MsgpackSerialization.dumps(msg)
)
connector._redis_conn.pipeline().set.assert_called_once_with(
topic, MsgpackSerialization.dumps(msg), ex=expire
)
if not pipe:
connector._redis_conn.pipeline().execute.assert_called_once()


@pytest.mark.parametrize("topic, msg, expire", [["topic1", "msg1", None], ["topic2", "msg2", 400]])
def test_redis_producer_set(producer, topic, msg, expire):
def test_redis_connector_set(connector, topic, msg, expire):
pipe = None

producer.set(topic, msg, pipe, expire)
connector.set(topic, msg, pipe, expire)

if pipe:
producer.r.pipeline().set.assert_called_once_with(topic, msg, ex=expire)
connector._redis_conn.pipeline().set.assert_called_once_with(topic, msg, ex=expire)
else:
producer.r.set.assert_called_once_with(topic, msg, ex=expire)
connector._redis_conn.set.assert_called_once_with(topic, msg, ex=expire)


@pytest.mark.parametrize("pattern", ["samx", "samy"])
def test_redis_producer_keys(producer, pattern):
ret = producer.keys(pattern)
producer.r.keys.assert_called_once_with(pattern)
def test_redis_connector_keys(connector, pattern):
ret = connector.keys(pattern)
connector._redis_conn.keys.assert_called_once_with(pattern)
assert ret == redis.Redis().keys()


def test_redis_producer_pipeline(producer):
ret = producer.pipeline()
producer.r.pipeline.assert_called_once()
def test_redis_connector_pipeline(connector):
ret = connector.pipeline()
connector._redis_conn.pipeline.assert_called_once()
assert ret == redis.Redis().pipeline()


@pytest.mark.parametrize("topic,use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_producer_delete(producer, topic, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe)

producer.delete(topic, pipe)

if pipe:
producer.pipeline().delete.assert_called_once_with(topic)
else:
producer.r.delete.assert_called_once_with(topic)


@pytest.mark.parametrize("topic, use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_producer_get(producer, topic, use_pipe):
pipe = use_pipe_fcn(producer, use_pipe)

ret = producer.get(topic, pipe)
if pipe:
producer.pipeline().get.assert_called_once_with(topic)
assert ret == redis.Redis().pipeline().get()
else:
producer.r.get.assert_called_once_with(topic)
assert ret == redis.Redis().get()


def use_pipe_fcn(producer, use_pipe):
def use_pipe_fcn(connector, use_pipe):
if use_pipe:
return producer.pipeline()
return connector.pipeline()
return None


@pytest.mark.parametrize("topic,use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_connector_delete(connector, topic, use_pipe):
pipe = use_pipe_fcn(connector, use_pipe)

connector.delete(topic, pipe)

if pipe:
connector.pipeline().delete.assert_called_once_with(topic)
else:
connector._redis_conn.delete.assert_called_once_with(topic)


@pytest.mark.parametrize("topic, use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_connector_get(connector, topic, use_pipe):
pipe = use_pipe_fcn(connector, use_pipe)

ret = connector.get(topic, pipe)
if pipe:
connector.pipeline().get.assert_called_once_with(topic)
assert ret == redis.Redis().pipeline().get()
else:
connector._redis_conn.get.assert_called_once_with(topic)
assert ret == redis.Redis().get()


@pytest.mark.parametrize(
"topics, pattern",
"subscribed_topics, subscribed_patterns, msgs",
[
["topics1", None],
[["topics1", "topics2"], None],
[None, "pattern1"],
[None, ["pattern1", "pattern2"]],
["topics1", None, ["topics1"]],
[["topics1", "topics2"], None, ["topics1", "topics2"]],
[None, "pattern1", ["pattern1"]],
[None, ["patt*", "top*"], ["pattern1", "topics1"]],
],
)
def test_redis_consumer_init(consumer, topics, pattern):
with mock.patch("bec_lib.redis_connector.redis.Redis"):
consumer = RedisConsumer(
"localhost", "1", topics, pattern, redis_cls=redis.Redis, cb=lambda *args, **kwargs: ...
def test_redis_connector_register(
redisdb, connected_connector, subscribed_topics, subscribed_patterns, msgs
):
connector = connected_connector
test_msg = TestMessage("test")
cb_mock = mock.Mock(spec=[]) # spec is here to remove all attributes
if subscribed_topics:
connector.register(
subscribed_topics, subscribed_patterns, cb=cb_mock, start_thread=False, a=1
)

if topics:
if isinstance(topics, list):
assert consumer.topics == topics
else:
assert consumer.topics == [topics]
if pattern:
if isinstance(pattern, list):
assert consumer.pattern == pattern
else:
assert consumer.pattern == [pattern]

assert consumer.r == redis.Redis()
assert consumer.pubsub == consumer.r.pubsub()
assert consumer.host == "localhost"
assert consumer.port == "1"
for msg in msgs:
connector.send(msg, TestMessage(msg))
connector.poll_messages()
msg_object = MessageObject(msg, TestMessage(msg))
cb_mock.assert_called_with(msg_object, a=1)
if subscribed_patterns:
connector.register(
subscribed_topics, subscribed_patterns, cb=cb_mock, start_thread=False, a=1
)
for msg in msgs:
connector.send(msg, TestMessage(msg))
connector.poll_messages()
msg_object = MessageObject(msg, TestMessage(msg))
cb_mock.assert_called_with(msg_object, a=1)


@pytest.mark.parametrize("pattern, topics", [["pattern", "topics1"], [None, "topics2"]])
def test_redis_consumer_initialize_connector(consumer, pattern, topics):
consumer.pattern = pattern

consumer.topics = topics
consumer.initialize_connector()

if consumer.pattern is not None:
consumer.pubsub.psubscribe.assert_called_once_with(consumer.pattern)
else:
consumer.pubsub.subscribe.assert_called_with(consumer.topics)


def test_redis_consumer_poll_messages(consumer):
def test_redis_register_poll_messages(redisdb, connected_connector):
connector = connected_connector
cb_fcn_has_been_called = False

def cb_fcn(msg, **kwargs):
nonlocal cb_fcn_has_been_called
cb_fcn_has_been_called = True
print(msg)

consumer.cb = cb_fcn
assert kwargs["a"] == 1

test_msg = TestMessage("test")
consumer.pubsub.get_message.return_value = {
"channel": "",
"data": MsgpackSerialization.dumps(test_msg),
}
ret = consumer.poll_messages()
consumer.pubsub.get_message.assert_called_once_with(ignore_subscribe_messages=True)
connector.register("test", cb=cb_fcn, a=1, start_thread=False)
redisdb.publish("test", MsgpackSerialization.dumps(test_msg))

connector.poll_messages(timeout=1)

assert cb_fcn_has_been_called


def test_redis_consumer_shutdown(consumer):
consumer.shutdown()
consumer.pubsub.close.assert_called_once()
with pytest.raises(TimeoutError):
connector.poll_messages(timeout=0.1)


def test_redis_consumer_additional_kwargs(connector):
cons = connector.consumer(topics="topic1", parent="here", cb=lambda *args, **kwargs: ...)
assert "parent" in cons.kwargs
def test_redis_connector_xadd(connector):
connector.xadd("topic1", {"key": "value"})
connector._redis_conn.xadd.assert_called_once_with("topic1", {"key": "value"})


@pytest.mark.parametrize(
"topics, pattern",
[
["topics1", None],
[["topics1", "topics2"], None],
[None, "pattern1"],
[None, ["pattern1", "pattern2"]],
],
)
def test_mixin_init_topics_and_pattern(mixin, topics, pattern):
ret_topics, ret_pattern = mixin._init_topics_and_pattern(mixin, topics, pattern)

if topics:
if isinstance(topics, list):
assert ret_topics == topics
else:
assert ret_topics == [topics]
if pattern:
if isinstance(pattern, list):
assert ret_pattern == pattern
else:
assert ret_pattern == [pattern]
def test_redis_connector_xadd_with_maxlen(connector):
connector.xadd("topic1", {"key": "value"}, max_size=100)
connector._redis_conn.xadd.assert_called_once_with("topic1", {"key": "value"}, maxlen=100)


def test_mixin_init_redis_cls(mixin, consumer):
mixin._init_redis_cls(consumer, None)
assert consumer.r == redis.Redis(host="localhost", port=1)
def test_redis_connector_xadd_with_expire(connector):
connector.xadd("topic1", {"key": "value"}, expire=100)
connector._redis_conn.pipeline().xadd.assert_called_once_with("topic1", {"key": "value"})
connector._redis_conn.pipeline().expire.assert_called_once_with("topic1", 100)
connector._redis_conn.pipeline().execute.assert_called_once()


@pytest.mark.parametrize(
"topics, pattern",
[
["topics1", None],
[["topics1", "topics2"], None],
[None, "pattern1"],
[None, ["pattern1", "pattern2"]],
],
)
def test_redis_consumer_threaded_init(consumer_threaded, topics, pattern):
with mock.patch("bec_lib.redis_connector.redis.Redis"):
consumer_threaded = RedisConsumerThreaded(
"localhost", "1", topics, pattern, redis_cls=redis.Redis, cb=lambda *args, **kwargs: ...
)

if topics:
if isinstance(topics, list):
assert consumer_threaded.topics == topics
else:
assert consumer_threaded.topics == [topics]
if pattern:
if isinstance(pattern, list):
assert consumer_threaded.pattern == pattern
else:
assert consumer_threaded.pattern == [pattern]

assert consumer_threaded.r == redis.Redis()
assert consumer_threaded.pubsub == consumer_threaded.r.pubsub()
assert consumer_threaded.host == "localhost"
assert consumer_threaded.port == "1"
assert consumer_threaded.sleep_times == [0.005, 0.1]
assert consumer_threaded.last_received_msg == 0
assert consumer_threaded.idle_time == 30
def test_redis_connector_xread(connector):
connector.xread("topic1", "id")
connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)


def test_redis_connector_xadd(producer):
producer.xadd("topic1", {"key": "value"})
producer.r.xadd.assert_called_once_with("topic1", {"key": MsgpackSerialization.dumps("value")})
def test_redis_connector_xadd(connector):
connector.xadd("topic1", {"key": "value"})
connector._redis_conn.xadd.assert_called_once_with(
"topic1", {"key": MsgpackSerialization.dumps("value")}
)

test_msg = TestMessage("test")
producer.xadd("topic1", {"data": test_msg})
producer.r.xadd.assert_called_with("topic1", {"data": MsgpackSerialization.dumps(test_msg)})
producer.r.xrevrange.return_value = [
connector.xadd("topic1", {"data": test_msg})
connector._redis_conn.xadd.assert_called_with(
"topic1", {"data": MsgpackSerialization.dumps(test_msg)}
)
connector._redis_conn.xrevrange.return_value = [
(b"1707391599960-0", {b"data": MsgpackSerialization.dumps(test_msg)})
]
msg = producer.get_last("topic1")
msg = connector.get_last("topic1")
assert msg == test_msg


def test_redis_connector_xadd_with_maxlen(producer):
producer.xadd("topic1", {"key": "value"}, max_size=100)
producer.r.xadd.assert_called_once_with(
def test_redis_connector_xadd_with_maxlen(connector):
connector.xadd("topic1", {"key": "value"}, max_size=100)
connector._redis_conn.xadd.assert_called_once_with(
"topic1", {"key": MsgpackSerialization.dumps("value")}, maxlen=100
)


def test_redis_connector_xadd_with_expire(producer):
producer.xadd("topic1", {"key": "value"}, expire=100)
producer.r.pipeline().xadd.assert_called_once_with(
def test_redis_connector_xadd_with_expire(connector):
connector.xadd("topic1", {"key": "value"}, expire=100)
connector._redis_conn.pipeline().xadd.assert_called_once_with(
"topic1", {"key": MsgpackSerialization.dumps("value")}
)
producer.r.pipeline().expire.assert_called_once_with("topic1", 100)
producer.r.pipeline().execute.assert_called_once()
connector._redis_conn.pipeline().expire.assert_called_once_with("topic1", 100)
connector._redis_conn.pipeline().execute.assert_called_once()


def test_redis_connector_xread(producer):
producer.xread("topic1", "id")
producer.r.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)
def test_redis_connector_xread(connector):
connector.xread("topic1", "id")
connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)


def test_redis_connector_xread_without_id(producer):
producer.xread("topic1", from_start=True)
producer.r.xread.assert_called_once_with({"topic1": "0-0"}, count=None, block=None)
producer.r.xread.reset_mock()
def test_redis_connector_xread_without_id(connector):
connector.xread("topic1", from_start=True)
connector._redis_conn.xread.assert_called_once_with({"topic1": "0-0"}, count=None, block=None)
connector._redis_conn.xread.reset_mock()

producer.stream_keys["topic1"] = "id"
producer.xread("topic1")
producer.r.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)
connector.stream_keys["topic1"] = "id"
connector.xread("topic1")
connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)


def test_redis_connector_xread_from_end(producer):
producer.xread("topic1", from_start=False)
producer.r.xrevrange.assert_called_once_with("topic1", "+", "-", count=1)
def test_redis_connector_xread_from_end(connector):
connector.xread("topic1", from_start=False)
connector._redis_conn.xrevrange.assert_called_once_with("topic1", "+", "-", count=1)


def test_redis_connector_get_last(producer):
producer.r.xrevrange.return_value = [
def test_redis_connector_get_last(connector):
connector._redis_conn.xrevrange.return_value = [
(b"1707391599960-0", {b"key": MsgpackSerialization.dumps("value")})
]
msg = producer.get_last("topic1")
producer.r.xrevrange.assert_called_once_with("topic1", "+", "-", count=1)
msg = connector.get_last("topic1")
connector._redis_conn.xrevrange.assert_called_once_with("topic1", "+", "-", count=1)
assert msg is None # no key given, default is b'data'
assert producer.get_last("topic1", "key") == "value"
assert producer.get_last("topic1", None) == {"key": "value"}
assert connector.get_last("topic1", "key") == "value"
assert connector.get_last("topic1", None) == {"key": "value"}


def test_redis_xrange(producer):
producer.xrange("topic1", "start", "end")
producer.r.xrange.assert_called_once_with("topic1", "start", "end", count=None)
def test_redis_connector_xread_without_id(connector):
connector.xread("topic1", from_start=True)
connector._redis_conn.xread.assert_called_once_with({"topic1": "0-0"}, count=None, block=None)
connector._redis_conn.xread.reset_mock()

connector.stream_keys["topic1"] = "id"
connector.xread("topic1")
connector._redis_conn.xread.assert_called_once_with({"topic1": "id"}, count=None, block=None)


def test_redis_xrange_topic_with_suffix(producer):
producer.xrange("topic1", "start", "end")
producer.r.xrange.assert_called_once_with("topic1", "start", "end", count=None)
def test_redis_xrange(connector):
connector.xrange("topic1", "start", "end")
connector._redis_conn.xrange.assert_called_once_with("topic1", "start", "end", count=None)


def test_redis_consumer_threaded_no_cb_without_messages(consumer_threaded):
with mock.patch.object(consumer_threaded.pubsub, "get_message", return_value=None):
consumer_threaded.cb = mock.MagicMock()
consumer_threaded.poll_messages()
consumer_threaded.cb.assert_not_called()
def test_redis_xrange_topic_with_suffix(connector):
connector.xrange("topic1", "start", "end")
connector._redis_conn.xrange.assert_called_once_with("topic1", "start", "end", count=None)


def test_redis_consumer_threaded_cb_called_with_messages(consumer_threaded):
message = {"channel": b"topic1", "data": MsgpackSerialization.dumps(TestMessage("test"))}

with mock.patch.object(consumer_threaded.pubsub, "get_message", return_value=message):
consumer_threaded.cb = mock.MagicMock()
consumer_threaded.poll_messages()
msg_object = MessageObject("topic1", TestMessage("test"))
consumer_threaded.cb.assert_called_once_with(msg_object)
# def test_redis_stream_register_threaded_get_id():
# register = RedisStreamConsumerThreaded(
# "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
# )
# register.stream_keys["topic1"] = b"1691610882756-0"
# assert register.get_id("topic1") == b"1691610882756-0"
# assert register.get_id("doesnt_exist") == "0-0"


def test_redis_consumer_threaded_shutdown(consumer_threaded):
consumer_threaded.shutdown()
consumer_threaded.pubsub.close.assert_called_once()
# def test_redis_stream_register_threaded_poll_messages():
# register = RedisStreamConsumerThreaded(
# "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
# )
# with mock.patch.object(
# register, "get_newest_message", return_value=None
# ) as mock_get_newest_message:
# register.poll_messages()
# mock_get_newest_message.assert_called_once()
# register._redis_conn.xread.assert_not_called()


def test_redis_stream_consumer_threaded_get_newest_message():
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
consumer.r.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})]
msgs = []
consumer.get_newest_message(msgs)
assert "topic1" in consumer.stream_keys
assert consumer.stream_keys["topic1"] == b"1691610882756-0"
# def test_redis_stream_register_threaded_poll_messages_newest_only():
# register = RedisStreamConsumerThreaded(
# "localhost",
# "1",
# topics="topic1",
# cb=mock.MagicMock(),
# redis_cls=mock.MagicMock(),
# newest_only=True,
# )
#
# register._redis_conn.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})]
# register.poll_messages()
# register._redis_conn.xread.assert_not_called()
# register.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))


def test_redis_stream_consumer_threaded_get_newest_message_no_msg():
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
consumer.r.xrevrange.return_value = []
msgs = []
consumer.get_newest_message(msgs)
assert "topic1" in consumer.stream_keys
assert consumer.stream_keys["topic1"] == "0-0"
# def test_redis_stream_register_threaded_poll_messages_read():
# register = RedisStreamConsumerThreaded(
# "localhost",
# "1",
# topics="topic1",
# cb=mock.MagicMock(),
# redis_cls=mock.MagicMock(),
# )
# register.stream_keys["topic1"] = "0-0"
#
# msg = [[b"topic1", [(b"1691610714612-0", {b"data": b"msg"})]]]
#
# register._redis_conn.xread.return_value = msg
# register.poll_messages()
# register._redis_conn.xread.assert_called_once_with({"topic1": "0-0"}, count=1)
# register.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))


def test_redis_stream_consumer_threaded_get_id():
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
consumer.stream_keys["topic1"] = b"1691610882756-0"
assert consumer.get_id("topic1") == b"1691610882756-0"
assert consumer.get_id("doesnt_exist") == "0-0"


def test_redis_stream_consumer_threaded_poll_messages():
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
with mock.patch.object(
consumer, "get_newest_message", return_value=None
) as mock_get_newest_message:
consumer.poll_messages()
mock_get_newest_message.assert_called_once()
consumer.r.xread.assert_not_called()


def test_redis_stream_consumer_threaded_poll_messages_newest_only():
consumer = RedisStreamConsumerThreaded(
"localhost",
"1",
topics="topic1",
cb=mock.MagicMock(),
redis_cls=mock.MagicMock(),
newest_only=True,
)

consumer.r.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})]
consumer.poll_messages()
consumer.r.xread.assert_not_called()
consumer.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))


def test_redis_stream_consumer_threaded_poll_messages_read():
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
consumer.stream_keys["topic1"] = "0-0"

msg = [[b"topic1", [(b"1691610714612-0", {b"data": b"msg"})]]]

consumer.r.xread.return_value = msg
consumer.poll_messages()
consumer.r.xread.assert_called_once_with({"topic1": "0-0"}, count=1)
consumer.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))


@pytest.mark.parametrize(
"topics,expected",
[
("topic1", ["topic1"]),
(["topic1"], ["topic1"]),
(["topic1", "topic2"], ["topic1", "topic2"]),
],
)
def test_redis_stream_consumer_threaded_init_topics(topics, expected):
consumer = RedisStreamConsumerThreaded(
"localhost", "1", topics=topics, cb=mock.MagicMock(), redis_cls=mock.MagicMock()
)
assert consumer.topics == expected
# @pytest.mark.parametrize(
# "topics,expected",
# [
# ("topic1", ["topic1"]),
# (["topic1"], ["topic1"]),
# (["topic1", "topic2"], ["topic1", "topic2"]),
# ],
# )
# def test_redis_stream_register_threaded_init_topics(topics, expected):
# register = RedisStreamConsumerThreaded(
# "localhost",
# "1",
# topics=topics,
# cb=mock.MagicMock(),
# redis_cls=mock.MagicMock(),
# )
# assert register.topics == expected
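The rewritten tests above define the surface of the unified connector: send, register, and poll_messages replace the old producer/consumer split. A minimal usage sketch, assuming a Redis instance on localhost:6379 and the TestMessage class defined at the top of this test file (hedged illustration only, modeled on the calls the tests exercise, not a verbatim API reference):

    from bec_lib.redis_connector import RedisConnector

    connector = RedisConnector("localhost:6379")

    def on_message(msg_object, **kwargs):
        # extra keyword arguments passed to register() (here a=1) are forwarded to the callback
        print(msg_object.value, kwargs["a"])

    connector.register("topic1", cb=on_message, start_thread=False, a=1)
    connector.send("topic1", TestMessage("hello"))
    connector.poll_messages(timeout=1)  # dispatches pending messages; TimeoutError if none arrive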
|
@ -63,7 +63,7 @@ from bec_lib.tests.utils import ConnectorMock
)
def test_update_with_queue_status(queue_msg):
scan_manager = ScanManager(ConnectorMock(""))
scan_manager.producer._get_buffer[MessageEndpoints.scan_queue_status()] = queue_msg
scan_manager.connector._get_buffer[MessageEndpoints.scan_queue_status()] = queue_msg
scan_manager.update_with_queue_status(queue_msg)
assert (
scan_manager.scan_storage.find_scan_by_ID("bfa582aa-f9cd-4258-ab5d-3e5d54d3dde5")
|
@ -105,6 +105,6 @@ def test_scan_report_get_mv_status(scan_report, lrange_return, expected):
scan_report.request.request = messages.ScanQueueMessage(
scan_type="mv", parameter={"args": {"samx": [5], "samy": [5]}}
)
with mock.patch.object(scan_report._client.device_manager.producer, "lrange") as mock_lrange:
with mock.patch.object(scan_report._client.device_manager.connector, "lrange") as mock_lrange:
mock_lrange.return_value = lrange_return
assert scan_report._get_mv_status() == expected
|
@ -14,7 +14,6 @@ parser.add_argument("--redis", default="localhost:6379", help="redis host and po

clargs = parser.parse_args()
connector = RedisConnector(clargs.redis)
producer = connector.producer()

with open(clargs.config, "r", encoding="utf-8") as stream:
data = yaml.safe_load(stream)
@ -22,4 +21,4 @@ for name, device in data.items():
device["name"] = name
config_data = list(data.values())
msg = messages.AvailableResourceMessage(resource=config_data)
producer.set(MessageEndpoints.device_config(), msg)
connector.set(MessageEndpoints.device_config(), msg)
|
@ -11,7 +11,6 @@ class DAPServiceManager:

def __init__(self, services: list) -> None:
self.connector = None
self.producer = None
self._started = False
self.client = None
self._dap_request_thread = None
@ -24,13 +23,11 @@ class DAPServiceManager:
"""
Start the dap request consumer.
"""
self._dap_request_thread = self.connector.consumer(
topics=MessageEndpoints.dap_request(), cb=self._dap_request_callback, parent=self
self.connector.register(
topics=MessageEndpoints.dap_request(), cb=self._dap_request_callback
)
self._dap_request_thread.start()

@staticmethod
def _dap_request_callback(msg: MessageObject, *, parent: DAPServiceManager) -> None:
def _dap_request_callback(self, msg: MessageObject) -> None:
"""
Callback function for dap request consumer.

@ -41,7 +38,7 @@ class DAPServiceManager:
dap_request_msg = messages.DAPRequestMessage.loads(msg.value)
if not dap_request_msg:
return
parent.process_dap_request(dap_request_msg)
self.process_dap_request(dap_request_msg)

def process_dap_request(self, dap_request_msg: messages.DAPRequestMessage) -> None:
"""
@ -153,7 +150,7 @@ class DAPServiceManager:
dap_response_msg = messages.DAPResponseMessage(
success=success, data=data, error=error, dap_request=dap_request_msg, metadata=metadata
)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.dap_response(metadata.get("RID")), dap_response_msg, expire=60
)

@ -168,7 +165,6 @@ class DAPServiceManager:
return
self.client = client
self.connector = client.connector
self.producer = self.connector.producer()
self._start_dap_request_consumer()
self.update_available_dap_services()
self.publish_available_services()
@ -264,12 +260,12 @@ class DAPServiceManager:
"""send all available dap services to the broker"""
msg = messages.AvailableResourceMessage(resource=self.available_dap_services)
# pylint: disable=protected-access
self.producer.set(
self.connector.set(
MessageEndpoints.dap_available_plugins(f"DAPServer/{self.client._service_id}"), msg
)

def shutdown(self) -> None:
if not self._started:
return
self._dap_request_thread.stop()
self.connector.shutdown()
self._started = False
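Because register() can dispatch to bound methods directly, the staticmethod/parent indirection in the DAPServiceManager callback becomes unnecessary. A condensed sketch of the resulting wiring (illustrative only; the class name here is hypothetical and the body is abridged from the hunks above):

    class DAPServiceManagerSketch:
        def _start_dap_request_consumer(self):
            # the bound method itself is the callback; no parent=... kwarg needed
            self.connector.register(
                topics=MessageEndpoints.dap_request(), cb=self._dap_request_callback
            )

        def _dap_request_callback(self, msg):
            dap_request_msg = messages.DAPRequestMessage.loads(msg.value)
            if not dap_request_msg:
                return
            self.process_dap_request(dap_request_msg)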
|
@ -148,7 +148,7 @@ class LmfitService1D(DAPServiceBase):
out = self.process()
if out:
stream_output, metadata = out
self.client.producer.xadd(
self.client.connector.xadd(
MessageEndpoints.processed_data(self.model.__class__.__name__),
msg={
"data": MsgpackSerialization.dumps(
|
@ -72,7 +72,7 @@ def test_DAPServiceManager_init(service_manager):
def test_DAPServiceManager_request_callback(service_manager, msg, process_called):
msg_obj = MessageObject(value=msg, topic="topic")
with mock.patch.object(service_manager, "process_dap_request") as mock_process_dap_request:
service_manager._dap_request_callback(msg_obj, parent=service_manager)
service_manager._dap_request_callback(msg_obj)
if process_called:
mock_process_dap_request.assert_called_once_with(msg)

|
@ -134,7 +134,7 @@ def test_LmfitService1D_process_until_finished(lmfit_service):
lmfit_service.process_until_finished(event)
assert get_data.call_count == 2
assert process.call_count == 2
assert lmfit_service.client.producer.xadd.call_count == 2
assert lmfit_service.client.connector.xadd.call_count == 2


def test_LmfitService1D_configure(lmfit_service):
|
@ -18,7 +18,7 @@ from device_server.rpc_mixin import RPCMixin

logger = bec_logger.logger

consumer_stop = threading.Event()
register_stop = threading.Event()


class DisabledDeviceError(Exception):
@ -38,14 +38,10 @@ class DeviceServer(RPCMixin, BECService):
super().__init__(config, connector_cls, unique_service=True)
self._tasks = []
self.device_manager = None
self.threads = []
self.sig_thread = None
self.sig_thread = self.connector.consumer(
self.connector.register(
MessageEndpoints.scan_queue_modification(),
cb=self.consumer_interception_callback,
parent=self,
cb=self.register_interception_callback,
)
self.sig_thread.start()
self.executor = ThreadPoolExecutor(max_workers=4)
self._start_device_manager()

@ -55,19 +51,16 @@ class DeviceServer(RPCMixin, BECService):

def start(self) -> None:
"""start the device server"""
if consumer_stop.is_set():
consumer_stop.clear()
if register_stop.is_set():
register_stop.clear()

self.connector.register(
MessageEndpoints.device_instructions(),
event=register_stop,
cb=self.instructions_callback,
parent=self,
)

self.threads = [
self.connector.consumer(
MessageEndpoints.device_instructions(),
event=consumer_stop,
cb=self.instructions_callback,
parent=self,
)
]
for thread in self.threads:
thread.start()
self.status = BECStatus.RUNNING

def update_status(self, status: BECStatus):
@ -76,17 +69,13 @@ class DeviceServer(RPCMixin, BECService):

def stop(self) -> None:
"""stop the device server"""
consumer_stop.set()
for thread in self.threads:
thread.join()
register_stop.set()
self.status = BECStatus.IDLE

def shutdown(self) -> None:
"""shutdown the device server"""
super().shutdown()
self.stop()
self.sig_thread.signal_event.set()
self.sig_thread.join()
self.device_manager.shutdown()

def _update_device_metadata(self, instr) -> None:
@ -97,8 +86,7 @@ class DeviceServer(RPCMixin, BECService):
device_root = dev.split(".")[0]
self.device_manager.devices.get(device_root).metadata = instr.metadata

@staticmethod
def consumer_interception_callback(msg, *, parent, **_kwargs) -> None:
def register_interception_callback(self, msg, **_kwargs) -> None:
"""callback for receiving scan modifications / interceptions"""
mvalue = msg.value
if mvalue is None:
@ -106,7 +94,7 @@ class DeviceServer(RPCMixin, BECService):
return
logger.info(f"Receiving: {mvalue.content}")
if mvalue.content.get("action") in ["pause", "abort", "halt"]:
parent.stop_devices()
self.stop_devices()

def stop_devices(self) -> None:
"""stop all enabled devices"""
@ -279,7 +267,7 @@ class DeviceServer(RPCMixin, BECService):
devices = instr.content["device"]
if not isinstance(devices, list):
devices = [devices]
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
for dev in devices:
obj = self.device_manager.devices.get(dev)
obj.metadata = instr.metadata
@ -288,11 +276,11 @@ class DeviceServer(RPCMixin, BECService):
dev_msg = messages.DeviceReqStatusMessage(
device=dev, success=True, metadata=instr.metadata
)
self.producer.set_and_publish(MessageEndpoints.device_req_status(dev), dev_msg, pipe)
self.connector.set_and_publish(MessageEndpoints.device_req_status(dev), dev_msg, pipe)
pipe.execute()

def _status_callback(self, status):
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
if hasattr(status, "device"):
obj = status.device
else:
@ -302,12 +290,12 @@ class DeviceServer(RPCMixin, BECService):
device=device_name, success=status.success, metadata=status.instruction.metadata
)
logger.debug(f"req status for device {device_name}: {status.success}")
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.device_req_status(device_name), dev_msg, pipe
)
response = status.instruction.metadata.get("response")
if response:
self.producer.lpush(
self.connector.lpush(
MessageEndpoints.device_req_status(status.instruction.metadata["RID"]),
dev_msg,
pipe,
@ -328,7 +316,7 @@ class DeviceServer(RPCMixin, BECService):
dev_config_msg = messages.DeviceMessage(
signals=obj.root.read_configuration(), metadata=metadata
)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.device_read_configuration(obj.root.name), dev_config_msg, pipe
)

@ -342,7 +330,7 @@ class DeviceServer(RPCMixin, BECService):

def _read_and_update_devices(self, devices: list[str], metadata: dict) -> list:
start = time.time()
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
signal_container = []
for dev in devices:
device_root = dev.split(".")[0]
@ -354,17 +342,17 @@ class DeviceServer(RPCMixin, BECService):
except Exception as exc:
signals = self._retry_obj_method(dev, obj, "read", exc)

self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.device_read(device_root),
messages.DeviceMessage(signals=signals, metadata=metadata),
pipe,
)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.device_readback(device_root),
messages.DeviceMessage(signals=signals, metadata=metadata),
pipe,
)
self.producer.set(
self.connector.set(
MessageEndpoints.device_status(device_root),
messages.DeviceStatusMessage(device=device_root, status=0, metadata=metadata),
pipe,
@ -377,7 +365,7 @@ class DeviceServer(RPCMixin, BECService):

def _read_config_and_update_devices(self, devices: list[str], metadata: dict) -> list:
start = time.time()
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
signal_container = []
for dev in devices:
self.device_manager.devices.get(dev).metadata = metadata
@ -387,7 +375,7 @@ class DeviceServer(RPCMixin, BECService):
signal_container.append(signals)
except Exception as exc:
signals = self._retry_obj_method(dev, obj, "read_configuration", exc)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.device_read_configuration(dev),
messages.DeviceMessage(signals=signals, metadata=metadata),
pipe,
@ -420,9 +408,11 @@ class DeviceServer(RPCMixin, BECService):
f"Failed to run {method} on device {device_root}. Trying to load an old value."
)
if method == "read":
old_msg = self.producer.get(MessageEndpoints.device_read(device_root))
old_msg = self.connector.get(MessageEndpoints.device_read(device_root))
elif method == "read_configuration":
old_msg = self.producer.get(MessageEndpoints.device_read_configuration(device_root))
old_msg = self.connector.get(
MessageEndpoints.device_read_configuration(device_root)
)
else:
raise ValueError(f"Unknown method {method}.")
if not old_msg:
@ -435,7 +425,7 @@ class DeviceServer(RPCMixin, BECService):
if not isinstance(devices, list):
devices = [devices]

pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
for dev in devices:
obj = self.device_manager.devices[dev].obj
if hasattr(obj, "_staged"):
@ -444,7 +434,7 @@ class DeviceServer(RPCMixin, BECService):
logger.info(f"Device {obj.name} was already staged and will be first unstaged.")
self.device_manager.devices[dev].obj.unstage()
self.device_manager.devices[dev].obj.stage()
self.producer.set(
self.connector.set(
MessageEndpoints.device_staged(dev),
messages.DeviceStatusMessage(device=dev, status=1, metadata=instr.metadata),
pipe,
@ -456,7 +446,7 @@ class DeviceServer(RPCMixin, BECService):
if not isinstance(devices, list):
devices = [devices]

pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
for dev in devices:
obj = self.device_manager.devices[dev].obj
if hasattr(obj, "_staged"):
@ -465,7 +455,7 @@ class DeviceServer(RPCMixin, BECService):
self.device_manager.devices[dev].obj.unstage()
else:
logger.debug(f"Device {obj.name} was already unstaged.")
self.producer.set(
self.connector.set(
MessageEndpoints.device_staged(dev),
messages.DeviceStatusMessage(device=dev, status=0, metadata=instr.metadata),
pipe,
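In the DeviceServer, the consumer threads and their join() handshake give way to a single register() call that is stopped through a shared threading.Event. A sketch of the lifecycle (illustrative; connector, instructions_callback, and server stand in for the attributes used in the hunks above):

    import threading

    register_stop = threading.Event()

    connector.register(
        MessageEndpoints.device_instructions(),
        event=register_stop,
        cb=instructions_callback,  # placeholder for the server's callback
        parent=server,
    )

    # stop() no longer joins threads; it only signals the event:
    register_stop.set()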
|
@ -16,17 +16,11 @@ class ConfigUpdateHandler:
def __init__(self, device_manager: DeviceManagerDS) -> None:
self.device_manager = device_manager
self.connector = self.device_manager.connector
self._config_request_handler = None

self._start_config_handler()

def _start_config_handler(self) -> None:
self._config_request_handler = self.connector.consumer(
self.connector.register(
MessageEndpoints.device_server_config_request(),
cb=self._device_config_callback,
parent=self,
)
self._config_request_handler.start()

@staticmethod
def _device_config_callback(msg, *, parent, **_kwargs) -> None:
@ -74,7 +68,7 @@ class ConfigUpdateHandler:
accepted=accepted, message=error_msg, metadata=metadata
)
RID = metadata.get("RID")
self.device_manager.producer.set(
self.device_manager.connector.set(
MessageEndpoints.device_config_request_response(RID), msg, expire=60
)

@ -97,7 +91,7 @@ class ConfigUpdateHandler:
"low": device.obj.low_limit_travel.get(),
"high": device.obj.high_limit_travel.get(),
}
self.device_manager.producer.set_and_publish(
self.device_manager.connector.set_and_publish(
MessageEndpoints.device_limits(device.name),
messages.DeviceMessage(signals=limits),
)
|
@ -51,7 +51,7 @@ class DSDevice(DeviceBase):
self.metadata = {}
self.initialized = False

def initialize_device_buffer(self, producer):
def initialize_device_buffer(self, connector):
"""initialize the device read and readback buffer on redis with a new reading"""
dev_msg = messages.DeviceMessage(signals=self.obj.read(), metadata={})
dev_config_msg = messages.DeviceMessage(signals=self.obj.read_configuration(), metadata={})
@ -62,14 +62,14 @@ class DSDevice(DeviceBase):
}
else:
limits = None
pipe = producer.pipeline()
producer.set_and_publish(MessageEndpoints.device_readback(self.name), dev_msg, pipe=pipe)
producer.set(topic=MessageEndpoints.device_read(self.name), msg=dev_msg, pipe=pipe)
producer.set_and_publish(
pipe = connector.pipeline()
connector.set_and_publish(MessageEndpoints.device_readback(self.name), dev_msg, pipe=pipe)
connector.set(topic=MessageEndpoints.device_read(self.name), msg=dev_msg, pipe=pipe)
connector.set_and_publish(
MessageEndpoints.device_read_configuration(self.name), dev_config_msg, pipe=pipe
)
if limits is not None:
producer.set_and_publish(
connector.set_and_publish(
MessageEndpoints.device_limits(self.name),
messages.DeviceMessage(signals=limits),
pipe=pipe,
@ -318,7 +318,7 @@ class DeviceManagerDS(DeviceManagerBase):
self.update_config(obj, config)

# refresh the device info
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
self.reset_device_data(obj, pipe)
self.publish_device_info(obj, pipe)
pipe.execute()
@ -369,7 +369,7 @@ class DeviceManagerDS(DeviceManagerBase):
def initialize_enabled_device(self, opaas_obj):
"""connect to an enabled device and initialize the device buffer"""
self.connect_device(opaas_obj.obj)
opaas_obj.initialize_device_buffer(self.producer)
opaas_obj.initialize_device_buffer(self.connector)

@staticmethod
def disconnect_device(obj):
@ -420,7 +420,7 @@ class DeviceManagerDS(DeviceManagerBase):
"""

interface = get_device_info(obj, {})
self.producer.set(
self.connector.set(
MessageEndpoints.device_info(obj.name),
messages.DeviceInfoMessage(device=obj.name, info=interface),
pipe,
@ -428,9 +428,9 @@ class DeviceManagerDS(DeviceManagerBase):

def reset_device_data(self, obj: OphydObject, pipe=None) -> None:
"""delete all device data and device info"""
self.producer.delete(MessageEndpoints.device_status(obj.name), pipe)
self.producer.delete(MessageEndpoints.device_read(obj.name), pipe)
self.producer.delete(MessageEndpoints.device_info(obj.name), pipe)
self.connector.delete(MessageEndpoints.device_status(obj.name), pipe)
self.connector.delete(MessageEndpoints.device_read(obj.name), pipe)
self.connector.delete(MessageEndpoints.device_info(obj.name), pipe)

def _obj_callback_readback(self, *_args, obj: OphydObject, **kwargs):
if obj.connected:
@ -438,8 +438,8 @@ class DeviceManagerDS(DeviceManagerBase):
signals = obj.read()
metadata = self.devices.get(obj.root.name).metadata
dev_msg = messages.DeviceMessage(signals=signals, metadata=metadata)
pipe = self.producer.pipeline()
self.producer.set_and_publish(MessageEndpoints.device_readback(name), dev_msg, pipe)
pipe = self.connector.pipeline()
self.connector.set_and_publish(MessageEndpoints.device_readback(name), dev_msg, pipe)
pipe.execute()

@typechecked
@ -466,7 +466,7 @@ class DeviceManagerDS(DeviceManagerBase):
metadata = self.devices[name].metadata
msg = messages.DeviceMonitorMessage(device=name, data=value, metadata=metadata)
stream_msg = {"data": msg}
self.producer.xadd(
self.connector.xadd(
MessageEndpoints.device_monitor(name),
stream_msg,
max_size=min(100, int(max_size // dsize)),
@ -476,7 +476,7 @@ class DeviceManagerDS(DeviceManagerBase):
device = kwargs["obj"].root.name
status = 0
metadata = self.devices[device].metadata
self.producer.send(
self.connector.send(
MessageEndpoints.device_status(device),
messages.DeviceStatusMessage(device=device, status=status, metadata=metadata),
)
@ -489,8 +489,8 @@ class DeviceManagerDS(DeviceManagerBase):
device = kwargs["obj"].root.name
status = int(kwargs.get("value"))
metadata = self.devices[device].metadata
self.producer.set(
MessageEndpoints.device_status(kwargs["obj"].root.name),
self.connector.set(
MessageEndpoints.device_status(device),
messages.DeviceStatusMessage(device=device, status=status, metadata=metadata),
)

@ -521,12 +521,12 @@ class DeviceManagerDS(DeviceManagerBase):
)
)
ds_obj.emitted_points[metadata["scanID"]] = max_points
pipe = self.producer.pipeline()
self.producer.send(MessageEndpoints.device_read(obj.root.name), bundle, pipe=pipe)
pipe = self.connector.pipeline()
self.connector.send(MessageEndpoints.device_read(obj.root.name), bundle, pipe=pipe)
msg = messages.DeviceStatusMessage(
device=obj.root.name, status=max_points, metadata=metadata
)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.device_progress(obj.root.name), msg, pipe=pipe
)
pipe.execute()
@ -536,4 +536,4 @@ class DeviceManagerDS(DeviceManagerBase):
msg = messages.ProgressMessage(
value=value, max_value=max_value, done=done, metadata=metadata
)
self.producer.set_and_publish(MessageEndpoints.device_progress(obj.root.name), msg)
self.connector.set_and_publish(MessageEndpoints.device_progress(obj.root.name), msg)
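Device monitor data still flows through a Redis stream, now via connector.xadd, and the newest entry can be fetched back with get_last. A minimal sketch matching the calls above (values are illustrative; max_size bounds the stream length as in the hunk):

    msg = messages.DeviceMonitorMessage(device="eiger", data=value, metadata=metadata)
    connector.xadd(MessageEndpoints.device_monitor("eiger"), {"data": msg}, max_size=100)
    latest = connector.get_last(MessageEndpoints.device_monitor("eiger"), "data")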
|
@ -70,7 +70,7 @@ class RPCMixin:
def _send_rpc_result_to_client(
self, device: str, instr_params: dict, res: Any, result: StringIO
):
self.producer.set(
self.connector.set(
MessageEndpoints.device_rpc(instr_params.get("rpc_id")),
messages.DeviceRPCMessage(
device=device, return_val=res, out=result.getvalue(), success=True
@ -175,7 +175,7 @@ class RPCMixin:
}
logger.info(f"Received exception: {exc_formatted}, {exc}")
instr_params = instr.content.get("parameter")
self.producer.set(
self.connector.set(
MessageEndpoints.device_rpc(instr_params.get("rpc_id")),
messages.DeviceRPCMessage(
device=instr.content["device"], return_val=None, out=exc_formatted, success=False
|
@ -6,7 +6,7 @@ import numpy as np
import pytest
import yaml
from bec_lib import MessageEndpoints, messages
from bec_lib.tests.utils import ConnectorMock, ProducerMock, create_session_from_config
from bec_lib.tests.utils import ConnectorMock, create_session_from_config

from device_server.devices.devicemanager import DeviceManagerDS

@ -52,7 +52,7 @@ def load_device_manager():
service_mock = mock.MagicMock()
service_mock.connector = ConnectorMock("", store_data=False)
device_manager = DeviceManagerDS(service_mock, "")
device_manager.producer = service_mock.connector.producer()
device_manager.connector = service_mock.connector
device_manager.config_update_handler = mock.MagicMock()
with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file:
device_manager._session = create_session_from_config(yaml.safe_load(session_file))
@ -133,10 +133,10 @@ def test_flyer_event_callback():
device_manager._obj_flyer_callback(
obj=samx.obj, value={"data": {"idata": np.random.rand(20), "edata": np.random.rand(20)}}
)
pipe = device_manager.producer.pipeline()
pipe = device_manager.connector.pipeline()
bundle, progress = pipe._pipe_buffer[-2:]

# check producer method
# check connector method
assert bundle[0] == "send"
assert progress[0] == "set_and_publish"

@ -157,9 +157,9 @@ def test_obj_progress_callback():
samx = device_manager.devices.samx
samx.metadata = {"scanID": "12345"}

with mock.patch.object(device_manager, "producer") as mock_producer:
with mock.patch.object(device_manager, "connector") as mock_connector:
device_manager._obj_progress_callback(obj=samx.obj, value=1, max_value=2, done=False)
mock_producer.set_and_publish.assert_called_once_with(
mock_connector.set_and_publish.assert_called_once_with(
MessageEndpoints.device_progress("samx"),
messages.ProgressMessage(
value=1, max_value=2, done=False, metadata={"scanID": "12345"}
@ -176,9 +176,9 @@ def test_obj_monitor_callback(value):
eiger.metadata = {"scanID": "12345"}
value_size = len(value.tobytes()) / 1e6 # MB
max_size = 100
with mock.patch.object(device_manager, "producer") as mock_producer:
with mock.patch.object(device_manager, "connector") as mock_connector:
device_manager._obj_callback_monitor(obj=eiger.obj, value=value)
mock_producer.xadd.assert_called_once_with(
mock_connector.xadd.assert_called_once_with(
MessageEndpoints.device_monitor(eiger.name),
{
"data": messages.DeviceMonitorMessage(
|
@ -7,7 +7,7 @@ from bec_lib import Alarms, MessageEndpoints, ServiceConfig, messages
from bec_lib.device import OnFailure
from bec_lib.messages import BECStatus
from bec_lib.redis_connector import MessageObject
from bec_lib.tests.utils import ConnectorMock, ConsumerMock
from bec_lib.tests.utils import ConnectorMock
from ophyd import Staged
from ophyd.utils import errors as ophyd_errors
from test_device_manager_ds import device_manager, load_device_manager

@ -54,8 +54,6 @@ def test_start(device_server_mock):

device_server.start()

assert device_server.threads
assert isinstance(device_server.threads[0], ConsumerMock)
assert device_server.status == BECStatus.RUNNING

@ -187,10 +185,10 @@ def test_stop_devices(device_server_mock):
),
],
)
def test_consumer_interception_callback(device_server_mock, msg, stop_called):
def test_register_interception_callback(device_server_mock, msg, stop_called):
device_server = device_server_mock
with mock.patch.object(device_server, "stop_devices") as stop:
device_server.consumer_interception_callback(msg, parent=device_server)
device_server.register_interception_callback(msg, parent=device_server)
if stop_called:
stop.assert_called_once()
else:

@ -640,7 +638,7 @@ def test_set_device(device_server_mock, instr):
while True:
res = [
msg
for msg in device_server.producer.message_sent
for msg in device_server.connector.message_sent
if msg["queue"] == MessageEndpoints.device_req_status("samx")
]
if res:

@ -676,7 +674,7 @@ def test_read_device(device_server_mock, instr):
for device in devices:
res = [
msg
for msg in device_server.producer.message_sent
for msg in device_server.connector.message_sent
if msg["queue"] == MessageEndpoints.device_read(device)
]
assert res[-1]["msg"].metadata["RID"] == instr.metadata["RID"]

@ -690,7 +688,7 @@ def test_read_config_and_update_devices(device_server_mock, devices):
for device in devices:
res = [
msg
for msg in device_server.producer.message_sent
for msg in device_server.connector.message_sent
if msg["queue"] == MessageEndpoints.device_read_configuration(device)
]
config = device_server.device_manager.devices[device].obj.read_configuration()

@ -755,8 +753,8 @@ def test_retry_obj_method_buffer(device_server_mock, instr):
return

signals_before = getattr(samx.obj, instr)()
device_server.producer = mock.MagicMock()
device_server.producer.get.return_value = messages.DeviceMessage(
device_server.connector = mock.MagicMock()
device_server.connector.get.return_value = messages.DeviceMessage(
signals=signals_before, metadata={"RID": "test", "stream": "primary"}
)

|
@ -13,7 +13,7 @@ from device_server.rpc_mixin import RPCMixin
def rpc_cls():
rpc_mixin = RPCMixin()
rpc_mixin.connector = mock.MagicMock()
rpc_mixin.producer = mock.MagicMock()
rpc_mixin.connector = mock.MagicMock()
rpc_mixin.device_manager = mock.MagicMock()
yield rpc_mixin

@ -93,7 +93,7 @@ def test_get_result_from_rpc_list_from_stage(rpc_cls):

def test_send_rpc_exception(rpc_cls, instr):
rpc_cls._send_rpc_exception(Exception(), instr)
rpc_cls.producer.set.assert_called_once_with(
rpc_cls.connector.set.assert_called_once_with(
MessageEndpoints.device_rpc("rpc_id"),
messages.DeviceRPCMessage(
device="device",

@ -108,7 +108,7 @@ def test_send_rpc_result_to_client(rpc_cls):
result = mock.MagicMock()
result.getvalue.return_value = "result"
rpc_cls._send_rpc_result_to_client("device", {"rpc_id": "rpc_id"}, 1, result)
rpc_cls.producer.set.assert_called_once_with(
rpc_cls.connector.set.assert_called_once_with(
MessageEndpoints.device_rpc("rpc_id"),
messages.DeviceRPCMessage(device="device", return_val=1, out="result", success=True),
expire=1800,
|
@ -325,7 +325,7 @@ class NexusFileWriter(FileWriter):
file_data[key] = val if not isinstance(val, list) else merge_dicts(val)
msg_data = {"file_path": file_path, "data": file_data}
msg = messages.FileContentMessage(**msg_data)
self.file_writer_manager.producer.set_and_publish(MessageEndpoints.file_content(), msg)
self.file_writer_manager.connector.set_and_publish(MessageEndpoints.file_content(), msg)

with h5py.File(file_path, "w") as file:
HDF5StorageWriter.write(writer_storage._storage, device_storage, file)
|
@ -80,10 +80,13 @@ class FileWriterManager(BECService):
self._lock = threading.RLock()
self.file_writer_config = self._service_config.service_config.get("file_writer")
self.writer_mixin = FileWriterMixin(self.file_writer_config)
self.producer = self.connector.producer()
self._start_device_manager()
self._start_scan_segment_consumer()
self._start_scan_status_consumer()
self.connector.register(
patterns=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self
)
self.connector.register(
MessageEndpoints.scan_status(), cb=self._scan_status_callback, parent=self
)
self.scan_storage = {}
self.file_writer = NexusFileWriter(self)

@ -92,20 +95,7 @@ class FileWriterManager(BECService):
self.device_manager = DeviceManagerBase(self)
self.device_manager.initialize([self.bootstrap_server])

def _start_scan_segment_consumer(self):
self._scan_segment_consumer = self.connector.consumer(
pattern=MessageEndpoints.scan_segment(), cb=self._scan_segment_callback, parent=self
)
self._scan_segment_consumer.start()

def _start_scan_status_consumer(self):
self._scan_status_consumer = self.connector.consumer(
MessageEndpoints.scan_status(), cb=self._scan_status_callback, parent=self
)
self._scan_status_consumer.start()

@staticmethod
def _scan_segment_callback(msg: MessageObject, *, parent: FileWriterManager):
def _scan_segment_callback(self, msg: MessageObject, *, parent: FileWriterManager):
msgs = msg.value
for scan_msg in msgs:
parent.insert_to_scan_storage(scan_msg)
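
The two removed _start_*_consumer helpers above collapse into direct connector.register(...) calls in __init__: registration now creates and starts the listener in one step, so no consumer handle has to be kept and started separately. A sketch of how such a register() helper can work, with the subscribe loop stubbed out (this is not the bec_lib implementation, only the shape of the change):

    import threading

    class MiniConnector:
        def register(self, endpoint=None, cb=None, parent=None, patterns=None):
            # create the listener thread and start it immediately; the caller
            # no longer needs to hold a consumer object and call start() on it
            listener = threading.Thread(
                target=self._listen, args=(patterns or endpoint, cb, parent), daemon=True
            )
            listener.start()
            return listener

        def _listen(self, topic, cb, parent):
            # stub: a real implementation would block on the message broker
            # and invoke cb(msg, parent=parent) for every incoming message
            pass
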
@ -188,7 +178,7 @@ class FileWriterManager(BECService):
return
if self.scan_storage[scanID].baseline:
return
baseline = self.producer.get(MessageEndpoints.public_scan_baseline(scanID))
baseline = self.connector.get(MessageEndpoints.public_scan_baseline(scanID))
if not baseline:
return
self.scan_storage[scanID].baseline = baseline.content["data"]

@ -205,13 +195,13 @@ class FileWriterManager(BECService):
"""
if not self.scan_storage.get(scanID):
return
msgs = self.producer.keys(MessageEndpoints.public_file(scanID, "*"))
msgs = self.connector.keys(MessageEndpoints.public_file(scanID, "*"))
if not msgs:
return

# extract name from 'public/<scanID>/file/<name>'
names = [msg.decode().split("/")[-1] for msg in msgs]
file_msgs = [self.producer.get(msg.decode()) for msg in msgs]
file_msgs = [self.connector.get(msg.decode()) for msg in msgs]
if not file_msgs:
return
for name, file_msg in zip(names, file_msgs):

@ -236,7 +226,7 @@ class FileWriterManager(BECService):
if not self.scan_storage.get(scanID):
return
# get all async devices
async_device_keys = self.producer.keys(MessageEndpoints.device_async_readback(scanID, "*"))
async_device_keys = self.connector.keys(MessageEndpoints.device_async_readback(scanID, "*"))
if not async_device_keys:
return
for device_key in async_device_keys:

@ -244,7 +234,7 @@ class FileWriterManager(BECService):
device_name = key.split(MessageEndpoints.device_async_readback(scanID, ""))[-1].split(
":"
)[0]
msgs = self.producer.xrange(key, min="-", max="+")
msgs = self.connector.xrange(key, min="-", max="+")
if not msgs:
continue
self._process_async_data(msgs, scanID, device_name)

@ -298,7 +288,7 @@ class FileWriterManager(BECService):

try:
file_path = self.writer_mixin.compile_full_filename(scan, file_suffix)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.public_file(scanID, "master"),
messages.FileMessage(file_path=file_path, done=False),
)

@ -319,7 +309,7 @@ class FileWriterManager(BECService):
)
successful = False
self.scan_storage.pop(scanID)
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.public_file(scanID, "master"),
messages.FileMessage(file_path=file_path, successful=successful),
)
|
@ -25,7 +25,7 @@ def load_FileWriter():
service_mock = mock.MagicMock()
service_mock.connector = ConnectorMock("")
device_manager = DeviceManagerBase(service_mock, "")
device_manager.producer = service_mock.connector.producer()
device_manager.connector = service_mock.connector
with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file:
device_manager._session = create_session_from_config(yaml.safe_load(session_file))
device_manager._load_session()

@ -152,13 +152,13 @@ def test_write_file_raises_alarm_on_error():
def test_update_baseline_reading():
file_manager = load_FileWriter()
file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID")
with mock.patch.object(file_manager, "producer") as mock_producer:
mock_producer.get.return_value = messages.ScanBaselineMessage(
with mock.patch.object(file_manager, "connector") as mock_connector:
mock_connector.get.return_value = messages.ScanBaselineMessage(
scanID="scanID", data={"data": "data"}
)
file_manager.update_baseline_reading("scanID")
assert file_manager.scan_storage["scanID"].baseline == {"data": "data"}
mock_producer.get.assert_called_once_with(MessageEndpoints.public_scan_baseline("scanID"))
mock_connector.get.assert_called_once_with(MessageEndpoints.public_scan_baseline("scanID"))


def test_scan_storage_append():

@ -178,30 +178,30 @@ def test_scan_storage_ready_to_write():

def test_update_file_references():
file_manager = load_FileWriter()
with mock.patch.object(file_manager, "producer") as mock_producer:
with mock.patch.object(file_manager, "connector") as mock_connector:
file_manager.update_file_references("scanID")
mock_producer.keys.assert_not_called()
mock_connector.keys.assert_not_called()


def test_update_file_references_gets_keys():
file_manager = load_FileWriter()
file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID")
with mock.patch.object(file_manager, "producer") as mock_producer:
with mock.patch.object(file_manager, "connector") as mock_connector:
file_manager.update_file_references("scanID")
mock_producer.keys.assert_called_once_with(MessageEndpoints.public_file("scanID", "*"))
mock_connector.keys.assert_called_once_with(MessageEndpoints.public_file("scanID", "*"))


def test_update_async_data():
file_manager = load_FileWriter()
file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID")
with mock.patch.object(file_manager, "producer") as mock_producer:
with mock.patch.object(file_manager, "connector") as mock_connector:
with mock.patch.object(file_manager, "_process_async_data") as mock_process:
key = MessageEndpoints.device_async_readback("scanID", "dev1")
mock_producer.keys.return_value = [key.encode()]
mock_connector.keys.return_value = [key.encode()]
data = [(b"0-0", b'{"data": "data"}')]
mock_producer.xrange.return_value = data
mock_connector.xrange.return_value = data
file_manager.update_async_data("scanID")
mock_producer.xrange.assert_called_once_with(key, min="-", max="+")
mock_connector.xrange.assert_called_once_with(key, min="-", max="+")
mock_process.assert_called_once_with(data, "scanID", "dev1")

|
@ -14,7 +14,7 @@ if TYPE_CHECKING:

class BECEmitter(EmitterBase):
def __init__(self, scan_bundler: ScanBundler) -> None:
super().__init__(scan_bundler.producer)
super().__init__(scan_bundler.connector)
self.scan_bundler = scan_bundler

def on_scan_point_emit(self, scanID: str, pointID: int):

@ -46,9 +46,16 @@ class BECEmitter(EmitterBase):
data=sb.sync_storage[scanID]["baseline"],
metadata=sb.sync_storage[scanID]["info"],
)
pipe = sb.producer.pipeline()
sb.producer.set(
MessageEndpoints.public_scan_baseline(scanID=scanID), msg, expire=1800, pipe=pipe
pipe = sb.connector.pipeline()
sb.connector.set(
MessageEndpoints.public_scan_baseline(scanID=scanID),
msg,
expire=1800,
pipe=pipe,
)
sb.connector.set_and_publish(
MessageEndpoints.scan_baseline(),
msg,
pipe=pipe,
)
sb.producer.set_and_publish(MessageEndpoints.scan_baseline(), msg, pipe=pipe)
pipe.execute()
|
@ -17,7 +17,7 @@ if TYPE_CHECKING:

class BlueskyEmitter(EmitterBase):
def __init__(self, scan_bundler: ScanBundler) -> None:
super().__init__(scan_bundler.producer)
super().__init__(scan_bundler.connector)
self.scan_bundler = scan_bundler
self.bluesky_metadata = {}

@ -27,7 +27,7 @@ class BlueskyEmitter(EmitterBase):
self.bluesky_metadata[scanID] = {}
doc = self._get_run_start_document(scanID)
self.bluesky_metadata[scanID]["start"] = doc
self.producer.raw_send(MessageEndpoints.bluesky_events(), msgpack.dumps(("start", doc)))
self.connector.raw_send(MessageEndpoints.bluesky_events(), msgpack.dumps(("start", doc)))
self.send_descriptor_document(scanID)

def _get_run_start_document(self, scanID) -> dict:

@ -71,7 +71,7 @@ class BlueskyEmitter(EmitterBase):
"""Bluesky only: send descriptor document"""
doc = self._get_descriptor_document(scanID)
self.bluesky_metadata[scanID]["descriptor"] = doc
self.producer.raw_send(
self.connector.raw_send(
MessageEndpoints.bluesky_events(), msgpack.dumps(("descriptor", doc))
)

@ -85,7 +85,7 @@ class BlueskyEmitter(EmitterBase):
logger.warning(f"Failed to remove {scanID} from {storage}.")

def send_bluesky_scan_point(self, scanID, pointID) -> None:
self.producer.raw_send(
self.connector.raw_send(
MessageEndpoints.bluesky_events(),
msgpack.dumps(("event", self._prepare_bluesky_event_data(scanID, pointID))),
)
|
@ -6,16 +6,16 @@ from bec_lib import messages


class EmitterBase:
def __init__(self, producer) -> None:
def __init__(self, connector) -> None:
self._send_buffer = Queue()
self.producer = producer
self._start_buffered_producer()
self.connector = connector
self._start_buffered_connector()

def _start_buffered_producer(self):
self._buffered_producer_thread = threading.Thread(
def _start_buffered_connector(self):
self._buffered_connector_thread = threading.Thread(
target=self._buffered_publish, daemon=True, name="buffered_publisher"
)
self._buffered_producer_thread.start()
self._buffered_connector_thread.start()

def add_message(self, msg: messages.BECMessage, endpoint: str, public: str = None):
self._send_buffer.put((msg, endpoint, public))

@ -37,20 +37,20 @@ class EmitterBase:
time.sleep(0.1)
return

pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
msgs = messages.BundleMessage()
_, endpoint, _ = msgs_to_send[0]
for msg, endpoint, public in msgs_to_send:
msg_dump = msg
msgs.append(msg_dump)
if public:
self.producer.set(
self.connector.set(
public,
msg_dump,
pipe=pipe,
expire=1800,
)
self.producer.send(endpoint, msgs, pipe=pipe)
self.connector.send(endpoint, msgs, pipe=pipe)
pipe.execute()

def on_init(self, scanID: str):
|
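EmitterBase keeps its buffered-publish design; only the attribute and thread names change from producer to connector. The idea: messages are queued in memory and a daemon thread drains the queue through one pipeline, so bursts of scan points become a single round trip. A self-contained sketch of that pattern (Queue and threading are stdlib; the connector here is any object with a send method):

    import threading
    from queue import Empty, Queue

    class BufferedPublisher:
        def __init__(self, connector):
            self._send_buffer = Queue()
            self.connector = connector
            threading.Thread(target=self._publish_loop, daemon=True).start()

        def add_message(self, msg, endpoint: str) -> None:
            self._send_buffer.put((msg, endpoint))  # cheap, non-blocking for the caller

        def _publish_loop(self) -> None:
            while True:
                try:
                    msg, endpoint = self._send_buffer.get(timeout=0.1)
                except Empty:
                    continue
                self.connector.send(endpoint, msg)  # drained in the background
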
@ -20,9 +20,23 @@ class ScanBundler(BECService):

self.device_manager = None
self._start_device_manager()
self._start_device_read_consumer()
self._start_scan_queue_consumer()
self._start_scan_status_consumer()
self.connector.register(
patterns=MessageEndpoints.device_read("*"),
cb=self._device_read_callback,
name="device_read_register",
)
self.connector.register(
MessageEndpoints.scan_queue_status(),
cb=self._scan_queue_callback,
group_id="scan_bundler",
name="scan_queue_register",
)
self.connector.register(
MessageEndpoints.scan_status(),
cb=self._scan_status_callback,
group_id="scan_bundler",
name="scan_status_register",
)

self.sync_storage = {}
self.monitored_devices = {}

@ -56,56 +70,24 @@ class ScanBundler(BECService):
self.device_manager = DeviceManagerBase(self)
self.device_manager.initialize(self.bootstrap_server)

def _start_device_read_consumer(self):
self._device_read_consumer = self.connector.consumer(
pattern=MessageEndpoints.device_read("*"),
cb=self._device_read_callback,
parent=self,
name="device_read_consumer",
)
self._device_read_consumer.start()

def _start_scan_queue_consumer(self):
self._scan_queue_consumer = self.connector.consumer(
MessageEndpoints.scan_queue_status(),
cb=self._scan_queue_callback,
group_id="scan_bundler",
parent=self,
name="scan_queue_consumer",
)
self._scan_queue_consumer.start()

def _start_scan_status_consumer(self):
self._scan_status_consumer = self.connector.consumer(
MessageEndpoints.scan_status(),
cb=self._scan_status_callback,
group_id="scan_bundler",
parent=self,
name="scan_status_consumer",
)
self._scan_status_consumer.start()

@staticmethod
def _device_read_callback(msg, parent, **_kwargs):
def _device_read_callback(self, msg, **_kwargs):
# pylint: disable=protected-access
dev = msg.topic.split(MessageEndpoints._device_read + "/")[-1]
msgs = msg.value
logger.debug(f"Received reading from device {dev}")
if not isinstance(msgs, list):
msgs = [msgs]
task = parent.executor.submit(parent._add_device_to_storage, msgs, dev)
parent.executor_tasks.append(task)
task = self.executor.submit(self._add_device_to_storage, msgs, dev)
self.executor_tasks.append(task)

@staticmethod
def _scan_queue_callback(msg, parent, **_kwargs):
def _scan_queue_callback(self, msg, **_kwargs):
msg = msg.value
logger.trace(msg)
parent.current_queue = msg.content["queue"]["primary"].get("info")
self.current_queue = msg.content["queue"]["primary"].get("info")

@staticmethod
def _scan_status_callback(msg, parent, **_kwargs):
def _scan_status_callback(self, msg, **_kwargs):
msg = msg.value
parent.handle_scan_status_message(msg)
self.handle_scan_status_message(msg)

def handle_scan_status_message(self, msg: messages.ScanStatusMessage) -> None:
"""handle scan status messages"""

@ -270,7 +252,7 @@ class ScanBundler(BECService):
}

def _get_scan_status_history(self, length):
return self.producer.lrange(MessageEndpoints.scan_status() + "_list", length * -1, -1)
return self.connector.lrange(MessageEndpoints.scan_status() + "_list", length * -1, -1)

def _wait_for_scanID(self, scanID, timeout_time=10):
elapsed_time = 0

@ -344,10 +326,10 @@ class ScanBundler(BECService):
self.sync_storage[scanID][pointID][dev.name] = read

def _get_last_device_readback(self, devices: list) -> list:
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
for dev in devices:
self.producer.get(MessageEndpoints.device_readback(dev.name), pipe)
return [msg.content["signals"] for msg in self.producer.execute_pipeline(pipe)]
self.connector.get(MessageEndpoints.device_readback(dev.name), pipe)
return [msg.content["signals"] for msg in self.connector.execute_pipeline(pipe)]

def cleanup_storage(self):
"""remove old scanIDs to free memory"""
|
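The ScanBundler hunks also change the callback convention: the callbacks were @staticmethod functions that received the service instance through a parent keyword; they are now bound methods, so the instance arrives as self and register() no longer needs to forward parent. The difference in miniature:

    class Service:
        @staticmethod
        def _callback_old(msg, parent, **_kwargs):
            parent.handle(msg)  # instance smuggled in via a kwarg

        def _callback_new(self, msg, **_kwargs):
            self.handle(msg)  # instance is simply self

        def handle(self, msg):
            print("handled:", msg)

    # callers drop the extra argument, as the updated tests below do:
    Service()._callback_new("reading")
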
@ -57,16 +57,16 @@ def test_send_baseline_BEC():
sb.sync_storage[scanID] = {"info": {}, "status": "open", "sent": set()}
sb.sync_storage[scanID]["baseline"] = {}
msg = messages.ScanBaselineMessage(scanID=scanID, data=sb.sync_storage[scanID]["baseline"])
with mock.patch.object(sb, "producer") as producer:
with mock.patch.object(sb, "connector") as connector:
bec_emitter._send_baseline(scanID)
pipe = producer.pipeline()
producer.set.assert_called_once_with(
pipe = connector.pipeline()
connector.set.assert_called_once_with(
MessageEndpoints.public_scan_baseline(scanID),
msg,
expire=1800,
pipe=pipe,
)
producer.set_and_publish.assert_called_once_with(
connector.set_and_publish.assert_called_once_with(
MessageEndpoints.scan_baseline(),
msg,
pipe=pipe,
|
@ -13,7 +13,7 @@ from scan_bundler.bluesky_emitter import BlueskyEmitter
def test_run_start_document(scanID):
sb = load_ScanBundlerMock()
bls_emitter = BlueskyEmitter(sb)
with mock.patch.object(bls_emitter.producer, "raw_send") as send:
with mock.patch.object(bls_emitter.connector, "raw_send") as send:
with mock.patch.object(bls_emitter, "send_descriptor_document") as send_descr:
with mock.patch.object(
bls_emitter, "_get_run_start_document", return_value={}

@ -45,7 +45,7 @@ def test_send_descriptor_document():
bls_emitter = BlueskyEmitter(sb)
scanID = "lkajsdl"
bls_emitter.bluesky_metadata[scanID] = {}
with mock.patch.object(bls_emitter.producer, "raw_send") as send:
with mock.patch.object(bls_emitter.connector, "raw_send") as send:
with mock.patch.object(
bls_emitter, "_get_descriptor_document", return_value={}
) as get_descr:
|
@ -51,30 +51,30 @@ from scan_bundler.emitter import EmitterBase
],
)
def test_publish_data(msgs):
producer = mock.MagicMock()
with mock.patch.object(EmitterBase, "_start_buffered_producer") as start:
emitter = EmitterBase(producer)
connector = mock.MagicMock()
with mock.patch.object(EmitterBase, "_start_buffered_connector") as start:
emitter = EmitterBase(connector)
start.assert_called_once()
with mock.patch.object(emitter, "_get_messages_from_buffer", return_value=msgs) as get_msgs:
emitter._publish_data()
get_msgs.assert_called_once()

if not msgs:
producer.send.assert_not_called()
connector.send.assert_not_called()
return

pipe = producer.pipeline()
pipe = connector.pipeline()
msgs_bundle = messages.BundleMessage()
_, endpoint, _ = msgs[0]
for msg, endpoint, public in msgs:
msg_dump = msg
msgs_bundle.append(msg_dump)
if public:
producer.set.assert_has_calls(
producer.set(public, msg_dump, pipe=pipe, expire=1800)
connector.set.assert_has_calls(
connector.set(public, msg_dump, pipe=pipe, expire=1800)
)

producer.send.assert_called_with(endpoint, msgs_bundle, pipe=pipe)
connector.send.assert_called_with(endpoint, msgs_bundle, pipe=pipe)


@pytest.mark.parametrize(

@ -93,8 +93,8 @@ def test_publish_data(msgs):
],
)
def test_add_message(msg, endpoint, public):
producer = mock.MagicMock()
emitter = EmitterBase(producer)
connector = mock.MagicMock()
emitter = EmitterBase(connector)
emitter.add_message(msg, endpoint, public)
msgs = emitter._get_messages_from_buffer()
out_msg, out_endpoint, out_public = msgs[0]
|
@ -36,7 +36,7 @@ def load_ScanBundlerMock():
service_mock = mock.MagicMock()
service_mock.connector = ConnectorMock("")
device_manager = ScanBundlerDeviceManagerMock(service_mock, "")
device_manager.producer = service_mock.connector.producer()
device_manager.connector = service_mock.connector
with open(f"{dir_path}/tests/test_config.yaml", "r") as session_file:
device_manager._session = create_session_from_config(yaml.safe_load(session_file))
device_manager._load_session()

@ -74,7 +74,7 @@ def test_device_read_callback():
msg.topic = MessageEndpoints.device_read("samx")

with mock.patch.object(scan_bundler, "_add_device_to_storage") as add_dev:
scan_bundler._device_read_callback(msg, scan_bundler)
scan_bundler._device_read_callback(msg)
add_dev.assert_called_once_with([dev_msg], "samx")


@ -157,7 +157,7 @@ def test_wait_for_scanID(scanID, storageID, scan_msg):
)
def test_get_scan_status_history(msgs):
sb = load_ScanBundlerMock()
with mock.patch.object(sb.producer, "lrange", return_value=[msg for msg in msgs]) as lrange:
with mock.patch.object(sb.connector, "lrange", return_value=[msg for msg in msgs]) as lrange:
res = sb._get_scan_status_history(5)
lrange.assert_called_once_with(MessageEndpoints.scan_status() + "_list", -5, -1)
assert res == msgs

@ -371,7 +371,7 @@ def test_scan_queue_callback(queue_msg):
sb = load_ScanBundlerMock()
msg = MessageMock()
msg.value = queue_msg
sb._scan_queue_callback(msg, sb)
sb._scan_queue_callback(msg)
assert sb.current_queue == queue_msg.content["queue"]["primary"].get("info")


@ -399,7 +399,7 @@ def test_scan_status_callback(scan_msg):
msg.value = scan_msg

with mock.patch.object(sb, "handle_scan_status_message") as handle_scan_status_message_mock:
sb._scan_status_callback(msg, sb)
sb._scan_status_callback(msg)
handle_scan_status_message_mock.assert_called_once_with(scan_msg)


@ -744,10 +744,10 @@ def test_get_last_device_readback():
signals={"samx": {"samx": 0.51, "setpoint": 0.5, "motor_is_moving": 0}},
metadata={"scanID": "laksjd", "readout_priority": "monitored"},
)
with mock.patch.object(sb, "producer") as producer_mock:
producer_mock.execute_pipeline.return_value = [dev_msg]
with mock.patch.object(sb, "connector") as connector_mock:
connector_mock.execute_pipeline.return_value = [dev_msg]
ret = sb._get_last_device_readback([sb.device_manager.devices.samx])
assert producer_mock.get.mock_calls == [
mock.call(MessageEndpoints.device_readback("samx"), producer_mock.pipeline())
assert connector_mock.get.mock_calls == [
mock.call(MessageEndpoints.device_readback("samx"), connector_mock.pipeline())
]
assert ret == [dev_msg.content["signals"]]
|
@ -458,7 +458,7 @@ class LamNIFermatScan(ScanBase, LamNIMixin):
yield from self.stubs.kickoff(device="rtx")
while True:
yield from self.stubs.read_and_wait(group="primary", wait_group="readout_primary")
msg = self.device_manager.producer.get(MessageEndpoints.device_status("rt_scan"))
msg = self.device_manager.connector.get(MessageEndpoints.device_status("rt_scan"))
if msg:
status = msg
status_id = status.content.get("status", 1)
|
@ -163,7 +163,7 @@ class OwisGrid(AsyncFlyScanBase):

def scan_progress(self) -> int:
"""Timeout of the progress bar. This gets updated in the frequency of scan segments"""
msg = self.device_manager.producer.get(MessageEndpoints.device_progress("mcs"))
msg = self.device_manager.connector.get(MessageEndpoints.device_progress("mcs"))
if not msg:
self.timeout_progress += 1
return self.timeout_progress
|
@ -106,7 +106,7 @@ class SgalilGrid(AsyncFlyScanBase):

def scan_progress(self) -> int:
"""Timeout of the progress bar. This gets updated in the frequency of scan segments"""
msg = self.device_manager.producer.get(MessageEndpoints.device_progress("mcs"))
msg = self.device_manager.connector.get(MessageEndpoints.device_progress("mcs"))
if not msg:
self.timeout_progress += 1
return self.timeout_progress
|
@ -10,8 +10,8 @@ class DeviceValidation:
Mixin class for validation methods
"""

def __init__(self, producer, worker):
self.producer = producer
def __init__(self, connector, worker):
self.connector = connector
self.worker = worker

def get_device_status(self, endpoint: MessageEndpoints, devices: list) -> list:

@ -25,10 +25,10 @@ class DeviceValidation:
Returns:
list: List of BECMessage objects
"""
pipe = self.producer.pipeline()
pipe = self.connector.pipeline()
for dev in devices:
self.producer.get(endpoint(dev), pipe)
return self.producer.execute_pipeline(pipe)
self.connector.get(endpoint(dev), pipe)
return self.connector.execute_pipeline(pipe)

def devices_are_ready(
self,
|
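get_device_status batches one get per device on a client-side pipeline and executes it once, which keeps device polling to a single round trip regardless of the number of devices. The same logic, isolated (connector stubbed; endpoint is any callable mapping a device name to a key):

    class PipelinedReader:
        def __init__(self, connector):
            self.connector = connector

        def get_device_status(self, endpoint, devices: list) -> list:
            pipe = self.connector.pipeline()              # commands are only collected here
            for dev in devices:
                self.connector.get(endpoint(dev), pipe)   # queued on the pipe, not executed
            return self.connector.execute_pipeline(pipe)  # one round trip for all devices
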
@ -27,25 +27,19 @@ class ScanGuard:
self.parent = parent
self.device_manager = self.parent.device_manager
self.connector = self.parent.connector
self.producer = self.connector.producer()
self._start_scan_queue_request_consumer()

def _start_scan_queue_request_consumer(self):
self._scan_queue_request_consumer = self.connector.consumer(
self.connector.register(
MessageEndpoints.scan_queue_request(),
cb=self._scan_queue_request_callback,
parent=self,
)

self._scan_queue_modification_request_consumer = self.connector.consumer(
self.connector.register(
MessageEndpoints.scan_queue_modification_request(),
cb=self._scan_queue_modification_request_callback,
parent=self,
)

self._scan_queue_request_consumer.start()
self._scan_queue_modification_request_consumer.start()

def _is_valid_scan_request(self, request) -> ScanStatus:
try:
self._check_valid_request(request)

@ -63,7 +57,7 @@ class ScanGuard:
raise ScanRejection("Invalid request.")

def _check_valid_scan(self, request) -> None:
avail_scans = self.producer.get(MessageEndpoints.available_scans())
avail_scans = self.connector.get(MessageEndpoints.available_scans())
scan_type = request.content.get("scan_type")
if scan_type not in avail_scans.resource:
raise ScanRejection(f"Unknown scan type {scan_type}.")

@ -140,7 +134,7 @@ class ScanGuard:
message=scan_status.message,
metadata=metadata,
)
self.device_manager.producer.send(sqrr, rrm)
self.device_manager.connector.send(sqrr, rrm)

def _handle_scan_request(self, msg):
"""

@ -181,10 +175,10 @@ class ScanGuard:
self._send_scan_request_response(ScanStatus(), mod_msg.metadata)

sqm = MessageEndpoints.scan_queue_modification()
self.device_manager.producer.send(sqm, mod_msg)
self.device_manager.connector.send(sqm, mod_msg)

def _append_to_scan_queue(self, msg):
logger.info("Appending new scan to queue")
msg = msg
sqi = MessageEndpoints.scan_queue_insert()
self.device_manager.producer.send(sqi, msg)
self.device_manager.connector.send(sqi, msg)
|
@ -101,7 +101,7 @@ class ScanManager:

def publish_available_scans(self):
"""send all available scans to the broker"""
self.parent.producer.set(
self.parent.connector.set(
MessageEndpoints.available_scans(),
AvailableResourceMessage(resource=self.available_scans),
)
|
@ -51,11 +51,10 @@ class QueueManager:
def __init__(self, parent) -> None:
self.parent = parent
self.connector = parent.connector
self.producer = parent.producer
self.num_queues = 1
self.key = ""
self.queues = {}
self._start_scan_queue_consumer()
self._start_scan_queue_register()
self._lock = threading.RLock()

def add_to_queue(self, scan_queue: str, msg: messages.ScanQueueMessage, position=-1) -> None:

@ -91,17 +90,15 @@ class QueueManager:
self.queues[queue_name] = ScanQueue(self, queue_name=queue_name)
self.queues[queue_name].start_worker()

def _start_scan_queue_consumer(self) -> None:
self._scan_queue_consumer = self.connector.consumer(
def _start_scan_queue_register(self) -> None:
self.connector.register(
MessageEndpoints.scan_queue_insert(), cb=self._scan_queue_callback, parent=self
)
self._scan_queue_modification_consumer = self.connector.consumer(
self.connector.register(
MessageEndpoints.scan_queue_modification(),
cb=self._scan_queue_modification_callback,
parent=self,
)
self._scan_queue_consumer.start()
self._scan_queue_modification_consumer.start()

@staticmethod
def _scan_queue_callback(msg, parent, **_kwargs) -> None:

@ -233,7 +230,7 @@ class QueueManager:
logger.info("New scan queue:")
for queue in self.describe_queue():
logger.info(f"\n {queue}")
self.producer.set_and_publish(
self.connector.set_and_publish(
MessageEndpoints.scan_queue_status(),
messages.ScanQueueStatusMessage(queue=queue_export),
)

@ -685,7 +682,7 @@ class InstructionQueueItem:
self.instructions = []
self.parent = parent
self.queue = RequestBlockQueue(instruction_queue=self, assembler=assembler)
self.producer = self.parent.queue_manager.producer
self.connector = self.parent.queue_manager.connector
self._is_scan = False
self.is_active = False  # set to true while a worker is processing the instructions
self.completed = False

@ -790,7 +787,7 @@ class InstructionQueueItem:
msg = messages.ScanQueueHistoryMessage(
status=self.status.name, queueID=self.queue_id, info=self.describe()
)
self.parent.queue_manager.producer.lpush(
self.parent.queue_manager.connector.lpush(
MessageEndpoints.scan_queue_history(), msg, max_size=100
)

|
@ -23,7 +23,6 @@ class ScanServer(BECService):

def __init__(self, config: ServiceConfig, connector_cls: ConnectorBase):
super().__init__(config, connector_cls, unique_service=True)
self.producer = self.connector.producer()
self._start_scan_manager()
self._start_queue_manager()
self._start_device_manager()

@ -52,15 +51,12 @@ class ScanServer(BECService):
self.scan_guard = ScanGuard(parent=self)

def _start_alarm_handler(self):
self._alarm_consumer = self.connector.consumer(
MessageEndpoints.alarm(), cb=self._alarm_callback, parent=self
)
self._alarm_consumer.start()
self.connector.register(MessageEndpoints.alarm(), cb=self._alarm_callback, parent=self)

def _reset_scan_number(self):
if self.producer.get(MessageEndpoints.scan_number()) is None:
if self.connector.get(MessageEndpoints.scan_number()) is None:
self.scan_number = 1
if self.producer.get(MessageEndpoints.dataset_number()) is None:
if self.connector.get(MessageEndpoints.dataset_number()) is None:
self.dataset_number = 1

@staticmethod

@ -74,25 +70,24 @@ class ScanServer(BECService):
@property
def scan_number(self) -> int:
"""get the current scan number"""
return int(self.producer.get(MessageEndpoints.scan_number()))
return int(self.connector.get(MessageEndpoints.scan_number()))

@scan_number.setter
def scan_number(self, val: int):
"""set the current scan number"""
self.producer.set(MessageEndpoints.scan_number(), val)
self.connector.set(MessageEndpoints.scan_number(), val)

@property
def dataset_number(self) -> int:
"""get the current dataset number"""
return int(self.producer.get(MessageEndpoints.dataset_number()))
return int(self.connector.get(MessageEndpoints.dataset_number()))

@dataset_number.setter
def dataset_number(self, val: int):
"""set the current dataset number"""
self.producer.set(MessageEndpoints.dataset_number(), val)
self.connector.set(MessageEndpoints.dataset_number(), val)

def shutdown(self) -> None:
"""shutdown the scan server"""

self.device_manager.shutdown()
self.queue_manager.shutdown()
|
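With the producer gone, scan_number and dataset_number remain plain properties whose backing store is the key-value backend reached through the connector: every read and write goes over the wire, and the int() conversion assumes the backend returns a string or bytes value (Redis-style). Reduced to its essentials, with a literal key standing in for MessageEndpoints.scan_number():

    class ScanCounter:
        def __init__(self, connector):
            self.connector = connector

        @property
        def scan_number(self) -> int:
            # read-through: no cached copy is kept on the service
            return int(self.connector.get("scan_number"))

        @scan_number.setter
        def scan_number(self, val: int) -> None:
            self.connector.set("scan_number", val)
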
@ -5,7 +5,9 @@ import uuid
from collections.abc import Callable

import numpy as np
from bec_lib import MessageEndpoints, ProducerConnector, Status, bec_logger, messages

from bec_lib import MessageEndpoints, Status, bec_logger, messages
from bec_lib.connector import ConnectorBase

from .errors import DeviceMessageError, ScanAbortion

@ -13,8 +15,8 @@ logger = bec_logger.logger


class ScanStubs:
def __init__(self, producer: ProducerConnector, device_msg_callback: Callable = None) -> None:
self.producer = producer
def __init__(self, connector: ConnectorBase, device_msg_callback: Callable = None) -> None:
self.connector = connector
self.device_msg_metadata = (
device_msg_callback if device_msg_callback is not None else lambda: {}
)

@ -62,7 +64,7 @@ class ScanStubs:

def _get_from_rpc(self, rpc_id):
while True:
msg = self.producer.get(MessageEndpoints.device_rpc(rpc_id))
msg = self.connector.get(MessageEndpoints.device_rpc(rpc_id))
if msg:
break
time.sleep(0.001)

@ -81,7 +83,7 @@ class ScanStubs:
if not isinstance(return_val, dict):
return return_val
if return_val.get("type") == "status" and return_val.get("RID"):
return Status(self.producer, return_val.get("RID"))
return Status(self.connector, return_val.get("RID"))
return return_val

def set_and_wait(self, *, device: list[str], positions: list | np.ndarray):

@ -182,7 +184,7 @@ class ScanStubs:
DIID (int): device instruction ID

"""
msg = self.producer.get(MessageEndpoints.device_req_status(device))
msg = self.connector.get(MessageEndpoints.device_req_status(device))
if not msg:
return 0
matching_RID = msg.metadata.get("RID") == RID

@ -199,7 +201,7 @@ class ScanStubs:
RID (str): request ID

"""
msg = self.producer.get(MessageEndpoints.device_progress(device))
msg = self.connector.get(MessageEndpoints.device_progress(device))
if not msg:
return None
matching_RID = msg.metadata.get("RID") == RID
|
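_get_from_rpc polls the RPC endpoint until a reply appears under the rpc_id key, sleeping briefly between attempts; the req_status and progress stubs follow the same get-then-check-RID shape. A standalone sketch of the polling loop (the key layout is illustrative, not the real endpoint string, and a timeout is added here that the original loop does not have):

    import time

    def wait_for_rpc(connector, rpc_id: str, interval: float = 0.001, timeout: float = 10.0):
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            msg = connector.get(f"device_rpc/{rpc_id}")  # illustrative key layout
            if msg:
                return msg
            time.sleep(interval)  # avoid hammering the backend
        raise TimeoutError(f"no RPC reply for {rpc_id!r} within {timeout}s")
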
@ -41,7 +41,7 @@ class ScanWorker(threading.Thread):
self._groups = {}
self.interception_msg = None
self.reset()
self.validate = DeviceValidation(self.device_manager.producer, self)
self.validate = DeviceValidation(self.device_manager.connector, self)

def open_scan(self, instr: messages.DeviceInstructionMessage) -> None:
"""

@ -138,7 +138,7 @@ class ScanWorker(threading.Thread):
"""
devices = [dev.name for dev in self.device_manager.devices.get_software_triggered_devices()]
self._last_trigger = instr
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=devices,

@ -157,7 +157,7 @@ class ScanWorker(threading.Thread):
"""

# send instruction
self.device_manager.producer.send(MessageEndpoints.device_instructions(), instr)
self.device_manager.connector.send(MessageEndpoints.device_instructions(), instr)

def read_devices(self, instr: messages.DeviceInstructionMessage) -> None:
"""

@ -171,7 +171,7 @@ class ScanWorker(threading.Thread):
self._publish_readback(instr)
return

producer = self.device_manager.producer
connector = self.device_manager.connector

devices = instr.content.get("device")
if devices is None:

@ -181,7 +181,7 @@ class ScanWorker(threading.Thread):
readout_priority=self.readout_priority
)
]
producer.send(
connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=devices,

@ -201,7 +201,7 @@ class ScanWorker(threading.Thread):

"""
# logger.info("kickoff")
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=instr.content.get("device"),

@ -225,7 +225,7 @@ class ScanWorker(threading.Thread):
devices = instr.content.get("device")
if not isinstance(devices, list):
devices = [devices]
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=devices,

@ -251,7 +251,7 @@ class ScanWorker(threading.Thread):
)
]
params = instr.content["parameter"]
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=baseline_devices, action="read", parameter=params, metadata=instr.metadata

@ -266,7 +266,7 @@ class ScanWorker(threading.Thread):
instr (DeviceInstructionMessage): Device instruction received from the scan assembler
"""
devices = [dev.name for dev in self.device_manager.devices.enabled_devices]
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=devices,

@ -285,7 +285,7 @@ class ScanWorker(threading.Thread):
Args:
instr (DeviceInstructionMessage): Device instruction received from the scan assembler
"""
producer = self.device_manager.producer
connector = self.device_manager.connector
data = instr.content["parameter"]["data"]
devices = instr.content["device"]
if not isinstance(devices, list):

@ -294,7 +294,7 @@ class ScanWorker(threading.Thread):
data = [data]
for device, dev_data in zip(devices, data):
msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata)
producer.set_and_publish(MessageEndpoints.device_read(device), msg)
connector.set_and_publish(MessageEndpoints.device_read(device), msg)

def send_rpc(self, instr: messages.DeviceInstructionMessage) -> None:
"""

@ -304,7 +304,7 @@ class ScanWorker(threading.Thread):
instr (DeviceInstructionMessage): Device instruction received from the scan assembler

"""
self.device_manager.producer.send(MessageEndpoints.device_instructions(), instr)
self.device_manager.connector.send(MessageEndpoints.device_instructions(), instr)

def process_scan_report_instruction(self, instr):
"""

@ -333,7 +333,7 @@ class ScanWorker(threading.Thread):
if dev.name not in async_devices
]
for det in async_devices:
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=det,

@ -344,7 +344,7 @@ class ScanWorker(threading.Thread):
)
self._staged_devices.update(async_devices)

self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=devices,

@ -375,7 +375,7 @@ class ScanWorker(threading.Thread):
parameter = {} if not instr else instr.content["parameter"]
metadata = {} if not instr else instr.metadata
self._staged_devices.difference_update(devices)
self.device_manager.producer.send(
self.device_manager.connector.send(
MessageEndpoints.device_instructions(),
messages.DeviceInstructionMessage(
device=devices, action="unstage", parameter=parameter, metadata=metadata

@ -459,7 +459,7 @@ class ScanWorker(threading.Thread):
matching_DIID = device_status[ind].metadata.get("DIID") >= devices[ind][1]
matching_RID = device_status[ind].metadata.get("RID") == instr.metadata["RID"]
if matching_DIID and matching_RID:
last_pos_msg = self.device_manager.producer.get(
last_pos_msg = self.device_manager.connector.get(
MessageEndpoints.device_readback(failed_device[0])
)
last_pos = last_pos_msg.content["signals"][failed_device[0]]["value"]

@ -603,25 +603,25 @@ class ScanWorker(threading.Thread):
def _publish_readback(
self, instr: messages.DeviceInstructionMessage, devices: list = None
) -> None:
producer = self.device_manager.producer
connector = self.device_manager.connector
if not devices:
devices = instr.content.get("device")

# cached readout
readouts = self._get_readback(devices)
pipe = producer.pipeline()
pipe = connector.pipeline()
for readout, device in zip(readouts, devices):
msg = messages.DeviceMessage(signals=readout, metadata=instr.metadata)
producer.set_and_publish(MessageEndpoints.device_read(device), msg, pipe)
connector.set_and_publish(MessageEndpoints.device_read(device), msg, pipe)
return pipe.execute()

def _get_readback(self, devices: list) -> list:
producer = self.device_manager.producer
connector = self.device_manager.connector
# cached readout
pipe = producer.pipeline()
pipe = connector.pipeline()
for dev in devices:
producer.get(MessageEndpoints.device_readback(dev), pipe=pipe)
return producer.execute_pipeline(pipe)
connector.get(MessageEndpoints.device_readback(dev), pipe=pipe)
return connector.execute_pipeline(pipe)

def _check_for_interruption(self) -> None:
if self.status == InstructionQueueStatus.PAUSED:

@ -700,11 +700,13 @@ class ScanWorker(threading.Thread):
scanID=self.current_scanID, status=status, info=self.current_scan_info
)
expire = None if status in ["open", "paused"] else 1800
pipe = self.device_manager.producer.pipeline()
self.device_manager.producer.set(
pipe = self.device_manager.connector.pipeline()
self.device_manager.connector.set(
MessageEndpoints.public_scan_info(self.current_scanID), msg, pipe=pipe, expire=expire
)
self.device_manager.producer.set_and_publish(MessageEndpoints.scan_status(), msg, pipe=pipe)
self.device_manager.connector.set_and_publish(
MessageEndpoints.scan_status(), msg, pipe=pipe
)
pipe.execute()

def _process_instructions(self, queue: InstructionQueueItem) -> None:
|
@ -213,7 +213,7 @@ class RequestBase(ABC):
if metadata is None:
self.metadata = {}
self.stubs = ScanStubs(
producer=self.device_manager.producer, device_msg_callback=self.device_msg_metadata
connector=self.device_manager.connector, device_msg_callback=self.device_msg_metadata
)

@property

@ -239,7 +239,7 @@ class RequestBase(ABC):

def run_pre_scan_macros(self):
"""run pre scan macros if any"""
macros = self.device_manager.producer.lrange(MessageEndpoints.pre_scan_macros(), 0, -1)
macros = self.device_manager.connector.lrange(MessageEndpoints.pre_scan_macros(), 0, -1)
for macro in macros:
macro = macro.decode().strip()
func_name = self._get_func_name_from_macro(macro)

@ -558,12 +558,12 @@ class SyncFlyScanBase(ScanBase, ABC):

def _get_flyer_status(self) -> list:
flyer = self.scan_motors[0]
producer = self.device_manager.producer
connector = self.device_manager.connector

pipe = producer.pipeline()
producer.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe)
producer.get(MessageEndpoints.device_readback(flyer), pipe)
return producer.execute_pipeline(pipe)
pipe = connector.pipeline()
connector.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe)
connector.get(MessageEndpoints.device_readback(flyer), pipe)
return connector.execute_pipeline(pipe)

@abstractmethod
def scan_core(self):

@ -1098,7 +1098,7 @@ class RoundScanFlySim(SyncFlyScanBase):

while True:
yield from self.stubs.read_and_wait(group="primary", wait_group="readout_primary")
status = self.device_manager.producer.get(MessageEndpoints.device_status(self.flyer))
status = self.device_manager.connector.get(MessageEndpoints.device_status(self.flyer))
if status:
device_is_idle = status.content.get("status", 1) == 0
matching_RID = self.metadata.get("RID") == status.metadata.get("RID")

@ -1318,12 +1318,12 @@ class MonitorScan(ScanBase):
self._check_limits()

def _get_flyer_status(self) -> list:
producer = self.device_manager.producer
connector = self.device_manager.connector

pipe = producer.pipeline()
producer.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe)
producer.get(MessageEndpoints.device_readback(self.flyer), pipe)
return producer.execute_pipeline(pipe)
pipe = connector.pipeline()
connector.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe)
connector.get(MessageEndpoints.device_readback(self.flyer), pipe)
return connector.execute_pipeline(pipe)

def scan_core(self):
yield from self.stubs.set(
|
@ -12,7 +12,7 @@ from scan_server.scan_guard import ScanGuard, ScanRejection, ScanStatus
@pytest.fixture
def scan_guard_mock(scan_server_mock):
sg = ScanGuard(parent=scan_server_mock)
sg.device_manager.producer = mock.MagicMock()
sg.device_manager.connector = mock.MagicMock()
yield sg


@ -113,8 +113,8 @@ def test_valid_request(scan_server_mock, scan_queue_msg, valid):

def test_check_valid_scan_raises_for_unknown_scan(scan_guard_mock):
sg = scan_guard_mock
sg.producer = mock.MagicMock()
sg.producer.get.return_value = messages.AvailableResourceMessage(
sg.connector = mock.MagicMock()
sg.connector.get.return_value = messages.AvailableResourceMessage(
resource={"fermat_scan": "fermat_scan"}
)

@ -130,8 +130,8 @@ def test_check_valid_scan_raises_for_unknown_scan(scan_guard_mock):

def test_check_valid_scan_accepts_known_scan(scan_guard_mock):
sg = scan_guard_mock
sg.producer = mock.MagicMock()
sg.producer.get.return_value = messages.AvailableResourceMessage(
sg.connector = mock.MagicMock()
sg.connector.get.return_value = messages.AvailableResourceMessage(
resource={"fermat_scan": "fermat_scan"}
)

@ -146,8 +146,8 @@ def test_check_valid_scan_accepts_known_scan(scan_guard_mock):

def test_check_valid_scan_device_rpc(scan_guard_mock):
sg = scan_guard_mock
sg.producer = mock.MagicMock()
sg.producer.get.return_value = messages.AvailableResourceMessage(
sg.connector = mock.MagicMock()
sg.connector.get.return_value = messages.AvailableResourceMessage(
resource={"device_rpc": "device_rpc"}
)
request = messages.ScanQueueMessage(

@ -162,8 +162,8 @@ def test_check_valid_scan_device_rpc(scan_guard_mock):

def test_check_valid_scan_device_rpc_raises(scan_guard_mock):
sg = scan_guard_mock
sg.producer = mock.MagicMock()
sg.producer.get.return_value = messages.AvailableResourceMessage(
sg.connector = mock.MagicMock()
sg.connector.get.return_value = messages.AvailableResourceMessage(
resource={"device_rpc": "device_rpc"}
)
request = messages.ScanQueueMessage(

@ -184,7 +184,7 @@ def test_handle_scan_modification_request(scan_guard_mock):
msg = messages.ScanQueueModificationMessage(
scanID="scanID", action="abort", parameter={}, metadata={"RID": "RID"}
)
with mock.patch.object(sg.device_manager.producer, "send") as send:
with mock.patch.object(sg.device_manager.connector, "send") as send:
sg._handle_scan_modification_request(msg)
send.assert_called_once_with(MessageEndpoints.scan_queue_modification(), msg)

@ -207,7 +207,7 @@ def test_append_to_scan_queue(scan_guard_mock):
parameter={"args": {"samx": (-5, 5), "samy": (-5, 5)}, "kwargs": {"step": 3}},
queue="primary",
)
with mock.patch.object(sg.device_manager.producer, "send") as send:
with mock.patch.object(sg.device_manager.connector, "send") as send:
sg._append_to_scan_queue(msg)
send.assert_called_once_with(MessageEndpoints.scan_queue_insert(), msg)

@ -251,7 +251,7 @@ def test_scan_queue_modification_request_callback(scan_guard_mock):

def test_send_scan_request_response(scan_guard_mock):
sg = scan_guard_mock
with mock.patch.object(sg.device_manager.producer, "send") as send:
with mock.patch.object(sg.device_manager.connector, "send") as send:
sg._send_scan_request_response(ScanStatus(), {"RID": "RID"})
send.assert_called_once_with(
MessageEndpoints.scan_queue_request_response(),
|
@ -166,40 +166,40 @@ def test_set_halt_disables_return_to_start(queuemanager_mock):

def test_set_pause(queuemanager_mock):
    queue_manager = queuemanager_mock()
    queue_manager.producer.message_sent = []
    queue_manager.connector.message_sent = []
    queue_manager.set_pause(queue="primary")
    assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED
    assert len(queue_manager.producer.message_sent) == 1
    assert len(queue_manager.connector.message_sent) == 1
    assert (
        queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
        queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
    )


def test_set_deferred_pause(queuemanager_mock):
    queue_manager = queuemanager_mock()
    queue_manager.producer.message_sent = []
    queue_manager.connector.message_sent = []
    queue_manager.set_deferred_pause(queue="primary")
    assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED
    assert len(queue_manager.producer.message_sent) == 1
    assert len(queue_manager.connector.message_sent) == 1
    assert (
        queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
        queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
    )


def test_set_continue(queuemanager_mock):
    queue_manager = queuemanager_mock()
    queue_manager.producer.message_sent = []
    queue_manager.connector.message_sent = []
    queue_manager.set_continue(queue="primary")
    assert queue_manager.queues["primary"].status == ScanQueueStatus.RUNNING
    assert len(queue_manager.producer.message_sent) == 1
    assert len(queue_manager.connector.message_sent) == 1
    assert (
        queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
        queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
    )


def test_set_abort(queuemanager_mock):
    queue_manager = queuemanager_mock()
    queue_manager.producer.message_sent = []
    queue_manager.connector.message_sent = []
    msg = messages.ScanQueueMessage(
        scan_type="mv",
        parameter={"args": {"samx": (1,)}, "kwargs": {}},

@ -210,23 +210,23 @@ def test_set_abort(queuemanager_mock):
    queue_manager.add_to_queue(scan_queue="primary", msg=msg)
    queue_manager.set_abort(queue="primary")
    assert queue_manager.queues["primary"].status == ScanQueueStatus.PAUSED
    assert len(queue_manager.producer.message_sent) == 2
    assert len(queue_manager.connector.message_sent) == 2
    assert (
        queue_manager.producer.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
        queue_manager.connector.message_sent[0].get("queue") == MessageEndpoints.scan_queue_status()
    )


def test_set_abort_with_empty_queue(queuemanager_mock):
    queue_manager = queuemanager_mock()
    queue_manager.producer.message_sent = []
    queue_manager.connector.message_sent = []
    queue_manager.set_abort(queue="primary")
    assert queue_manager.queues["primary"].status == ScanQueueStatus.RUNNING
    assert len(queue_manager.producer.message_sent) == 0
    assert len(queue_manager.connector.message_sent) == 0


def test_set_clear_sends_message(queuemanager_mock):
    queue_manager = queuemanager_mock()
    queue_manager.producer.message_sent = []
    queue_manager.connector.message_sent = []
    setter_mock = mock.Mock(wraps=ScanQueue.worker_status.fset)
    # pylint: disable=assignment-from-no-return
    # pylint: disable=too-many-function-args

@ -238,9 +238,9 @@ def test_set_clear_sends_message(queuemanager_mock):
        mock_property.fset.assert_called_once_with(
            queue_manager.queues["primary"], InstructionQueueStatus.STOPPED
        )
        assert len(queue_manager.producer.message_sent) == 1
        assert len(queue_manager.connector.message_sent) == 1
        assert (
            queue_manager.producer.message_sent[0].get("queue")
            queue_manager.connector.message_sent[0].get("queue")
            == MessageEndpoints.scan_queue_status()
        )
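These queue-manager tests inspect the mock connector's outbox: judging from the assertions, `ConnectorMock` records each outgoing message as a dict carrying at least the target `"queue"` endpoint. A sketch of that inspection pattern, assuming the `queuemanager_mock` fixture from this module:

def test_pause_publishes_queue_status(queuemanager_mock):
    queue_manager = queuemanager_mock()
    queue_manager.connector.message_sent = []  # reset the recorded outbox
    queue_manager.set_pause(queue="primary")
    sent = queue_manager.connector.message_sent
    # exactly one message, addressed to the scan-queue-status endpoint
    assert len(sent) == 1
    assert sent[0].get("queue") == MessageEndpoints.scan_queue_status()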
@ -11,7 +11,7 @@ from scan_server.scan_stubs import ScanAbortion, ScanStubs

@pytest.fixture
def stubs():
    connector = ConnectorMock("")
    yield ScanStubs(connector.producer())
    yield ScanStubs(connector)


@pytest.mark.parametrize(

@ -36,7 +36,11 @@ def stubs():
                device="rtx",
                action="kickoff",
                parameter={
                    "configure": {"num_pos": 5, "positions": [1, 2, 3, 4, 5], "exp_time": 2},
                    "configure": {
                        "num_pos": 5,
                        "positions": [1, 2, 3, 4, 5],
                        "exp_time": 2,
                    },
                    "wait_group": "kickoff",
                },
                metadata={},

@ -45,6 +49,8 @@ def stubs():
    ],
)
def test_kickoff(stubs, device, parameter, metadata, reference_msg):
    connector = ConnectorMock("")
    stubs = ScanStubs(connector)
    msg = list(stubs.kickoff(device=device, parameter=parameter, metadata=metadata))
    assert msg[0] == reference_msg

@ -52,12 +58,19 @@ def test_kickoff(stubs, device, parameter, metadata, reference_msg):
@pytest.mark.parametrize(
    "msg,raised_error",
    [
        (messages.DeviceRPCMessage(device="samx", return_val="", out="", success=True), None),
        (
            messages.DeviceRPCMessage(device="samx", return_val="", out="", success=True),
            None,
        ),
        (
            messages.DeviceRPCMessage(
                device="samx",
                return_val="",
                out={"error": "TypeError", "msg": "some weird error", "traceback": "traceback"},
                out={
                    "error": "TypeError",
                    "msg": "some weird error",
                    "traceback": "traceback",
                },
                success=False,
            ),
            ScanAbortion,

@ -69,8 +82,7 @@ def test_kickoff(stubs, device, parameter, metadata, reference_msg):
    ],
)
def test_rpc_raises_scan_abortion(stubs, msg, raised_error):
    msg = msg
    with mock.patch.object(stubs.producer, "get", return_value=msg) as prod_get:
    with mock.patch.object(stubs.connector, "get", return_value=msg) as prod_get:
        if raised_error is None:
            stubs._get_from_rpc("rpc-id")
        else:

@ -106,8 +118,8 @@ def test_rpc_raises_scan_abortion(stubs, msg, raised_error):
def test_device_progress(stubs, msg, ret_value, raised_error):
    if raised_error:
        with pytest.raises(DeviceMessageError):
            with mock.patch.object(stubs.producer, "get", return_value=msg):
            with mock.patch.object(stubs.connector, "get", return_value=msg):
                assert stubs.get_device_progress(device="samx", RID="rid") == ret_value
        return
    with mock.patch.object(stubs.producer, "get", return_value=msg):
    with mock.patch.object(stubs.connector, "get", return_value=msg):
        assert stubs.get_device_progress(device="samx", RID="rid") == ret_value
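With producers gone, `ScanStubs` is constructed from the connector itself rather than from `connector.producer()`. A sketch of the updated fixture, assuming `ConnectorMock` from `bec_lib.tests.utils` as imported above:

import pytest

from bec_lib.tests.utils import ConnectorMock
from scan_server.scan_stubs import ScanStubs


@pytest.fixture
def stubs():
    # the mock connector exposes get/send/set itself, so it is handed
    # to ScanStubs directly instead of via connector.producer()
    connector = ConnectorMock("")
    yield ScanStubs(connector)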
@ -4,7 +4,7 @@ from unittest import mock

import pytest
from bec_lib import MessageEndpoints, messages
from bec_lib.tests.utils import ProducerMock, dm, dm_with_devices
from bec_lib.tests.utils import ConnectorMock, dm, dm_with_devices
from utils import scan_server_mock

from scan_server.errors import DeviceMessageError, ScanAbortion

@ -22,7 +22,7 @@ from scan_server.scan_worker import ScanWorker

@pytest.fixture
def scan_worker_mock(scan_server_mock) -> ScanWorker:
    scan_server_mock.device_manager.producer = mock.MagicMock()
    scan_server_mock.device_manager.connector = mock.MagicMock()
    scan_worker = ScanWorker(parent=scan_server_mock)
    yield scan_worker

@ -295,7 +295,7 @@ def test_wait_for_devices(scan_worker_mock, instructions, wait_type):
def test_complete_devices(scan_worker_mock, instructions):
    worker = scan_worker_mock
    with mock.patch.object(worker, "_wait_for_status") as wait_for_status_mock:
        with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
        with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
            worker.complete_devices(instructions)
            if instructions.content["device"]:
                devices = instructions.content["device"]

@ -328,7 +328,7 @@ def test_complete_devices(scan_worker_mock, instructions):
)
def test_pre_scan(scan_worker_mock, instructions):
    worker = scan_worker_mock
    with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
    with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
        with mock.patch.object(worker, "_wait_for_status") as wait_for_status_mock:
            worker.pre_scan(instructions)
            devices = [dev.name for dev in worker.device_manager.devices.enabled_devices]

@ -457,12 +457,12 @@ def test_pre_scan(scan_worker_mock, instructions):
)
def test_check_for_failed_movements(scan_worker_mock, device_status, devices, instr, abort):
    worker = scan_worker_mock
    worker.device_manager.producer = ProducerMock()
    worker.device_manager.connector = ConnectorMock()
    if abort:
        with pytest.raises(ScanAbortion):
            worker.device_manager.producer._get_buffer[MessageEndpoints.device_readback("samx")] = (
                messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
            )
            worker.device_manager.connector._get_buffer[
                MessageEndpoints.device_readback("samx")
            ] = messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
            worker._check_for_failed_movements(device_status, devices, instr)
    else:
        worker._check_for_failed_movements(device_status, devices, instr)

@ -577,12 +577,12 @@ def test_check_for_failed_movements(scan_worker_mock, device_status, devices, in
)
def test_wait_for_idle(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReqStatusMessage):
    worker = scan_worker_mock
    worker.device_manager.producer = ProducerMock()
    worker.device_manager.connector = ConnectorMock()

    with mock.patch.object(
        worker.validate, "get_device_status", return_value=[req_msg]
    ) as device_status:
        worker.device_manager.producer._get_buffer[MessageEndpoints.device_readback("samx")] = (
        worker.device_manager.connector._get_buffer[MessageEndpoints.device_readback("samx")] = (
            messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
        )

@ -635,7 +635,7 @@ def test_wait_for_idle(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReq
)
def test_wait_for_read(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReqStatusMessage):
    worker = scan_worker_mock
    worker.device_manager.producer = ProducerMock()
    worker.device_manager.connector = ConnectorMock()

    with mock.patch.object(
        worker.validate, "get_device_status", return_value=[req_msg]

@ -643,9 +643,9 @@ def test_wait_for_read(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReq
        with mock.patch.object(worker, "_check_for_interruption") as interruption_mock:
            assert worker._groups == {}
            worker._groups["scan_motor"] = {"samx": 3, "samy": 4}
            worker.device_manager.producer._get_buffer[MessageEndpoints.device_readback("samx")] = (
                messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
            )
            worker.device_manager.connector._get_buffer[
                MessageEndpoints.device_readback("samx")
            ] = messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
            worker._add_wait_group(msg1)
            worker._wait_for_read(msg2)
            assert worker._groups == {"scan_motor": {"samy": 4}}

@ -730,7 +730,7 @@ def test_wait_for_device_server(scan_worker_mock):
)
def test_set_devices(scan_worker_mock, instr):
    worker = scan_worker_mock
    with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
    with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
        worker.set_devices(instr)
        send_mock.assert_called_once_with(MessageEndpoints.device_instructions(), instr)

@ -755,7 +755,7 @@ def test_set_devices(scan_worker_mock, instr):
)
def test_trigger_devices(scan_worker_mock, instr):
    worker = scan_worker_mock
    with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
    with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
        worker.trigger_devices(instr)
        devices = [
            dev.name for dev in worker.device_manager.devices.get_software_triggered_devices()

@ -797,7 +797,7 @@ def test_trigger_devices(scan_worker_mock, instr):
)
def test_send_rpc(scan_worker_mock, instr):
    worker = scan_worker_mock
    with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
    with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
        worker.send_rpc(instr)
        send_mock.assert_called_once_with(MessageEndpoints.device_instructions(), instr)

@ -840,7 +840,7 @@ def test_read_devices(scan_worker_mock, instr):
        instr_devices = []
    worker.readout_priority.update({"monitored": instr_devices})
    devices = [dev.name for dev in worker._get_devices_from_instruction(instr)]
    with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
    with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
        worker.read_devices(instr)

        if instr.content.get("device"):

@ -888,7 +888,7 @@ def test_read_devices(scan_worker_mock, instr):
)
def test_kickoff_devices(scan_worker_mock, instr, devices, parameter, metadata):
    worker = scan_worker_mock
    with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
    with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
        worker.kickoff_devices(instr)
        send_mock.assert_called_once_with(
            MessageEndpoints.device_instructions(),

@ -920,29 +920,27 @@ def test_kickoff_devices(scan_worker_mock, instr, devices, parameter, metadata):
def test_publish_readback(scan_worker_mock, instr, devices):
    worker = scan_worker_mock
    with mock.patch.object(worker, "_get_readback", return_value=[{}]) as get_readback:
        with mock.patch.object(worker.device_manager, "producer") as producer_mock:
        with mock.patch.object(worker.device_manager, "connector") as connector_mock:
            worker._publish_readback(instr)

            get_readback.assert_called_once_with(["samx"])
            pipe = producer_mock.pipeline()
            pipe = connector_mock.pipeline()
            msg = messages.DeviceMessage(signals={}, metadata=instr.metadata)

            producer_mock.set_and_publish.assert_called_once_with(
            connector_mock.set_and_publish.assert_called_once_with(
                MessageEndpoints.device_read("samx"), msg, pipe
            )
            pipe.execute.assert_called_once()


def test_get_readback(scan_worker_mock):
    worker = scan_worker_mock
    devices = ["samx"]
    with mock.patch.object(worker.device_manager, "producer") as producer_mock:
    with mock.patch.object(worker.device_manager, "connector") as connector_mock:
        worker._get_readback(devices)
        pipe = producer_mock.pipeline()
        producer_mock.get.assert_called_once_with(
        pipe = connector_mock.pipeline()
        connector_mock.get.assert_called_once_with(
            MessageEndpoints.device_readback("samx"), pipe=pipe
        )
        producer_mock.execute_pipeline.assert_called_once()
        connector_mock.execute_pipeline.assert_called_once()


def test_publish_data_as_read(scan_worker_mock):

@ -958,12 +956,12 @@ def test_publish_data_as_read(scan_worker_mock):
            "RID": "requestID",
        },
    )
    with mock.patch.object(worker.device_manager, "producer") as producer_mock:
    with mock.patch.object(worker.device_manager, "connector") as connector_mock:
        worker.publish_data_as_read(instr)
        msg = messages.DeviceMessage(
            signals=instr.content["parameter"]["data"], metadata=instr.metadata
        )
        producer_mock.set_and_publish.assert_called_once_with(
        connector_mock.set_and_publish.assert_called_once_with(
            MessageEndpoints.device_read("samx"), msg
        )

@ -983,13 +981,13 @@ def test_publish_data_as_read_multiple(scan_worker_mock):
            "RID": "requestID",
        },
    )
    with mock.patch.object(worker.device_manager, "producer") as producer_mock:
    with mock.patch.object(worker.device_manager, "connector") as connector_mock:
        worker.publish_data_as_read(instr)
        mock_calls = []
        for device, dev_data in zip(devices, data):
            msg = messages.DeviceMessage(signals=dev_data, metadata=instr.metadata)
            mock_calls.append(mock.call(MessageEndpoints.device_read(device), msg))
        assert producer_mock.set_and_publish.mock_calls == mock_calls
        assert connector_mock.set_and_publish.mock_calls == mock_calls


def test_check_for_interruption(scan_worker_mock):

@ -1048,7 +1046,7 @@ def test_open_scan(scan_worker_mock, instr, corr_num_points, scan_id):
    if "pointID" in instr.metadata:
        worker.max_point_id = instr.metadata["pointID"]

    assert worker.parent.producer.get(MessageEndpoints.scan_number()) == None
    assert worker.parent.connector.get(MessageEndpoints.scan_number()) == None

    with mock.patch.object(worker, "current_instruction_queue_item") as queue_mock:
        with mock.patch.object(worker, "_initialize_scan_info") as init_mock:

@ -1181,7 +1179,7 @@ def test_stage_device(scan_worker_mock, msg):
    worker.device_manager.devices["eiger"]._config["readoutPriority"] = "async"

    with mock.patch.object(worker, "_wait_for_stage") as wait_mock:
        with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
        with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
            worker.stage_devices(msg)
            async_devices = [dev.name for dev in worker.device_manager.devices.async_devices()]
            devices = [

@ -1251,7 +1249,7 @@ def test_unstage_device(scan_worker_mock, msg, devices, parameter, metadata, cle
    if not devices:
        devices = [dev.name for dev in worker.device_manager.devices.enabled_devices]

    with mock.patch.object(worker.device_manager.producer, "send") as send_mock:
    with mock.patch.object(worker.device_manager.connector, "send") as send_mock:
        with mock.patch.object(worker, "_wait_for_stage") as wait_mock:
            worker.unstage_devices(msg, devices, cleanup)

@ -1270,12 +1268,12 @@ def test_unstage_device(scan_worker_mock, msg, devices, parameter, metadata, cle
@pytest.mark.parametrize("status,expire", [("open", None), ("closed", 1800), ("aborted", 1800)])
def test_send_scan_status(scan_worker_mock, status, expire):
    worker = scan_worker_mock
    worker.device_manager.producer = ProducerMock()
    worker.device_manager.connector = ConnectorMock()
    worker.current_scanID = str(uuid.uuid4())
    worker._send_scan_status(status)
    scan_info_msgs = [
        msg
        for msg in worker.device_manager.producer.message_sent
        for msg in worker.device_manager.connector.message_sent
        if msg["queue"] == MessageEndpoints.public_scan_info(scanID=worker.current_scanID)
    ]
    assert len(scan_info_msgs) == 1
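The readback tests above exercise the connector's pipeline API, which carries over the old producer's semantics: several `get` calls are queued into one batch and executed in a single round trip. A rough sketch of the idiom as the tests exercise it (the helper name and device list are illustrative):

def read_back_devices(connector, devices):
    """Fetch the latest readback message for each device in one batch."""
    pipe = connector.pipeline()
    for dev in devices:
        connector.get(MessageEndpoints.device_readback(dev), pipe=pipe)
    # execute_pipeline runs all queued gets and returns their results
    return connector.execute_pipeline(pipe)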
@ -6,7 +6,7 @@ import numpy as np
import pytest
from bec_lib import messages
from bec_lib.devicemanager import DeviceContainer
from bec_lib.tests.utils import ProducerMock
from bec_lib.tests.utils import ConnectorMock

from scan_plugins.LamNIFermatScan import LamNIFermatScan
from scan_plugins.otf_scan import OTFScan

@ -80,7 +80,7 @@ class DeviceMock:

class DMMock:
    devices = DeviceContainer()
    producer = ProducerMock()
    connector = ConnectorMock()

    def add_device(self, name):
        self.devices[name] = DeviceMock(name)

@ -1099,7 +1099,7 @@ def test_pre_scan_macro():
        device_manager=device_manager, parameter=scan_msg.content["parameter"]
    )
    with mock.patch.object(
        request.device_manager.producer,
        request.device_manager.connector,
        "lrange",
        new_callable=mock.PropertyMock,
        return_value=macros,
@ -26,9 +26,9 @@ dir_path = os.path.abspath(os.path.join(os.path.dirname(bec_lib.__file__), "./co
class ConfigHandler:
    def __init__(self, scibec_connector: SciBecConnector, connector: ConnectorBase) -> None:
        self.scibec_connector = scibec_connector
        self.connector = connector
        self.device_manager = DeviceManager(self.scibec_connector.scihub)
        self.device_manager.initialize(scibec_connector.config.redis)
        self.producer = connector.producer()
        self.validator = SciBecValidator(os.path.join(dir_path, "openapi_schema.json"))

    def parse_config_request(self, msg: messages.DeviceConfigMessage) -> None:

@ -53,7 +53,7 @@ class ConfigHandler:

    def send_config(self, msg: messages.DeviceConfigMessage) -> None:
        """broadcast a new config"""
        self.producer.send(MessageEndpoints.device_config_update(), msg)
        self.connector.send(MessageEndpoints.device_config_update(), msg)

    def send_config_request_reply(self, accepted, error_msg, metadata):
        """send a config request reply"""

@ -61,7 +61,7 @@ class ConfigHandler:
            accepted=accepted, message=error_msg, metadata=metadata
        )
        RID = metadata.get("RID")
        self.producer.set(MessageEndpoints.device_config_request_response(RID), msg, expire=60)
        self.connector.set(MessageEndpoints.device_config_request_response(RID), msg, expire=60)

    def _set_config(self, msg: messages.DeviceConfigMessage):
        config = msg.content["config"]

@ -127,14 +127,14 @@ class ConfigHandler:

    def _update_device_server(self, RID: str, config: dict, action="update") -> None:
        msg = messages.DeviceConfigMessage(action=action, config=config, metadata={"RID": RID})
        self.producer.send(MessageEndpoints.device_server_config_request(), msg)
        self.connector.send(MessageEndpoints.device_server_config_request(), msg)

    def _wait_for_device_server_update(self, RID: str, timeout_time=10) -> bool:
        timeout = timeout_time
        time_step = 0.05
        elapsed_time = 0
        while True:
            msg = self.producer.get(MessageEndpoints.device_config_request_response(RID))
            msg = self.connector.get(MessageEndpoints.device_config_request_response(RID))
            if msg:
                return msg.content["accepted"], msg

@ -188,11 +188,11 @@ class ConfigHandler:
        self.validator.validate_device_patch(update)

    def update_config_in_redis(self, device):
        config = self.device_manager.producer.get(MessageEndpoints.device_config())
        config = self.device_manager.connector.get(MessageEndpoints.device_config())
        config = config.content["resource"]
        index = next(
            index for index, dev_conf in enumerate(config) if dev_conf["name"] == device.name
        )
        config[index] = device._config
        msg = messages.AvailableResourceMessage(resource=config)
        self.device_manager.producer.set(MessageEndpoints.device_config(), msg)
        self.device_manager.connector.set(MessageEndpoints.device_config(), msg)
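On the service side the intermediate producer object disappears entirely: classes that used to hold `self.producer = connector.producer()` now call `send`/`set`/`get` on the injected connector directly. A condensed sketch of the resulting shape, trimmed to the messaging part (the real class also wires up a device manager and a validator):

class ConfigHandlerSketch:
    """Condensed illustration of the post-refactor messaging surface."""

    def __init__(self, connector):
        self.connector = connector  # no producer() indirection anymore

    def send_config(self, msg):
        self.connector.send(MessageEndpoints.device_config_update(), msg)

    def send_config_request_reply(self, RID, msg):
        # request replies are short-lived, hence the 60 s expiry
        self.connector.set(
            MessageEndpoints.device_config_request_response(RID), msg, expire=60
        )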
@ -31,7 +31,6 @@ class SciBecConnector:
    def __init__(self, scihub: SciHub, connector: ConnectorBase) -> None:
        self.scihub = scihub
        self.connector = connector
        self.producer = connector.producer()
        self.scibec = None
        self.host = None
        self.target_bl = None

@ -132,25 +131,24 @@ class SciBecConnector:
        """
        Set the scibec account in redis
        """
        self.producer.set(
        self.connector.set(
            MessageEndpoints.scibec(),
            messages.CredentialsMessage(credentials={"url": self.host, "token": f"Bearer {token}"}),
        )

    def set_redis_config(self, config):
        msg = messages.AvailableResourceMessage(resource=config)
        self.producer.set(MessageEndpoints.device_config(), msg)
        self.connector.set(MessageEndpoints.device_config(), msg)

    def _start_metadata_handler(self) -> None:
        self._metadata_handler = SciBecMetadataHandler(self)

    def _start_config_request_handler(self) -> None:
        self._config_request_handler = self.connector.consumer(
        self._config_request_handler = self.connector.register(
            MessageEndpoints.device_config_request(),
            cb=self._device_config_request_callback,
            parent=self,
        )
        self._config_request_handler.start()

    @staticmethod
    def _device_config_request_callback(msg, *, parent, **_kwargs) -> None:

@ -159,7 +157,7 @@ class SciBecConnector:

    def connect_to_scibec(self):
        """
        Connect to SciBec and set the producer to the write account
        Connect to SciBec and set the connector to the write account
        """
        self._load_environment()
        if not self._env_configured:

@ -205,7 +203,7 @@ class SciBecConnector:
        write_account = self.scibec_info["activeExperiment"]["writeAccount"]
        if write_account[0] == "p":
            write_account = write_account.replace("p", "e")
        self.producer.set(MessageEndpoints.account(), write_account.encode())
        self.connector.set(MessageEndpoints.account(), write_account.encode())

    def shutdown(self):
        """
@ -15,22 +15,20 @@ if TYPE_CHECKING:
class SciBecMetadataHandler:
    def __init__(self, scibec_connector: SciBecConnector) -> None:
        self.scibec_connector = scibec_connector
        self._scan_status_consumer = None
        self._scan_status_register = None
        self._start_scan_subscription()
        self._file_subscription = None
        self._start_file_subscription()

    def _start_scan_subscription(self):
        self._scan_status_consumer = self.scibec_connector.connector.consumer(
        self._scan_status_register = self.scibec_connector.connector.register(
            MessageEndpoints.scan_status(), cb=self._handle_scan_status, parent=self
        )
        self._scan_status_consumer.start()

    def _start_file_subscription(self):
        self._file_subscription = self.scibec_connector.connector.consumer(
        self._file_subscription = self.scibec_connector.connector.register(
            MessageEndpoints.file_content(), cb=self._handle_file_content, parent=self
        )
        self._file_subscription.start()

    @staticmethod
    def _handle_scan_status(msg, *, parent, **_kwargs) -> None:

@ -171,7 +169,7 @@ class SciBecMetadataHandler:
        """
        Shutdown the metadata handler
        """
        if self._scan_status_consumer:
            self._scan_status_consumer.shutdown()
        if self._scan_status_register:
            self._scan_status_register.shutdown()
        if self._file_subscription:
            self._file_subscription.shutdown()
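Subscriptions follow the same unification: `connector.consumer(...)` becomes `connector.register(...)`, and the handles are renamed to match (`_scan_status_consumer` becomes `_scan_status_register`). A sketch of the new subscription setup, assuming (as the hunks above suggest) that `register` keeps the `cb=`/`parent=` callback wiring and returns a handle that supports `shutdown()`:

def start_scan_subscription(scibec_connector):
    # register() replaces the old consumer() factory; the callback signature
    # (msg, *, parent, **kwargs) is unchanged
    return scibec_connector.connector.register(
        MessageEndpoints.scan_status(),
        cb=SciBecMetadataHandler._handle_scan_status,
        parent=scibec_connector,
    )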
@ -22,7 +22,6 @@ class SciLogConnector:
    def __init__(self, scihub: SciHub, connector: RedisConnector) -> None:
        self.scihub = scihub
        self.connector = connector
        self.producer = self.connector.producer()
        self.host = None
        self.user = None
        self.user_secret = None

@ -44,7 +43,7 @@ class SciLogConnector:

    def set_bec_token(self, token: str) -> None:
        """set the scilog token in redis"""
        self.producer.set(
        self.connector.set(
            MessageEndpoints.logbook(),
            msgpack.dumps({"url": self.host, "user": self.user, "token": f"Bearer {token}"}),
        )
@ -334,7 +334,7 @@ def test_config_handler_update_device_config_available_keys(config_handler, avai

def test_config_handler_wait_for_device_server_update(config_handler):
    RID = "12345"
    with mock.patch.object(config_handler.producer, "get") as mock_get:
    with mock.patch.object(config_handler.connector, "get") as mock_get:
        mock_get.side_effect = [
            None,
            None,

@ -346,7 +346,7 @@ def test_config_handler_wait_for_device_server_update(config_handler):

def test_config_handler_wait_for_device_server_update_timeout(config_handler):
    RID = "12345"
    with mock.patch.object(config_handler.producer, "get", return_value=None) as mock_get:
    with mock.patch.object(config_handler.connector, "get", return_value=None) as mock_get:
        with pytest.raises(TimeoutError):
            config_handler._wait_for_device_server_update(RID, timeout_time=0.1)
        mock_get.assert_called()
@ -138,6 +138,6 @@ def test_scibec_update_experiment_info(SciBecMock):

def test_update_eaccount_in_redis(SciBecMock):
    SciBecMock.scibec_info = {"activeExperiment": {"writeAccount": "p12345"}}
    with mock.patch.object(SciBecMock, "producer") as mock_producer:
    with mock.patch.object(SciBecMock, "connector") as mock_connector:
        SciBecMock._update_eaccount_in_redis()
        mock_producer.set.assert_called_once_with(MessageEndpoints.account(), b"e12345")
        mock_connector.set.assert_called_once_with(MessageEndpoints.account(), b"e12345")