mirror of
https://github.com/bec-project/bec.git
synced 2026-06-02 08:18:31 +02:00
refactor: stream subs in redisconnector
- clearer organisation: class to handle subscription tasks - more correct: no multiple subscriptions, test logic
This commit is contained in:
@@ -114,7 +114,7 @@ def test_mv_scan_nested_device(capsys, bec_ipython_client_fixture):
|
||||
bec.metadata.update({"unit_test": "test_mv_scan_nested_device"})
|
||||
dev = bec.device_manager.devices
|
||||
scans.mv(dev.hexapod.x, 10, dev.hexapod.y, 20, relative=False).wait()
|
||||
if not bec.connector._messages_queue.empty():
|
||||
if not bec.connector._message_callbacks_queue.empty():
|
||||
print("Waiting for messages to be processed")
|
||||
time.sleep(0.5)
|
||||
current_pos_hexapod_x = dev.hexapod.x.read(cached=True)["hexapod_x"]["value"]
|
||||
@@ -126,7 +126,7 @@ def test_mv_scan_nested_device(capsys, bec_ipython_client_fixture):
|
||||
current_pos_hexapod_y, 20, atol=dev.hexapod._config["deviceConfig"].get("tolerance", 0.5)
|
||||
)
|
||||
scans.umv(dev.hexapod.x, 10, dev.hexapod.y, 20, relative=False)
|
||||
if not bec.connector._messages_queue.empty():
|
||||
if not bec.connector._message_callbacks_queue.empty():
|
||||
print("Waiting for messages to be processed")
|
||||
time.sleep(0.5)
|
||||
current_pos_hexapod_x = dev.hexapod.x.read(cached=True)["hexapod_x"]["value"]
|
||||
|
||||
+255
-199
@@ -17,6 +17,7 @@ import threading
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import MutableMapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
@@ -31,6 +32,7 @@ from typing import (
|
||||
Generator,
|
||||
Iterable,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
ParamSpec,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
@@ -40,6 +42,7 @@ from typing import (
|
||||
import louie
|
||||
import redis.client
|
||||
import redis.exceptions
|
||||
from astroid.nodes import Unknown
|
||||
from redis.backoff import ExponentialBackoff
|
||||
from redis.client import Pipeline, Redis
|
||||
from redis.retry import Retry
|
||||
@@ -57,6 +60,7 @@ from bec_lib.messages import (
|
||||
)
|
||||
from bec_lib.serialization import MsgpackSerialization
|
||||
|
||||
logger = bec_logger.logger
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from concurrent.futures import Future
|
||||
|
||||
@@ -84,6 +88,11 @@ class InvalidItemForOperation(ValueError): ...
|
||||
class WrongArguments(ValueError): ...
|
||||
|
||||
|
||||
def _error_log_with_context(msg: str):
|
||||
context = "".join(traceback.format_stack(limit=5)[:-1])
|
||||
logger.error(msg + f" Context:\n{context}")
|
||||
|
||||
|
||||
def _raise_incompatible_message(msg, endpoint):
|
||||
raise IncompatibleMessageForEndpoint(
|
||||
f"Message type {type(msg)} is not compatible with endpoint {endpoint}. Expected {endpoint.message_type}"
|
||||
@@ -198,34 +207,167 @@ class GeneratorExecution:
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamSubscriptionInfo:
|
||||
id: str
|
||||
topic: str
|
||||
newest_only: bool
|
||||
from_start: bool
|
||||
class StreamSubInfo:
|
||||
cb_ref: Callable
|
||||
kwargs: dict
|
||||
kwargs: dict[str, Unknown]
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, StreamSubscriptionInfo):
|
||||
if not isinstance(other, StreamSubInfo):
|
||||
return False
|
||||
return (
|
||||
self.topic == other.topic
|
||||
and self.cb_ref == other.cb_ref
|
||||
and self.from_start == other.from_start
|
||||
)
|
||||
return self.cb_ref == other.cb_ref
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.cb_ref.__hash__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DirectReadingStreamSubscriptionInfo(StreamSubscriptionInfo):
|
||||
class DirectReadStreamSubInfo(StreamSubInfo):
|
||||
stop_event: threading.Event
|
||||
thread: threading.Thread | None = None
|
||||
thread: threading.Thread
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.cb_ref.__hash__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamMessage:
|
||||
msg: dict
|
||||
callbacks: Iterable[tuple[Callable, dict]]
|
||||
callbacks: Iterable[tuple[Callable, dict[str, Unknown]]]
|
||||
|
||||
|
||||
class StreamSubsEntry(NamedTuple):
|
||||
read_id: str
|
||||
subs: set[StreamSubInfo]
|
||||
|
||||
|
||||
StreamResponseList = list[tuple[bytes, list[tuple[bytes, dict[bytes, bytes]]]]]
|
||||
StreamSubsRegistry = dict[str, StreamSubsEntry]
|
||||
|
||||
|
||||
class StreamSubs:
|
||||
def __init__(self) -> None:
|
||||
"""Manager for stream subscriptions. Since operations often need to be combined,
|
||||
use the lock directly at point of call, it is generally not used in the methods."""
|
||||
self.lock = threading.RLock()
|
||||
|
||||
self._subs: StreamSubsRegistry = {}
|
||||
self._direct_read_subs: dict[
|
||||
str, dict[DirectReadStreamSubInfo, DirectReadStreamSubInfo]
|
||||
] = {}
|
||||
self.from_start_subs: dict[str, set[StreamSubInfo]] = {}
|
||||
|
||||
@property
|
||||
def normal_subs(self):
|
||||
return {t: s.subs for t, s in self._subs.items()}
|
||||
|
||||
@property
|
||||
def all_topics(self):
|
||||
with self.lock:
|
||||
from_start_keys = [k for k in self.from_start_subs if self.from_start_subs[k] != set()]
|
||||
dr_sub_keys = [k for k in self._direct_read_subs if self._direct_read_subs[k] != set()]
|
||||
return list(set((*self._subs.keys(), *dr_sub_keys, *from_start_keys)))
|
||||
|
||||
def topic_ids(self) -> dict[str, str]:
|
||||
"""Get Redis read Ids for active subscriptions"""
|
||||
return {topic: infos.read_id for topic, infos in self._subs.items()}
|
||||
|
||||
def update_normal_ids(self, updated_ids: dict[str, str]):
|
||||
for topic, id in updated_ids.items():
|
||||
if topic in self._subs:
|
||||
self._subs[topic] = StreamSubsEntry(id, self._subs[topic].subs)
|
||||
|
||||
def from_start_topics(self) -> set[str]:
|
||||
"""Get topics for new `from_start` subscriptions which haven't been read yet"""
|
||||
return set(self.from_start_subs.keys())
|
||||
|
||||
def end_id(self, topic: str):
|
||||
"""Return the last read id for a given topic if given, or "+" """
|
||||
return self._subs[topic].read_id if topic in self._subs else "+"
|
||||
|
||||
def move_from_start_to_normal(self, topics_and_end_ids: dict[str, str]):
|
||||
if topics_and_end_ids.keys() != self.from_start_subs.keys():
|
||||
_error_log_with_context(
|
||||
f"Mismatch of subs to move! {topics_and_end_ids.keys()=}, {self.from_start_subs.keys()=} Was a lock forgotten?"
|
||||
)
|
||||
for topic in topics_and_end_ids:
|
||||
if topic in self._subs:
|
||||
if topics_and_end_ids[topic] != self._subs[topic].read_id:
|
||||
_error_log_with_context(f"Mismatch of ID! Was a lock forgotten?")
|
||||
for sub in self.from_start_subs.pop(topic):
|
||||
self._subs[topic].subs.add(sub) # type: ignore
|
||||
else:
|
||||
self._subs[topic] = StreamSubsEntry(
|
||||
read_id=topics_and_end_ids[topic], subs=self.from_start_subs.pop(topic)
|
||||
)
|
||||
|
||||
def _check_registered(self, topic, new_sub: StreamSubInfo):
|
||||
if (
|
||||
(topic in self.from_start_subs and new_sub in self.from_start_subs[topic])
|
||||
or (topic in self._direct_read_subs and new_sub in self._direct_read_subs[topic])
|
||||
or (topic in self._subs and new_sub in self._subs[topic].subs)
|
||||
):
|
||||
raise ValueError(f"Received duplicate subscription for {new_sub=}.")
|
||||
|
||||
def add_direct_listener(self, topic: str, new_sub: DirectReadStreamSubInfo):
|
||||
self._check_registered(topic, new_sub)
|
||||
if not topic in self._direct_read_subs:
|
||||
self._direct_read_subs[topic] = {}
|
||||
self._direct_read_subs[topic][new_sub] = new_sub
|
||||
new_sub.thread.start()
|
||||
|
||||
def add(self, from_start: bool, last_id: str, topic: str, new_sub: StreamSubInfo):
|
||||
self._check_registered(topic, new_sub)
|
||||
if from_start:
|
||||
if topic in self.from_start_subs:
|
||||
subs = self.from_start_subs[topic]
|
||||
else:
|
||||
subs = set()
|
||||
self.from_start_subs[topic] = subs
|
||||
else:
|
||||
if not topic in self._subs:
|
||||
subs = set()
|
||||
self._subs[topic] = StreamSubsEntry(read_id=last_id, subs=subs)
|
||||
else:
|
||||
subs = self._subs[topic].subs
|
||||
subs.add(new_sub)
|
||||
|
||||
@staticmethod
|
||||
def _kill_direct_stream(sub: DirectReadStreamSubInfo, topic: str):
|
||||
sub.stop_event.set()
|
||||
sub.thread.join(timeout=1)
|
||||
if sub.thread.is_alive():
|
||||
_error_log_with_context(
|
||||
f"RedisConnector direct stream callback thread for {topic=}, {sub.cb_ref=} failed to shutdown"
|
||||
)
|
||||
|
||||
def remove(self, topic: str, cb: Callable | None = None) -> bool:
|
||||
removed = False
|
||||
if cb is None: # Remove all subs for the given topic
|
||||
removed |= bool(self.from_start_subs.pop(topic, False))
|
||||
removed |= bool(self._subs.pop(topic, False))
|
||||
if (subs := self._direct_read_subs.pop(topic, None)) is not None:
|
||||
for sub in subs:
|
||||
self._kill_direct_stream(sub, topic)
|
||||
removed = True
|
||||
return removed
|
||||
test_subinfo = StreamSubInfo(louie.saferef.safe_ref(cb), {})
|
||||
if topic in self.from_start_subs and test_subinfo in self.from_start_subs[topic]:
|
||||
self.from_start_subs[topic].remove(test_subinfo)
|
||||
removed = True
|
||||
if len(self.from_start_subs[topic]) == 0:
|
||||
del self.from_start_subs[topic]
|
||||
if topic in self._direct_read_subs and test_subinfo in self._direct_read_subs[topic]:
|
||||
sub = self._direct_read_subs[topic].pop(test_subinfo) # type: ignore # hash is the same
|
||||
self._kill_direct_stream(sub, topic)
|
||||
removed = True
|
||||
if len(self._direct_read_subs[topic]) == 0:
|
||||
del self._direct_read_subs[topic]
|
||||
if topic in self._subs and test_subinfo in self._subs[topic].subs:
|
||||
self._subs[topic].subs.remove(test_subinfo)
|
||||
removed = True
|
||||
if len(self._subs[topic].subs) == 0:
|
||||
del self._subs[topic]
|
||||
return removed
|
||||
|
||||
|
||||
class RedisConnector:
|
||||
@@ -276,13 +418,12 @@ class RedisConnector:
|
||||
collections.defaultdict(list)
|
||||
)
|
||||
self._topics_cb_lock = threading.Lock()
|
||||
self._stream_topics_subscription = collections.defaultdict(list)
|
||||
self._stream_topics_subscription_lock = threading.Lock()
|
||||
self._stream_subs = StreamSubs()
|
||||
|
||||
self._events_listener_thread: threading.Thread | None = None
|
||||
self._stream_events_listener_thread: threading.Thread | None = None
|
||||
self._events_dispatcher_thread: threading.Thread | None = None
|
||||
self._messages_queue = queue.Queue()
|
||||
self._message_callbacks_queue = queue.Queue()
|
||||
self._stop_events_listener_thread = threading.Event()
|
||||
self._stop_stream_events_listener_thread = threading.Event()
|
||||
self.stream_keys: dict[str, str] = {}
|
||||
@@ -392,12 +533,12 @@ class RedisConnector:
|
||||
self._stream_events_listener_thread.join(timeout=per_thread_timeout_s)
|
||||
self._stream_events_listener_thread = None
|
||||
if self._events_dispatcher_thread:
|
||||
self._messages_queue.put(StopIteration)
|
||||
self._message_callbacks_queue.put(StopIteration)
|
||||
self._events_dispatcher_thread.join(timeout=per_thread_timeout_s)
|
||||
self._events_dispatcher_thread = None
|
||||
|
||||
# this will take care of shutting down direct listening threads
|
||||
self._unregister_stream(self._stream_topics_subscription)
|
||||
self._unregister_stream(self._stream_subs.all_topics)
|
||||
|
||||
# release all connections
|
||||
self._pubsub_conn.close()
|
||||
@@ -463,7 +604,7 @@ class RedisConnector:
|
||||
>>> connector.raise_alarm(
|
||||
severity=Alarms.WARNING,
|
||||
info=ErrorInfo(
|
||||
id=str(uuid.uuid4()),
|
||||
id=str(uuid.uuid4()),_stream_topic_subscriptions
|
||||
error_message="ValueError",
|
||||
compact_error_message="test alarm",
|
||||
exception_type="ValueError",
|
||||
@@ -648,7 +789,7 @@ class RedisConnector:
|
||||
self._topics_cb[topic].append(item)
|
||||
self._start_events_dispatcher_thread(start_thread)
|
||||
|
||||
def _add_direct_stream_listener(self, topic, cb_ref, **kwargs):
|
||||
def _create_direct_stream_listener(self, topic, cb_ref, kwargs):
|
||||
"""
|
||||
Add a direct listener for a topic. This is used when newest_only is True.
|
||||
|
||||
@@ -658,123 +799,95 @@ class RedisConnector:
|
||||
kwargs (dict): additional keyword arguments to be transmitted to the callback
|
||||
|
||||
Returns:
|
||||
None
|
||||
DirectReadStreamSubInfo with an unstarted thread
|
||||
"""
|
||||
info = DirectReadingStreamSubscriptionInfo(
|
||||
id="-",
|
||||
topic=topic,
|
||||
newest_only=True,
|
||||
from_start=False,
|
||||
cb_ref=cb_ref,
|
||||
kwargs=kwargs,
|
||||
stop_event=threading.Event(),
|
||||
stop_event = threading.Event()
|
||||
thread = threading.Thread(
|
||||
target=self._direct_stream_listener, args=(topic, stop_event, cb_ref, kwargs)
|
||||
)
|
||||
if info in self._stream_topics_subscription[topic]:
|
||||
raise RuntimeError("Already registered stream topic with the same callback")
|
||||
return DirectReadStreamSubInfo(cb_ref, kwargs, stop_event, thread)
|
||||
|
||||
info.thread = threading.Thread(target=self._direct_stream_listener, args=(info,))
|
||||
with self._stream_topics_subscription_lock:
|
||||
self._stream_topics_subscription[topic].append(info)
|
||||
info.thread.start()
|
||||
|
||||
def _direct_stream_listener(self, info: DirectReadingStreamSubscriptionInfo):
|
||||
stop_event = info.stop_event
|
||||
cb_ref = info.cb_ref
|
||||
kwargs = info.kwargs
|
||||
topic = info.topic
|
||||
def _direct_stream_listener(self, topic: str, stop_event: threading.Event, cb_ref, kwargs):
|
||||
read_id = "-"
|
||||
while not stop_event.is_set():
|
||||
ret = self._redis_conn.xrevrange(topic, "+", info.id, count=1)
|
||||
if not ret:
|
||||
time.sleep(0.1)
|
||||
if not (response := self._redis_conn.xrevrange(topic, "+", read_id, count=1)):
|
||||
stop_event.wait(timeout=0.1)
|
||||
continue
|
||||
redis_id, msg_dict = ret[0] # type: ignore : we are using Redis synchronously
|
||||
redis_id, msg_dict = response[0] # type: ignore : we are using Redis synchronously
|
||||
timestamp, _, ind = redis_id.partition(b"-")
|
||||
info.id = f"{timestamp.decode()}-{int(ind.decode())+1}"
|
||||
read_id = f"{timestamp.decode()}-{int(ind.decode())+1}"
|
||||
stream_msg = StreamMessage(
|
||||
{key.decode(): MsgpackSerialization.loads(val) for key, val in msg_dict.items()},
|
||||
((cb_ref, kwargs),),
|
||||
)
|
||||
self._messages_queue.put(stream_msg)
|
||||
self._message_callbacks_queue.put(stream_msg)
|
||||
|
||||
def _get_stream_topics_id(self) -> tuple[dict, dict]:
|
||||
stream_topics_id = {}
|
||||
from_start_stream_topics_id = {}
|
||||
with self._stream_topics_subscription_lock:
|
||||
for topic, subscription_info_list in self._stream_topics_subscription.items():
|
||||
for info in subscription_info_list:
|
||||
if isinstance(info, DirectReadingStreamSubscriptionInfo):
|
||||
continue
|
||||
if info.from_start:
|
||||
from_start_stream_topics_id[topic] = info.id
|
||||
else:
|
||||
stream_topics_id[topic] = info.id
|
||||
return from_start_stream_topics_id, stream_topics_id
|
||||
|
||||
def _handle_stream_msg_list(self, msg_list, from_start=False):
|
||||
for topic, msgs in msg_list:
|
||||
subscription_info_list = self._stream_topics_subscription[topic.decode()]
|
||||
for index, record in msgs:
|
||||
callbacks = []
|
||||
for info in subscription_info_list:
|
||||
info.id = index.decode()
|
||||
if from_start and not info.from_start:
|
||||
continue
|
||||
callbacks.append((info.cb_ref, info.kwargs))
|
||||
if callbacks:
|
||||
def _handle_stream_msg_list(
|
||||
self, redis_response: StreamResponseList, subs: dict[str, set[StreamSubInfo]]
|
||||
):
|
||||
new_ids = {}
|
||||
for btopic, msgs in redis_response:
|
||||
for read_id, record in msgs:
|
||||
topic: str = btopic.decode() if isinstance(btopic, bytes) else btopic # type: ignore
|
||||
if callbacks := subs.get(topic):
|
||||
msg_dict = {
|
||||
k.decode(): MsgpackSerialization.loads(msg) for k, msg in record.items()
|
||||
}
|
||||
msg = StreamMessage(msg_dict, callbacks)
|
||||
self._messages_queue.put(msg)
|
||||
for info in subscription_info_list:
|
||||
info.from_start = False
|
||||
msg = StreamMessage(msg_dict, [(cb.cb_ref, cb.kwargs) for cb in callbacks])
|
||||
self._message_callbacks_queue.put(msg)
|
||||
new_ids[topic] = read_id.decode()
|
||||
return new_ids
|
||||
|
||||
def _try_read_streams(self, topics_ids: dict[str, str], from_start: bool = False):
|
||||
try:
|
||||
if from_start:
|
||||
return [(t, self._redis_conn.xrange(t, "-", end)) for t, end in topics_ids.items()]
|
||||
else:
|
||||
return self._redis_conn.xread(topics_ids, block=200) or [] # type: ignore strs are fine key and id types
|
||||
except redis.exceptions.ConnectionError:
|
||||
logger.error("Failed to connect to redis. Is the server running?")
|
||||
except redis.exceptions.NoPermissionError:
|
||||
logger.error(f"Permission denied for stream topics: {set(topics_ids.keys())}")
|
||||
# pylint: disable=broad-except
|
||||
except Exception:
|
||||
sys.excepthook(*sys.exc_info()) # type: ignore # inside except
|
||||
|
||||
def _read_from_start_streams_and_migrate(self) -> bool:
|
||||
"""Returns whether there was an error"""
|
||||
with self._stream_subs.lock:
|
||||
if from_start_topics := self._stream_subs.from_start_topics():
|
||||
topics_and_end_ids = {t: self._stream_subs.end_id(t) for t in from_start_topics}
|
||||
response = self._try_read_streams(topics_and_end_ids, from_start=True)
|
||||
if response is not None:
|
||||
updated_end_ids = self._handle_stream_msg_list(
|
||||
response, self._stream_subs.from_start_subs
|
||||
)
|
||||
new_end_ids = {t: "0-0" for t in from_start_topics}
|
||||
new_end_ids.update(updated_end_ids)
|
||||
self._stream_subs.move_from_start_to_normal(new_end_ids)
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_stream_messages_loop(self) -> None:
|
||||
"""
|
||||
Get stream messages loop. This method is run in a separate thread and listens
|
||||
for messages from the redis server.
|
||||
"""
|
||||
error = False
|
||||
|
||||
while not self._stop_stream_events_listener_thread.is_set():
|
||||
try:
|
||||
from_start_stream_topics_id, stream_topics_id = self._get_stream_topics_id()
|
||||
if not any((stream_topics_id, from_start_stream_topics_id)):
|
||||
self._stop_stream_events_listener_thread.wait(timeout=0.1)
|
||||
continue
|
||||
msg_list = []
|
||||
from_start_msg_list = []
|
||||
# first handle the 'from_start' streams ;
|
||||
# in the case of reading from start what is expected is to call the
|
||||
# callbacks for existing items, without waiting for a new element to be added
|
||||
# to the stream
|
||||
if from_start_stream_topics_id:
|
||||
# read the streams contents from beginning
|
||||
from_start_msg_list = self._redis_conn.xread(
|
||||
from_start_stream_topics_id, block=200
|
||||
)
|
||||
if stream_topics_id:
|
||||
msg_list = self._redis_conn.xread(stream_topics_id, block=200)
|
||||
except redis.exceptions.ConnectionError:
|
||||
if not error:
|
||||
error = True
|
||||
bec_logger.logger.error("Failed to connect to redis. Is the server running?")
|
||||
# First read the "from_start" streams, up until any id which is already in the normal
|
||||
# subs, then all those them to the normal streams
|
||||
error = self._read_from_start_streams_and_migrate()
|
||||
# Then read all the normal streams
|
||||
with self._stream_subs.lock:
|
||||
normal_topics = self._stream_subs.topic_ids()
|
||||
normal_subs = self._stream_subs.normal_subs
|
||||
if normal_topics and (response := self._try_read_streams(normal_topics)) is not None:
|
||||
updated_ids = self._handle_stream_msg_list(response, normal_subs)
|
||||
with self._stream_subs.lock:
|
||||
self._stream_subs.update_normal_ids(updated_ids)
|
||||
if error: # Encountered an error on xread, wait a while without the lock
|
||||
self._stop_stream_events_listener_thread.wait(timeout=1)
|
||||
except redis.exceptions.NoPermissionError:
|
||||
bec_logger.logger.error(
|
||||
f"Permission denied for stream topics: \n Topics id: {from_start_stream_topics_id}, Stream topics id: {stream_topics_id}"
|
||||
)
|
||||
if not error:
|
||||
error = True
|
||||
self._stop_stream_events_listener_thread.wait(timeout=1)
|
||||
# pylint: disable=broad-except
|
||||
except Exception:
|
||||
sys.excepthook(*sys.exc_info()) # type: ignore # inside except
|
||||
else:
|
||||
error = False
|
||||
with self._stream_topics_subscription_lock:
|
||||
self._handle_stream_msg_list(from_start_msg_list, from_start=True)
|
||||
self._handle_stream_msg_list(msg_list)
|
||||
|
||||
def _register_stream(
|
||||
self,
|
||||
@@ -805,50 +918,27 @@ class RedisConnector:
|
||||
cb_ref = louie.saferef.safe_ref(cb)
|
||||
|
||||
self._start_events_dispatcher_thread(start_thread)
|
||||
|
||||
if newest_only:
|
||||
# if newest_only is True, we need to provide a separate callback for each topic,
|
||||
# directly calling the callback. This is because we need to have a backpressure
|
||||
# mechanism in place, and we cannot rely on the dispatcher thread to handle it.
|
||||
with self._stream_subs.lock:
|
||||
for topic in topics:
|
||||
self._add_direct_stream_listener(topic, cb_ref, **kwargs)
|
||||
else:
|
||||
with self._stream_topics_subscription_lock:
|
||||
for topic in topics:
|
||||
if newest_only:
|
||||
new_sub = self._create_direct_stream_listener(topic, cb_ref, kwargs)
|
||||
self._stream_subs.add_direct_listener(topic, new_sub)
|
||||
else:
|
||||
new_sub = StreamSubInfo(cb_ref, kwargs)
|
||||
try:
|
||||
stream_info = self._redis_conn.xinfo_stream(topic)
|
||||
except redis.exceptions.ResponseError:
|
||||
# no such key
|
||||
last_id = "0-0"
|
||||
last_id = "0-0" # no such key
|
||||
else:
|
||||
last_id = stream_info["last-entry"][0].decode() # type: ignore # we are using the sync Redis client
|
||||
new_subscription = StreamSubscriptionInfo(
|
||||
id="0-0" if from_start else last_id,
|
||||
topic=topic,
|
||||
newest_only=newest_only,
|
||||
from_start=from_start,
|
||||
cb_ref=cb_ref,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
subscriptions = self._stream_topics_subscription[topic]
|
||||
if new_subscription in subscriptions:
|
||||
# raise an error if attempted to register a stream with the same callback,
|
||||
# whereas it has already been registered as a 'direct reading' stream with
|
||||
# newest_only=True ; it is clearly an error case that would produce weird results
|
||||
index = subscriptions.index(new_subscription)
|
||||
if isinstance(subscriptions[index], DirectReadingStreamSubscriptionInfo):
|
||||
raise RuntimeError(
|
||||
"Already registered stream topic with the same callback with 'newest_only=True'"
|
||||
)
|
||||
else:
|
||||
subscriptions.append(new_subscription)
|
||||
self._stream_subs.add(from_start, last_id, topic, new_sub)
|
||||
|
||||
if self._stream_events_listener_thread is None:
|
||||
# create the thread that will get all messages for this connector
|
||||
self._stream_events_listener_thread = threading.Thread(
|
||||
target=self._get_stream_messages_loop
|
||||
)
|
||||
self._stream_events_listener_thread.start()
|
||||
if self._stream_events_listener_thread is None:
|
||||
# create the thread that will get all messages for this connector
|
||||
self._stream_events_listener_thread = threading.Thread(
|
||||
target=self._get_stream_messages_loop
|
||||
)
|
||||
self._stream_events_listener_thread.start()
|
||||
|
||||
def _filter_topics_cb(self, topics: list, cb: Callable | None):
|
||||
unsubscribe_list = []
|
||||
@@ -875,9 +965,7 @@ class RedisConnector:
|
||||
patterns = self._normalize_patterns(patterns)
|
||||
# see if registered streams can be unregistered
|
||||
for pattern in patterns:
|
||||
self._unregister_stream(
|
||||
fnmatch.filter(self._stream_topics_subscription, pattern), cb
|
||||
)
|
||||
self._unregister_stream(fnmatch.filter(self._stream_subs.all_topics, pattern), cb)
|
||||
pubsub_unsubscribe_list = self._filter_topics_cb(patterns, cb)
|
||||
if pubsub_unsubscribe_list:
|
||||
self._pubsub_conn.punsubscribe(pubsub_unsubscribe_list)
|
||||
@@ -889,41 +977,9 @@ class RedisConnector:
|
||||
self._pubsub_conn.unsubscribe(unsubscribe_list)
|
||||
|
||||
def _unregister_stream(self, topics: list[str], cb: Callable | None = None) -> bool:
|
||||
"""
|
||||
Unregister a stream listener.
|
||||
|
||||
Args:
|
||||
topics (list[str]): list of stream topics
|
||||
|
||||
Returns:
|
||||
bool: True if the stream listener has been removed, False otherwise
|
||||
"""
|
||||
unsubscribe_list = []
|
||||
with self._stream_topics_subscription_lock:
|
||||
for topic in topics:
|
||||
subscription_infos = self._stream_topics_subscription[topic]
|
||||
# remove from list if callback corresponds
|
||||
self._stream_topics_subscription[topic] = list(
|
||||
filter(lambda sub_info: cb and sub_info.cb_ref() is not cb, subscription_infos)
|
||||
)
|
||||
if not self._stream_topics_subscription[topic]:
|
||||
# no callbacks left, unsubscribe
|
||||
unsubscribe_list += subscription_infos
|
||||
# clean the topics that have been unsubscribed
|
||||
for subscription_info in unsubscribe_list:
|
||||
if isinstance(subscription_info, DirectReadingStreamSubscriptionInfo):
|
||||
subscription_info.stop_event.set()
|
||||
if subscription_info.thread:
|
||||
subscription_info.thread.join()
|
||||
# it is possible to register the same stream multiple times with different
|
||||
# callbacks, in this case when unregistering with cb=None (unregister all)
|
||||
# the topic can be deleted multiple times, hence try...except in code below
|
||||
try:
|
||||
del self._stream_topics_subscription[subscription_info.topic]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return len(unsubscribe_list) > 0
|
||||
"""Unregister callbacks from a list of topics. Returns true if any were removed"""
|
||||
with self._stream_subs.lock:
|
||||
return any([self._stream_subs.remove(topic, cb) for topic in topics])
|
||||
|
||||
def _get_messages_loop(self) -> None:
|
||||
"""
|
||||
@@ -948,7 +1004,7 @@ class RedisConnector:
|
||||
else:
|
||||
error = False
|
||||
if msg is not None:
|
||||
self._messages_queue.put(msg)
|
||||
self._message_callbacks_queue.put(msg)
|
||||
|
||||
def _execute_callback(self, cb, msg, kwargs):
|
||||
try:
|
||||
@@ -959,13 +1015,13 @@ class RedisConnector:
|
||||
else:
|
||||
if inspect.isgenerator(g):
|
||||
# reschedule execution to delineate the generator
|
||||
self._messages_queue.put(g)
|
||||
self._message_callbacks_queue.put(g)
|
||||
|
||||
def _handle_message(self, msg: StreamMessage | GeneratorExecution | PubSubMessage):
|
||||
if inspect.isgenerator(msg):
|
||||
g = msg
|
||||
fut = self._generator_executor.submit(next, g)
|
||||
self._messages_queue.put(GeneratorExecution(fut, g))
|
||||
self._message_callbacks_queue.put(GeneratorExecution(fut, g))
|
||||
elif isinstance(msg, StreamMessage):
|
||||
for cb_ref, kwargs in msg.callbacks:
|
||||
cb = cb_ref()
|
||||
@@ -980,9 +1036,9 @@ class RedisConnector:
|
||||
pass
|
||||
else:
|
||||
fut = self._generator_executor.submit(g.send, res)
|
||||
self._messages_queue.put(GeneratorExecution(fut, g))
|
||||
self._message_callbacks_queue.put(GeneratorExecution(fut, g))
|
||||
else:
|
||||
self._messages_queue.put(GeneratorExecution(fut, g))
|
||||
self._message_callbacks_queue.put(GeneratorExecution(fut, g))
|
||||
else:
|
||||
channel = msg["channel"].decode()
|
||||
with self._topics_cb_lock:
|
||||
@@ -1011,7 +1067,7 @@ class RedisConnector:
|
||||
while True:
|
||||
try:
|
||||
# wait for a message and return it before timeout expires
|
||||
msg = self._messages_queue.get(timeout=remaining_timeout, block=True)
|
||||
msg = self._message_callbacks_queue.get(timeout=remaining_timeout, block=True)
|
||||
except queue.Empty as exc:
|
||||
remaining_timeout = cast(float, remaining_timeout)
|
||||
timeout = cast(float, timeout)
|
||||
@@ -1032,7 +1088,7 @@ class RedisConnector:
|
||||
bec_logger.logger.error(f"Error handling message {msg}:\n{content}")
|
||||
|
||||
if timeout is None:
|
||||
if self._messages_queue.empty():
|
||||
if self._message_callbacks_queue.empty():
|
||||
# no message to process
|
||||
return True
|
||||
else:
|
||||
|
||||
@@ -369,20 +369,22 @@ def test_redis_connector_register_stream(connected_connector):
|
||||
connector.poll_messages()
|
||||
cb_mock1.assert_not_called()
|
||||
cb_mock2.assert_called_once_with({"data": 2}, a=2)
|
||||
assert "test" in connector._stream_subs.all_topics
|
||||
connector.unregister("test")
|
||||
assert connector._stream_topics_subscription["test"] == []
|
||||
assert connector._stream_subs.all_topics == []
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
def test_redis_connector_register_stream_identical(connected_connector):
|
||||
connector = connected_connector
|
||||
|
||||
received_event1 = mock.Mock(spec=[])
|
||||
received_event2 = mock.Mock(spec=[])
|
||||
|
||||
connector.register(TestStreamEndpoint, cb=received_event1, start_thread=False)
|
||||
connector.register(TestStreamEndpoint, cb=received_event1, start_thread=False)
|
||||
connector.register(TestStreamEndpoint, cb=received_event2, start_thread=False)
|
||||
connector.register(TestStreamEndpoint2, cb=received_event1, start_thread=False)
|
||||
connector.register(TestStreamEndpoint2, cb=received_event2, start_thread=False)
|
||||
connector.xadd("test", {"data": 1})
|
||||
connector.poll_messages(timeout=1)
|
||||
assert received_event1.call_count == 1
|
||||
@@ -392,14 +394,11 @@ def test_redis_connector_register_stream_identical(connected_connector):
|
||||
assert received_event1.call_count == 2
|
||||
|
||||
try:
|
||||
with pytest.raises(RuntimeError):
|
||||
with pytest.raises(ValueError):
|
||||
connector.register(
|
||||
TestStreamEndpoint2, cb=received_event1, newest_only=True, start_thread=False
|
||||
)
|
||||
connector.register(
|
||||
TestStreamEndpoint2, cb=received_event2, newest_only=True, start_thread=False
|
||||
)
|
||||
with pytest.raises(RuntimeError):
|
||||
with pytest.raises(ValueError):
|
||||
connector.register(TestStreamEndpoint2, cb=received_event2, start_thread=False)
|
||||
finally:
|
||||
connector.unregister(TestStreamEndpoint2)
|
||||
@@ -427,7 +426,8 @@ def test_redis_connector_register_stream_list(connected_connector, endpoint):
|
||||
connector.poll_messages()
|
||||
assert mock.call({"data": 2}, a=1) in cb_mock.mock_calls
|
||||
connector.unregister(endpoint)
|
||||
assert len(connector._stream_topics_subscription) == 0
|
||||
all_topics = connector._stream_subs.all_topics
|
||||
assert len(all_topics) == 0
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
@@ -448,12 +448,14 @@ def test_redis_connector_register_stream_from_start(connected_connector):
|
||||
cb_mock1.assert_called_once_with({"data": 3}, a=1)
|
||||
cb_mock2.assert_called_once_with({"data": 3}, a=2)
|
||||
cb_mock1.reset_mock()
|
||||
connector.unregister(TestStreamEndpoint, cb=cb_mock1)
|
||||
connector.register(TestStreamEndpoint, cb=cb_mock1, from_start=True, start_thread=False, a=3)
|
||||
connector.poll_messages(timeout=1)
|
||||
cb_mock1.assert_has_calls(
|
||||
[mock.call({"data": 1}, a=3), mock.call({"data": 2}, a=3), mock.call({"data": 3}, a=3)]
|
||||
)
|
||||
cb_mock1.reset_mock()
|
||||
connector.unregister(TestStreamEndpoint, cb=cb_mock1)
|
||||
connector.register(TestStreamEndpoint, cb=cb_mock1, start_thread=False, a=4)
|
||||
with pytest.raises(TimeoutError):
|
||||
connector.poll_messages(timeout=1)
|
||||
|
||||
Reference in New Issue
Block a user