mirror of
https://github.com/ivan-usov-org/bec.git
synced 2025-04-20 01:40:02 +02:00
refactor: endpoints return EndpointInfo object instead of string
This commit is contained in:
parent
4ac0bbca16
commit
a4adb64f5f
@ -31,9 +31,9 @@ class AsyncDataHandler:
|
||||
async_data = {}
|
||||
for device_key in async_device_keys:
|
||||
key = device_key.decode()
|
||||
device_name = key.split(MessageEndpoints.device_async_readback(scan_id, ""))[-1].split(
|
||||
":"
|
||||
)[0]
|
||||
device_name = key.split(MessageEndpoints.device_async_readback(scan_id, "").endpoint)[
|
||||
-1
|
||||
].split(":")[0]
|
||||
data = self.get_async_data_for_device(scan_id, device_name)
|
||||
if not data:
|
||||
continue
|
||||
|
@ -109,7 +109,7 @@ class BECService:
|
||||
)
|
||||
|
||||
def _update_existing_services(self):
|
||||
service_keys = self.connector.keys(MessageEndpoints.service_status("*"))
|
||||
service_keys = self.connector.keys(MessageEndpoints.service_status("*").endpoint)
|
||||
if not service_keys:
|
||||
return
|
||||
services = [service.decode().split(":", maxsplit=1)[0] for service in service_keys]
|
||||
|
@ -92,7 +92,7 @@ class Status:
|
||||
|
||||
while True:
|
||||
request_status = self._connector.lrange(
|
||||
MessageEndpoints.device_req_status(self._RID), 0, -1
|
||||
MessageEndpoints.device_req_status_container(self._RID), 0, -1
|
||||
)
|
||||
if request_status:
|
||||
break
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,3 +1,7 @@
|
||||
"""
|
||||
BECMessage classes for communication between BEC components.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
@ -83,8 +87,7 @@ class BundleMessage(BECMessage):
|
||||
"""append a new BECMessage to the bundle"""
|
||||
if not isinstance(msg, BECMessage):
|
||||
raise AttributeError(f"Cannot append message of type {msg.__class__.__name__}")
|
||||
else:
|
||||
self.messages.append(msg)
|
||||
self.messages.append(msg)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.messages)
|
||||
|
@ -14,9 +14,9 @@ import redis
|
||||
import redis.client
|
||||
import redis.exceptions
|
||||
|
||||
from bec_lib.logger import bec_logger
|
||||
from bec_lib.connector import ConnectorBase, MessageObject
|
||||
from bec_lib.endpoints import MessageEndpoints
|
||||
from bec_lib.endpoints import EndpointInfo, MessageEndpoints
|
||||
from bec_lib.logger import bec_logger
|
||||
from bec_lib.messages import AlarmMessage, BECMessage, LogMessage
|
||||
from bec_lib.serialization import MsgpackSerialization
|
||||
|
||||
@ -24,6 +24,30 @@ if TYPE_CHECKING:
|
||||
from bec_lib.alarm_handler import Alarms
|
||||
|
||||
|
||||
def _validate_endpoint(func, endpoint):
|
||||
if not isinstance(endpoint, EndpointInfo):
|
||||
return
|
||||
if func.__name__ not in endpoint.message_op:
|
||||
raise ValueError(f"Endpoint {endpoint} is not compatible with {func.__name__} method")
|
||||
|
||||
|
||||
def check_topic(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, topic, *args, **kwargs):
|
||||
if isinstance(topic, str):
|
||||
warnings.warn(
|
||||
"RedisConnector methods with a string topic are deprecated and should not be used anymore. Use RedisConnector methods with an EndpointInfo instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return func(self, topic, *args, **kwargs)
|
||||
if isinstance(topic, EndpointInfo):
|
||||
_validate_endpoint(func, topic)
|
||||
return func(self, topic.endpoint, *args, **kwargs)
|
||||
return func(self, topic, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class RedisConnector(ConnectorBase):
|
||||
def __init__(self, bootstrap: list, redis_cls=None):
|
||||
super().__init__(bootstrap)
|
||||
@ -74,21 +98,10 @@ class RedisConnector(ConnectorBase):
|
||||
"""send an error as log"""
|
||||
self.send(MessageEndpoints.log(), LogMessage(log_type="error", log_msg=msg))
|
||||
|
||||
def raise_alarm(
|
||||
self,
|
||||
severity: Alarms,
|
||||
alarm_type: str,
|
||||
source: str,
|
||||
msg: str,
|
||||
metadata: dict,
|
||||
):
|
||||
def raise_alarm(self, severity: Alarms, alarm_type: str, source: str, msg: str, metadata: dict):
|
||||
"""raise an alarm"""
|
||||
alarm_msg = AlarmMessage(
|
||||
severity=severity,
|
||||
alarm_type=alarm_type,
|
||||
source=source,
|
||||
msg=msg,
|
||||
metadata=metadata,
|
||||
severity=severity, alarm_type=alarm_type, source=source, msg=msg, metadata=metadata
|
||||
)
|
||||
self.set_and_publish(MessageEndpoints.alarm(), alarm_msg)
|
||||
|
||||
@ -112,6 +125,7 @@ class RedisConnector(ConnectorBase):
|
||||
client = pipe if pipe is not None else self._redis_conn
|
||||
client.publish(topic, msg)
|
||||
|
||||
@check_topic
|
||||
def send(self, topic: str, msg: BECMessage, pipe=None) -> None:
|
||||
"""send to redis"""
|
||||
if not isinstance(msg, BECMessage):
|
||||
@ -124,8 +138,7 @@ class RedisConnector(ConnectorBase):
|
||||
# under the hood, it uses asyncio - this lets the possibility to stop
|
||||
# the loop on demand
|
||||
self._events_listener_thread = threading.Thread(
|
||||
target=self._get_messages_loop,
|
||||
args=(self._pubsub_conn,),
|
||||
target=self._get_messages_loop, args=(self._pubsub_conn,)
|
||||
)
|
||||
self._events_listener_thread.start()
|
||||
# make a weakref from the callable, using louie;
|
||||
@ -135,6 +148,9 @@ class RedisConnector(ConnectorBase):
|
||||
if patterns is not None:
|
||||
if isinstance(patterns, str):
|
||||
patterns = [patterns]
|
||||
elif isinstance(patterns, EndpointInfo):
|
||||
_validate_endpoint(self.register, patterns)
|
||||
patterns = [patterns.endpoint]
|
||||
|
||||
self._pubsub_conn.psubscribe(patterns)
|
||||
for pattern in patterns:
|
||||
@ -142,6 +158,9 @@ class RedisConnector(ConnectorBase):
|
||||
else:
|
||||
if isinstance(topics, str):
|
||||
topics = [topics]
|
||||
elif isinstance(topics, EndpointInfo):
|
||||
_validate_endpoint(self.register, topics)
|
||||
topics = [topics.endpoint]
|
||||
|
||||
self._pubsub_conn.subscribe(topics)
|
||||
for topic in topics:
|
||||
@ -181,10 +200,7 @@ class RedisConnector(ConnectorBase):
|
||||
callbacks = self._topics_cb[msg["pattern"].decode()]
|
||||
else:
|
||||
callbacks = self._topics_cb[channel]
|
||||
msg = MessageObject(
|
||||
topic=channel,
|
||||
value=MsgpackSerialization.loads(msg["data"]),
|
||||
)
|
||||
msg = MessageObject(topic=channel, value=MsgpackSerialization.loads(msg["data"]))
|
||||
for cb_ref, kwargs in callbacks:
|
||||
cb = cb_ref()
|
||||
if cb:
|
||||
@ -214,6 +230,7 @@ class RedisConnector(ConnectorBase):
|
||||
while self.poll_messages():
|
||||
...
|
||||
|
||||
@check_topic
|
||||
def lpush(
|
||||
self, topic: str, msg: str, pipe=None, max_size: int = None, expire: int = None
|
||||
) -> None:
|
||||
@ -234,12 +251,14 @@ class RedisConnector(ConnectorBase):
|
||||
if not pipe:
|
||||
client.execute()
|
||||
|
||||
@check_topic
|
||||
def lset(self, topic: str, index: int, msg: str, pipe=None) -> None:
|
||||
client = pipe if pipe is not None else self._redis_conn
|
||||
if isinstance(msg, BECMessage):
|
||||
msg = MsgpackSerialization.dumps(msg)
|
||||
return client.lset(topic, index, msg)
|
||||
|
||||
@check_topic
|
||||
def rpush(self, topic: str, msg: str, pipe=None) -> int:
|
||||
"""O(1) for each element added, so O(N) to add N elements when the
|
||||
command is called with multiple arguments. Insert all the specified
|
||||
@ -251,6 +270,7 @@ class RedisConnector(ConnectorBase):
|
||||
msg = MsgpackSerialization.dumps(msg)
|
||||
return client.rpush(topic, msg)
|
||||
|
||||
@check_topic
|
||||
def lrange(self, topic: str, start: int, end: int, pipe=None):
|
||||
"""O(S+N) where S is the distance of start offset from HEAD for small
|
||||
lists, from nearest end (HEAD or TAIL) for large lists; and N is the
|
||||
@ -262,16 +282,17 @@ class RedisConnector(ConnectorBase):
|
||||
cmd_result = client.lrange(topic, start, end)
|
||||
if pipe:
|
||||
return cmd_result
|
||||
else:
|
||||
# in case of command executed in a pipe, use 'execute_pipeline' method
|
||||
ret = []
|
||||
for msg in cmd_result:
|
||||
try:
|
||||
ret.append(MsgpackSerialization.loads(msg))
|
||||
except RuntimeError:
|
||||
ret.append(msg)
|
||||
return ret
|
||||
|
||||
# in case of command executed in a pipe, use 'execute_pipeline' method
|
||||
ret = []
|
||||
for msg in cmd_result:
|
||||
try:
|
||||
ret.append(MsgpackSerialization.loads(msg))
|
||||
except RuntimeError:
|
||||
ret.append(msg)
|
||||
return ret
|
||||
|
||||
@check_topic
|
||||
def set_and_publish(self, topic: str, msg, pipe=None, expire: int = None) -> None:
|
||||
"""piped combination of self.publish and self.set"""
|
||||
client = pipe if pipe is not None else self.pipeline()
|
||||
@ -283,6 +304,7 @@ class RedisConnector(ConnectorBase):
|
||||
if not pipe:
|
||||
client.execute()
|
||||
|
||||
@check_topic
|
||||
def set(self, topic: str, msg, pipe=None, expire: int = None) -> None:
|
||||
"""set redis value"""
|
||||
client = pipe if pipe is not None else self._redis_conn
|
||||
@ -292,13 +314,18 @@ class RedisConnector(ConnectorBase):
|
||||
|
||||
def keys(self, pattern: str) -> list:
|
||||
"""returns all keys matching a pattern"""
|
||||
if isinstance(pattern, EndpointInfo):
|
||||
_validate_endpoint(self.keys, pattern)
|
||||
pattern = pattern.endpoint
|
||||
return self._redis_conn.keys(pattern)
|
||||
|
||||
@check_topic
|
||||
def delete(self, topic, pipe=None):
|
||||
"""delete topic"""
|
||||
client = pipe if pipe is not None else self._redis_conn
|
||||
client.delete(topic)
|
||||
|
||||
@check_topic
|
||||
def get(self, topic: str, pipe=None):
|
||||
"""retrieve entry, either via hgetall or get"""
|
||||
client = pipe if pipe is not None else self._redis_conn
|
||||
@ -311,6 +338,7 @@ class RedisConnector(ConnectorBase):
|
||||
except RuntimeError:
|
||||
return data
|
||||
|
||||
@check_topic
|
||||
def xadd(self, topic: str, msg_dict: dict, max_size=None, pipe=None, expire: int = None):
|
||||
"""
|
||||
add to stream
|
||||
@ -345,6 +373,7 @@ class RedisConnector(ConnectorBase):
|
||||
if not pipe and expire:
|
||||
client.execute()
|
||||
|
||||
@check_topic
|
||||
def get_last(self, topic: str, key="data"):
|
||||
"""retrieve last entry from stream"""
|
||||
client = self._redis_conn
|
||||
@ -359,13 +388,9 @@ class RedisConnector(ConnectorBase):
|
||||
return msg_dict
|
||||
return msg_dict.get(key)
|
||||
|
||||
@check_topic
|
||||
def xread(
|
||||
self,
|
||||
topic: str,
|
||||
id: str = None,
|
||||
count: int = None,
|
||||
block: int = None,
|
||||
from_start=False,
|
||||
self, topic: str, id: str = None, count: int = None, block: int = None, from_start=False
|
||||
) -> list:
|
||||
"""
|
||||
read from stream
|
||||
@ -422,6 +447,7 @@ class RedisConnector(ConnectorBase):
|
||||
self.stream_keys[topic] = index
|
||||
return out if out else None
|
||||
|
||||
@check_topic
|
||||
def xrange(self, topic: str, min: str, max: str, count: int = None):
|
||||
"""
|
||||
read a range from stream
|
||||
@ -464,9 +490,9 @@ class RedisConnector(ConnectorBase):
|
||||
|
||||
In order to keep this fail-safe and simple it uses 'mock'...
|
||||
"""
|
||||
from unittest.mock import (
|
||||
from unittest.mock import ( # import is done here, to not pollute the file with something normally in tests
|
||||
Mock,
|
||||
) # import is done here, to not pollute the file with something normally in tests
|
||||
)
|
||||
|
||||
warnings.warn(
|
||||
"RedisConnector.consumer() is deprecated and should not be used anymore. Use RedisConnector.register() with 'topics', 'patterns', 'cb' or 'start_thread' instead. Additional keyword args are transmitted to the callback. For the caller, the main difference with RedisConnector.register() is that it does not return a new thread.",
|
||||
|
@ -90,7 +90,7 @@ class ScanReport:
|
||||
"""get the status of a move request"""
|
||||
motors = list(self.request.request.content["parameter"]["args"].keys())
|
||||
request_status = self._client.device_manager.connector.lrange(
|
||||
MessageEndpoints.device_req_status(self.request.requestID), 0, -1
|
||||
MessageEndpoints.device_req_status_container(self.request.requestID), 0, -1
|
||||
)
|
||||
if len(request_status) == len(motors):
|
||||
return True
|
||||
|
@ -13,7 +13,7 @@ import yaml
|
||||
from bec_lib import BECClient, messages
|
||||
from bec_lib.connector import ConnectorBase
|
||||
from bec_lib.devicemanager import DeviceManagerBase
|
||||
from bec_lib.endpoints import MessageEndpoints
|
||||
from bec_lib.endpoints import EndpointInfo, MessageEndpoints
|
||||
from bec_lib.logger import bec_logger
|
||||
from bec_lib.scans import Scans
|
||||
from bec_lib.service_config import ServiceConfig
|
||||
@ -542,24 +542,18 @@ class ConnectorMock(ConnectorBase): # pragma: no cover
|
||||
def register(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def set(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def set_and_publish(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def keys(self, *args, **kwargs):
|
||||
return []
|
||||
|
||||
def set(self, topic, msg, pipe=None, expire: int = None):
|
||||
if pipe:
|
||||
pipe._pipe_buffer.append(("set", (topic, msg), {"expire": expire}))
|
||||
pipe._pipe_buffer.append(("set", (topic.endpoint, msg), {"expire": expire}))
|
||||
return
|
||||
self.message_sent.append({"queue": topic, "msg": msg, "expire": expire})
|
||||
|
||||
def raw_send(self, topic, msg, pipe=None):
|
||||
if pipe:
|
||||
pipe._pipe_buffer.append(("send", (topic, msg), {}))
|
||||
pipe._pipe_buffer.append(("send", (topic.endpoint, msg), {}))
|
||||
return
|
||||
self.message_sent.append({"queue": topic, "msg": msg})
|
||||
|
||||
@ -570,7 +564,7 @@ class ConnectorMock(ConnectorBase): # pragma: no cover
|
||||
|
||||
def set_and_publish(self, topic, msg, pipe=None, expire: int = None):
|
||||
if pipe:
|
||||
pipe._pipe_buffer.append(("set_and_publish", (topic, msg), {"expire": expire}))
|
||||
pipe._pipe_buffer.append(("set_and_publish", (topic.endpoint, msg), {"expire": expire}))
|
||||
return
|
||||
self.message_sent.append({"queue": topic, "msg": msg, "expire": expire})
|
||||
|
||||
@ -592,6 +586,8 @@ class ConnectorMock(ConnectorBase): # pragma: no cover
|
||||
return []
|
||||
|
||||
def get(self, topic, pipe=None):
|
||||
if isinstance(topic, EndpointInfo):
|
||||
topic = topic.endpoint
|
||||
if pipe:
|
||||
pipe._pipe_buffer.append(("get", (topic,), {}))
|
||||
return
|
||||
|
@ -89,8 +89,8 @@ def test_get_async_data_for_scan():
|
||||
producer = mock.MagicMock()
|
||||
async_data = AsyncDataHandler(producer)
|
||||
producer.keys.return_value = [
|
||||
MessageEndpoints.device_async_readback("scanID", "samx").encode(),
|
||||
MessageEndpoints.device_async_readback("scanID", "samy").encode(),
|
||||
MessageEndpoints.device_async_readback("scanID", "samx").endpoint.encode(),
|
||||
MessageEndpoints.device_async_readback("scanID", "samy").endpoint.encode(),
|
||||
]
|
||||
with mock.patch.object(async_data, "get_async_data_for_device") as mock_get:
|
||||
async_data.get_async_data_for_scan("scanID")
|
||||
|
@ -116,8 +116,8 @@ def test_bec_service_service_status():
|
||||
|
||||
def test_bec_service_update_existing_services():
|
||||
service_keys = [
|
||||
MessageEndpoints.service_status("service1").encode(),
|
||||
MessageEndpoints.service_status("service2").encode(),
|
||||
MessageEndpoints.service_status("service1").endpoint.encode(),
|
||||
MessageEndpoints.service_status("service2").endpoint.encode(),
|
||||
]
|
||||
service_msgs = [
|
||||
messages.StatusMessage(name="service1", status=BECStatus.RUNNING, info={}, metadata={}),
|
||||
@ -136,8 +136,8 @@ def test_bec_service_update_existing_services():
|
||||
|
||||
def test_bec_service_update_existing_services_ignores_wrong_msgs():
|
||||
service_keys = [
|
||||
MessageEndpoints.service_status("service1").encode(),
|
||||
MessageEndpoints.service_status("service2").encode(),
|
||||
MessageEndpoints.service_status("service1").endpoint.encode(),
|
||||
MessageEndpoints.service_status("service2").endpoint.encode(),
|
||||
]
|
||||
service_msgs = [
|
||||
messages.StatusMessage(name="service1", status=BECStatus.RUNNING, info={}, metadata={}),
|
||||
|
@ -63,7 +63,7 @@ from bec_lib.tests.utils import ConnectorMock
|
||||
)
|
||||
def test_update_with_queue_status(queue_msg):
|
||||
scan_manager = ScanManager(ConnectorMock(""))
|
||||
scan_manager.connector._get_buffer[MessageEndpoints.scan_queue_status()] = queue_msg
|
||||
scan_manager.connector._get_buffer[MessageEndpoints.scan_queue_status().endpoint] = queue_msg
|
||||
scan_manager.update_with_queue_status(queue_msg)
|
||||
assert (
|
||||
scan_manager.scan_storage.find_scan_by_ID("bfa582aa-f9cd-4258-ab5d-3e5d54d3dde5")
|
||||
|
@ -5,13 +5,13 @@ import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import ophyd
|
||||
from ophyd import Kind, OphydObject, Staged
|
||||
from ophyd.utils import errors as ophyd_errors
|
||||
|
||||
from bec_lib import Alarms, BECService, MessageEndpoints, bec_logger, messages
|
||||
from bec_lib.connector import ConnectorBase
|
||||
from bec_lib.device import OnFailure
|
||||
from bec_lib.messages import BECStatus
|
||||
from ophyd import Kind, OphydObject, Staged
|
||||
from ophyd.utils import errors as ophyd_errors
|
||||
|
||||
from device_server.devices import rgetattr
|
||||
from device_server.devices.devicemanager import DeviceManagerDS
|
||||
from device_server.rpc_mixin import RPCMixin
|
||||
@ -39,8 +39,7 @@ class DeviceServer(RPCMixin, BECService):
|
||||
self._tasks = []
|
||||
self.device_manager = None
|
||||
self.connector.register(
|
||||
MessageEndpoints.scan_queue_modification(),
|
||||
cb=self.register_interception_callback,
|
||||
MessageEndpoints.scan_queue_modification(), cb=self.register_interception_callback
|
||||
)
|
||||
self.executor = ThreadPoolExecutor(max_workers=4)
|
||||
self._start_device_manager()
|
||||
@ -296,7 +295,7 @@ class DeviceServer(RPCMixin, BECService):
|
||||
response = status.instruction.metadata.get("response")
|
||||
if response:
|
||||
self.connector.lpush(
|
||||
MessageEndpoints.device_req_status(status.instruction.metadata["RID"]),
|
||||
MessageEndpoints.device_req_status_container(status.instruction.metadata["RID"]),
|
||||
dev_msg,
|
||||
pipe,
|
||||
expire=18000,
|
||||
|
@ -10,6 +10,10 @@ import numpy as np
|
||||
import ophyd
|
||||
import ophyd.sim as ops
|
||||
import ophyd_devices as opd
|
||||
from ophyd.ophydobj import OphydObject
|
||||
from ophyd.signal import EpicsSignalBase
|
||||
from typeguard import typechecked
|
||||
|
||||
from bec_lib import (
|
||||
BECService,
|
||||
DeviceBase,
|
||||
@ -22,9 +26,6 @@ from bec_lib import (
|
||||
from bec_lib.connector import ConnectorBase
|
||||
from device_server.devices.config_update_handler import ConfigUpdateHandler
|
||||
from device_server.devices.device_serializer import get_device_info
|
||||
from ophyd.ophydobj import OphydObject
|
||||
from ophyd.signal import EpicsSignalBase
|
||||
from typeguard import typechecked
|
||||
|
||||
try:
|
||||
from bec_plugins import devices as plugin_devices
|
||||
@ -64,7 +65,9 @@ class DSDevice(DeviceBase):
|
||||
limits = None
|
||||
pipe = connector.pipeline()
|
||||
connector.set_and_publish(MessageEndpoints.device_readback(self.name), dev_msg, pipe=pipe)
|
||||
connector.set(topic=MessageEndpoints.device_read(self.name), msg=dev_msg, pipe=pipe)
|
||||
connector.set_and_publish(
|
||||
topic=MessageEndpoints.device_read(self.name), msg=dev_msg, pipe=pipe
|
||||
)
|
||||
connector.set_and_publish(
|
||||
MessageEndpoints.device_read_configuration(self.name), dev_config_msg, pipe=pipe
|
||||
)
|
||||
@ -476,7 +479,7 @@ class DeviceManagerDS(DeviceManagerBase):
|
||||
device = kwargs["obj"].root.name
|
||||
status = 0
|
||||
metadata = self.devices[device].metadata
|
||||
self.connector.send(
|
||||
self.connector.set(
|
||||
MessageEndpoints.device_status(device),
|
||||
messages.DeviceStatusMessage(device=device, status=status, metadata=metadata),
|
||||
)
|
||||
|
@ -141,8 +141,8 @@ def test_flyer_event_callback():
|
||||
assert progress[0] == "set_and_publish"
|
||||
|
||||
# check endpoint
|
||||
assert bundle[1][0] == MessageEndpoints.device_read("samx")
|
||||
assert progress[1][0] == MessageEndpoints.device_progress("samx")
|
||||
assert bundle[1][0] == MessageEndpoints.device_read("samx").endpoint
|
||||
assert progress[1][0] == MessageEndpoints.device_progress("samx").endpoint
|
||||
|
||||
# check message
|
||||
bundle_msg = bundle[1][1]
|
||||
|
@ -621,6 +621,7 @@ def test_kickoff_device(device_server_mock, instr):
|
||||
kickoff.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.timeout(5)
|
||||
@pytest.mark.parametrize(
|
||||
"instr",
|
||||
[
|
||||
@ -639,7 +640,7 @@ def test_set_device(device_server_mock, instr):
|
||||
res = [
|
||||
msg
|
||||
for msg in device_server.connector.message_sent
|
||||
if msg["queue"] == MessageEndpoints.device_req_status("samx")
|
||||
if msg["queue"] == MessageEndpoints.device_req_status("samx").endpoint
|
||||
]
|
||||
if res:
|
||||
break
|
||||
@ -675,7 +676,7 @@ def test_read_device(device_server_mock, instr):
|
||||
res = [
|
||||
msg
|
||||
for msg in device_server.connector.message_sent
|
||||
if msg["queue"] == MessageEndpoints.device_read(device)
|
||||
if msg["queue"] == MessageEndpoints.device_read(device).endpoint
|
||||
]
|
||||
assert res[-1]["msg"].metadata["RID"] == instr.metadata["RID"]
|
||||
assert res[-1]["msg"].metadata["stream"] == "primary"
|
||||
@ -689,12 +690,12 @@ def test_read_config_and_update_devices(device_server_mock, devices):
|
||||
res = [
|
||||
msg
|
||||
for msg in device_server.connector.message_sent
|
||||
if msg["queue"] == MessageEndpoints.device_read_configuration(device)
|
||||
if msg["queue"] == MessageEndpoints.device_read_configuration(device).endpoint
|
||||
]
|
||||
config = device_server.device_manager.devices[device].obj.read_configuration()
|
||||
msg = res[-1]["msg"]
|
||||
assert msg.content["signals"].keys() == config.keys()
|
||||
assert res[-1]["queue"] == MessageEndpoints.device_read_configuration(device)
|
||||
assert res[-1]["queue"] == MessageEndpoints.device_read_configuration(device).endpoint
|
||||
|
||||
|
||||
def test_read_and_update_devices_exception(device_server_mock):
|
||||
|
@ -16,6 +16,7 @@ from bec_lib.alarm_handler import Alarms
|
||||
from bec_lib.async_data import AsyncDataHandler
|
||||
from bec_lib.file_utils import FileWriterMixin
|
||||
from bec_lib.redis_connector import MessageObject, RedisConnector
|
||||
|
||||
from file_writer.file_writer import NexusFileWriter
|
||||
|
||||
logger = bec_logger.logger
|
||||
@ -229,9 +230,9 @@ class FileWriterManager(BECService):
|
||||
return
|
||||
for device_key in async_device_keys:
|
||||
key = device_key.decode()
|
||||
device_name = key.split(MessageEndpoints.device_async_readback(scanID, ""))[-1].split(
|
||||
":"
|
||||
)[0]
|
||||
device_name = key.split(MessageEndpoints.device_async_readback(scanID, "").endpoint)[
|
||||
-1
|
||||
].split(":")[0]
|
||||
msgs = self.connector.xrange(key, min="-", max="+")
|
||||
if not msgs:
|
||||
continue
|
||||
|
@ -196,7 +196,7 @@ def test_update_async_data():
|
||||
file_manager.scan_storage["scanID"] = ScanStorage(10, "scanID")
|
||||
with mock.patch.object(file_manager, "connector") as mock_connector:
|
||||
with mock.patch.object(file_manager, "_process_async_data") as mock_process:
|
||||
key = MessageEndpoints.device_async_readback("scanID", "dev1")
|
||||
key = MessageEndpoints.device_async_readback("scanID", "dev1").endpoint
|
||||
mock_connector.keys.return_value = [key.encode()]
|
||||
data = [(b"0-0", b'{"data": "data"}')]
|
||||
mock_connector.xrange.return_value = data
|
||||
|
@ -44,12 +44,7 @@ class EmitterBase:
|
||||
msg_dump = msg
|
||||
msgs.append(msg_dump)
|
||||
if public:
|
||||
self.connector.set(
|
||||
public,
|
||||
msg_dump,
|
||||
pipe=pipe,
|
||||
expire=1800,
|
||||
)
|
||||
self.connector.set(public, msg_dump, pipe=pipe, expire=1800)
|
||||
self.connector.send(endpoint, msgs, pipe=pipe)
|
||||
pipe.execute()
|
||||
|
||||
|
@ -72,7 +72,7 @@ class ScanBundler(BECService):
|
||||
|
||||
def _device_read_callback(self, msg, **_kwargs):
|
||||
# pylint: disable=protected-access
|
||||
dev = msg.topic.split(MessageEndpoints._device_read + "/")[-1]
|
||||
dev = msg.topic.split(MessageEndpoints.device_read("").endpoint)[-1]
|
||||
msgs = msg.value
|
||||
logger.debug(f"Received reading from device {dev}")
|
||||
if not isinstance(msgs, list):
|
||||
@ -251,16 +251,12 @@ class ScanBundler(BECService):
|
||||
)
|
||||
}
|
||||
|
||||
def _get_scan_status_history(self, length):
|
||||
return self.connector.lrange(MessageEndpoints.scan_status() + "_list", length * -1, -1)
|
||||
|
||||
def _wait_for_scanID(self, scanID, timeout_time=10):
|
||||
elapsed_time = 0
|
||||
while not scanID in self.storage_initialized:
|
||||
msgs = self._get_scan_status_history(5)
|
||||
for msg in msgs:
|
||||
if msg.content["scanID"] == scanID:
|
||||
self.handle_scan_status_message(msg)
|
||||
msg = self.connector.get(MessageEndpoints.public_scan_info(scanID))
|
||||
if msg and msg.content["scanID"] == scanID:
|
||||
self.handle_scan_status_message(msg)
|
||||
if scanID in self.sync_storage:
|
||||
if self.sync_storage[scanID]["status"] in ["closed", "aborted"]:
|
||||
logger.info(
|
||||
|
@ -71,7 +71,7 @@ def test_device_read_callback():
|
||||
metadata={"scanID": "laksjd", "readout_priority": "monitored"},
|
||||
)
|
||||
msg.value = dev_msg
|
||||
msg.topic = MessageEndpoints.device_read("samx")
|
||||
msg.topic = MessageEndpoints.device_read("samx").endpoint
|
||||
|
||||
with mock.patch.object(scan_bundler, "_add_device_to_storage") as add_dev:
|
||||
scan_bundler._device_read_callback(msg)
|
||||
@ -81,55 +81,43 @@ def test_device_read_callback():
|
||||
@pytest.mark.parametrize(
|
||||
"scanID,storageID,scan_msg",
|
||||
[
|
||||
("adlk-jalskdj", None, []),
|
||||
("adlk-jalskdj", None, None),
|
||||
(
|
||||
"adlk-jalskdjs",
|
||||
"adlk-jalskdjs",
|
||||
[
|
||||
messages.ScanStatusMessage(
|
||||
scanID="adlk-jalskdjs",
|
||||
status="open",
|
||||
info={
|
||||
"scan_motors": ["samx"],
|
||||
"readout_priority": {
|
||||
"monitored": ["samx"],
|
||||
"baseline": [],
|
||||
"on_request": [],
|
||||
},
|
||||
"queueID": "my-queue-ID",
|
||||
"scan_number": 5,
|
||||
"scan_type": "step",
|
||||
},
|
||||
)
|
||||
],
|
||||
messages.ScanStatusMessage(
|
||||
scanID="adlk-jalskdjs",
|
||||
status="open",
|
||||
info={
|
||||
"scan_motors": ["samx"],
|
||||
"readout_priority": {"monitored": ["samx"], "baseline": [], "on_request": []},
|
||||
"queueID": "my-queue-ID",
|
||||
"scan_number": 5,
|
||||
"scan_type": "step",
|
||||
},
|
||||
),
|
||||
),
|
||||
(
|
||||
"adlk-jalskdjs",
|
||||
"",
|
||||
[
|
||||
messages.ScanStatusMessage(
|
||||
scanID="adlk-jalskdjs",
|
||||
status="open",
|
||||
info={
|
||||
"scan_motors": ["samx"],
|
||||
"readout_priority": {
|
||||
"monitored": ["samx"],
|
||||
"baseline": [],
|
||||
"on_request": [],
|
||||
},
|
||||
"queueID": "my-queue-ID",
|
||||
"scan_number": 5,
|
||||
"scan_type": "step",
|
||||
},
|
||||
)
|
||||
],
|
||||
messages.ScanStatusMessage(
|
||||
scanID="adlk-jalskdjs",
|
||||
status="open",
|
||||
info={
|
||||
"scan_motors": ["samx"],
|
||||
"readout_priority": {"monitored": ["samx"], "baseline": [], "on_request": []},
|
||||
"queueID": "my-queue-ID",
|
||||
"scan_number": 5,
|
||||
"scan_type": "step",
|
||||
},
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_wait_for_scanID(scanID, storageID, scan_msg):
|
||||
sb = load_ScanBundlerMock()
|
||||
sb.storage_initialized.add(storageID)
|
||||
with mock.patch.object(sb, "_get_scan_status_history", return_value=scan_msg) as get_scan_msgs:
|
||||
with mock.patch.object(sb.connector, "get", return_value=scan_msg) as get_scan_msgs:
|
||||
if not storageID and not scan_msg:
|
||||
with pytest.raises(TimeoutError):
|
||||
sb._wait_for_scanID(scanID, 1)
|
||||
@ -137,32 +125,6 @@ def test_wait_for_scanID(scanID, storageID, scan_msg):
|
||||
sb._wait_for_scanID(scanID)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"msgs",
|
||||
[
|
||||
[
|
||||
messages.ScanStatusMessage(
|
||||
scanID="scanID",
|
||||
status="open",
|
||||
info={
|
||||
"primary": ["samx"],
|
||||
"queueID": "my-queue-ID",
|
||||
"scan_number": 5,
|
||||
"scan_type": "step",
|
||||
},
|
||||
)
|
||||
],
|
||||
[],
|
||||
],
|
||||
)
|
||||
def test_get_scan_status_history(msgs):
|
||||
sb = load_ScanBundlerMock()
|
||||
with mock.patch.object(sb.connector, "lrange", return_value=[msg for msg in msgs]) as lrange:
|
||||
res = sb._get_scan_status_history(5)
|
||||
lrange.assert_called_once_with(MessageEndpoints.scan_status() + "_list", -5, -1)
|
||||
assert res == msgs
|
||||
|
||||
|
||||
def test_add_device_to_storage_returns_without_scanID():
|
||||
msg = messages.DeviceMessage(
|
||||
signals={"samx": {"samx": 0.51, "setpoint": 0.5, "motor_is_moving": 0}},
|
||||
|
@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
from bec_lib import DeviceManagerBase, MessageEndpoints, bec_logger, messages
|
||||
|
||||
from .errors import LimitError, ScanAbortion
|
||||
@ -561,7 +562,9 @@ class SyncFlyScanBase(ScanBase, ABC):
|
||||
connector = self.device_manager.connector
|
||||
|
||||
pipe = connector.pipeline()
|
||||
connector.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe)
|
||||
connector.lrange(
|
||||
MessageEndpoints.device_req_status_container(self.metadata["RID"]), 0, -1, pipe
|
||||
)
|
||||
connector.get(MessageEndpoints.device_readback(flyer), pipe)
|
||||
return connector.execute_pipeline(pipe)
|
||||
|
||||
@ -1321,7 +1324,9 @@ class MonitorScan(ScanBase):
|
||||
connector = self.device_manager.connector
|
||||
|
||||
pipe = connector.pipeline()
|
||||
connector.lrange(MessageEndpoints.device_req_status(self.metadata["RID"]), 0, -1, pipe)
|
||||
connector.lrange(
|
||||
MessageEndpoints.device_req_status_container(self.metadata["RID"]), 0, -1, pipe
|
||||
)
|
||||
connector.get(MessageEndpoints.device_readback(self.flyer), pipe)
|
||||
return connector.execute_pipeline(pipe)
|
||||
|
||||
|
@ -461,7 +461,7 @@ def test_check_for_failed_movements(scan_worker_mock, device_status, devices, in
|
||||
if abort:
|
||||
with pytest.raises(ScanAbortion):
|
||||
worker.device_manager.connector._get_buffer[
|
||||
MessageEndpoints.device_readback("samx")
|
||||
MessageEndpoints.device_readback("samx").endpoint
|
||||
] = messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
|
||||
worker._check_for_failed_movements(device_status, devices, instr)
|
||||
else:
|
||||
@ -582,9 +582,9 @@ def test_wait_for_idle(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReq
|
||||
with mock.patch.object(
|
||||
worker.validate, "get_device_status", return_value=[req_msg]
|
||||
) as device_status:
|
||||
worker.device_manager.connector._get_buffer[MessageEndpoints.device_readback("samx")] = (
|
||||
messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
|
||||
)
|
||||
worker.device_manager.connector._get_buffer[
|
||||
MessageEndpoints.device_readback("samx").endpoint
|
||||
] = messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
|
||||
|
||||
worker._add_wait_group(msg1)
|
||||
if req_msg.content["success"]:
|
||||
@ -644,7 +644,7 @@ def test_wait_for_read(scan_worker_mock, msg1, msg2, req_msg: messages.DeviceReq
|
||||
assert worker._groups == {}
|
||||
worker._groups["scan_motor"] = {"samx": 3, "samy": 4}
|
||||
worker.device_manager.connector._get_buffer[
|
||||
MessageEndpoints.device_readback("samx")
|
||||
MessageEndpoints.device_readback("samx").endpoint
|
||||
] = messages.DeviceMessage(signals={"samx": {"value": 4}}, metadata={})
|
||||
worker._add_wait_group(msg1)
|
||||
worker._wait_for_read(msg2)
|
||||
@ -1289,7 +1289,7 @@ def test_send_scan_status(scan_worker_mock, status, expire):
|
||||
scan_info_msgs = [
|
||||
msg
|
||||
for msg in worker.device_manager.connector.message_sent
|
||||
if msg["queue"] == MessageEndpoints.public_scan_info(scanID=worker.current_scanID)
|
||||
if msg["queue"] == MessageEndpoints.public_scan_info(scanID=worker.current_scanID).endpoint
|
||||
]
|
||||
assert len(scan_info_msgs) == 1
|
||||
assert scan_info_msgs[0]["expire"] == expire
|
||||
|
Loading…
x
Reference in New Issue
Block a user