refactor: stream subs in redisconnector

- clearer organisation: class to handle subscription tasks
- more correct: no multiple subscriptions, test logic
This commit is contained in:
2026-03-13 13:45:42 +01:00
committed by David Perl
parent 751f66ef7b
commit caa906de20
3 changed files with 267 additions and 209 deletions
@@ -114,7 +114,7 @@ def test_mv_scan_nested_device(capsys, bec_ipython_client_fixture):
bec.metadata.update({"unit_test": "test_mv_scan_nested_device"})
dev = bec.device_manager.devices
scans.mv(dev.hexapod.x, 10, dev.hexapod.y, 20, relative=False).wait()
if not bec.connector._messages_queue.empty():
if not bec.connector._message_callbacks_queue.empty():
print("Waiting for messages to be processed")
time.sleep(0.5)
current_pos_hexapod_x = dev.hexapod.x.read(cached=True)["hexapod_x"]["value"]
@@ -126,7 +126,7 @@ def test_mv_scan_nested_device(capsys, bec_ipython_client_fixture):
current_pos_hexapod_y, 20, atol=dev.hexapod._config["deviceConfig"].get("tolerance", 0.5)
)
scans.umv(dev.hexapod.x, 10, dev.hexapod.y, 20, relative=False)
if not bec.connector._messages_queue.empty():
if not bec.connector._message_callbacks_queue.empty():
print("Waiting for messages to be processed")
time.sleep(0.5)
current_pos_hexapod_x = dev.hexapod.x.read(cached=True)["hexapod_x"]["value"]
+255 -199
View File
@@ -17,6 +17,7 @@ import threading
import time
import traceback
import warnings
from collections import defaultdict
from collections.abc import MutableMapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
@@ -31,6 +32,7 @@ from typing import (
Generator,
Iterable,
Literal,
NamedTuple,
ParamSpec,
TypedDict,
TypeVar,
@@ -40,6 +42,7 @@ from typing import (
import louie
import redis.client
import redis.exceptions
from astroid.nodes import Unknown
from redis.backoff import ExponentialBackoff
from redis.client import Pipeline, Redis
from redis.retry import Retry
@@ -57,6 +60,7 @@ from bec_lib.messages import (
)
from bec_lib.serialization import MsgpackSerialization
logger = bec_logger.logger
if TYPE_CHECKING: # pragma: no cover
from concurrent.futures import Future
@@ -84,6 +88,11 @@ class InvalidItemForOperation(ValueError): ...
class WrongArguments(ValueError): ...
def _error_log_with_context(msg: str):
context = "".join(traceback.format_stack(limit=5)[:-1])
logger.error(msg + f" Context:\n{context}")
def _raise_incompatible_message(msg, endpoint):
raise IncompatibleMessageForEndpoint(
f"Message type {type(msg)} is not compatible with endpoint {endpoint}. Expected {endpoint.message_type}"
@@ -198,34 +207,167 @@ class GeneratorExecution:
@dataclass
class StreamSubscriptionInfo:
id: str
topic: str
newest_only: bool
from_start: bool
class StreamSubInfo:
cb_ref: Callable
kwargs: dict
kwargs: dict[str, Unknown]
def __eq__(self, other):
if not isinstance(other, StreamSubscriptionInfo):
if not isinstance(other, StreamSubInfo):
return False
return (
self.topic == other.topic
and self.cb_ref == other.cb_ref
and self.from_start == other.from_start
)
return self.cb_ref == other.cb_ref
def __hash__(self) -> int:
return self.cb_ref.__hash__()
@dataclass
class DirectReadingStreamSubscriptionInfo(StreamSubscriptionInfo):
class DirectReadStreamSubInfo(StreamSubInfo):
stop_event: threading.Event
thread: threading.Thread | None = None
thread: threading.Thread
def __hash__(self) -> int:
return self.cb_ref.__hash__()
@dataclass
class StreamMessage:
msg: dict
callbacks: Iterable[tuple[Callable, dict]]
callbacks: Iterable[tuple[Callable, dict[str, Unknown]]]
class StreamSubsEntry(NamedTuple):
read_id: str
subs: set[StreamSubInfo]
StreamResponseList = list[tuple[bytes, list[tuple[bytes, dict[bytes, bytes]]]]]
StreamSubsRegistry = dict[str, StreamSubsEntry]
class StreamSubs:
def __init__(self) -> None:
"""Manager for stream subscriptions. Since operations often need to be combined,
use the lock directly at point of call, it is generally not used in the methods."""
self.lock = threading.RLock()
self._subs: StreamSubsRegistry = {}
self._direct_read_subs: dict[
str, dict[DirectReadStreamSubInfo, DirectReadStreamSubInfo]
] = {}
self.from_start_subs: dict[str, set[StreamSubInfo]] = {}
@property
def normal_subs(self):
return {t: s.subs for t, s in self._subs.items()}
@property
def all_topics(self):
with self.lock:
from_start_keys = [k for k in self.from_start_subs if self.from_start_subs[k] != set()]
dr_sub_keys = [k for k in self._direct_read_subs if self._direct_read_subs[k] != set()]
return list(set((*self._subs.keys(), *dr_sub_keys, *from_start_keys)))
def topic_ids(self) -> dict[str, str]:
"""Get Redis read Ids for active subscriptions"""
return {topic: infos.read_id for topic, infos in self._subs.items()}
def update_normal_ids(self, updated_ids: dict[str, str]):
for topic, id in updated_ids.items():
if topic in self._subs:
self._subs[topic] = StreamSubsEntry(id, self._subs[topic].subs)
def from_start_topics(self) -> set[str]:
"""Get topics for new `from_start` subscriptions which haven't been read yet"""
return set(self.from_start_subs.keys())
def end_id(self, topic: str):
"""Return the last read id for a given topic if given, or "+" """
return self._subs[topic].read_id if topic in self._subs else "+"
def move_from_start_to_normal(self, topics_and_end_ids: dict[str, str]):
if topics_and_end_ids.keys() != self.from_start_subs.keys():
_error_log_with_context(
f"Mismatch of subs to move! {topics_and_end_ids.keys()=}, {self.from_start_subs.keys()=} Was a lock forgotten?"
)
for topic in topics_and_end_ids:
if topic in self._subs:
if topics_and_end_ids[topic] != self._subs[topic].read_id:
_error_log_with_context(f"Mismatch of ID! Was a lock forgotten?")
for sub in self.from_start_subs.pop(topic):
self._subs[topic].subs.add(sub) # type: ignore
else:
self._subs[topic] = StreamSubsEntry(
read_id=topics_and_end_ids[topic], subs=self.from_start_subs.pop(topic)
)
def _check_registered(self, topic, new_sub: StreamSubInfo):
if (
(topic in self.from_start_subs and new_sub in self.from_start_subs[topic])
or (topic in self._direct_read_subs and new_sub in self._direct_read_subs[topic])
or (topic in self._subs and new_sub in self._subs[topic].subs)
):
raise ValueError(f"Received duplicate subscription for {new_sub=}.")
def add_direct_listener(self, topic: str, new_sub: DirectReadStreamSubInfo):
self._check_registered(topic, new_sub)
if not topic in self._direct_read_subs:
self._direct_read_subs[topic] = {}
self._direct_read_subs[topic][new_sub] = new_sub
new_sub.thread.start()
def add(self, from_start: bool, last_id: str, topic: str, new_sub: StreamSubInfo):
self._check_registered(topic, new_sub)
if from_start:
if topic in self.from_start_subs:
subs = self.from_start_subs[topic]
else:
subs = set()
self.from_start_subs[topic] = subs
else:
if not topic in self._subs:
subs = set()
self._subs[topic] = StreamSubsEntry(read_id=last_id, subs=subs)
else:
subs = self._subs[topic].subs
subs.add(new_sub)
@staticmethod
def _kill_direct_stream(sub: DirectReadStreamSubInfo, topic: str):
sub.stop_event.set()
sub.thread.join(timeout=1)
if sub.thread.is_alive():
_error_log_with_context(
f"RedisConnector direct stream callback thread for {topic=}, {sub.cb_ref=} failed to shutdown"
)
def remove(self, topic: str, cb: Callable | None = None) -> bool:
removed = False
if cb is None: # Remove all subs for the given topic
removed |= bool(self.from_start_subs.pop(topic, False))
removed |= bool(self._subs.pop(topic, False))
if (subs := self._direct_read_subs.pop(topic, None)) is not None:
for sub in subs:
self._kill_direct_stream(sub, topic)
removed = True
return removed
test_subinfo = StreamSubInfo(louie.saferef.safe_ref(cb), {})
if topic in self.from_start_subs and test_subinfo in self.from_start_subs[topic]:
self.from_start_subs[topic].remove(test_subinfo)
removed = True
if len(self.from_start_subs[topic]) == 0:
del self.from_start_subs[topic]
if topic in self._direct_read_subs and test_subinfo in self._direct_read_subs[topic]:
sub = self._direct_read_subs[topic].pop(test_subinfo) # type: ignore # hash is the same
self._kill_direct_stream(sub, topic)
removed = True
if len(self._direct_read_subs[topic]) == 0:
del self._direct_read_subs[topic]
if topic in self._subs and test_subinfo in self._subs[topic].subs:
self._subs[topic].subs.remove(test_subinfo)
removed = True
if len(self._subs[topic].subs) == 0:
del self._subs[topic]
return removed
class RedisConnector:
@@ -276,13 +418,12 @@ class RedisConnector:
collections.defaultdict(list)
)
self._topics_cb_lock = threading.Lock()
self._stream_topics_subscription = collections.defaultdict(list)
self._stream_topics_subscription_lock = threading.Lock()
self._stream_subs = StreamSubs()
self._events_listener_thread: threading.Thread | None = None
self._stream_events_listener_thread: threading.Thread | None = None
self._events_dispatcher_thread: threading.Thread | None = None
self._messages_queue = queue.Queue()
self._message_callbacks_queue = queue.Queue()
self._stop_events_listener_thread = threading.Event()
self._stop_stream_events_listener_thread = threading.Event()
self.stream_keys: dict[str, str] = {}
@@ -392,12 +533,12 @@ class RedisConnector:
self._stream_events_listener_thread.join(timeout=per_thread_timeout_s)
self._stream_events_listener_thread = None
if self._events_dispatcher_thread:
self._messages_queue.put(StopIteration)
self._message_callbacks_queue.put(StopIteration)
self._events_dispatcher_thread.join(timeout=per_thread_timeout_s)
self._events_dispatcher_thread = None
# this will take care of shutting down direct listening threads
self._unregister_stream(self._stream_topics_subscription)
self._unregister_stream(self._stream_subs.all_topics)
# release all connections
self._pubsub_conn.close()
@@ -463,7 +604,7 @@ class RedisConnector:
>>> connector.raise_alarm(
severity=Alarms.WARNING,
info=ErrorInfo(
id=str(uuid.uuid4()),
id=str(uuid.uuid4()),_stream_topic_subscriptions
error_message="ValueError",
compact_error_message="test alarm",
exception_type="ValueError",
@@ -648,7 +789,7 @@ class RedisConnector:
self._topics_cb[topic].append(item)
self._start_events_dispatcher_thread(start_thread)
def _add_direct_stream_listener(self, topic, cb_ref, **kwargs):
def _create_direct_stream_listener(self, topic, cb_ref, kwargs):
"""
Add a direct listener for a topic. This is used when newest_only is True.
@@ -658,123 +799,95 @@ class RedisConnector:
kwargs (dict): additional keyword arguments to be transmitted to the callback
Returns:
None
DirectReadStreamSubInfo with an unstarted thread
"""
info = DirectReadingStreamSubscriptionInfo(
id="-",
topic=topic,
newest_only=True,
from_start=False,
cb_ref=cb_ref,
kwargs=kwargs,
stop_event=threading.Event(),
stop_event = threading.Event()
thread = threading.Thread(
target=self._direct_stream_listener, args=(topic, stop_event, cb_ref, kwargs)
)
if info in self._stream_topics_subscription[topic]:
raise RuntimeError("Already registered stream topic with the same callback")
return DirectReadStreamSubInfo(cb_ref, kwargs, stop_event, thread)
info.thread = threading.Thread(target=self._direct_stream_listener, args=(info,))
with self._stream_topics_subscription_lock:
self._stream_topics_subscription[topic].append(info)
info.thread.start()
def _direct_stream_listener(self, info: DirectReadingStreamSubscriptionInfo):
stop_event = info.stop_event
cb_ref = info.cb_ref
kwargs = info.kwargs
topic = info.topic
def _direct_stream_listener(self, topic: str, stop_event: threading.Event, cb_ref, kwargs):
read_id = "-"
while not stop_event.is_set():
ret = self._redis_conn.xrevrange(topic, "+", info.id, count=1)
if not ret:
time.sleep(0.1)
if not (response := self._redis_conn.xrevrange(topic, "+", read_id, count=1)):
stop_event.wait(timeout=0.1)
continue
redis_id, msg_dict = ret[0] # type: ignore : we are using Redis synchronously
redis_id, msg_dict = response[0] # type: ignore : we are using Redis synchronously
timestamp, _, ind = redis_id.partition(b"-")
info.id = f"{timestamp.decode()}-{int(ind.decode())+1}"
read_id = f"{timestamp.decode()}-{int(ind.decode())+1}"
stream_msg = StreamMessage(
{key.decode(): MsgpackSerialization.loads(val) for key, val in msg_dict.items()},
((cb_ref, kwargs),),
)
self._messages_queue.put(stream_msg)
self._message_callbacks_queue.put(stream_msg)
def _get_stream_topics_id(self) -> tuple[dict, dict]:
stream_topics_id = {}
from_start_stream_topics_id = {}
with self._stream_topics_subscription_lock:
for topic, subscription_info_list in self._stream_topics_subscription.items():
for info in subscription_info_list:
if isinstance(info, DirectReadingStreamSubscriptionInfo):
continue
if info.from_start:
from_start_stream_topics_id[topic] = info.id
else:
stream_topics_id[topic] = info.id
return from_start_stream_topics_id, stream_topics_id
def _handle_stream_msg_list(self, msg_list, from_start=False):
for topic, msgs in msg_list:
subscription_info_list = self._stream_topics_subscription[topic.decode()]
for index, record in msgs:
callbacks = []
for info in subscription_info_list:
info.id = index.decode()
if from_start and not info.from_start:
continue
callbacks.append((info.cb_ref, info.kwargs))
if callbacks:
def _handle_stream_msg_list(
self, redis_response: StreamResponseList, subs: dict[str, set[StreamSubInfo]]
):
new_ids = {}
for btopic, msgs in redis_response:
for read_id, record in msgs:
topic: str = btopic.decode() if isinstance(btopic, bytes) else btopic # type: ignore
if callbacks := subs.get(topic):
msg_dict = {
k.decode(): MsgpackSerialization.loads(msg) for k, msg in record.items()
}
msg = StreamMessage(msg_dict, callbacks)
self._messages_queue.put(msg)
for info in subscription_info_list:
info.from_start = False
msg = StreamMessage(msg_dict, [(cb.cb_ref, cb.kwargs) for cb in callbacks])
self._message_callbacks_queue.put(msg)
new_ids[topic] = read_id.decode()
return new_ids
def _try_read_streams(self, topics_ids: dict[str, str], from_start: bool = False):
try:
if from_start:
return [(t, self._redis_conn.xrange(t, "-", end)) for t, end in topics_ids.items()]
else:
return self._redis_conn.xread(topics_ids, block=200) or [] # type: ignore strs are fine key and id types
except redis.exceptions.ConnectionError:
logger.error("Failed to connect to redis. Is the server running?")
except redis.exceptions.NoPermissionError:
logger.error(f"Permission denied for stream topics: {set(topics_ids.keys())}")
# pylint: disable=broad-except
except Exception:
sys.excepthook(*sys.exc_info()) # type: ignore # inside except
def _read_from_start_streams_and_migrate(self) -> bool:
"""Returns whether there was an error"""
with self._stream_subs.lock:
if from_start_topics := self._stream_subs.from_start_topics():
topics_and_end_ids = {t: self._stream_subs.end_id(t) for t in from_start_topics}
response = self._try_read_streams(topics_and_end_ids, from_start=True)
if response is not None:
updated_end_ids = self._handle_stream_msg_list(
response, self._stream_subs.from_start_subs
)
new_end_ids = {t: "0-0" for t in from_start_topics}
new_end_ids.update(updated_end_ids)
self._stream_subs.move_from_start_to_normal(new_end_ids)
else:
return True
return False
def _get_stream_messages_loop(self) -> None:
"""
Get stream messages loop. This method is run in a separate thread and listens
for messages from the redis server.
"""
error = False
while not self._stop_stream_events_listener_thread.is_set():
try:
from_start_stream_topics_id, stream_topics_id = self._get_stream_topics_id()
if not any((stream_topics_id, from_start_stream_topics_id)):
self._stop_stream_events_listener_thread.wait(timeout=0.1)
continue
msg_list = []
from_start_msg_list = []
# first handle the 'from_start' streams ;
# in the case of reading from start what is expected is to call the
# callbacks for existing items, without waiting for a new element to be added
# to the stream
if from_start_stream_topics_id:
# read the streams contents from beginning
from_start_msg_list = self._redis_conn.xread(
from_start_stream_topics_id, block=200
)
if stream_topics_id:
msg_list = self._redis_conn.xread(stream_topics_id, block=200)
except redis.exceptions.ConnectionError:
if not error:
error = True
bec_logger.logger.error("Failed to connect to redis. Is the server running?")
# First read the "from_start" streams, up until any id which is already in the normal
# subs, then all those them to the normal streams
error = self._read_from_start_streams_and_migrate()
# Then read all the normal streams
with self._stream_subs.lock:
normal_topics = self._stream_subs.topic_ids()
normal_subs = self._stream_subs.normal_subs
if normal_topics and (response := self._try_read_streams(normal_topics)) is not None:
updated_ids = self._handle_stream_msg_list(response, normal_subs)
with self._stream_subs.lock:
self._stream_subs.update_normal_ids(updated_ids)
if error: # Encountered an error on xread, wait a while without the lock
self._stop_stream_events_listener_thread.wait(timeout=1)
except redis.exceptions.NoPermissionError:
bec_logger.logger.error(
f"Permission denied for stream topics: \n Topics id: {from_start_stream_topics_id}, Stream topics id: {stream_topics_id}"
)
if not error:
error = True
self._stop_stream_events_listener_thread.wait(timeout=1)
# pylint: disable=broad-except
except Exception:
sys.excepthook(*sys.exc_info()) # type: ignore # inside except
else:
error = False
with self._stream_topics_subscription_lock:
self._handle_stream_msg_list(from_start_msg_list, from_start=True)
self._handle_stream_msg_list(msg_list)
def _register_stream(
self,
@@ -805,50 +918,27 @@ class RedisConnector:
cb_ref = louie.saferef.safe_ref(cb)
self._start_events_dispatcher_thread(start_thread)
if newest_only:
# if newest_only is True, we need to provide a separate callback for each topic,
# directly calling the callback. This is because we need to have a backpressure
# mechanism in place, and we cannot rely on the dispatcher thread to handle it.
with self._stream_subs.lock:
for topic in topics:
self._add_direct_stream_listener(topic, cb_ref, **kwargs)
else:
with self._stream_topics_subscription_lock:
for topic in topics:
if newest_only:
new_sub = self._create_direct_stream_listener(topic, cb_ref, kwargs)
self._stream_subs.add_direct_listener(topic, new_sub)
else:
new_sub = StreamSubInfo(cb_ref, kwargs)
try:
stream_info = self._redis_conn.xinfo_stream(topic)
except redis.exceptions.ResponseError:
# no such key
last_id = "0-0"
last_id = "0-0" # no such key
else:
last_id = stream_info["last-entry"][0].decode() # type: ignore # we are using the sync Redis client
new_subscription = StreamSubscriptionInfo(
id="0-0" if from_start else last_id,
topic=topic,
newest_only=newest_only,
from_start=from_start,
cb_ref=cb_ref,
kwargs=kwargs,
)
subscriptions = self._stream_topics_subscription[topic]
if new_subscription in subscriptions:
# raise an error if attempted to register a stream with the same callback,
# whereas it has already been registered as a 'direct reading' stream with
# newest_only=True ; it is clearly an error case that would produce weird results
index = subscriptions.index(new_subscription)
if isinstance(subscriptions[index], DirectReadingStreamSubscriptionInfo):
raise RuntimeError(
"Already registered stream topic with the same callback with 'newest_only=True'"
)
else:
subscriptions.append(new_subscription)
self._stream_subs.add(from_start, last_id, topic, new_sub)
if self._stream_events_listener_thread is None:
# create the thread that will get all messages for this connector
self._stream_events_listener_thread = threading.Thread(
target=self._get_stream_messages_loop
)
self._stream_events_listener_thread.start()
if self._stream_events_listener_thread is None:
# create the thread that will get all messages for this connector
self._stream_events_listener_thread = threading.Thread(
target=self._get_stream_messages_loop
)
self._stream_events_listener_thread.start()
def _filter_topics_cb(self, topics: list, cb: Callable | None):
unsubscribe_list = []
@@ -875,9 +965,7 @@ class RedisConnector:
patterns = self._normalize_patterns(patterns)
# see if registered streams can be unregistered
for pattern in patterns:
self._unregister_stream(
fnmatch.filter(self._stream_topics_subscription, pattern), cb
)
self._unregister_stream(fnmatch.filter(self._stream_subs.all_topics, pattern), cb)
pubsub_unsubscribe_list = self._filter_topics_cb(patterns, cb)
if pubsub_unsubscribe_list:
self._pubsub_conn.punsubscribe(pubsub_unsubscribe_list)
@@ -889,41 +977,9 @@ class RedisConnector:
self._pubsub_conn.unsubscribe(unsubscribe_list)
def _unregister_stream(self, topics: list[str], cb: Callable | None = None) -> bool:
"""
Unregister a stream listener.
Args:
topics (list[str]): list of stream topics
Returns:
bool: True if the stream listener has been removed, False otherwise
"""
unsubscribe_list = []
with self._stream_topics_subscription_lock:
for topic in topics:
subscription_infos = self._stream_topics_subscription[topic]
# remove from list if callback corresponds
self._stream_topics_subscription[topic] = list(
filter(lambda sub_info: cb and sub_info.cb_ref() is not cb, subscription_infos)
)
if not self._stream_topics_subscription[topic]:
# no callbacks left, unsubscribe
unsubscribe_list += subscription_infos
# clean the topics that have been unsubscribed
for subscription_info in unsubscribe_list:
if isinstance(subscription_info, DirectReadingStreamSubscriptionInfo):
subscription_info.stop_event.set()
if subscription_info.thread:
subscription_info.thread.join()
# it is possible to register the same stream multiple times with different
# callbacks, in this case when unregistering with cb=None (unregister all)
# the topic can be deleted multiple times, hence try...except in code below
try:
del self._stream_topics_subscription[subscription_info.topic]
except KeyError:
pass
return len(unsubscribe_list) > 0
"""Unregister callbacks from a list of topics. Returns true if any were removed"""
with self._stream_subs.lock:
return any([self._stream_subs.remove(topic, cb) for topic in topics])
def _get_messages_loop(self) -> None:
"""
@@ -948,7 +1004,7 @@ class RedisConnector:
else:
error = False
if msg is not None:
self._messages_queue.put(msg)
self._message_callbacks_queue.put(msg)
def _execute_callback(self, cb, msg, kwargs):
try:
@@ -959,13 +1015,13 @@ class RedisConnector:
else:
if inspect.isgenerator(g):
# reschedule execution to delineate the generator
self._messages_queue.put(g)
self._message_callbacks_queue.put(g)
def _handle_message(self, msg: StreamMessage | GeneratorExecution | PubSubMessage):
if inspect.isgenerator(msg):
g = msg
fut = self._generator_executor.submit(next, g)
self._messages_queue.put(GeneratorExecution(fut, g))
self._message_callbacks_queue.put(GeneratorExecution(fut, g))
elif isinstance(msg, StreamMessage):
for cb_ref, kwargs in msg.callbacks:
cb = cb_ref()
@@ -980,9 +1036,9 @@ class RedisConnector:
pass
else:
fut = self._generator_executor.submit(g.send, res)
self._messages_queue.put(GeneratorExecution(fut, g))
self._message_callbacks_queue.put(GeneratorExecution(fut, g))
else:
self._messages_queue.put(GeneratorExecution(fut, g))
self._message_callbacks_queue.put(GeneratorExecution(fut, g))
else:
channel = msg["channel"].decode()
with self._topics_cb_lock:
@@ -1011,7 +1067,7 @@ class RedisConnector:
while True:
try:
# wait for a message and return it before timeout expires
msg = self._messages_queue.get(timeout=remaining_timeout, block=True)
msg = self._message_callbacks_queue.get(timeout=remaining_timeout, block=True)
except queue.Empty as exc:
remaining_timeout = cast(float, remaining_timeout)
timeout = cast(float, timeout)
@@ -1032,7 +1088,7 @@ class RedisConnector:
bec_logger.logger.error(f"Error handling message {msg}:\n{content}")
if timeout is None:
if self._messages_queue.empty():
if self._message_callbacks_queue.empty():
# no message to process
return True
else:
@@ -369,20 +369,22 @@ def test_redis_connector_register_stream(connected_connector):
connector.poll_messages()
cb_mock1.assert_not_called()
cb_mock2.assert_called_once_with({"data": 2}, a=2)
assert "test" in connector._stream_subs.all_topics
connector.unregister("test")
assert connector._stream_topics_subscription["test"] == []
assert connector._stream_subs.all_topics == []
@pytest.mark.timeout(10)
def test_redis_connector_register_stream_identical(connected_connector):
connector = connected_connector
received_event1 = mock.Mock(spec=[])
received_event2 = mock.Mock(spec=[])
connector.register(TestStreamEndpoint, cb=received_event1, start_thread=False)
connector.register(TestStreamEndpoint, cb=received_event1, start_thread=False)
connector.register(TestStreamEndpoint, cb=received_event2, start_thread=False)
connector.register(TestStreamEndpoint2, cb=received_event1, start_thread=False)
connector.register(TestStreamEndpoint2, cb=received_event2, start_thread=False)
connector.xadd("test", {"data": 1})
connector.poll_messages(timeout=1)
assert received_event1.call_count == 1
@@ -392,14 +394,11 @@ def test_redis_connector_register_stream_identical(connected_connector):
assert received_event1.call_count == 2
try:
with pytest.raises(RuntimeError):
with pytest.raises(ValueError):
connector.register(
TestStreamEndpoint2, cb=received_event1, newest_only=True, start_thread=False
)
connector.register(
TestStreamEndpoint2, cb=received_event2, newest_only=True, start_thread=False
)
with pytest.raises(RuntimeError):
with pytest.raises(ValueError):
connector.register(TestStreamEndpoint2, cb=received_event2, start_thread=False)
finally:
connector.unregister(TestStreamEndpoint2)
@@ -427,7 +426,8 @@ def test_redis_connector_register_stream_list(connected_connector, endpoint):
connector.poll_messages()
assert mock.call({"data": 2}, a=1) in cb_mock.mock_calls
connector.unregister(endpoint)
assert len(connector._stream_topics_subscription) == 0
all_topics = connector._stream_subs.all_topics
assert len(all_topics) == 0
@pytest.mark.timeout(10)
@@ -448,12 +448,14 @@ def test_redis_connector_register_stream_from_start(connected_connector):
cb_mock1.assert_called_once_with({"data": 3}, a=1)
cb_mock2.assert_called_once_with({"data": 3}, a=2)
cb_mock1.reset_mock()
connector.unregister(TestStreamEndpoint, cb=cb_mock1)
connector.register(TestStreamEndpoint, cb=cb_mock1, from_start=True, start_thread=False, a=3)
connector.poll_messages(timeout=1)
cb_mock1.assert_has_calls(
[mock.call({"data": 1}, a=3), mock.call({"data": 2}, a=3), mock.call({"data": 3}, a=3)]
)
cb_mock1.reset_mock()
connector.unregister(TestStreamEndpoint, cb=cb_mock1)
connector.register(TestStreamEndpoint, cb=cb_mock1, start_thread=False, a=4)
with pytest.raises(TimeoutError):
connector.poll_messages(timeout=1)