bec/bec_lib/tests/test_redis_connector.py

689 lines
23 KiB
Python

from dataclasses import dataclass, field
from unittest import mock
import pytest
import redis
import bec_lib.messages as bec_messages
from bec_lib.alarm_handler import Alarms
from bec_lib.connector import ConsumerConnectorError
from bec_lib.endpoints import MessageEndpoints
from bec_lib.messages import AlarmMessage, BECMessage, LogMessage
from bec_lib.redis_connector import (
MessageObject,
RedisConnector,
RedisConsumer,
RedisConsumerMixin,
RedisConsumerThreaded,
RedisProducer,
RedisStreamConsumerThreaded,
)
from bec_lib.serialization import MsgpackSerialization
@pytest.fixture
def producer():
    """Yield a ``RedisProducer`` created while ``redis.Redis`` is replaced by a mock."""
    redis_patch = mock.patch("bec_lib.redis_connector.redis.Redis")
    # The patch stays active for the whole test because the yield is inside it.
    with redis_patch:
        yield RedisProducer("localhost", 1)
@pytest.fixture
def connector():
    """Yield a ``RedisConnector`` created while ``redis.Redis`` is replaced by a mock."""
    redis_patch = mock.patch("bec_lib.redis_connector.redis.Redis")
    # The patch stays active for the whole test because the yield is inside it.
    with redis_patch:
        yield RedisConnector("localhost:1")
@pytest.fixture
def consumer():
    """Yield a ``RedisConsumer`` on topic ``"topics"`` backed by a mocked ``redis.Redis``."""
    redis_patch = mock.patch("bec_lib.redis_connector.redis.Redis")
    # The patch stays active for the whole test because the yield is inside it.
    with redis_patch:
        yield RedisConsumer("localhost", "1", topics="topics")
@pytest.fixture
def consumer_threaded():
    """Yield a ``RedisConsumerThreaded`` on topic ``"topics"`` with ``redis.Redis`` mocked."""
    redis_patch = mock.patch("bec_lib.redis_connector.redis.Redis")
    # The patch stays active for the whole test because the yield is inside it.
    with redis_patch:
        yield RedisConsumerThreaded("localhost", "1", topics="topics")
@pytest.fixture
def mixin():
    """Yield the ``RedisConsumerMixin`` class itself while ``redis.Redis`` is mocked."""
    redis_patch = mock.patch("bec_lib.redis_connector.redis.Redis")
    with redis_patch:
        # The class (not an instance) is handed to the tests.
        yield RedisConsumerMixin
def test_redis_connector_producer(connector):
    """``connector.producer()`` must return a ``RedisProducer`` instance."""
    created = connector.producer()
    assert isinstance(created, RedisProducer)
@pytest.mark.parametrize(
    "topics, threaded", [["topics", True], ["topics", False], [None, True], [None, False]]
)
def test_redis_connector_consumer(connector, threaded, topics):
    """Check the consumer type built by ``connector.consumer`` and the errors it raises.

    Covers threaded and non-threaded consumers, each with and without topics.
    """

    def noop_cb(*args, **kwargs):
        return ...

    pattern = None
    threads_before = len(connector._threads)

    if not threaded:
        if not topics:
            # A plain consumer without topics is expected to be rejected.
            with pytest.raises(ConsumerConnectorError):
                connector.consumer(topics=topics, threaded=threaded, cb=noop_cb)
            return
        created = connector.consumer(topics=topics, threaded=threaded, cb=noop_cb)
        assert isinstance(created, RedisConsumer)
        return

    if topics is None and pattern is None:
        with pytest.raises(ValueError) as exc_info:
            connector.consumer(topics=topics, threaded=threaded, cb=noop_cb)
        assert exc_info.value.args[0] == "Topics must be set for threaded consumer"
        return

    created = connector.consumer(topics=topics, threaded=threaded, cb=noop_cb)
    # A threaded consumer must have been registered with the connector.
    assert len(connector._threads) == threads_before + 1
    assert isinstance(created, RedisConsumerThreaded)
def test_redis_connector_log_warning(connector):
    """``log_warning`` must send a ``"warning"`` ``LogMessage`` to the log endpoint."""
    send_mock = mock.MagicMock()
    connector._notifications_producer.send = send_mock

    connector.log_warning("msg")

    endpoint = MessageEndpoints.log()
    expected_msg = LogMessage(log_type="warning", log_msg="msg")
    send_mock.assert_called_once_with(endpoint, expected_msg)
def test_redis_connector_log_message(connector):
    """``log_message`` must send a ``"log"`` ``LogMessage`` to the log endpoint."""
    send_mock = mock.MagicMock()
    connector._notifications_producer.send = send_mock

    connector.log_message("msg")

    endpoint = MessageEndpoints.log()
    expected_msg = LogMessage(log_type="log", log_msg="msg")
    send_mock.assert_called_once_with(endpoint, expected_msg)
def test_redis_connector_log_error(connector):
    """``log_error`` must send an ``"error"`` ``LogMessage`` to the log endpoint."""
    send_mock = mock.MagicMock()
    connector._notifications_producer.send = send_mock

    connector.log_error("msg")

    endpoint = MessageEndpoints.log()
    expected_msg = LogMessage(log_type="error", log_msg="msg")
    send_mock.assert_called_once_with(endpoint, expected_msg)
@pytest.mark.parametrize(
    "severity, alarm_type, source, msg, metadata",
    [
        [Alarms.MAJOR, "alarm", "source", "content1", {"metadata": "metadata1"}],
        [Alarms.MINOR, "alarm", "source", "content1", {"metadata": "metadata1"}],
        [Alarms.WARNING, "alarm", "source", "content1", {"metadata": "metadata1"}],
    ],
)
def test_redis_connector_raise_alarm(connector, severity, alarm_type, source, msg, metadata):
    """``raise_alarm`` must set-and-publish a matching ``AlarmMessage`` on the alarm endpoint."""
    publish_mock = mock.MagicMock()
    connector._notifications_producer.set_and_publish = publish_mock

    connector.raise_alarm(severity, alarm_type, source, msg, metadata)

    endpoint = MessageEndpoints.alarm()
    expected_alarm = AlarmMessage(
        severity=severity, alarm_type=alarm_type, source=source, msg=msg, metadata=metadata
    )
    publish_mock.assert_called_once_with(endpoint, expected_alarm)
@dataclass(eq=False)
class TestMessage(BECMessage):
    """Minimal ``BECMessage`` subclass used as a payload by the tests in this module."""

    # The "Test" prefix makes pytest consider this class for collection and warn
    # about its generated ``__init__``; opt out explicitly. Being un-annotated,
    # this is a plain class attribute and not a dataclass field.
    __test__ = False

    msg_type = "test_message"
    msg: str
    # have to add this field here,
    # could be inherited but it requires Python 3.10
    # and 'kw_only=True'
    metadata: dict = field(default_factory=dict)


# register at BEC messages module level, to be able to
# find it when using "loads()"
bec_messages.TestMessage = TestMessage
@pytest.mark.parametrize(
    "topic , msg", [["topic1", TestMessage("msg1")], ["topic2", TestMessage("msg2")]]
)
def test_redis_producer_send(producer, topic, msg):
    """``send`` must publish the serialized message, directly and through a pipeline."""
    # Without a pipeline the message goes to the redis client itself.
    producer.send(topic, msg)
    producer.r.publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg))

    # With a pipeline the publish has to be issued on the pipeline instead.
    pipeline = producer.pipeline()
    producer.send(topic, msg, pipe=pipeline)
    pipeline_mock = producer.r.pipeline()
    pipeline_mock.publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg))
@pytest.mark.parametrize(
    "topic, msgs, max_size, expire",
    [["topic1", "msgs", None, None], ["topic1", "msgs", 10, None], ["topic1", "msgs", None, 100]],
)
def test_redis_producer_lpush(producer, topic, msgs, max_size, expire):
    """``lpush`` without a caller-supplied pipeline must push, trim/expire and execute.

    The test only covers the ``pipe=None`` case, so the final ``execute``
    assertion is unconditional (the former ``if not pipe`` guard was always true).
    """
    producer.lpush(topic, msgs, None, max_size, expire)

    pipeline_mock = producer.r.pipeline()
    pipeline_mock.lpush.assert_called_once_with(topic, msgs)
    if max_size:
        pipeline_mock.ltrim.assert_called_once_with(topic, 0, max_size)
    if expire:
        pipeline_mock.expire.assert_called_once_with(topic, expire)
    pipeline_mock.execute.assert_called_once()
@pytest.mark.parametrize(
    "topic, msgs, max_size, expire",
    [
        ["topic1", TestMessage("msgs"), None, None],
        ["topic1", TestMessage("msgs"), 10, None],
        ["topic1", TestMessage("msgs"), None, 100],
    ],
)
def test_redis_producer_lpush_BECMessage(producer, topic, msgs, max_size, expire):
    """``lpush`` of a ``BECMessage`` must push its msgpack-serialized form.

    The test only covers the ``pipe=None`` case, so the final ``execute``
    assertion is unconditional (the former ``if not pipe`` guard was always true).
    """
    producer.lpush(topic, msgs, None, max_size, expire)

    pipeline_mock = producer.r.pipeline()
    pipeline_mock.lpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs))
    if max_size:
        pipeline_mock.ltrim.assert_called_once_with(topic, 0, max_size)
    if expire:
        pipeline_mock.expire.assert_called_once_with(topic, expire)
    pipeline_mock.execute.assert_called_once()
@pytest.mark.parametrize(
    "topic , index , msgs, use_pipe", [["topic1", 1, "msg1", True], ["topic2", 4, "msg2", False]]
)
def test_redis_producer_lset(producer, topic, index, msgs, use_pipe):
    """``lset`` must forward to the pipeline when one is given, else to the client."""
    pipe = use_pipe_fcn(producer, use_pipe)
    result = producer.lset(topic, index, msgs, pipe)

    if not pipe:
        producer.r.lset.assert_called_once_with(topic, index, msgs)
        assert result == redis.Redis().lset()
        return

    producer.r.pipeline().lset.assert_called_once_with(topic, index, msgs)
    assert result == redis.Redis().pipeline().lset()
@pytest.mark.parametrize(
    "topic , index , msgs, use_pipe",
    [["topic1", 1, TestMessage("msg1"), True], ["topic2", 4, TestMessage("msg2"), False]],
)
def test_redis_producer_lset_BECMessage(producer, topic, index, msgs, use_pipe):
    """``lset`` of a ``BECMessage`` must forward its msgpack-serialized form."""
    pipe = use_pipe_fcn(producer, use_pipe)
    result = producer.lset(topic, index, msgs, pipe)

    if not pipe:
        producer.r.lset.assert_called_once_with(topic, index, MsgpackSerialization.dumps(msgs))
        assert result == redis.Redis().lset()
        return

    producer.r.pipeline().lset.assert_called_once_with(
        topic, index, MsgpackSerialization.dumps(msgs)
    )
    assert result == redis.Redis().pipeline().lset()
@pytest.mark.parametrize(
    "topic, msgs, use_pipe", [["topic1", "msg1", True], ["topic2", "msg2", False]]
)
def test_redis_producer_rpush(producer, topic, msgs, use_pipe):
    """``rpush`` must forward to the pipeline when one is given, else to the client."""
    pipe = use_pipe_fcn(producer, use_pipe)
    result = producer.rpush(topic, msgs, pipe)

    if not pipe:
        producer.r.rpush.assert_called_once_with(topic, msgs)
        assert result == redis.Redis().rpush()
        return

    producer.r.pipeline().rpush.assert_called_once_with(topic, msgs)
    assert result == redis.Redis().pipeline().rpush()
@pytest.mark.parametrize(
    "topic, msgs, use_pipe",
    [["topic1", TestMessage("msg1"), True], ["topic2", TestMessage("msg2"), False]],
)
def test_redis_producer_rpush_BECMessage(producer, topic, msgs, use_pipe):
    """``rpush`` of a ``BECMessage`` must forward its msgpack-serialized form."""
    pipe = use_pipe_fcn(producer, use_pipe)
    result = producer.rpush(topic, msgs, pipe)

    if not pipe:
        producer.r.rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs))
        assert result == redis.Redis().rpush()
        return

    producer.r.pipeline().rpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs))
    assert result == redis.Redis().pipeline().rpush()
@pytest.mark.parametrize(
    "topic, start, end, use_pipe", [["topic1", 0, 4, True], ["topic2", 3, 7, False]]
)
def test_redis_producer_lrange(producer, topic, start, end, use_pipe):
    """``lrange`` must query the pipeline when one is given, else the client."""
    pipe = use_pipe_fcn(producer, use_pipe)
    result = producer.lrange(topic, start, end, pipe)

    if not pipe:
        producer.r.lrange.assert_called_once_with(topic, start, end)
        # With the mocked client the direct call is expected to yield an empty list.
        assert result == []
        return

    producer.r.pipeline().lrange.assert_called_once_with(topic, start, end)
    assert result == redis.Redis().pipeline().lrange()
@pytest.mark.parametrize(
    "topic, msg, pipe, expire", [["topic1", "msg1", None, 400], ["topic2", "msg2", None, None]]
)
def test_redis_producer_set_and_publish(producer, topic, msg, pipe, expire):
    """``set_and_publish`` must publish and set on a pipeline, then expire/execute as needed."""
    producer.set_and_publish(topic, msg, pipe, expire)

    pipeline_mock = producer.r.pipeline()
    pipeline_mock.publish.assert_called_once_with(topic, msg)
    pipeline_mock.set.assert_called_once_with(topic, msg)
    if expire:
        pipeline_mock.expire.assert_called_once_with(topic, expire)
    if not pipe:
        # No pipeline supplied by the caller: the producer must execute its own.
        pipeline_mock.execute.assert_called_once()
@pytest.mark.parametrize("topic, msg, expire", [["topic1", "msg1", None], ["topic2", "msg2", 400]])
def test_redis_producer_set(producer, topic, msg, expire):
    """``set`` without a pipeline must call the client's ``set`` with ``ex=expire``.

    Only the ``pipe=None`` path is exercised; the former ``if pipe:`` branch
    could never run because ``pipe`` was hard-coded to ``None``.
    """
    producer.set(topic, msg, None, expire)
    producer.r.set.assert_called_once_with(topic, msg, ex=expire)
@pytest.mark.parametrize("pattern", ["samx", "samy"])
def test_redis_producer_keys(producer, pattern):
    """``keys`` must forward the pattern to the client and return the client's result."""
    found = producer.keys(pattern)
    producer.r.keys.assert_called_once_with(pattern)
    expected = redis.Redis().keys()
    assert found == expected
def test_redis_producer_pipeline(producer):
    """``pipeline`` must create exactly one client pipeline and return it."""
    created = producer.pipeline()
    # Check the call count before the comparison below calls ``pipeline()`` again.
    producer.r.pipeline.assert_called_once()
    expected = redis.Redis().pipeline()
    assert created == expected
@pytest.mark.parametrize("topic,use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_producer_delete(producer, topic, use_pipe):
    """``delete`` must be issued on the pipeline when one is given, else on the client."""
    pipe = use_pipe_fcn(producer, use_pipe)
    producer.delete(topic, pipe)

    target = producer.pipeline() if pipe else producer.r
    target.delete.assert_called_once_with(topic)
@pytest.mark.parametrize("topic, use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_producer_get(producer, topic, use_pipe):
    """``get`` must read from the pipeline when one is given, else from the client."""
    pipe = use_pipe_fcn(producer, use_pipe)
    result = producer.get(topic, pipe)

    if not pipe:
        producer.r.get.assert_called_once_with(topic)
        assert result == redis.Redis().get()
        return

    producer.pipeline().get.assert_called_once_with(topic)
    assert result == redis.Redis().pipeline().get()
def use_pipe_fcn(producer, use_pipe):
    """Return ``producer.pipeline()`` when ``use_pipe`` is truthy, otherwise ``None``."""
    return producer.pipeline() if use_pipe else None
@pytest.mark.parametrize(
    "topics, pattern",
    [
        ["topics1", None],
        [["topics1", "topics2"], None],
        [None, "pattern1"],
        [None, ["pattern1", "pattern2"]],
    ],
)
def test_redis_consumer_init(consumer, topics, pattern):
    """A new ``RedisConsumer`` must store list-normalized topics/patterns and its connection."""

    def as_list(value):
        # Scalars are expected to be wrapped in a single-element list.
        return value if isinstance(value, list) else [value]

    with mock.patch("bec_lib.redis_connector.redis.Redis"):
        built = RedisConsumer(
            "localhost", "1", topics, pattern, redis_cls=redis.Redis, cb=lambda *args, **kwargs: ...
        )

        if topics:
            assert built.topics == as_list(topics)
        if pattern:
            assert built.pattern == as_list(pattern)

        assert built.r == redis.Redis()
        assert built.pubsub == built.r.pubsub()
        assert built.host == "localhost"
        assert built.port == "1"
@pytest.mark.parametrize("pattern, topics", [["pattern", "topics1"], [None, "topics2"]])
def test_redis_consumer_initialize_connector(consumer, pattern, topics):
    """``initialize_connector`` must ``psubscribe`` to a pattern if set, else ``subscribe``."""
    consumer.pattern = pattern
    consumer.topics = topics

    consumer.initialize_connector()

    pubsub = consumer.pubsub
    if consumer.pattern is None:
        pubsub.subscribe.assert_called_with(consumer.topics)
    else:
        pubsub.psubscribe.assert_called_once_with(consumer.pattern)
def test_redis_consumer_poll_messages(consumer):
    """``poll_messages`` must fetch one pubsub message and hand it to the callback."""
    invocations = []

    def cb_fcn(msg, **kwargs):
        # Record the call first, then echo the message as before.
        invocations.append(True)
        print(msg)

    consumer.cb = cb_fcn

    payload = TestMessage("test")
    consumer.pubsub.get_message.return_value = {
        "channel": "",
        "data": MsgpackSerialization.dumps(payload),
    }

    consumer.poll_messages()

    consumer.pubsub.get_message.assert_called_once_with(ignore_subscribe_messages=True)
    assert invocations
def test_redis_consumer_shutdown(consumer):
    """``shutdown`` must close the consumer's pubsub exactly once."""
    consumer.shutdown()
    pubsub_close = consumer.pubsub.close
    pubsub_close.assert_called_once()
def test_redis_consumer_additional_kwargs(connector):
    """Extra keyword arguments given to ``connector.consumer`` must land in ``kwargs``."""
    created = connector.consumer(topics="topic1", parent="here", cb=lambda *args, **kwargs: ...)
    stored_kwargs = created.kwargs
    assert "parent" in stored_kwargs
@pytest.mark.parametrize(
    "topics, pattern",
    [
        ["topics1", None],
        [["topics1", "topics2"], None],
        [None, "pattern1"],
        [None, ["pattern1", "pattern2"]],
    ],
)
def test_mixin_init_topics_and_pattern(mixin, topics, pattern):
    """``_init_topics_and_pattern`` must return topics and patterns normalized to lists."""
    # The fixture yields the class, so it is also passed explicitly as ``self``.
    ret_topics, ret_pattern = mixin._init_topics_and_pattern(mixin, topics, pattern)

    for given, returned in ((topics, ret_topics), (pattern, ret_pattern)):
        if not given:
            continue
        expected = given if isinstance(given, list) else [given]
        assert returned == expected
def test_mixin_init_redis_cls(mixin, consumer):
    """``_init_redis_cls`` with no class given must leave ``consumer.r`` as a ``redis.Redis``."""
    mixin._init_redis_cls(consumer, None)
    expected_client = redis.Redis(host="localhost", port=1)
    assert consumer.r == expected_client
@pytest.mark.parametrize(
    "topics, pattern",
    [
        ["topics1", None],
        [["topics1", "topics2"], None],
        [None, "pattern1"],
        [None, ["pattern1", "pattern2"]],
    ],
)
def test_redis_consumer_threaded_init(consumer_threaded, topics, pattern):
    """A new ``RedisConsumerThreaded`` must store normalized topics and its default timings."""

    def as_list(value):
        # Scalars are expected to be wrapped in a single-element list.
        return value if isinstance(value, list) else [value]

    with mock.patch("bec_lib.redis_connector.redis.Redis"):
        built = RedisConsumerThreaded(
            "localhost", "1", topics, pattern, redis_cls=redis.Redis, cb=lambda *args, **kwargs: ...
        )

        if topics:
            assert built.topics == as_list(topics)
        if pattern:
            assert built.pattern == as_list(pattern)

        assert built.r == redis.Redis()
        assert built.pubsub == built.r.pubsub()
        assert built.host == "localhost"
        assert built.port == "1"
        assert built.sleep_times == [0.005, 0.1]
        assert built.last_received_msg == 0
        assert built.idle_time == 30
def test_redis_connector_xadd(producer):
    """``xadd`` must serialize each field value; ``get_last`` must decode it again."""
    stream = "topic1"

    # Plain values are serialized per field.
    producer.xadd(stream, {"key": "value"})
    producer.r.xadd.assert_called_once_with(stream, {"key": MsgpackSerialization.dumps("value")})

    # The same holds for BECMessage values.
    payload = TestMessage("test")
    producer.xadd(stream, {"data": payload})
    producer.r.xadd.assert_called_with(stream, {"data": MsgpackSerialization.dumps(payload)})

    # Reading back the newest entry returns the deserialized message.
    producer.r.xrevrange.return_value = [
        (b"1707391599960-0", {b"data": MsgpackSerialization.dumps(payload)})
    ]
    latest = producer.get_last(stream)
    assert latest == payload
def test_redis_connector_xadd_with_maxlen(producer):
    """``xadd`` with ``max_size`` must pass it to the client as ``maxlen``."""
    producer.xadd("topic1", {"key": "value"}, max_size=100)
    serialized_fields = {"key": MsgpackSerialization.dumps("value")}
    producer.r.xadd.assert_called_once_with("topic1", serialized_fields, maxlen=100)
def test_redis_connector_xadd_with_expire(producer):
    """``xadd`` with ``expire`` must add, expire and execute on a pipeline."""
    producer.xadd("topic1", {"key": "value"}, expire=100)

    pipeline_mock = producer.r.pipeline()
    serialized_fields = {"key": MsgpackSerialization.dumps("value")}
    pipeline_mock.xadd.assert_called_once_with("topic1", serialized_fields)
    pipeline_mock.expire.assert_called_once_with("topic1", 100)
    pipeline_mock.execute.assert_called_once()
def test_redis_connector_xread(producer):
    """``xread`` with an explicit id must pass that id to the client for the topic."""
    producer.xread("topic1", "id")
    expected_streams = {"topic1": "id"}
    producer.r.xread.assert_called_once_with(expected_streams, count=None, block=None)
def test_redis_connector_xread_without_id(producer):
    """Without an id, ``xread`` must use ``"0-0"`` from start, then a stored stream key."""
    xread_mock = producer.r.xread

    # No stored key and from_start=True: read from the very beginning.
    producer.xread("topic1", from_start=True)
    xread_mock.assert_called_once_with({"topic1": "0-0"}, count=None, block=None)

    # Once a key is stored for the topic, it takes the place of the id.
    xread_mock.reset_mock()
    producer.stream_keys["topic1"] = "id"
    producer.xread("topic1")
    xread_mock.assert_called_once_with({"topic1": "id"}, count=None, block=None)
def test_redis_connector_xread_from_end(producer):
    """``xread`` with ``from_start=False`` must request the topic from ``"$"``."""
    producer.xread("topic1", from_start=False)
    expected_streams = {"topic1": "$"}
    producer.r.xread.assert_called_once_with(expected_streams, count=None, block=None)
def test_redis_connector_xread_from_new_topic(producer):
    """``xread`` on a topic with no stored key and ``from_start=False`` must use ``"$"``."""
    # NOTE(review): the body matches test_redis_connector_xread_from_end exactly;
    # confirm whether a distinct scenario was intended here.
    producer.xread("topic1", from_start=False)
    expected_streams = {"topic1": "$"}
    producer.r.xread.assert_called_once_with(expected_streams, count=None, block=None)
def test_redis_connector_get_last(producer):
    """``get_last`` must fetch the newest stream entry and select the requested key."""
    newest_entry = (b"1707391599960-0", {b"key": MsgpackSerialization.dumps("value")})
    producer.r.xrevrange.return_value = [newest_entry]

    default_lookup = producer.get_last("topic1")
    producer.r.xrevrange.assert_called_once_with("topic1", "+", "-", count=1)

    assert default_lookup is None  # no key given, default is b'data'
    assert producer.get_last("topic1", "key") == "value"
    # Passing None as key returns the whole decoded entry.
    assert producer.get_last("topic1", None) == {"key": "value"}
def test_redis_xrange(producer):
    """``xrange`` must forward topic, bounds and ``count=None`` to the client."""
    bounds = ("start", "end")
    producer.xrange("topic1", *bounds)
    producer.r.xrange.assert_called_once_with("topic1", *bounds, count=None)
def test_redis_xrange_topic_with_suffix(producer):
    """``xrange`` must forward topic, bounds and ``count=None`` to the client."""
    # NOTE(review): the body matches test_redis_xrange exactly and uses no
    # suffixed topic; confirm whether a distinct scenario was intended here.
    bounds = ("start", "end")
    producer.xrange("topic1", *bounds)
    producer.r.xrange.assert_called_once_with("topic1", *bounds, count=None)
def test_redis_consumer_threaded_no_cb_without_messages(consumer_threaded):
    """The callback must not fire when the pubsub has no message to deliver."""
    empty_poll = mock.patch.object(consumer_threaded.pubsub, "get_message", return_value=None)
    with empty_poll:
        callback = mock.MagicMock()
        consumer_threaded.cb = callback
        consumer_threaded.poll_messages()
        callback.assert_not_called()
def test_redis_consumer_threaded_cb_called_with_messages(consumer_threaded):
    """A polled pubsub message must reach the callback as a decoded ``MessageObject``."""
    incoming = {"channel": b"topic1", "data": MsgpackSerialization.dumps(TestMessage("test"))}
    with mock.patch.object(consumer_threaded.pubsub, "get_message", return_value=incoming):
        callback = mock.MagicMock()
        consumer_threaded.cb = callback
        consumer_threaded.poll_messages()
        # The channel is expected as str and the payload deserialized.
        expected = MessageObject("topic1", TestMessage("test"))
        callback.assert_called_once_with(expected)
def test_redis_consumer_threaded_shutdown(consumer_threaded):
    """``shutdown`` must close the threaded consumer's pubsub exactly once."""
    consumer_threaded.shutdown()
    pubsub_close = consumer_threaded.pubsub.close
    pubsub_close.assert_called_once()
def test_redis_stream_consumer_threaded_get_newest_message():
    """``get_newest_message`` must record the id of the newest entry as the stream key."""
    stream_consumer = RedisStreamConsumerThreaded(
        "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
    )
    newest_id = b"1691610882756-0"
    stream_consumer.r.xrevrange.return_value = [(newest_id, {b"data": b"msg"})]

    collected = []
    stream_consumer.get_newest_message(collected)

    assert "topic1" in stream_consumer.stream_keys
    assert stream_consumer.stream_keys["topic1"] == newest_id
def test_redis_stream_consumer_threaded_get_newest_message_no_msg():
    """With an empty stream, ``get_newest_message`` must store ``"0-0"`` as the stream key."""
    stream_consumer = RedisStreamConsumerThreaded(
        "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
    )
    stream_consumer.r.xrevrange.return_value = []

    collected = []
    stream_consumer.get_newest_message(collected)

    assert "topic1" in stream_consumer.stream_keys
    assert stream_consumer.stream_keys["topic1"] == "0-0"
def test_redis_stream_consumer_threaded_get_id():
    """``get_id`` must return the stored stream key, or ``"0-0"`` for unknown topics."""
    stream_consumer = RedisStreamConsumerThreaded(
        "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
    )
    known_id = b"1691610882756-0"
    stream_consumer.stream_keys["topic1"] = known_id

    assert stream_consumer.get_id("topic1") == known_id
    assert stream_consumer.get_id("doesnt_exist") == "0-0"
def test_redis_stream_consumer_threaded_poll_messages():
    """With no stream key stored, polling must fetch the newest message and skip ``xread``."""
    stream_consumer = RedisStreamConsumerThreaded(
        "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
    )
    newest_patch = mock.patch.object(stream_consumer, "get_newest_message", return_value=None)
    with newest_patch as mock_get_newest_message:
        stream_consumer.poll_messages()
        mock_get_newest_message.assert_called_once()
        stream_consumer.r.xread.assert_not_called()
def test_redis_stream_consumer_threaded_poll_messages_newest_only():
    """With ``newest_only=True``, polling must deliver the newest entry without ``xread``."""
    stream_consumer = RedisStreamConsumerThreaded(
        "localhost",
        "1",
        topics="topic1",
        cb=mock.MagicMock(),
        redis_cls=mock.MagicMock(),
        newest_only=True,
    )
    stream_consumer.r.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})]

    stream_consumer.poll_messages()

    stream_consumer.r.xread.assert_not_called()
    expected = MessageObject(topic="topic1", value=b"msg")
    stream_consumer.cb.assert_called_once_with(expected)
def test_redis_stream_consumer_threaded_poll_messages_read():
    """With a stored stream key, polling must ``xread`` from it and pass the entry on."""
    stream_consumer = RedisStreamConsumerThreaded(
        "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
    )
    stream_consumer.stream_keys["topic1"] = "0-0"
    xread_reply = [[b"topic1", [(b"1691610714612-0", {b"data": b"msg"})]]]
    stream_consumer.r.xread.return_value = xread_reply

    stream_consumer.poll_messages()

    stream_consumer.r.xread.assert_called_once_with({"topic1": "0-0"}, count=1)
    expected = MessageObject(topic="topic1", value=b"msg")
    stream_consumer.cb.assert_called_once_with(expected)
@pytest.mark.parametrize(
    "topics,expected",
    [
        ("topic1", ["topic1"]),
        (["topic1"], ["topic1"]),
        (["topic1", "topic2"], ["topic1", "topic2"]),
    ],
)
def test_redis_stream_consumer_threaded_init_topics(topics, expected):
    """The stream consumer must store its topics as a list, wrapping a single string."""
    stream_consumer = RedisStreamConsumerThreaded(
        "localhost", "1", topics=topics, cb=mock.MagicMock(), redis_cls=mock.MagicMock()
    )
    stored_topics = stream_consumer.topics
    assert stored_topics == expected