fix(bec_lib): fixed access to global vars

This commit is contained in:
wakonig_k 2024-06-12 19:06:22 +02:00
parent db991bcf06
commit f621ef280e
4 changed files with 42 additions and 18 deletions

View File

@ -250,13 +250,13 @@ class BECService:
""" """
self.connector.delete(MessageEndpoints.global_vars(name)) self.connector.delete(MessageEndpoints.global_vars(name))
def global_vars(self) -> str: def show_global_vars(self) -> str:
"""Get all available global variables""" """Get all available global variables"""
# sadly, this cannot be a property as it causes side effects with IPython's tab completion # sadly, this cannot be a property as it causes side effects with IPython's tab completion
available_keys = self.connector.keys(MessageEndpoints.global_vars("*")) available_keys = self.connector.keys(MessageEndpoints.global_vars("*"))
def get_endpoint_from_topic(topic: str) -> str: def get_endpoint_from_topic(topic: str) -> str:
return topic.decode().split(MessageEndpoints.global_vars(""))[-1] return topic.decode().split(MessageEndpoints.global_vars("").endpoint)[-1]
endpoints = [get_endpoint_from_topic(k) for k in available_keys] endpoints = [get_endpoint_from_topic(k) for k in available_keys]

View File

@ -1,6 +1,8 @@
import fakeredis
import pytest import pytest
from bec_lib import bec_logger from bec_lib import bec_logger
from bec_lib.redis_connector import RedisConnector
# overwrite threads_check fixture from bec_lib, # overwrite threads_check fixture from bec_lib,
# to have it in autouse # to have it in autouse
@ -10,3 +12,18 @@ from bec_lib import bec_logger
def threads_check(threads_check): def threads_check(threads_check):
yield yield
bec_logger.logger.remove() bec_logger.logger.remove()
def fake_redis_server(host, port):
redis = fakeredis.FakeRedis()
return redis
@pytest.fixture
def connected_connector():
connector = RedisConnector("localhost:1", redis_cls=fake_redis_server)
connector._redis_conn.flushall()
try:
yield connector
finally:
connector.shutdown()

View File

@ -152,3 +152,26 @@ def test_bec_service_default_config():
os.path.abspath(service._service_config.service_config["file_writer"]["base_path"]) os.path.abspath(service._service_config.service_config["file_writer"]["base_path"])
== bec_lib_path == bec_lib_path
) )
def test_bec_service_show_global_vars(capsys):
config = ServiceConfig(redis={"host": "localhost", "port": 6379})
with bec_service(config=config, unique_service=True) as service:
ep = MessageEndpoints.global_vars("test").endpoint.encode()
with mock.patch.object(service.connector, "keys", return_value=[ep]):
with mock.patch.object(service, "get_global_var", return_value="test_value"):
service.show_global_vars()
captured = capsys.readouterr()
assert "test" in captured.out
assert "test_value" in captured.out
def test_bec_service_globals(connected_connector):
config = ServiceConfig(redis={"host": "localhost", "port": 1})
with bec_service(config=config, unique_service=True) as service:
service.connector = connected_connector
service.set_global_var("test", "test_value")
assert service.get_global_var("test") == "test_value"
service.delete_global_var("test")
assert service.get_global_var("test") is None

View File

@ -2,7 +2,6 @@ import threading
import time import time
from unittest import mock from unittest import mock
import fakeredis
import pytest import pytest
import redis import redis
from redis.client import Pipeline from redis.client import Pipeline
@ -20,21 +19,6 @@ from bec_lib.serialization import MsgpackSerialization
# pylint: disable=unused-argument # pylint: disable=unused-argument
def fake_redis_server(host, port):
redis = fakeredis.FakeRedis()
return redis
@pytest.fixture
def connected_connector():
connector = RedisConnector("localhost:1", redis_cls=fake_redis_server)
connector._redis_conn.flushall()
try:
yield connector
finally:
connector.shutdown()
TestStreamEndpoint = EndpointInfo("test", TestMessage(), MessageOp.STREAM) TestStreamEndpoint = EndpointInfo("test", TestMessage(), MessageOp.STREAM)
TestStreamEndpoint2 = EndpointInfo("test2", TestMessage(), MessageOp.STREAM) TestStreamEndpoint2 = EndpointInfo("test2", TestMessage(), MessageOp.STREAM)