from __future__ import annotations

import datetime
import json
import re
from pathlib import Path
from types import TracebackType
from typing import Any, Optional
from urllib.parse import quote

from confz import FileSource
from dateutil.parser import ParserError, parse  # type: ignore
from loguru import logger
from sqlmodel import Session, SQLModel, create_engine

from icon_service_base.database.config import OperationMode, PostgreSQLConfig
from icon_service_base.database.create_config import create_config


def json_loads_or_return_input(input_string: str) -> dict[str, Any] | Any:
    """
    Try to parse a string as JSON, if it fails return the original string.
    """
    try:
        return json.loads(input_string)
    except (TypeError, json.JSONDecodeError):
        return input_string


def parse_datetime_or_return_str(input_string: str) -> datetime.datetime | str:
    """
    Try to parse a string as a datetime, if it fails return the original string.

    Parsing is delegated to ``dateutil.parser.parse``. Besides ``ParserError`` (for
    unknown/invalid formats or impossible dates), ``parse`` raises ``OverflowError``
    when a numeric component exceeds the largest C integer (e.g. a long run of
    digits). Both mean "not a datetime" here, so the input is returned unchanged.
    """
    try:
        # Attempts to parse the string as a datetime object
        return parse(input_string)
    except (ParserError, OverflowError):
        # If parsing fails, return the original input string
        return input_string


def is_datetime_format(input_string: str) -> bool:
    """
    Check if a string is in datetime format.

    Returns ``True`` when ``dateutil.parser.parse`` accepts the string. It returns
    ``False`` when ``parse`` raises ``ParserError`` (unknown/invalid format) or
    ``OverflowError`` (a numeric component too large for a C integer, e.g. a long
    run of digits). Previously the latter propagated to the caller.
    """
    try:
        parse(input_string)
    except (ParserError, OverflowError):
        return False
    return True


def json_dumps(data: Any) -> str | list:
    """
    Serialize a Python object into a JSON-formatted string, with custom handling for
    datetime and list objects.

    Dispatch by type:

    - ``str``: JSON-encoded (quoted) if it parses as a datetime, otherwise returned
      unchanged.
    - ``datetime.datetime``: ``str(data)`` JSON-encoded.
    - ``list``: a Python ``list`` holding ``json_dumps`` of each element (not a
      single JSON string).
    - ``SQLModel`` instance: ``data.json()``.
    - anything else: ``json.dumps(data)``.

    In the last two cases, bare ``Infinity`` / ``-Infinity`` tokens in the serialized
    text are wrapped in double quotes because they are not valid JSON. Occurrences
    inside JSON string literals (values or keys) are left untouched.
    """
    # 'Infinity' is an unallowed token in JSON, thus make it a string
    # https://stackoverflow.com/questions/48356938/store-infinity-in-postgres-json-via-django
    # The first alternative consumes a whole JSON string literal (honouring backslash
    # escapes), so text such as "To Infinity" inside a string is not rewritten. Only
    # a bare token outside any string literal lands in group 1.
    pattern = r'"(?:[^"\\]|\\.)*"|(-?Infinity)'

    def quote_bare_infinity(match: re.Match[str]) -> str:
        token = match.group(1)
        return match.group(0) if token is None else f'"{token}"'

    result: str | list

    if isinstance(data, str):
        if is_datetime_format(data):
            result = json.dumps(data)
        else:
            result = data
    elif isinstance(data, datetime.datetime):
        result = json.dumps(str(data))
    elif isinstance(data, list):
        result = [json_dumps(element) for element in data]
    else:
        if isinstance(data, SQLModel):
            serialized_data = data.json()
        else:
            serialized_data = json.dumps(data)
        result = re.sub(pattern, quote_bare_infinity, serialized_data)
    return result


def deserialize_json_dict(json_string: str) -> Any:
    """
    Deserialize a JSON string into a Python dictionary.

    Steps:

    1. Decode ``json_string`` with ``json.loads``.
    2. Replace every string *value* (at any nesting depth; dict keys are left alone)
       equal to ``"Infinity"`` or ``"-Infinity"`` with ``float("inf")`` /
       ``float("-inf")``.
    3. If the decoded value is a string, try to decode it as JSON once more
       (``json_loads_or_return_input``). If the result is still a string, try to
       parse it as a datetime (``parse_datetime_or_return_str``).

    Despite the name, the result may be any JSON value (dict, list, number, ...),
    a ``datetime.datetime`` or a ``str``.
    """
    # 'Infinity' is an unallowed token in JSON, thus it was stored as a string. Convert
    # it back *after* decoding. Rewriting the raw text (as a regex would) also hits
    # dict keys and text following an escaped quote, which yields invalid JSON or
    # corrupted strings.
    # NOTE: a genuine string value "Infinity" cannot be told apart from an encoded
    # float, so it is returned as a float as well.
    infinity_strings = ("Infinity", "-Infinity")

    def restore_infinity(value: Any) -> Any:
        """Recursively turn "Infinity"/"-Infinity" string values back into floats."""
        if isinstance(value, str):
            return float(value) if value in infinity_strings else value
        if isinstance(value, list):
            return [restore_infinity(item) for item in value]
        if isinstance(value, dict):
            return {key: restore_infinity(item) for key, item in value.items()}
        return value

    result: Any

    val = restore_infinity(json.loads(json_string))
    # Handles values that were JSON-encoded twice; non-strings come back unchanged.
    json_dict_or_val = json_loads_or_return_input(val)
    if isinstance(json_dict_or_val, str):
        result = parse_datetime_or_return_str(json_dict_or_val)
    else:
        result = json_dict_or_val
    return result


class PostgresDatabaseSession(Session):
    """A class to represent a session with the PostgreSQL database.

    This class inherits from SQLModel's Session class and implements Python's context
    manager protocol. This class helps to ensure that sessions are properly opened
    and closed, even in cases of error.

    The main goal of this class is to provide a way to manage persistence operations
    for ORM-mapped objects.

    Every instance builds its own engine (and therefore its own connection pool) from
    ``postgres_<environment>.yaml``. That engine is disposed when the ``with`` block
    exits, so pooled connections are released immediately.

    Attributes:
        bind: Represents the database engine to which this session is bound.

    Example:
        This class is designed to be used with a context manager (the 'with' keyword).
        Here's how you can use it to query data from a table represented by a SQLModel
        class named 'YourModel':

        ```python
        from your_module.models import YourModel  # replace with your model
        from sqlmodel import select

        with PostgresDatabaseSession() as session:
            row = session.exec(select(YourModel).limit(1)).all()
        ```

        You can also use it to add new data to the table:

        ```python
        with PostgresDatabaseSession() as session:
            new_row = YourModel(...)  # replace ... with your data
            session.add(new_row)
            session.commit()
        ```
    """

    # NOTE(review): declared but never assigned in this class; the constructor
    # argument is named ``config_folder``.
    conf_folder: Path | str

    def __init__(self, config_folder: Optional[Path | str] = None) -> None:
        """Initializes a new session bound to a freshly created database engine.

        Args:
            config_folder: Forwarded to ``create_config`` together with the file name
                ``postgres_<environment>.yaml``.
        """
        self._config = create_config(
            PostgreSQLConfig,
            config_folder=config_folder,
            config_file=f"postgres_{OperationMode().environment}.yaml",
        )

        # Percent-encode the credentials so characters such as '@', ':', '/', '%'
        # or spaces in the user name or password cannot break the URL. quote() with
        # safe="" also encodes '+', so the result round-trips whether the URL parser
        # uses unquote or unquote_plus.
        user = quote(str(self._config.user), safe="")
        password = quote(str(self._config.password), safe="")

        engine = create_engine(
            f"postgresql://{user}:{password}@"
            f"{self._config.host.host}:{self._config.port}/"
            f"{self._config.database}",
            echo=False,
            json_serializer=json_dumps,
            json_deserializer=deserialize_json_dict,
        )
        super().__init__(bind=engine)
        # This engine belongs to this session only; keep a reference so its pool can
        # be disposed in __exit__.
        self._engine = engine

    def __enter__(self) -> PostgresDatabaseSession:
        """Begins the runtime context related to the database session."""

        super().__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        """Ends the runtime context related to the database session.

        Ensures that the session is properly closed, even in the case of an error.
        It then disposes of the per-session engine so its pooled connections are
        closed now rather than whenever the engine is garbage collected.
        """

        try:
            super().__exit__(exc_type, exc_value, exc_traceback)
        finally:
            self._engine.dispose()