diff --git a/icon_service_base/database/influxdb_connection.py b/icon_service_base/database/influxdb_connection.py index be301f3..4cd893e 100644 --- a/icon_service_base/database/influxdb_connection.py +++ b/icon_service_base/database/influxdb_connection.py @@ -2,6 +2,7 @@ from __future__ import annotations from pathlib import Path from types import TracebackType +from typing import Optional from confz import FileSource from influxdb_client import ( # type: ignore @@ -49,10 +50,25 @@ class InfluxDBSession: ``` """ - def __init__(self, config_folder: Path | str) -> None: - self._config = InfluxDBConfig( - config_sources=FileSource(Path(config_folder) / "influxdb_config.yaml") - ) + conf_folder: Path | str + + def __init__(self, config_folder: Optional[Path | str] = None) -> None: + config_folder = config_folder or getattr(self, "conf_folder", None) + if InfluxDBConfig.CONFIG_SOURCES is not None or config_folder is not None: + config_sources = None + if config_folder is not None: + config_sources = FileSource( + Path(config_folder) / "influxdb_config.yaml" + ) + self._config = InfluxDBConfig(config_sources=config_sources) + else: + logger.error( + "No config folder given. Please provide a config folder either by " + "passing it to the constructor or by setting the 'conf_folder' " + "attribute." + ) + return + self.url = self._config.url self.token = str(self._config.token) self.org = self._config.org @@ -60,7 +76,7 @@ class InfluxDBSession: self.write_api: WriteApi self.buckets_api: BucketsApi | None = None - def __enter__(self) -> InfluxDBConnection: + def __enter__(self) -> InfluxDBSession: self.client = InfluxDBClient(url=self.url, token=self.token, org=self.org) self.write_api = self.client.write_api(write_options=SYNCHRONOUS) # type: ignore return self diff --git a/icon_service_base/database/postgres_connection.py b/icon_service_base/database/postgres_connection.py index 01ce670..fca19bd 100644 --- a/icon_service_base/database/postgres_connection.py +++ b/icon_service_base/database/postgres_connection.py @@ -5,7 +5,7 @@ import json import re from pathlib import Path from types import TracebackType -from typing import Any +from typing import Any, Optional from confz import FileSource from dateutil.parser import ParserError, parse # type: ignore @@ -128,13 +128,25 @@ class PostgresDatabaseSession(Session): ``` """ - def __init__(self, config_folder: Path | str) -> None: + conf_folder: Path | str + + def __init__(self, config_folder: Optional[Path | str] = None) -> None: """Initializes a new session bound to the database engine.""" - self._config = PostgreSQLConfig( - config_sources=FileSource( - Path(config_folder) / f"postgres_{OperationMode().environment}.yaml" + config_folder = config_folder or getattr(self, "conf_folder", None) + if PostgreSQLConfig.CONFIG_SOURCES is not None or config_folder is not None: + config_sources = None + if config_folder is not None: + config_sources = FileSource( + Path(config_folder) / f"postgres_{OperationMode().environment}.yaml" + ) + self._config = PostgreSQLConfig(config_sources=config_sources) + else: + logger.error( + "No config folder given. Please provide a config folder either by " + "passing it to the constructor or by setting the 'conf_folder' " + "attribute." ) - ) + return super().__init__( bind=create_engine(