feat: added stream consumer

This commit is contained in:
wakonig_k 2023-08-09 20:44:11 +02:00 committed by wakonig_k
parent b0467a86aa
commit b4043e970a

View File

@ -91,6 +91,53 @@ class RedisConnector(ConnectorBase):
**kwargs, **kwargs,
) )
def stream_consumer(
self,
topics=None,
pattern=None,
group_id=None,
event=None,
cb=None,
from_start=False,
newest_only=False,
**kwargs,
):
"""
Threaded stream consumer for redis streams.
Args:
topics (str, list): topics to subscribe to
pattern (str, list): pattern to subscribe to
group_id (str): group id
event (threading.Event): event to stop the consumer
cb (function): callback function
from_start (bool): read from start. Defaults to False.
newest_only (bool): read only the newest message. Defaults to False.
"""
if cb is None:
raise ValueError("The callback function must be specified.")
if pattern:
raise ValueError("Pattern is currently not supported for stream consumer.")
if topics is None and pattern is None:
raise ValueError("Topics must be set for stream consumer.")
listener = RedisStreamConsumerThreaded(
self.host,
self.port,
topics,
pattern,
group_id,
event,
cb,
redis_cls=self.redis_cls,
from_start=from_start,
newest_only=newest_only,
**kwargs,
)
self._threads.append(listener)
return listener
@catch_connection_error @catch_connection_error
def log_warning(self, msg): def log_warning(self, msg):
"""send a warning""" """send a warning"""
@ -343,6 +390,7 @@ class RedisConsumerMixin:
else: else:
self.r = redis.Redis(host=self.host, port=self.port) self.r = redis.Redis(host=self.host, port=self.port)
@catch_connection_error
def initialize_connector(self) -> None: def initialize_connector(self) -> None:
if self.pattern is not None: if self.pattern is not None:
self.pubsub.psubscribe(self.pattern) self.pubsub.psubscribe(self.pattern)
@ -378,11 +426,9 @@ class RedisConsumer(RedisConsumerMixin, ConsumerConnector):
cb=cb, cb=cb,
**kwargs, **kwargs,
) )
self.error_message_sent = False
self._init_redis_cls(redis_cls) self._init_redis_cls(redis_cls)
self.pubsub = self.r.pubsub() self.pubsub = self.r.pubsub()
self.initialize_connector() self.initialize_connector()
@catch_connection_error @catch_connection_error
@ -391,12 +437,19 @@ class RedisConsumer(RedisConsumerMixin, ConsumerConnector):
Poll messages from self.connector and call the callback function self.cb Poll messages from self.connector and call the callback function self.cb
""" """
try:
messages = self.pubsub.get_message(ignore_subscribe_messages=True) messages = self.pubsub.get_message(ignore_subscribe_messages=True)
if messages is not None: if messages is not None:
msg = MessageObject(topic=messages["channel"], value=messages["data"]) msg = MessageObject(topic=messages["channel"], value=messages["data"])
return self.cb(msg, **self.kwargs) return self.cb(msg, **self.kwargs)
time.sleep(0.01) time.sleep(0.01)
self.error_message_sent = False
except redis.exceptions.ConnectionError:
if not self.error_message_sent:
print("Failed to connect to redis. Is the server running?")
self.error_message_sent = True
time.sleep(1)
return None return None
def shutdown(self): def shutdown(self):
@ -404,6 +457,141 @@ class RedisConsumer(RedisConsumerMixin, ConsumerConnector):
self.pubsub.close() self.pubsub.close()
class RedisStreamConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded):
# pylint: disable=too-many-arguments
def __init__(
self,
host,
port,
topics=None,
pattern=None,
group_id=None,
event=None,
cb=None,
redis_cls=None,
from_start=False,
newest_only=False,
**kwargs,
):
self.host = host
self.port = port
self.from_start = from_start
self.newest_only = newest_only
bootstrap_server = "".join([host, ":", port])
topics, pattern = self._init_topics_and_pattern(topics, pattern)
super().__init__(
bootstrap_server=bootstrap_server,
topics=topics,
pattern=pattern,
group_id=group_id,
event=event,
cb=cb,
**kwargs,
)
self._init_redis_cls(redis_cls)
self.pubsub = self.r.pubsub()
self.sleep_times = [0.005, 0.1]
self.last_received_msg = 0
self.idle_time = 30
self.error_message_sent = False
self.stream_keys = {}
def initialize_connector(self) -> None:
pass
def _init_topics_and_pattern(self, topics, pattern):
if topics:
if isinstance(topics, list):
topics = [f"{topic}:stream" for topic in topics]
else:
topics = [f"{topics}:stream"]
if pattern:
if isinstance(pattern, list):
pattern = [f"{pat}:stream" for pat in pattern]
else:
pattern = [f"{pattern}:stream"]
return topics, pattern
def get_id(self, topic: str) -> str:
"""
Get the stream key for the given topic.
Args:
topic (str): topic to get the stream key for
"""
if topic not in self.stream_keys:
return "0-0"
return self.stream_keys.get(topic)
def get_newest_message(self, container: list, append=True) -> None:
"""
Get the newest message from the stream and update the stream key. If
append is True, append the message to the container.
Args:
container (list): container to append the message to
append (bool, optional): append to container. Defaults to True.
"""
for topic in self.topics:
msg = self.r.xrevrange(topic, "+", "-", count=1)
if msg:
if append:
container.append((topic, msg[0][1]))
self.stream_keys[topic] = msg[0][0]
else:
self.stream_keys[topic] = "0-0"
def poll_messages(self) -> None:
"""
Poll messages from self.connector and call the callback function self.cb
"""
try:
if self.pattern is not None:
keys = self.r.keys(self.pattern)
topics = [key.decode() for key in keys if key.decode().endswith(":stream")]
else:
topics = self.topics
messages = []
if self.newest_only:
self.get_newest_message(messages)
elif not self.from_start and not self.stream_keys:
self.get_newest_message(messages, append=False)
else:
streams = {f"{topic}": self.get_id(topic) for topic in topics}
read_msgs = self.r.xread(streams, count=1)
if read_msgs:
for msg in read_msgs:
topic = msg[0].decode()
messages.append((topic, msg[1][0][1]))
self.stream_keys[topic] = msg[1][-1][0]
if messages:
if MessageEndpoints.log() not in topics:
# no need to update the update frequency just for logs
self.last_received_msg = time.time()
for topic, msg in messages:
msg_obj = MessageObject(topic=topic, value=msg[b"data"])
self.cb(msg_obj, **self.kwargs)
else:
sleep_time = int(bool(time.time() - self.last_received_msg > self.idle_time))
if self.sleep_times[sleep_time]:
time.sleep(self.sleep_times[sleep_time])
self.error_message_sent = False
except redis.exceptions.ConnectionError:
if not self.error_message_sent:
print("Failed to connect to redis. Is the server running?")
self.error_message_sent = True
time.sleep(1)
def shutdown(self):
super().shutdown()
self.pubsub.close()
class RedisConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded): class RedisConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded):
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def __init__( def __init__(
@ -439,6 +627,7 @@ class RedisConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded):
self.sleep_times = [0.005, 0.1] self.sleep_times = [0.005, 0.1]
self.last_received_msg = 0 self.last_received_msg = 0
self.idle_time = 30 self.idle_time = 30
self.error_message_sent = False
def poll_messages(self) -> None: def poll_messages(self) -> None:
""" """
@ -457,8 +646,11 @@ class RedisConsumerThreaded(RedisConsumerMixin, ConsumerConnectorThreaded):
sleep_time = int(bool(time.time() - self.last_received_msg > self.idle_time)) sleep_time = int(bool(time.time() - self.last_received_msg > self.idle_time))
if self.sleep_times[sleep_time]: if self.sleep_times[sleep_time]:
time.sleep(self.sleep_times[sleep_time]) time.sleep(self.sleep_times[sleep_time])
self.error_message_sent = False
except redis.exceptions.ConnectionError: except redis.exceptions.ConnectionError:
if not self.error_message_sent:
print("Failed to connect to redis. Is the server running?") print("Failed to connect to redis. Is the server running?")
self.error_message_sent = True
time.sleep(1) time.sleep(1)
def shutdown(self): def shutdown(self):