# Source file: bec/bec_lib/bec_lib/redis_connector.py
# (viewer metadata captured with the source: 1163 lines, 46 KiB, Python)

"""
This module provides a connector to a redis server. It is a wrapper around the
redis library providing a simple interface to send and receive messages from a
redis server.
"""
from __future__ import annotations
import collections
import inspect
import itertools
import queue
import sys
import threading
import time
import traceback
import typing
import warnings
from collections.abc import MutableMapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import wraps
from glob import fnmatch
from typing import TYPE_CHECKING, Any, Awaitable, Generator, Literal
import louie
import redis
import redis.client
import redis.exceptions
from bec_lib.connector import ConnectorBase, MessageObject
from bec_lib.endpoints import EndpointInfo, MessageEndpoints
from bec_lib.logger import bec_logger
from bec_lib.messages import AlarmMessage, BECMessage, BundleMessage, ClientInfoMessage
from bec_lib.serialization import MsgpackSerialization
if TYPE_CHECKING:
from typing import Union
from bec_lib.alarm_handler import Alarms
def validate_endpoint(endpoint_arg):
    """
    Build a decorator that validates the endpoint argument of a connector method.

    The decorated function may receive the endpoint either positionally or as
    the keyword ``endpoint_arg``. A plain string endpoint is deprecated: a
    ``DeprecationWarning`` is emitted and the call is forwarded unchanged.
    An ``EndpointInfo`` is checked against the method name and the message
    type of every BECMessage found in the call arguments, then replaced by
    its ``endpoint`` attribute before the wrapped function is called.

    Args:
        endpoint_arg (str): name of the parameter that carries the endpoint

    Raises (at call time of the decorated function):
        TypeError: the endpoint is neither a str nor an EndpointInfo, or a
            message does not match the endpoint message type
        ValueError: the method name is not listed in the endpoint message_op
    """

    def decorator(func):
        # resolved once, at decoration time
        position = inspect.getfullargspec(func).args.index(endpoint_arg)

        @wraps(func)
        def wrapper(*args, **kwargs):
            given_positionally = position < len(args)
            endpoint = args[position] if given_positionally else kwargs[endpoint_arg]

            if isinstance(endpoint, str):
                warnings.warn(
                    "RedisConnector methods with a string topic are deprecated and should not be used anymore. Use RedisConnector methods with an EndpointInfo instead.",
                    DeprecationWarning,
                )
                return func(*args, **kwargs)
            if not isinstance(endpoint, EndpointInfo):
                raise TypeError(f"Endpoint {endpoint} is not EndpointInfo")
            if func.__name__ not in endpoint.message_op:
                raise ValueError(
                    f"Endpoint {endpoint} is not compatible with {func.__name__} method"
                )

            def incompatible(candidate):
                return TypeError(
                    f"Message type {type(candidate)} is not compatible with endpoint {endpoint}. Expected {endpoint.message_type}"
                )

            def check_members(members):
                # messages nested one level deep in a container argument
                for member in members:
                    if not isinstance(member, BECMessage) or endpoint.message_type == Any:
                        continue
                    if not isinstance(member, endpoint.message_type):
                        raise incompatible(member)

            for value in [*args, *kwargs.values()]:
                if isinstance(value, BECMessage):
                    if endpoint.message_type == Any:
                        continue
                    if not isinstance(value, endpoint.message_type):
                        # a bundle is accepted if every bundled message matches
                        if not isinstance(value, BundleMessage):
                            raise incompatible(value)
                        for bundled in value.messages:
                            if not isinstance(bundled, endpoint.message_type):
                                raise incompatible(bundled)
                if isinstance(value, dict):
                    check_members(value.values())
                if isinstance(value, (list, tuple)):
                    check_members(value)

            # hand the plain endpoint string to the wrapped function
            if given_positionally:
                positional = list(args)
                positional[position] = endpoint.endpoint
                return func(*positional, **kwargs)
            kwargs[endpoint_arg] = endpoint.endpoint
            return func(*args, **kwargs)

        return wrapper

    return decorator
@dataclass
class GeneratorExecution:
    """Pairs a generator with the pending result of its current step."""

    # result of the step submitted to the executor (built from `submit(...)`)
    fut: Awaitable
    # generator that produced the step and will receive the result
    g: Generator
@dataclass
class StreamSubscriptionInfo:
    """
    Description of one callback subscription on a stream topic.

    Two subscriptions compare equal when they share the same topic, the same
    callback reference and the same ``from_start`` flag; ``id``,
    ``newest_only`` and ``kwargs`` do not take part in the comparison.
    """

    id: str
    topic: str
    newest_only: bool
    from_start: bool
    cb_ref: callable
    kwargs: dict

    def __eq__(self, other):
        if not isinstance(other, StreamSubscriptionInfo):
            return False
        compared_fields = ("topic", "cb_ref", "from_start")
        return all(getattr(self, name) == getattr(other, name) for name in compared_fields)
@dataclass
class DirectReadingStreamSubscriptionInfo(StreamSubscriptionInfo):
    """Stream subscription that is read by its own dedicated thread."""

    # Not annotated, so these are plain class attributes rather than dataclass
    # fields; the connector assigns them per instance after construction.
    thread = None
    stop_event = None
@dataclass
class StreamMessage:
    """A decoded stream entry together with the callbacks that must receive it."""

    # decoded fields of the stream entry
    msg: dict
    # iterable of (callback reference, callback kwargs) pairs
    callbacks: list
class RedisConnector(ConnectorBase):
    """
    Redis connector class. This class is a wrapper around the redis library providing
    a simple interface to send and receive messages from a redis server.
    """

    def __init__(self, bootstrap: list, redis_cls=None, **kwargs):
        """
        Initialize the connector

        Args:
            bootstrap (list): list of strings in the form "host:port"; only the
                first entry is used. A single "host:port" string is accepted too.
            redis_cls (redis.client, optional): redis client class, instantiated
                with host and port only. Defaults to None, in which case
                redis.Redis is used and receives the extra keyword arguments.
            **kwargs: additional keyword arguments for redis.Redis
        """
        super().__init__(bootstrap)
        address = bootstrap[0] if isinstance(bootstrap, list) else bootstrap
        self.host, self.port = address.split(":")

        if redis_cls:
            self._redis_conn = redis_cls(host=self.host, port=self.port)
        else:
            self._redis_conn = redis.Redis(host=self.host, port=self.port, **kwargs)

        # main pubsub connection
        self._pubsub_conn = self._redis_conn.pubsub()
        self._pubsub_conn.ignore_subscribe_messages = True

        # registered callbacks, per pub/sub topic and per stream topic
        self._topics_cb = collections.defaultdict(list)
        self._topics_cb_lock = threading.Lock()
        self._stream_topics_subscription = collections.defaultdict(list)
        self._stream_topics_subscription_lock = threading.Lock()

        # worker threads are created lazily, on first registration
        self._events_listener_thread = None
        self._stream_events_listener_thread = None
        self._events_dispatcher_thread = None

        self._messages_queue = queue.Queue()
        self._stop_events_listener_thread = threading.Event()
        self._stop_stream_events_listener_thread = threading.Event()
        self.stream_keys = {}
        self._generator_executor = ThreadPoolExecutor()
def authenticate(self, password: str, username: str = "user"):
    """
    Store the credentials in the connection pool settings of the redis client

    Args:
        password (str): password
        username (str, optional): username. Defaults to "user".
    """
    pool_settings = self._redis_conn.connection_pool.connection_kwargs
    pool_settings["username"] = username
    pool_settings["password"] = password
def shutdown(self):
    """
    Shutdown the connector: stop the worker threads, unregister the stream
    subscriptions and close the redis connections.
    """
    super().shutdown()
    self._generator_executor.shutdown()

    # (attribute holding the thread, callable asking it to stop), in stop order
    stop_plan = (
        ("_events_listener_thread", self._stop_events_listener_thread.set),
        ("_stream_events_listener_thread", self._stop_stream_events_listener_thread.set),
        # the dispatcher stops when it dequeues the StopIteration sentinel
        ("_events_dispatcher_thread", lambda: self._messages_queue.put(StopIteration)),
    )
    for attr_name, request_stop in stop_plan:
        worker = getattr(self, attr_name)
        if worker:
            request_stop()
            worker.join()
            setattr(self, attr_name, None)

    # this will take care of shutting down direct listening threads
    self._unregister_stream(self._stream_topics_subscription)

    # release all connections
    self._pubsub_conn.close()
    self._redis_conn.close()
def send_client_info(
    self,
    message: str,
    show_asap: bool = False,
    source: Literal[
        "bec_ipython_client",
        "scan_server",
        "device_server",
        "scan_bundler",
        "file_writer",
        "scihub",
        "dap",
        None,
    ] = None,
    severity: int = 0,
    scope: str = None,
    rid: str = None,
    metadata: dict = None,
):
    """
    Send a message to the client, by adding a ClientInfoMessage to the
    client info stream (trimmed with max_size=100)

    Args:
        message (str): message
        show_asap (bool, optional): show asap. Defaults to False.
        source (Literal[str], optional): Any of the services: "bec_ipython_client", "scan_server", "device_server", "scan_bundler", "file_writer", "scihub", "dap". Defaults to None.
        severity (int, optional): severity. Defaults to 0.
        scope (str, optional): scope. Defaults to None.
        rid (str, optional): request ID. Defaults to None.
        metadata (dict, optional): metadata. Defaults to None.
    """
    info = ClientInfoMessage(
        message=message,
        source=source,
        severity=severity,
        show_asap=show_asap,
        scope=scope,
        RID=rid,
        metadata=metadata,
    )
    payload = {"data": info}
    self.xadd(MessageEndpoints.client_info(), msg_dict=payload, max_size=100)
def raise_alarm(self, severity: Alarms, alarm_type: str, source: str, msg: str, metadata: dict):
    """
    Raise an alarm: the AlarmMessage is stored and published on the alarm endpoint

    Args:
        severity (Alarms): alarm severity
        alarm_type (str): alarm type
        source (str): source
        msg (str): message
        metadata (dict): metadata
    """
    alarm = AlarmMessage(
        severity=severity,
        alarm_type=alarm_type,
        source=source,
        msg=msg,
        metadata=metadata,
    )
    self.set_and_publish(MessageEndpoints.alarm(), alarm)
def pipeline(self) -> redis.client.Pipeline:
    """Create a new pipeline on the connector's redis client"""
    new_pipe = self._redis_conn.pipeline()
    return new_pipe
def execute_pipeline(self, pipeline) -> list:
    """
    Execute a pipeline and return the results

    Each result is deserialized with MsgpackSerialization; a result for which
    deserialization raises RuntimeError is returned as-is.

    Args:
        pipeline (Pipeline): redis pipeline

    Returns:
        list: list of results

    Raises:
        TypeError: if pipeline is not a redis Pipeline
    """
    if not isinstance(pipeline, redis.client.Pipeline):
        raise TypeError(f"Expected a redis Pipeline, got {type(pipeline)}")

    decoded = []
    for raw in pipeline.execute():
        try:
            item = MsgpackSerialization.loads(raw)
        except RuntimeError:
            item = raw
        decoded.append(item)
    return decoded
def raw_send(self, topic: str, msg: bytes, pipe=None):
    """
    Publish a message on a topic, without any check or serialization.
    Use this method if you want to send a message that is not a BECMessage.

    Args:
        topic (str): topic
        msg (bytes): message
        pipe (Pipeline, optional): redis pipe. Defaults to None, in which case
            the message is published directly on the redis connection.
    """
    target = self._redis_conn if pipe is None else pipe
    target.publish(topic, msg)
@validate_endpoint("topic")
def send(self, topic: EndpointInfo, msg: BECMessage, pipe=None) -> None:
    """
    Send a message to a topic

    Args:
        topic (EndpointInfo): topic
        msg (BECMessage): message; serialized before sending. Anything else
            is sent unchanged.
        pipe (Pipeline, optional): redis pipe. Defaults to None.
    """
    payload = MsgpackSerialization.dumps(msg) if isinstance(msg, BECMessage) else msg
    self.raw_send(topic, payload, pipe)
def _start_events_dispatcher_thread(self, start_thread):
    """
    Start the events dispatcher thread, if requested and not already created.

    Args:
        start_thread (bool): when false, nothing is done
    """
    if not start_thread or self._events_dispatcher_thread is not None:
        return
    ready = threading.Event()
    dispatcher = threading.Thread(target=self._dispatch_events, args=(ready,))
    self._events_dispatcher_thread = dispatcher
    dispatcher.start()
    # synchronization of thread start
    ready.wait()
def _convert_endpointinfo(self, endpoint, check_message_op=True):
    """
    Convert an endpoint description to a flat list of topic strings.

    Args:
        endpoint: EndpointInfo, str, or a sequence / mapping of those
            (for a mapping the keys are used). Nested containers are flattened.
        check_message_op (bool, optional): for containers, require all the
            entries to share the same message operation. Defaults to True.

    Returns:
        tuple: (list of topic strings, message operation name). The name is ""
            for string endpoints and for containers when the check is disabled.

    Raises:
        ValueError: invalid endpoint type, or mixed message operations
    """
    if isinstance(endpoint, EndpointInfo):
        return [endpoint.endpoint], endpoint.message_op.name
    if isinstance(endpoint, str):
        return [endpoint], ""
    # Support list of endpoints or dict with endpoints as keys
    if not isinstance(endpoint, (Sequence, MutableMapping)):
        raise ValueError(f"Invalid endpoint {endpoint}")

    flattened = []
    common_op = None
    for entry in endpoint:
        entry_topics, entry_op = self._convert_endpointinfo(
            entry, check_message_op=check_message_op
        )
        if check_message_op:
            if common_op is None:
                common_op = entry_op
            elif entry_op != common_op:
                raise ValueError(f"All endpoints do not have the same type: {common_op}")
        flattened.extend(entry_topics)
    return flattened, common_op or ""
def _normalize_patterns(self, patterns):
    """
    Convert patterns to a list of strings.

    Args:
        patterns: anything accepted by _convert_endpointinfo

    Returns:
        list: list of pattern strings

    Raises:
        ValueError: if the converted patterns are not all strings
    """
    converted, _ = self._convert_endpointinfo(patterns)
    if isinstance(converted, str):
        return [converted]
    if not isinstance(converted, list) or not all(isinstance(p, str) for p in converted):
        raise ValueError("register: patterns must be a string or a list of strings")
    return converted
def register(
    self,
    topics: str | list[str] | EndpointInfo | list[EndpointInfo] = None,
    patterns: str | list[str] = None,
    cb: callable = None,
    start_thread: bool = True,
    from_start: bool = False,
    newest_only: bool = False,
    **kwargs,
):
    """
    Register a callback for a topic or a pattern

    When patterns is given, topics is not used.

    Args:
        topics (str, list, EndpointInfo, list[EndpointInfo], optional): topic or list of topics. Defaults to None. The topic should be a valid message endpoint in BEC and can be a string or an EndpointInfo object.
        patterns (str, list, optional): pattern or list of patterns. Defaults to None. In contrast to topics, patterns may contain "*" wildcards. The evaluated patterns should be a valid pub/sub message endpoint in BEC
        cb (callable, optional): callback. Defaults to None.
        start_thread (bool, optional): start the dispatcher thread. Defaults to True.
        from_start (bool, optional): for streams only: return data from start on first reading. Defaults to False.
        newest_only (bool, optional): for streams only: return newest data only. Defaults to False.
        **kwargs: additional keyword arguments to be transmitted to the callback

    Raises:
        ValueError: if cb is None, or if both topics and patterns are None

    Examples:
        >>> def my_callback(msg, **kwargs):
        ...     print(msg)
        ...
        >>> connector.register("test", cb=my_callback)
        >>> connector.register(topics="test", cb=my_callback)
        >>> connector.register(patterns="test:*", cb=my_callback)
        >>> connector.register(patterns="test:*", cb=my_callback, start_thread=False)
        >>> connector.register(patterns="test:*", cb=my_callback, start_thread=False, my_arg="test")
    """
    if cb is None:
        raise ValueError("Callback cb cannot be None")
    if topics is None and patterns is None:
        raise ValueError("topics and patterns cannot be both None")

    # make a weakref from the callable, using louie;
    # it can create safe refs for simple functions as well as methods
    entry = (louie.saferef.safe_ref(cb), kwargs)

    if self._events_listener_thread is None:
        # create the thread that will get all messages for this connector;
        listener = threading.Thread(target=self._get_messages_loop)
        self._events_listener_thread = listener
        listener.start()

    if patterns is None:
        topics, message_op = self._convert_endpointinfo(topics)
        if message_op == "STREAM":
            # streams are handled by a dedicated machinery
            return self._register_stream(
                topics=topics,
                cb=cb,
                from_start=from_start,
                newest_only=newest_only,
                start_thread=start_thread,
                **kwargs,
            )
        names = topics
        self._pubsub_conn.subscribe(names)
    else:
        names = self._normalize_patterns(patterns)
        self._pubsub_conn.psubscribe(names)

    with self._topics_cb_lock:
        for name in names:
            if entry not in self._topics_cb[name]:
                self._topics_cb[name].append(entry)

    self._start_events_dispatcher_thread(start_thread)
def _add_direct_stream_listener(self, topic, cb_ref, **kwargs) -> None:
    """
    Add a direct listener for a topic. This is used when newest_only is True.

    A dedicated thread is started for the subscription; nothing is returned.

    Args:
        topic (str): topic
        cb_ref (callable): weakref to callback
        kwargs (dict): additional keyword arguments to be transmitted to the callback

    Raises:
        RuntimeError: if the same callback is already registered on the topic
    """
    info = DirectReadingStreamSubscriptionInfo(
        id="-", topic=topic, newest_only=True, from_start=False, cb_ref=cb_ref, kwargs=kwargs
    )
    with self._stream_topics_subscription_lock:
        # the duplicate check and the insertion are done under the same lock,
        # so that two concurrent registrations cannot both pass the check
        subscriptions = self._stream_topics_subscription[topic]
        if info in subscriptions:
            raise RuntimeError("Already registered stream topic with the same callback")
        info.stop_event = threading.Event()
        info.thread = threading.Thread(target=self._direct_stream_listener, args=(info,))
        subscriptions.append(info)
    info.thread.start()
def _direct_stream_listener(self, info: DirectReadingStreamSubscriptionInfo):
    """
    Thread body of a direct stream subscription.

    Until the stop event of the subscription is set, read the newest entry of
    the stream that is not older than `info.id`, queue it for the callback of
    the subscription, and move `info.id` just after that entry.

    Args:
        info (DirectReadingStreamSubscriptionInfo): the subscription to serve
    """
    stop_requested = info.stop_event.is_set
    callbacks = ((info.cb_ref, info.kwargs),)
    stream = info.topic

    while not stop_requested():
        newest = self._redis_conn.xrevrange(stream, "+", info.id, count=1)
        if not newest:
            time.sleep(0.1)
            continue
        entry_id, fields = newest[0]
        # next lower bound: same timestamp, sequence number + 1
        millis, _, sequence = entry_id.partition(b"-")
        info.id = f"{millis.decode()}-{int(sequence.decode())+1}"
        decoded = {key.decode(): MsgpackSerialization.loads(val) for key, val in fields.items()}
        self._messages_queue.put(StreamMessage(decoded, callbacks))
def _get_stream_topics_id(self) -> tuple[dict, dict]:
    """
    Collect the ids to read from, for the streams handled by the stream
    listener thread. Direct reading subscriptions are skipped.

    Returns:
        tuple[dict, dict]: ({topic: id} for subscriptions reading from start,
            {topic: id} for the other subscriptions). When a topic has several
            subscriptions of the same kind, the id of the last one is kept.
    """
    stream_topics_id = {}
    from_start_stream_topics_id = {}
    with self._stream_topics_subscription_lock:
        for topic, subscription_info_list in self._stream_topics_subscription.items():
            for info in subscription_info_list:
                if isinstance(info, DirectReadingStreamSubscriptionInfo):
                    continue
                target = from_start_stream_topics_id if info.from_start else stream_topics_id
                target[topic] = info.id
    return from_start_stream_topics_id, stream_topics_id
def _handle_stream_msg_list(self, msg_list, from_start=False):
    """
    Queue the entries read from streams for the subscribed callbacks.

    Every subscription of a topic gets its id moved to each handled entry, and
    its `from_start` flag cleared once the entries of the topic are handled.

    Args:
        msg_list: list of (topic, [(entry id, fields), ...]), with topic, entry id
            and field names as bytes
        from_start (bool, optional): when True, only the subscriptions that
            asked to read from start get the entries. Defaults to False.
    """
    for raw_topic, entries in msg_list:
        subscribers = self._stream_topics_subscription[raw_topic.decode()]
        for entry_id, fields in entries:
            targets = []
            for subscriber in subscribers:
                subscriber.id = entry_id.decode()
                if from_start and not subscriber.from_start:
                    continue
                targets.append((subscriber.cb_ref, subscriber.kwargs))
            if not targets:
                continue
            payload = {
                name.decode(): MsgpackSerialization.loads(value)
                for name, value in fields.items()
            }
            self._messages_queue.put(StreamMessage(payload, targets))
        for subscriber in subscribers:
            subscriber.from_start = False
def _get_stream_messages_loop(self) -> bool:
    """
    Get stream messages loop. This method is run in a separate thread and listens
    for messages from the redis server.

    Returns:
        bool: True, once the stop event of the stream listener has been set
    """
    # remember that the connection error has been logged, to log it only once
    error = False
    while not self._stop_stream_events_listener_thread.is_set():
        try:
            from_start_stream_topics_id, stream_topics_id = self._get_stream_topics_id()
            if not any((stream_topics_id, from_start_stream_topics_id)):
                # nothing to read yet
                time.sleep(0.1)
                continue
            msg_list = []
            from_start_msg_list = []
            # first handle the 'from_start' streams ;
            # in the case of reading from start what is expected is to call the
            # callbacks for existing items, without waiting for a new element to be added
            # to the stream
            if from_start_stream_topics_id:
                # read the streams contents from beginning, not blocking
                from_start_msg_list = self._redis_conn.xread(from_start_stream_topics_id)
            if stream_topics_id:
                msg_list = self._redis_conn.xread(stream_topics_id, block=200)
        except redis.exceptions.ConnectionError:
            if not error:
                error = True
                bec_logger.logger.error("Failed to connect to redis. Is the server running?")
            time.sleep(1)
        # pylint: disable=broad-except
        except Exception:
            sys.excepthook(*sys.exc_info())
        else:
            error = False
            with self._stream_topics_subscription_lock:
                self._handle_stream_msg_list(from_start_msg_list, from_start=True)
                self._handle_stream_msg_list(msg_list)
    return True
def _register_stream(
    self,
    topics: list[str] = None,
    cb: callable = None,
    from_start: bool = False,
    newest_only: bool = False,
    start_thread: bool = True,
    **kwargs,
) -> None:
    """
    Register a callback for stream topics

    Args:
        topics (list[str], optional): Topics. These should be valid message endpoint strings.
        cb (callable, optional): callback. Defaults to None.
        from_start (bool, optional): read from start. Defaults to False.
        newest_only (bool, optional): read newest only. Defaults to False.
        start_thread (bool, optional): start the dispatcher thread. Defaults to True.
        **kwargs: additional keyword arguments to be transmitted to the callback

    Raises:
        ValueError: if newest_only and from_start are both True
        RuntimeError: if the callback is already registered on a topic as a
            'newest_only' (direct reading) subscription
    """
    if newest_only and from_start:
        raise ValueError("newest_only and from_start cannot be both True")

    # make a weakref from the callable, using louie;
    # it can create safe refs for simple functions as well as methods
    cb_ref = louie.saferef.safe_ref(cb)

    self._start_events_dispatcher_thread(start_thread)

    if newest_only:
        # if newest_only is True, we need to provide a separate callback for each topic,
        # directly calling the callback. This is because we need to have a backpressure
        # mechanism in place, and we cannot rely on the dispatcher thread to handle it.
        for topic in topics:
            self._add_direct_stream_listener(topic, cb_ref, **kwargs)
        return

    with self._stream_topics_subscription_lock:
        for topic in topics:
            try:
                stream_info = self._redis_conn.xinfo_stream(topic)
            except redis.exceptions.ResponseError:
                # no such key
                last_id = "0-0"
            else:
                # 'last-entry' is None when the stream exists but is empty
                last_entry = stream_info.get("last-entry")
                last_id = last_entry[0].decode() if last_entry else "0-0"

            new_subscription = StreamSubscriptionInfo(
                id="0-0" if from_start else last_id,
                topic=topic,
                newest_only=newest_only,
                from_start=from_start,
                cb_ref=cb_ref,
                kwargs=kwargs,
            )
            subscriptions = self._stream_topics_subscription[topic]
            if new_subscription in subscriptions:
                # raise an error if attempted to register a stream with the same callback,
                # whereas it has already been registered as a 'direct reading' stream with
                # newest_only=True ; it is clearly an error case that would produce weird results
                index = subscriptions.index(new_subscription)
                if isinstance(subscriptions[index], DirectReadingStreamSubscriptionInfo):
                    raise RuntimeError(
                        "Already registered stream topic with the same callback with 'newest_only=True'"
                    )
            else:
                subscriptions.append(new_subscription)

    if self._stream_events_listener_thread is None:
        # create the thread that will get all messages for this connector
        self._stream_events_listener_thread = threading.Thread(
            target=self._get_stream_messages_loop
        )
        self._stream_events_listener_thread.start()
def _filter_topics_cb(self, topics: list, cb: Union[callable, None]):
    """
    Remove a callback from the given pub/sub topics.

    Args:
        topics (list): topics (or patterns) to clean
        cb (callable, None): callback to remove; a false value removes all
            the callbacks of the topics

    Returns:
        list: topics left without any callback; they are also dropped from
            the internal registry
    """
    emptied = []
    with self._topics_cb_lock:
        for topic in topics:
            # keep the entries whose (dereferenced) callback is not `cb`
            remaining = [entry for entry in self._topics_cb[topic] if cb and entry[0]() is not cb]
            self._topics_cb[topic] = remaining
            if not remaining:
                # no callbacks left, unsubscribe
                emptied.append(topic)
        # clean the topics that have been unsubscribed
        for topic in emptied:
            del self._topics_cb[topic]
    return emptied
def unregister(self, topics=None, patterns=None, cb=None):
    """
    Unregister a callback from topics or patterns

    Nothing is done if the events listener thread has never been created.
    When patterns is given, topics is not used.

    Args:
        topics (optional): topic(s), as accepted by register. Defaults to None.
        patterns (optional): pattern(s); registered stream topics matching a
            pattern are unregistered too. Defaults to None.
        cb (callable, optional): callback to remove. Defaults to None, which
            removes all the callbacks.
    """
    if self._events_listener_thread is None:
        return

    if patterns is None:
        topics, _ = self._convert_endpointinfo(topics, check_message_op=False)
        if self._unregister_stream(topics, cb):
            # stream topics: nothing to do on the pub/sub side
            return
        to_unsubscribe = self._filter_topics_cb(topics, cb)
        if to_unsubscribe:
            self._pubsub_conn.unsubscribe(to_unsubscribe)
        return

    patterns = self._normalize_patterns(patterns)
    # see if registered streams can be unregistered
    for pattern in patterns:
        matching_streams = fnmatch.filter(self._stream_topics_subscription, pattern)
        self._unregister_stream(matching_streams, cb)
    to_punsubscribe = self._filter_topics_cb(patterns, cb)
    if to_punsubscribe:
        self._pubsub_conn.punsubscribe(to_punsubscribe)
def _unregister_stream(self, topics: list[str], cb: callable = None) -> bool:
    """
    Unregister a stream listener.

    Every removed direct reading subscription has its thread stopped, even
    when other subscriptions remain on the same topic. Topics left without
    subscription are dropped from the registry; unknown topics are ignored.

    Args:
        topics (list[str]): list of stream topics
        cb (callable, optional): callback to remove. Defaults to None, which
            removes all the subscriptions of the topics.

    Returns:
        bool: True if at least one stream subscription has been removed, False otherwise
    """
    removed = []
    emptied_topics = []
    with self._stream_topics_subscription_lock:
        for topic in topics:
            # no item access on unknown topics: it would create an entry in the
            # defaultdict (and `topics` can be the registry itself, see shutdown)
            if topic not in self._stream_topics_subscription:
                continue
            kept = []
            for sub_info in self._stream_topics_subscription[topic]:
                # keep the subscription if its callback is not the one to remove
                if cb and sub_info.cb_ref() is not cb:
                    kept.append(sub_info)
                else:
                    removed.append(sub_info)
            self._stream_topics_subscription[topic] = kept
            if not kept:
                emptied_topics.append(topic)

        # ask the direct reading threads of the removed subscriptions to stop
        for sub_info in removed:
            if isinstance(sub_info, DirectReadingStreamSubscriptionInfo):
                sub_info.stop_event.set()

        # clean the topics that have no subscription left; done after the loop
        # above, to not change the size of the registry while iterating on it
        for topic in emptied_topics:
            self._stream_topics_subscription.pop(topic, None)

    # wait for the threads outside of the lock
    for sub_info in removed:
        if isinstance(sub_info, DirectReadingStreamSubscriptionInfo):
            sub_info.thread.join()

    return bool(removed)
def _get_messages_loop(self) -> None:
    """
    Get messages loop. This method is run in a separate thread and listens
    for messages from the redis server, on the pubsub connection of the
    connector. Received messages are put in the messages queue.
    """
    # remember that the connection error has been logged, to log it only once
    connection_lost = False
    should_stop = self._stop_events_listener_thread.is_set
    while not should_stop():
        try:
            received = self._pubsub_conn.get_message(timeout=1)
        except redis.exceptions.ConnectionError:
            if not connection_lost:
                connection_lost = True
                bec_logger.logger.error("Failed to connect to redis. Is the server running?")
            time.sleep(1)
        # pylint: disable=broad-except
        except Exception:
            sys.excepthook(*sys.exc_info())
        else:
            connection_lost = False
            if received is not None:
                self._messages_queue.put(received)
def _execute_callback(self, cb, msg, kwargs):
    """
    Call a callback with a message. Exceptions raised by the callback are
    passed to sys.excepthook, they do not propagate.

    Args:
        cb (callable): callback
        msg: message, first argument of the callback
        kwargs (dict): keyword arguments of the callback
    """
    try:
        outcome = cb(msg, **kwargs)
    # pylint: disable=broad-except
    except Exception:
        sys.excepthook(*sys.exc_info())
        return
    if inspect.isgenerator(outcome):
        # reschedule execution to delineate the generator
        self._messages_queue.put(outcome)
def _handle_message(self, msg):
    """
    Handle one item of the messages queue.

    Args:
        msg: one of
            - a generator (returned by a callback): its first step is submitted
              to the generator executor
            - a StreamMessage: its callbacks are called with the decoded entry
            - a GeneratorExecution: when the pending step is done, the next
              step is submitted, else the item is queued again
            - a pub/sub message dict: the callbacks registered for the pattern
              (or the channel) are called with a MessageObject
    """
    if inspect.isgenerator(msg):
        first_step = self._generator_executor.submit(next, msg)
        self._messages_queue.put(GeneratorExecution(first_step, msg))
        return

    if isinstance(msg, StreamMessage):
        for cb_ref, cb_kwargs in msg.callbacks:
            callback = cb_ref()
            if callback:
                self._execute_callback(callback, msg.msg, cb_kwargs)
        return

    if isinstance(msg, GeneratorExecution):
        pending, gen = msg.fut, msg.g
        if not pending.done():
            # not ready yet, check again later
            self._messages_queue.put(GeneratorExecution(pending, gen))
            return
        try:
            step_result = pending.result()
        except StopIteration:
            # generator is exhausted
            return
        next_step = self._generator_executor.submit(gen.send, step_result)
        self._messages_queue.put(GeneratorExecution(next_step, gen))
        return

    # pub/sub message
    channel = msg["channel"].decode()
    with self._topics_cb_lock:
        pattern = msg["pattern"]
        registry_key = channel if pattern is None else pattern.decode()
        callbacks = self._topics_cb[registry_key]
    message = MessageObject(topic=channel, value=MsgpackSerialization.loads(msg["data"]))
    for cb_ref, cb_kwargs in callbacks:
        callback = cb_ref()
        if callback:
            self._execute_callback(callback, message, cb_kwargs)
def poll_messages(self, timeout=None) -> bool:
    """Poll messages from the messages queue

    If timeout is None, wait for at least one message. Processes until queue is empty,
    or until timeout is reached.

    Args:
        timeout (float): timeout in seconds

    Returns:
        bool: False if the stop sentinel (StopIteration) has been received,
            True otherwise

    Raises:
        TimeoutError: if no message at all has been received before timeout
    """
    start_time = time.perf_counter()
    remaining_timeout = timeout
    while True:
        try:
            # wait for a message and return it before timeout expires
            msg = self._messages_queue.get(timeout=remaining_timeout, block=True)
        except queue.Empty as exc:
            # queue.Empty can only happen with a timeout (not None)
            if remaining_timeout < timeout:
                # at least one message has been processed, so we do not raise
                # the timeout error
                return True
            raise TimeoutError(f"{self}: timeout waiting for messages") from exc

        if msg is StopIteration:
            return False

        try:
            self._handle_message(msg)
        # pylint: disable=broad-except
        except Exception:
            content = traceback.format_exc()
            bec_logger.logger.error(f"Error handling message {msg}:\n{content}")

        if timeout is None:
            if self._messages_queue.empty():
                # no message to process
                return True
        else:
            # calculate how much time remains and retry getting a message
            remaining_timeout = timeout - (time.perf_counter() - start_time)
            if remaining_timeout <= 0:
                return True
def _dispatch_events(self, started_event):
    """
    Body of the dispatcher thread: poll messages until poll_messages
    returns a false value.

    Args:
        started_event (threading.Event): set as soon as the thread runs
    """
    started_event.set()
    keep_polling = True
    while keep_polling:
        keep_polling = self.poll_messages()
@validate_endpoint("topic")
def lpush(
    self, topic: EndpointInfo, msg: str, pipe=None, max_size: int = None, expire: int = None
) -> None:
    """Time complexity: O(1) for each element added, so O(N) to
    add N elements when the command is called with multiple arguments.
    Insert all the specified values at the head of the list stored at key.
    If key does not exist, it is created as empty list before
    performing the push operations. When key holds a value that
    is not a list, an error is returned.

    Args:
        topic (EndpointInfo): redis topic
        msg (str): message; a BECMessage is serialized first
        pipe (Pipeline, optional): redis pipe. Defaults to None, in which case
            a new pipeline is created and executed.
        max_size (int, optional): maximum number of elements kept in the list.
            Defaults to None (no trimming).
        expire (int, optional): expire time of the list, passed to the redis
            'expire' command. Defaults to None.
    """
    client = pipe if pipe is not None else self.pipeline()
    if isinstance(msg, BECMessage):
        msg = MsgpackSerialization.dumps(msg)
    client.lpush(topic, msg)
    if max_size:
        # LTRIM bounds are inclusive: keeping 'max_size' elements means
        # keeping the indexes 0..max_size-1
        client.ltrim(topic, 0, max_size - 1)
    if expire:
        client.expire(topic, expire)
    if pipe is None:
        client.execute()
@validate_endpoint("topic")
def lset(self, topic: EndpointInfo, index: int, msg: str, pipe=None) -> None:
    """
    Set the list element at index to msg. A BECMessage is serialized first.
    The result of the redis 'lset' call is returned.
    """
    target = self._redis_conn if pipe is None else pipe
    payload = MsgpackSerialization.dumps(msg) if isinstance(msg, BECMessage) else msg
    return target.lset(topic, index, payload)
@validate_endpoint("topic")
def rpush(self, topic: EndpointInfo, msg: str, pipe=None) -> int:
    """O(1) for each element added, so O(N) to add N elements when the
    command is called with multiple arguments. Insert all the specified
    values at the tail of the list stored at key. If key does not exist,
    it is created as empty list before performing the push operation. When
    key holds a value that is not a list, an error is returned.

    A BECMessage is serialized before being pushed."""
    target = self._redis_conn if pipe is None else pipe
    payload = MsgpackSerialization.dumps(msg) if isinstance(msg, BECMessage) else msg
    return target.rpush(topic, payload)
@validate_endpoint("topic")
def lrange(self, topic: EndpointInfo, start: int, end: int, pipe=None):
    """O(S+N) where S is the distance of start offset from HEAD for small
    lists, from nearest end (HEAD or TAIL) for large lists; and N is the
    number of elements in the specified range. Returns the specified elements
    of the list stored at key. The offsets start and stop are zero-based indexes,
    with 0 being the first element of the list (the head of the list), 1 being
    the next element and so on.

    With a pipe, the result of the pipe command is returned as-is: use the
    'execute_pipeline' method to get the values. Without a pipe, the elements
    are deserialized (elements for which it raises RuntimeError are kept raw)."""
    target = self._redis_conn if pipe is None else pipe
    raw_items = target.lrange(topic, start, end)
    if pipe:
        return raw_items

    decoded = []
    for raw in raw_items:
        try:
            item = MsgpackSerialization.loads(raw)
        except RuntimeError:
            item = raw
        decoded.append(item)
    return decoded
@validate_endpoint("topic")
def set_and_publish(self, topic: EndpointInfo, msg, pipe=None, expire: int = None) -> None:
    """piped combination of 'set' and 'publish': the (serialized) message is
    stored under the topic key, with an optional expire time, and published on
    the topic. Without a pipe, a new pipeline is created and executed."""
    owns_pipeline = pipe is None
    client = self.pipeline() if owns_pipeline else pipe
    payload = MsgpackSerialization.dumps(msg) if isinstance(msg, BECMessage) else msg
    client.set(topic, payload, ex=expire)
    self.raw_send(topic, payload, pipe=client)
    if not pipe:
        client.execute()
@validate_endpoint("topic")
def set(self, topic: EndpointInfo, msg, pipe=None, expire: int = None) -> None:
    """set redis value; a BECMessage is serialized first, 'expire' is passed
    as the 'ex' argument of the redis 'set' command"""
    target = self._redis_conn if pipe is None else pipe
    payload = MsgpackSerialization.dumps(msg) if isinstance(msg, BECMessage) else msg
    target.set(topic, payload, ex=expire)
@validate_endpoint("pattern")
def keys(self, pattern: EndpointInfo) -> list:
    """returns all keys matching a pattern"""
    matching = self._redis_conn.keys(pattern)
    return matching
@validate_endpoint("topic")
def delete(self, topic: EndpointInfo, pipe=None):
    """delete topic, through the pipe if given, else directly"""
    target = self._redis_conn if pipe is None else pipe
    target.delete(topic)
@validate_endpoint("topic")
def get(self, topic: EndpointInfo, pipe=None):
    """retrieve entry, via the redis 'get' command

    With a pipe, the result of the pipe command is returned as-is. Without a
    pipe, the value is deserialized; if it raises RuntimeError, the raw value
    is returned."""
    target = self._redis_conn if pipe is None else pipe
    raw = target.get(topic)
    if pipe:
        return raw
    try:
        return MsgpackSerialization.loads(raw)
    except RuntimeError:
        return raw
def mget(self, topics: list[str], pipe=None):
    """retrieve multiple entries

    With a pipe, the result of the pipe command is returned as-is. Without a
    pipe, each value is deserialized; empty values are returned as None."""
    target = self._redis_conn if pipe is None else pipe
    raw_values = target.mget(topics)
    if pipe:
        return raw_values
    decoded = []
    for raw in raw_values:
        decoded.append(MsgpackSerialization.loads(raw) if raw else None)
    return decoded
@validate_endpoint("topic")
def xadd(
    self, topic: EndpointInfo, msg_dict: dict, max_size=None, pipe=None, expire: int = None
):
    """
    add to stream

    Without a pipe, the command is sent directly, except when an expire time
    is given: the 'xadd' and 'expire' commands are then sent in a new pipeline.

    Args:
        topic (str): redis topic
        msg_dict (dict): message to add; the values are serialized
        max_size (int, optional): max size of stream. Defaults to None.
        pipe (Pipeline, optional): redis pipe. Defaults to None.
        expire (int, optional): expire time. Defaults to None.

    Examples:
        >>> redis.xadd("test", {"test": "test"})
        >>> redis.xadd("test", {"test": "test"}, max_size=10)
    """
    if pipe:
        client = pipe
    elif expire:
        client = self.pipeline()
    else:
        client = self._redis_conn

    fields = {name: MsgpackSerialization.dumps(value) for name, value in msg_dict.items()}
    trimming = {"maxlen": max_size} if max_size else {}
    client.xadd(topic, fields, **trimming)

    if expire:
        client.expire(topic, expire)
        if not pipe:
            # pipeline created here, it has to be executed here
            client.execute()
@validate_endpoint("topic")
def get_last(self, topic: EndpointInfo, key=None, count=1):
    """
    Get last message from stream. Repeated calls will return
    the same message until a new message is added to the stream.

    Args:
        topic (str): redis topic
        key (str, optional): key to retrieve. Defaults to None. If None, the whole message is returned.
        count (int, optional): number of last elements to retrieve

    Returns:
        None if count <= 0, if the stream has no entry, or if a TypeError is
        raised while reading; the single message if count is 1; else the list
        of messages, oldest first.
    """
    if count <= 0:
        return None

    client = self._redis_conn
    messages = []
    try:
        newest_first = client.xrevrange(topic, "+", "-", count=count)
        if not newest_first:
            return None
        for _, fields in reversed(newest_first):
            if key is None:
                message = {
                    name.decode(): MsgpackSerialization.loads(value)
                    for name, value in fields.items()
                }
            else:
                message = MsgpackSerialization.loads(fields[key.encode()])
            messages.append(message)
    except TypeError:
        return None

    return messages if count > 1 else messages[0]
@validate_endpoint("topic")
def xread(
    self,
    topic: EndpointInfo,
    id: str = None,
    count: int = None,
    block: int = None,
    from_start=False,
) -> list | None:
    """
    read from stream

    The id of the last entry read is kept per topic in `stream_keys`, and used
    when no id is given. On the first read of a topic without id (and without
    from_start), the newest entry of the stream is returned, if any.

    Args:
        topic (str): redis topic
        id (str, optional): id to read from. Defaults to None.
        count (int, optional): number of messages to read. Defaults to None, which means all.
        block (int, optional): block for x milliseconds. Defaults to None.
        from_start (bool, optional): read from start. Defaults to False.

    Returns:
        [list]: list of messages, or None if there is no message

    Examples:
        >>> redis.xread("test", "0-0")
        >>> redis.xread("test", "0-0", count=1)
        # read one message at a time
        >>> key = 0
        >>> msg = redis.xread("test", key, count=1)
        >>> key = msg[0][1][0][0]
        >>> next_msg = redis.xread("test", key, count=1)
    """
    client = self._redis_conn
    if from_start:
        self.stream_keys[topic] = "0-0"

    if topic not in self.stream_keys and id is None:
        # first read of this topic: start from the newest entry
        try:
            newest = client.xrevrange(topic, "+", "-", count=1)
            if newest:
                entry_id, fields = newest[0]
                self.stream_keys[topic] = entry_id.decode()
                return [
                    {
                        name.decode(): MsgpackSerialization.loads(value)
                        for name, value in fields.items()
                    }
                ]
            self.stream_keys[topic] = "0-0"
        except redis.exceptions.ResponseError:
            self.stream_keys[topic] = "0-0"

    read_from = self.stream_keys[topic] if id is None else id
    entries = client.xread({topic: read_from}, count=count, block=block)
    return self._decode_stream_messages_xread(entries)
def _decode_stream_messages_xread(self, msg):
    """
    Decode the result of a redis 'xread' call, and remember for each topic the
    id of the last decoded entry in `stream_keys` (stored as received).

    Args:
        msg: list of (topic, [(entry id, fields), ...])

    Returns:
        list of decoded messages (dict), or None if there is no message
    """
    decoded = []
    for raw_topic, entries in msg:
        for entry_id, fields in entries:
            message = {
                name.decode(): MsgpackSerialization.loads(value)
                for name, value in fields.items()
            }
            decoded.append(message)
            self.stream_keys[raw_topic.decode()] = entry_id
    return decoded or None
@validate_endpoint("topic")
def xrange(self, topic: EndpointInfo, min: str, max: str, count: int = None):
    """
    read a range from stream

    Args:
        topic (str): redis topic
        min (str): min id. Use "-" to read from start
        max (str): max id. Use "+" to read to end
        count (int, optional): number of messages to read. Defaults to None.

    Returns:
        [list]: list of messages or None
    """
    readings = self._redis_conn.xrange(topic, min, max, count=count)
    decoded = [
        {name.decode(): MsgpackSerialization.loads(value) for name, value in fields.items()}
        for _, fields in readings
    ]
    return decoded or None
def producer(self):
    """Return itself as a producer, to be compatible with old code

    Deprecated: a FutureWarning is emitted on each call."""
    deprecation_notice = "RedisConnector.producer() is deprecated and should not be used anymore. A Connector is a producer now, just use the connector object."
    warnings.warn(deprecation_notice, FutureWarning)
    return self
def consumer(
    self,
    topics=None,
    patterns=None,
    group_id=None,
    event=None,
    cb=None,
    threaded=True,
    name=None,
    **kwargs,
):
    """Return a fake thread object to be compatible with old code

    In order to keep this fail-safe and simple it uses 'mock'...

    Calling `start()` on the returned object registers the callback, with
    `register(topics, patterns, cb, threaded, **kwargs)`. The arguments
    `group_id`, `event` and `name` are accepted for compatibility, they are
    not used.

    Deprecated: a FutureWarning is emitted on each call.
    """
    from unittest.mock import (  # import is done here, to not pollute the file with something normally in tests
        Mock,
    )

    warnings.warn(
        "RedisConnector.consumer() is deprecated and should not be used anymore. Use RedisConnector.register() with 'topics', 'patterns', 'cb' or 'start_thread' instead. Additional keyword args are transmitted to the callback. For the caller, the main difference with RedisConnector.register() is that it does not return a new thread.",
        FutureWarning,
    )
    dummy_thread = Mock(spec=threading.Thread)
    # 'side_effect' is what Mock calls when 'start' is called; a misspelled
    # attribute is silently accepted by Mock and nothing would be registered
    dummy_thread.start.side_effect = lambda: self.register(
        topics, patterns, cb, threaded, **kwargs
    )
    return dummy_thread