From f621ef280e5121a44277d1b51de586d8eae82be5 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Wed, 12 Jun 2024 19:06:22 +0200 Subject: [PATCH] fix(bec_lib): fixed access to global vars --- bec_lib/bec_lib/bec_service.py | 4 ++-- bec_lib/tests/conftest.py | 17 ++++++++++++++ bec_lib/tests/test_bec_service.py | 23 +++++++++++++++++++ .../tests/test_redis_connector_fakeredis.py | 16 ------------- 4 files changed, 42 insertions(+), 18 deletions(-) diff --git a/bec_lib/bec_lib/bec_service.py b/bec_lib/bec_lib/bec_service.py index 1b12dccd..b4f5e005 100644 --- a/bec_lib/bec_lib/bec_service.py +++ b/bec_lib/bec_lib/bec_service.py @@ -250,13 +250,13 @@ class BECService: """ self.connector.delete(MessageEndpoints.global_vars(name)) - def global_vars(self) -> str: + def show_global_vars(self) -> str: """Get all available global variables""" # sadly, this cannot be a property as it causes side effects with IPython's tab completion available_keys = self.connector.keys(MessageEndpoints.global_vars("*")) def get_endpoint_from_topic(topic: str) -> str: - return topic.decode().split(MessageEndpoints.global_vars(""))[-1] + return topic.decode().split(MessageEndpoints.global_vars("").endpoint)[-1] endpoints = [get_endpoint_from_topic(k) for k in available_keys] diff --git a/bec_lib/tests/conftest.py b/bec_lib/tests/conftest.py index 663d9f61..2c1e58eb 100644 --- a/bec_lib/tests/conftest.py +++ b/bec_lib/tests/conftest.py @@ -1,6 +1,8 @@ +import fakeredis import pytest from bec_lib import bec_logger +from bec_lib.redis_connector import RedisConnector # overwrite threads_check fixture from bec_lib, # to have it in autouse @@ -10,3 +12,18 @@ from bec_lib import bec_logger def threads_check(threads_check): yield bec_logger.logger.remove() + + +def fake_redis_server(host, port): + redis = fakeredis.FakeRedis() + return redis + + +@pytest.fixture +def connected_connector(): + connector = RedisConnector("localhost:1", redis_cls=fake_redis_server) + connector._redis_conn.flushall() + try: + yield connector + finally: + connector.shutdown() diff --git a/bec_lib/tests/test_bec_service.py b/bec_lib/tests/test_bec_service.py index d7a5b32a..3ed21fa6 100644 --- a/bec_lib/tests/test_bec_service.py +++ b/bec_lib/tests/test_bec_service.py @@ -152,3 +152,26 @@ def test_bec_service_default_config(): os.path.abspath(service._service_config.service_config["file_writer"]["base_path"]) == bec_lib_path ) + + +def test_bec_service_show_global_vars(capsys): + config = ServiceConfig(redis={"host": "localhost", "port": 6379}) + with bec_service(config=config, unique_service=True) as service: + ep = MessageEndpoints.global_vars("test").endpoint.encode() + with mock.patch.object(service.connector, "keys", return_value=[ep]): + with mock.patch.object(service, "get_global_var", return_value="test_value"): + service.show_global_vars() + captured = capsys.readouterr() + assert "test" in captured.out + assert "test_value" in captured.out + + +def test_bec_service_globals(connected_connector): + config = ServiceConfig(redis={"host": "localhost", "port": 1}) + with bec_service(config=config, unique_service=True) as service: + service.connector = connected_connector + service.set_global_var("test", "test_value") + assert service.get_global_var("test") == "test_value" + + service.delete_global_var("test") + assert service.get_global_var("test") is None diff --git a/bec_lib/tests/test_redis_connector_fakeredis.py b/bec_lib/tests/test_redis_connector_fakeredis.py index 0426d61a..8a9cccfa 100644 --- a/bec_lib/tests/test_redis_connector_fakeredis.py +++ b/bec_lib/tests/test_redis_connector_fakeredis.py @@ -2,7 +2,6 @@ import threading import time from unittest import mock -import fakeredis import pytest import redis from redis.client import Pipeline @@ -20,21 +19,6 @@ from bec_lib.serialization import MsgpackSerialization # pylint: disable=unused-argument -def fake_redis_server(host, port): - redis = fakeredis.FakeRedis() - return redis - - -@pytest.fixture -def connected_connector(): - connector = RedisConnector("localhost:1", redis_cls=fake_redis_server) - connector._redis_conn.flushall() - try: - yield connector - finally: - connector.shutdown() - - TestStreamEndpoint = EndpointInfo("test", TestMessage(), MessageOp.STREAM) TestStreamEndpoint2 = EndpointInfo("test2", TestMessage(), MessageOp.STREAM)