diff --git a/bec_lib/bec_lib/core/redis_connector.py b/bec_lib/bec_lib/core/redis_connector.py index 5f44dfd5..75dcbb05 100644 --- a/bec_lib/bec_lib/core/redis_connector.py +++ b/bec_lib/bec_lib/core/redis_connector.py @@ -91,6 +91,53 @@ class RedisConnector(ConnectorBase): **kwargs, ) + def stream_consumer( + self, + topics=None, + pattern=None, + group_id=None, + event=None, + cb=None, + from_start=False, + newest_only=False, + **kwargs, + ): + """ + Threaded stream consumer for redis streams. + + Args: + topics (str, list): topics to subscribe to + pattern (str, list): pattern to subscribe to + group_id (str): group id + event (threading.Event): event to stop the consumer + cb (function): callback function + from_start (bool): read from start. Defaults to False. + newest_only (bool): read only the newest message. Defaults to False. + """ + if cb is None: + raise ValueError("The callback function must be specified.") + + if pattern: + raise ValueError("Pattern is currently not supported for stream consumer.") + + if topics is None and pattern is None: + raise ValueError("Topics must be set for stream consumer.") + listener = RedisStreamConsumerThreaded( + self.host, + self.port, + topics, + pattern, + group_id, + event, + cb, + redis_cls=self.redis_cls, + from_start=from_start, + newest_only=newest_only, + **kwargs, + ) + self._threads.append(listener) + return listener + @catch_connection_error def log_warning(self, msg): """send a warning""" @@ -343,6 +390,7 @@ class RedisConsumerMixin: else: self.r = redis.Redis(host=self.host, port=self.port) + @catch_connection_error def initialize_connector(self) -> None: if self.pattern is not None: self.pubsub.psubscribe(self.pattern) @@ -378,11 +426,9 @@ class RedisConsumer(RedisConsumerMixin, ConsumerConnector): cb=cb, **kwargs, ) - + self.error_message_sent = False self._init_redis_cls(redis_cls) - self.pubsub = self.r.pubsub() - self.initialize_connector() @catch_connection_error @@ -391,12 +437,19 @@ class RedisConsumer(RedisConsumerMixin, ConsumerConnector): Poll messages from self.connector and call the callback function self.cb """ - messages = self.pubsub.get_message(ignore_subscribe_messages=True) - if messages is not None: - msg = MessageObject(topic=messages["channel"], value=messages["data"]) - return self.cb(msg, **self.kwargs) + try: + messages = self.pubsub.get_message(ignore_subscribe_messages=True) + if messages is not None: + msg = MessageObject(topic=messages["channel"], value=messages["data"]) + return self.cb(msg, **self.kwargs) - time.sleep(0.01) + time.sleep(0.01) + self.error_message_sent = False + except redis.exceptions.ConnectionError: + if not self.error_message_sent: + print("Failed to connect to redis. Is the server running?") + self.error_message_sent = True + time.sleep(1) return None def shutdown(self): @@ -404,6 +457,141 @@ class RedisConsumer(RedisConsumerMixin, ConsumerConnector): self.pubsub.close() +class RedisStreamConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded): + # pylint: disable=too-many-arguments + def __init__( + self, + host, + port, + topics=None, + pattern=None, + group_id=None, + event=None, + cb=None, + redis_cls=None, + from_start=False, + newest_only=False, + **kwargs, + ): + self.host = host + self.port = port + self.from_start = from_start + self.newest_only = newest_only + + bootstrap_server = "".join([host, ":", port]) + topics, pattern = self._init_topics_and_pattern(topics, pattern) + super().__init__( + bootstrap_server=bootstrap_server, + topics=topics, + pattern=pattern, + group_id=group_id, + event=event, + cb=cb, + **kwargs, + ) + + self._init_redis_cls(redis_cls) + self.pubsub = self.r.pubsub() + + self.sleep_times = [0.005, 0.1] + self.last_received_msg = 0 + self.idle_time = 30 + self.error_message_sent = False + self.stream_keys = {} + + def initialize_connector(self) -> None: + pass + + def _init_topics_and_pattern(self, topics, pattern): + if topics: + if isinstance(topics, list): + topics = [f"{topic}:stream" for topic in topics] + else: + topics = [f"{topics}:stream"] + if pattern: + if isinstance(pattern, list): + pattern = [f"{pat}:stream" for pat in pattern] + else: + pattern = [f"{pattern}:stream"] + return topics, pattern + + def get_id(self, topic: str) -> str: + """ + Get the stream key for the given topic. + + Args: + topic (str): topic to get the stream key for + """ + if topic not in self.stream_keys: + return "0-0" + return self.stream_keys.get(topic) + + def get_newest_message(self, container: list, append=True) -> None: + """ + Get the newest message from the stream and update the stream key. If + append is True, append the message to the container. + + Args: + container (list): container to append the message to + append (bool, optional): append to container. Defaults to True. + """ + for topic in self.topics: + msg = self.r.xrevrange(topic, "+", "-", count=1) + if msg: + if append: + container.append((topic, msg[0][1])) + self.stream_keys[topic] = msg[0][0] + else: + self.stream_keys[topic] = "0-0" + + def poll_messages(self) -> None: + """ + Poll messages from self.connector and call the callback function self.cb + + """ + try: + if self.pattern is not None: + keys = self.r.keys(self.pattern) + topics = [key.decode() for key in keys if key.decode().endswith(":stream")] + else: + topics = self.topics + messages = [] + if self.newest_only: + self.get_newest_message(messages) + elif not self.from_start and not self.stream_keys: + self.get_newest_message(messages, append=False) + else: + streams = {f"{topic}": self.get_id(topic) for topic in topics} + read_msgs = self.r.xread(streams, count=1) + if read_msgs: + for msg in read_msgs: + topic = msg[0].decode() + messages.append((topic, msg[1][0][1])) + self.stream_keys[topic] = msg[1][-1][0] + + if messages: + if MessageEndpoints.log() not in topics: + # no need to update the update frequency just for logs + self.last_received_msg = time.time() + for topic, msg in messages: + msg_obj = MessageObject(topic=topic, value=msg[b"data"]) + self.cb(msg_obj, **self.kwargs) + else: + sleep_time = int(bool(time.time() - self.last_received_msg > self.idle_time)) + if self.sleep_times[sleep_time]: + time.sleep(self.sleep_times[sleep_time]) + self.error_message_sent = False + except redis.exceptions.ConnectionError: + if not self.error_message_sent: + print("Failed to connect to redis. Is the server running?") + self.error_message_sent = True + time.sleep(1) + + def shutdown(self): + super().shutdown() + self.pubsub.close() + + class RedisConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded): # pylint: disable=too-many-arguments def __init__( @@ -439,6 +627,7 @@ class RedisConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded): self.sleep_times = [0.005, 0.1] self.last_received_msg = 0 self.idle_time = 30 + self.error_message_sent = False def poll_messages(self) -> None: """ @@ -457,8 +646,11 @@ class RedisConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded): sleep_time = int(bool(time.time() - self.last_received_msg > self.idle_time)) if self.sleep_times[sleep_time]: time.sleep(self.sleep_times[sleep_time]) + self.error_message_sent = False except redis.exceptions.ConnectionError: - print("Failed to connect to redis. Is the server running?") + if not self.error_message_sent: + print("Failed to connect to redis. Is the server running?") + self.error_message_sent = True time.sleep(1) def shutdown(self):