bec/bec_lib/tests/test_redis_connector.py

588 lines
20 KiB
Python

from dataclasses import dataclass, field
from unittest import mock
import pytest
import redis
import bec_lib.messages as bec_messages
from bec_lib.alarm_handler import Alarms
from bec_lib.connector import ConsumerConnectorError
from bec_lib.endpoints import MessageEndpoints
from bec_lib.messages import AlarmMessage, BECMessage, LogMessage
from bec_lib.redis_connector import (
MessageObject,
RedisConnector,
)
from bec_lib.serialization import MsgpackSerialization
@pytest.fixture
def connector():
    """Yield a RedisConnector built while ``redis.Redis`` is patched with a mock.

    The patch stays active for the whole test, so no real Redis server is
    contacted. The connector is shut down during teardown.
    """
    redis_patch = mock.patch("bec_lib.redis_connector.redis.Redis")
    with redis_patch:
        mocked_connector = RedisConnector("localhost:1")
        try:
            yield mocked_connector
        finally:
            mocked_connector.shutdown()
@pytest.fixture
def connected_connector(redis_proc):
    """Yield a RedisConnector attached to the server started by ``redis_proc``.

    ``redis_proc`` is presumably the pytest-redis process fixture. The
    connector is always shut down after the test, even when the test fails.
    """
    address = f"localhost:{redis_proc.port}"
    live_connector = RedisConnector(address)
    try:
        yield live_connector
    finally:
        live_connector.shutdown()
@pytest.mark.parametrize(
    "topics, threaded",
    [["topics", True], ["topics", False], [None, True], [None, False]],
)
def test_redis_connector_register(connected_connector, threaded, topics):
    """Expect ``register`` to reject ``topics=None`` and start the listener thread on request.

    NOTE(review): a second ``test_redis_connector_register`` is defined later in
    this module. It replaces this function at import time, so pytest never
    collects this test. Rename it if it is meant to run.
    """
    connector = connected_connector

    def noop_cb(*args, **kwargs):
        ...

    if topics is None:
        with pytest.raises(TypeError):
            connector.register(topics=topics, cb=noop_cb, start_thread=threaded)
        return

    connector.register(topics=topics, cb=noop_cb, start_thread=threaded)
    if threaded:
        assert connector._events_listener_thread is not None
def test_redis_connector_log_warning(connector):
    """``log_warning`` sends a warning LogMessage to the log endpoint."""
    with mock.patch.object(connector, "send", return_value=None) as send_mock:
        connector.log_warning("msg")
        log_endpoint = MessageEndpoints.log()
        send_mock.assert_called_once_with(
            log_endpoint, LogMessage(log_type="warning", log_msg="msg")
        )
def test_redis_connector_log_message(connector):
    """``log_message`` sends a plain ``log`` LogMessage to the log endpoint."""
    with mock.patch.object(connector, "send", return_value=None) as send_mock:
        connector.log_message("msg")
        log_endpoint = MessageEndpoints.log()
        send_mock.assert_called_once_with(
            log_endpoint, LogMessage(log_type="log", log_msg="msg")
        )
def test_redis_connector_log_error(connector):
    """``log_error`` sends an error LogMessage to the log endpoint."""
    with mock.patch.object(connector, "send", return_value=None) as send_mock:
        connector.log_error("msg")
        log_endpoint = MessageEndpoints.log()
        send_mock.assert_called_once_with(
            log_endpoint, LogMessage(log_type="error", log_msg="msg")
        )
@pytest.mark.parametrize(
    "severity, alarm_type, source, msg, metadata",
    [
        [level, "alarm", "source", "content1", {"metadata": "metadata1"}]
        for level in (Alarms.MAJOR, Alarms.MINOR, Alarms.WARNING)
    ],
)
def test_redis_connector_raise_alarm(connector, severity, alarm_type, source, msg, metadata):
    """``raise_alarm`` forwards an AlarmMessage to ``set_and_publish`` on the alarm endpoint."""
    with mock.patch.object(connector, "set_and_publish", return_value=None) as publish_mock:
        connector.raise_alarm(severity, alarm_type, source, msg, metadata)
        alarm_endpoint = MessageEndpoints.alarm()
        expected_alarm = AlarmMessage(
            severity=severity,
            alarm_type=alarm_type,
            source=source,
            msg=msg,
            metadata=metadata,
        )
        publish_mock.assert_called_once_with(alarm_endpoint, expected_alarm)
@dataclass(eq=False)
class TestMessage(BECMessage):
    """Minimal BECMessage subclass used as a payload throughout these tests."""

    __test__ = False  # the name starts with "Test"; stop pytest from collecting it
    msg_type = "test_message"
    msg: str
    # ``metadata`` is declared here instead of being inherited: inheriting it
    # would require Python 3.10 and ``kw_only=True``.
    metadata: dict = field(default_factory=dict)


# Register the class on the bec_lib.messages module so that ``loads()`` can
# find it by name when deserializing.
bec_messages.TestMessage = TestMessage
@pytest.mark.parametrize(
    "topic , msg", [["topic1", TestMessage("msg1")], ["topic2", TestMessage("msg2")]]
)
def test_redis_connector_send(connector, topic, msg):
    """``send`` publishes the serialized message on the client, or on a pipeline if given."""
    redis_conn = connector._redis_conn

    connector.send(topic, msg)
    redis_conn.publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg))

    connector.send(topic, msg, pipe=connector.pipeline())
    redis_conn.pipeline().publish.assert_called_once_with(
        topic, MsgpackSerialization.dumps(msg)
    )
@pytest.mark.parametrize(
    "topic, msgs, max_size, expire",
    [
        ["topic1", "msgs", None, None],
        ["topic1", "msgs", 10, None],
        ["topic1", "msgs", None, 100],
    ],
)
def test_redis_connector_lpush(connector, topic, msgs, max_size, expire):
    """Without a caller pipe, ``lpush`` uses its own pipeline: push, optional trim/expire, execute."""
    pipe = None
    connector.lpush(topic, msgs, pipe, max_size, expire)
    redis_pipe = connector._redis_conn.pipeline()
    redis_pipe.lpush.assert_called_once_with(topic, msgs)
    if max_size:
        redis_pipe.ltrim.assert_called_once_with(topic, 0, max_size)
    if expire:
        redis_pipe.expire.assert_called_once_with(topic, expire)
    if not pipe:
        redis_pipe.execute.assert_called_once()
@pytest.mark.parametrize(
    "topic, msgs, max_size, expire",
    [
        ["topic1", TestMessage("msgs"), None, None],
        ["topic1", TestMessage("msgs"), 10, None],
        ["topic1", TestMessage("msgs"), None, 100],
    ],
)
def test_redis_connector_lpush_BECMessage(connector, topic, msgs, max_size, expire):
    """``lpush`` must push BECMessage payloads in serialized form."""
    pipe = None
    connector.lpush(topic, msgs, pipe, max_size, expire)
    redis_pipe = connector._redis_conn.pipeline()
    redis_pipe.lpush.assert_called_once_with(topic, MsgpackSerialization.dumps(msgs))
    if max_size:
        redis_pipe.ltrim.assert_called_once_with(topic, 0, max_size)
    if expire:
        redis_pipe.expire.assert_called_once_with(topic, expire)
    if not pipe:
        redis_pipe.execute.assert_called_once()
@pytest.mark.parametrize(
    "topic , index , msgs, use_pipe",
    [["topic1", 1, "msg1", True], ["topic2", 4, "msg2", False]],
)
def test_redis_connector_lset(connector, topic, index, msgs, use_pipe):
    """``lset`` targets the pipeline when one is given, else the client, and returns its result."""
    pipe = use_pipe_fcn(connector, use_pipe)
    ret = connector.lset(topic, index, msgs, pipe)
    # The call-count assertion must run before the mock is invoked again below.
    if not pipe:
        connector._redis_conn.lset.assert_called_once_with(topic, index, msgs)
        assert ret == redis.Redis().lset()
        return
    connector._redis_conn.pipeline().lset.assert_called_once_with(topic, index, msgs)
    assert ret == redis.Redis().pipeline().lset()
@pytest.mark.parametrize(
    "topic , index , msgs, use_pipe",
    [
        ["topic1", 1, TestMessage("msg1"), True],
        ["topic2", 4, TestMessage("msg2"), False],
    ],
)
def test_redis_connector_lset_BECMessage(connector, topic, index, msgs, use_pipe):
    """``lset`` must serialize BECMessage payloads before passing them to Redis."""
    pipe = use_pipe_fcn(connector, use_pipe)
    ret = connector.lset(topic, index, msgs, pipe)
    if not pipe:
        connector._redis_conn.lset.assert_called_once_with(
            topic, index, MsgpackSerialization.dumps(msgs)
        )
        assert ret == redis.Redis().lset()
        return
    connector._redis_conn.pipeline().lset.assert_called_once_with(
        topic, index, MsgpackSerialization.dumps(msgs)
    )
    assert ret == redis.Redis().pipeline().lset()
@pytest.mark.parametrize(
    "topic, msgs, use_pipe", [["topic1", "msg1", True], ["topic2", "msg2", False]]
)
def test_redis_connector_rpush(connector, topic, msgs, use_pipe):
    """``rpush`` targets the pipeline when one is given, else the client, and returns its result."""
    pipe = use_pipe_fcn(connector, use_pipe)
    ret = connector.rpush(topic, msgs, pipe)
    if not pipe:
        connector._redis_conn.rpush.assert_called_once_with(topic, msgs)
        assert ret == redis.Redis().rpush()
        return
    connector._redis_conn.pipeline().rpush.assert_called_once_with(topic, msgs)
    assert ret == redis.Redis().pipeline().rpush()
@pytest.mark.parametrize(
    "topic, msgs, use_pipe",
    [["topic1", TestMessage("msg1"), True], ["topic2", TestMessage("msg2"), False]],
)
def test_redis_connector_rpush_BECMessage(connector, topic, msgs, use_pipe):
    """``rpush`` must serialize BECMessage payloads before passing them to Redis."""
    pipe = use_pipe_fcn(connector, use_pipe)
    ret = connector.rpush(topic, msgs, pipe)
    if not pipe:
        connector._redis_conn.rpush.assert_called_once_with(
            topic, MsgpackSerialization.dumps(msgs)
        )
        assert ret == redis.Redis().rpush()
        return
    connector._redis_conn.pipeline().rpush.assert_called_once_with(
        topic, MsgpackSerialization.dumps(msgs)
    )
    assert ret == redis.Redis().pipeline().rpush()
@pytest.mark.parametrize(
    "topic, start, end, use_pipe", [["topic1", 0, 4, True], ["topic2", 3, 7, False]]
)
def test_redis_connector_lrange(connector, topic, start, end, use_pipe):
    """``lrange`` queries the pipeline or the client with the given bounds."""
    pipe = use_pipe_fcn(connector, use_pipe)
    ret = connector.lrange(topic, start, end, pipe)
    if not pipe:
        connector._redis_conn.lrange.assert_called_once_with(topic, start, end)
        # Without a pipeline the connector post-processes the mocked reply;
        # this test expects that to yield an empty list.
        assert ret == []
        return
    connector._redis_conn.pipeline().lrange.assert_called_once_with(topic, start, end)
    assert ret == redis.Redis().pipeline().lrange()
@pytest.mark.parametrize(
    "topic, msg, pipe, expire",
    [
        ["topic1", TestMessage("msg1"), None, 400],
        ["topic2", TestMessage("msg2"), None, None],
        ["topic3", "msg3", None, None],
    ],
)
def test_redis_connector_set_and_publish(connector, topic, msg, pipe, expire):
    """``set_and_publish`` rejects non-BECMessage payloads, otherwise sets and publishes them."""
    if not isinstance(msg, BECMessage):
        with pytest.raises(TypeError):
            connector.set_and_publish(topic, msg, pipe, expire)
        return

    connector.set_and_publish(topic, msg, pipe, expire)
    redis_pipe = connector._redis_conn.pipeline()
    redis_pipe.publish.assert_called_once_with(topic, MsgpackSerialization.dumps(msg))
    redis_pipe.set.assert_called_once_with(topic, MsgpackSerialization.dumps(msg), ex=expire)
    if not pipe:
        redis_pipe.execute.assert_called_once()
@pytest.mark.parametrize("topic, msg, expire", [["topic1", "msg1", None], ["topic2", "msg2", 400]])
def test_redis_connector_set(connector, topic, msg, expire):
    """Without a pipeline, ``set`` writes directly to the Redis client with ``ex=expire``.

    NOTE(review): the pipeline code path of ``set`` is not covered by this test.
    """
    connector.set(topic, msg, None, expire)
    connector._redis_conn.set.assert_called_once_with(topic, msg, ex=expire)
@pytest.mark.parametrize("pattern", ["samx", "samy"])
def test_redis_connector_keys(connector, pattern):
    """``keys`` passes the pattern to the client and returns the client's reply."""
    matched = connector.keys(pattern)
    connector._redis_conn.keys.assert_called_once_with(pattern)
    assert matched == redis.Redis().keys()
def test_redis_connector_pipeline(connector):
    """``pipeline`` creates exactly one pipeline on the client and returns it."""
    created_pipe = connector.pipeline()
    connector._redis_conn.pipeline.assert_called_once()
    assert created_pipe == redis.Redis().pipeline()
def use_pipe_fcn(connector, use_pipe):
    """Return ``connector.pipeline()`` when *use_pipe* is truthy, else ``None``."""
    return connector.pipeline() if use_pipe else None
@pytest.mark.parametrize("topic,use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_connector_delete(connector, topic, use_pipe):
    """``delete`` goes through the pipeline when given one, else straight to the client."""
    pipe = use_pipe_fcn(connector, use_pipe)
    connector.delete(topic, pipe)
    target = connector.pipeline() if pipe else connector._redis_conn
    target.delete.assert_called_once_with(topic)
@pytest.mark.parametrize("topic, use_pipe", [["topic1", True], ["topic2", False]])
def test_redis_connector_get(connector, topic, use_pipe):
    """``get`` reads through the pipeline when given one, else from the client."""
    pipe = use_pipe_fcn(connector, use_pipe)
    value = connector.get(topic, pipe)
    if not pipe:
        connector._redis_conn.get.assert_called_once_with(topic)
        assert value == redis.Redis().get()
        return
    connector.pipeline().get.assert_called_once_with(topic)
    assert value == redis.Redis().pipeline().get()
@pytest.mark.parametrize(
    "subscribed_topics, subscribed_patterns, msgs",
    [
        ["topics1", None, ["topics1"]],
        [["topics1", "topics2"], None, ["topics1", "topics2"]],
        [None, "pattern1", ["pattern1"]],
        [None, ["patt*", "top*"], ["pattern1", "topics1"]],
    ],
)
def test_redis_connector_register(
    redisdb, connected_connector, subscribed_topics, subscribed_patterns, msgs
):
    """Callbacks registered on topics or patterns receive every matching message.

    Each message in *msgs* is sent, polled synchronously (``start_thread=False``)
    and must reach the callback as a ``MessageObject`` along with the extra
    ``a=1`` keyword supplied at registration.
    """
    connector = connected_connector
    cb_mock = mock.Mock(spec=[])  # empty spec: the mock exposes no attributes

    def register_and_check():
        # Register, then send/poll each message and check the callback saw it.
        connector.register(
            subscribed_topics, subscribed_patterns, cb=cb_mock, start_thread=False, a=1
        )
        for msg in msgs:
            connector.send(msg, TestMessage(msg))
            connector.poll_messages()
            cb_mock.assert_called_with(MessageObject(msg, TestMessage(msg)), a=1)

    if subscribed_topics:
        register_and_check()
    if subscribed_patterns:
        register_and_check()
def test_redis_register_poll_messages(redisdb, connected_connector):
    """A message published to Redis reaches the callback via ``poll_messages``.

    The callback must receive the ``a=1`` keyword given at registration. With
    nothing else published, a second short poll must raise ``TimeoutError``.
    """
    connector = connected_connector
    received_kwargs = []

    def cb_fcn(msg, **kwargs):
        received_kwargs.append(kwargs)
        assert kwargs["a"] == 1

    test_msg = TestMessage("test")
    connector.register("test", cb=cb_fcn, a=1, start_thread=False)
    redisdb.publish("test", MsgpackSerialization.dumps(test_msg))
    connector.poll_messages(timeout=1)
    assert received_kwargs

    with pytest.raises(TimeoutError):
        connector.poll_messages(timeout=0.1)
# A stale duplicate of test_redis_connector_xadd (expecting unserialized field
# values) used to be here; the later definition of the same name shadowed it, so
# pytest never collected it.
# A stale duplicate of test_redis_connector_xadd_with_maxlen used to be here; the
# later definition of the same name shadowed it, so pytest never collected it.
# A stale duplicate of test_redis_connector_xadd_with_expire used to be here; the
# later definition of the same name shadowed it, so pytest never collected it.
# An identical duplicate of test_redis_connector_xread used to be here; the later
# definition of the same name shadowed it, so pytest never collected it.
def test_redis_connector_xadd(connector):
    """``xadd`` serializes field values; ``get_last`` decodes the ``data`` field back."""
    redis_conn = connector._redis_conn

    connector.xadd("topic1", {"key": "value"})
    redis_conn.xadd.assert_called_once_with(
        "topic1", {"key": MsgpackSerialization.dumps("value")}
    )

    test_msg = TestMessage("test")
    connector.xadd("topic1", {"data": test_msg})
    redis_conn.xadd.assert_called_with("topic1", {"data": MsgpackSerialization.dumps(test_msg)})

    # Make the mocked stream hand back the entry that was just written.
    redis_conn.xrevrange.return_value = [
        (b"1707391599960-0", {b"data": MsgpackSerialization.dumps(test_msg)})
    ]
    assert connector.get_last("topic1") == test_msg
def test_redis_connector_xadd_with_maxlen(connector):
    """``max_size`` is forwarded to Redis as ``maxlen``."""
    connector.xadd("topic1", {"key": "value"}, max_size=100)
    serialized_fields = {"key": MsgpackSerialization.dumps("value")}
    connector._redis_conn.xadd.assert_called_once_with("topic1", serialized_fields, maxlen=100)
def test_redis_connector_xadd_with_expire(connector):
    """With ``expire``, ``xadd`` adds and sets the TTL in one executed pipeline."""
    connector.xadd("topic1", {"key": "value"}, expire=100)
    redis_pipe = connector._redis_conn.pipeline()
    redis_pipe.xadd.assert_called_once_with(
        "topic1", {"key": MsgpackSerialization.dumps("value")}
    )
    redis_pipe.expire.assert_called_once_with("topic1", 100)
    redis_pipe.execute.assert_called_once()
def test_redis_connector_xread(connector):
    """An explicit id is passed to Redis ``xread`` as the stream's start id."""
    connector.xread("topic1", "id")
    expected_streams = {"topic1": "id"}
    connector._redis_conn.xread.assert_called_once_with(expected_streams, count=None, block=None)
# An identical duplicate of test_redis_connector_xread_without_id used to be
# here; the later definition of the same name shadowed it, so pytest never
# collected it.
def test_redis_connector_xread_from_end(connector):
    """With ``from_start=False``, ``xread`` fetches only the newest entry via ``xrevrange``."""
    connector.xread("topic1", from_start=False)
    xrevrange_mock = connector._redis_conn.xrevrange
    xrevrange_mock.assert_called_once_with("topic1", "+", "-", count=1)
def test_redis_connector_get_last(connector):
    """``get_last`` reads the newest stream entry and decodes the requested field."""
    entry_id = b"1707391599960-0"
    connector._redis_conn.xrevrange.return_value = [
        (entry_id, {b"key": MsgpackSerialization.dumps("value")})
    ]
    newest = connector.get_last("topic1")
    connector._redis_conn.xrevrange.assert_called_once_with("topic1", "+", "-", count=1)
    # No key given: the default field is b'data', which this entry does not have.
    assert newest is None
    assert connector.get_last("topic1", "key") == "value"
    # key=None returns the whole decoded entry.
    assert connector.get_last("topic1", None) == {"key": "value"}
def test_redis_connector_xread_without_id(connector):
    """Without an id, ``xread`` starts at ``0-0`` or resumes from ``stream_keys``."""
    xread_mock = connector._redis_conn.xread

    connector.xread("topic1", from_start=True)
    xread_mock.assert_called_once_with({"topic1": "0-0"}, count=None, block=None)

    xread_mock.reset_mock()
    connector.stream_keys["topic1"] = "id"
    connector.xread("topic1")
    xread_mock.assert_called_once_with({"topic1": "id"}, count=None, block=None)
def test_redis_xrange(connector):
    """``xrange`` forwards the topic and bounds to Redis with ``count=None`` by default."""
    start, end = "start", "end"
    connector.xrange("topic1", start, end)
    connector._redis_conn.xrange.assert_called_once_with("topic1", start, end, count=None)
def test_redis_xrange_topic_with_suffix(connector):
    """``xrange`` forwards the topic and bounds unchanged.

    NOTE(review): despite its name, this test uses a plain topic and currently
    duplicates ``test_redis_xrange``; confirm what "suffix" was meant to cover.
    """
    start, end = "start", "end"
    connector.xrange("topic1", start, end)
    connector._redis_conn.xrange.assert_called_once_with("topic1", start, end, count=None)
# def test_redis_stream_register_threaded_get_id():
# register = RedisStreamConsumerThreaded(
# "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
# )
# register.stream_keys["topic1"] = b"1691610882756-0"
# assert register.get_id("topic1") == b"1691610882756-0"
# assert register.get_id("doesnt_exist") == "0-0"
# def test_redis_stream_register_threaded_poll_messages():
# register = RedisStreamConsumerThreaded(
# "localhost", "1", topics="topic1", cb=mock.MagicMock(), redis_cls=mock.MagicMock()
# )
# with mock.patch.object(
# register, "get_newest_message", return_value=None
# ) as mock_get_newest_message:
# register.poll_messages()
# mock_get_newest_message.assert_called_once()
# register._redis_conn.xread.assert_not_called()
# def test_redis_stream_register_threaded_poll_messages_newest_only():
# register = RedisStreamConsumerThreaded(
# "localhost",
# "1",
# topics="topic1",
# cb=mock.MagicMock(),
# redis_cls=mock.MagicMock(),
# newest_only=True,
# )
#
# register._redis_conn.xrevrange.return_value = [(b"1691610882756-0", {b"data": b"msg"})]
# register.poll_messages()
# register._redis_conn.xread.assert_not_called()
# register.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))
# def test_redis_stream_register_threaded_poll_messages_read():
# register = RedisStreamConsumerThreaded(
# "localhost",
# "1",
# topics="topic1",
# cb=mock.MagicMock(),
# redis_cls=mock.MagicMock(),
# )
# register.stream_keys["topic1"] = "0-0"
#
# msg = [[b"topic1", [(b"1691610714612-0", {b"data": b"msg"})]]]
#
# register._redis_conn.xread.return_value = msg
# register.poll_messages()
# register._redis_conn.xread.assert_called_once_with({"topic1": "0-0"}, count=1)
# register.cb.assert_called_once_with(MessageObject(topic="topic1", value=b"msg"))
# @pytest.mark.parametrize(
# "topics,expected",
# [
# ("topic1", ["topic1"]),
# (["topic1"], ["topic1"]),
# (["topic1", "topic2"], ["topic1", "topic2"]),
# ],
# )
# def test_redis_stream_register_threaded_init_topics(topics, expected):
# register = RedisStreamConsumerThreaded(
# "localhost",
# "1",
# topics=topics,
# cb=mock.MagicMock(),
# redis_cls=mock.MagicMock(),
# )
# assert register.topics == expected