cleaning up type hinting (using TYPE_CHECKING)

This commit is contained in:
Mose Müller 2023-08-03 14:44:14 +02:00
parent cea831f72c
commit 85a171c33e
3 changed files with 83 additions and 103 deletions

View File

@ -1,78 +1,16 @@
from __future__ import annotations from __future__ import annotations
import asyncio from abc import ABC
from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any
from collections.abc import Callable
from typing import Any, TypedDict
from pydase.data_service.data_service_list import DataServiceList if TYPE_CHECKING:
from .callback_manager import CallbackManager
from .data_service import DataService
from .task_manager import TaskManager
class AbstractDataService(ABC): class AbstractDataService(ABC):
__root__: AbstractDataService __root__: DataService
_task_manager: AbstractTaskManager _task_manager: TaskManager
_callback_manager: AbstractCallbackManager _callback_manager: CallbackManager
"""
This is a CallbackManager. Cannot type this here as this would lead to a recursive
loop.
"""
_autostart_tasks: dict[str, tuple[Any]] _autostart_tasks: dict[str, tuple[Any]]
class TaskDict(TypedDict):
task: asyncio.Task[None]
kwargs: dict[str, Any]
class AbstractTaskManager(ABC):
task_status_change_callbacks: list[Callable[[str, dict[str, Any] | None], Any]]
"""A list of callback functions to be invoked when the status of a task (start or
stop) changes."""
tasks: dict[str, TaskDict]
"""A dictionary to keep track of running tasks. The keys are the names of the
tasks and the values are TaskDict instances which include the task itself and
its kwargs.
"""
@abstractmethod
def _set_start_and_stop_for_async_methods(self) -> None:
...
@abstractmethod
def start_autostart_tasks(self) -> None:
...
class AbstractCallbackManager(ABC):
service: AbstractDataService
callbacks: set[Callable[[str, Any], None]]
_list_mapping: dict[int, DataServiceList]
"""
A dictionary mapping the id of the original lists to the corresponding
DataServiceList instances.
This is used to ensure that all references to the same list within the DataService
object point to the same DataServiceList, so that any modifications to that list can
be tracked consistently. The keys of the dictionary are the ids of the original
lists, and the values are the DataServiceList instances that wrap these lists.
"""
_notification_callbacks: list[Callable[[str, str, Any], Any]] = []
"""
A list of callback functions that are executed when a change occurs in the
DataService instance. These functions are intended to handle or respond to these
changes in some way, such as emitting a socket.io message to the frontend.
Each function in this list should be a callable that accepts three parameters:
- parent_path (str): The path to the parent of the attribute that was changed.
- name (str): The name of the attribute that was changed.
- value (Any): The new value of the attribute.
A callback function can be added to this list using the add_notification_callback
method. Whenever a change in the DataService instance occurs (or in its nested
DataService or DataServiceList instances), the emit_notification method is invoked,
which in turn calls all the callback functions in _notification_callbacks with the
appropriate arguments.
This implementation follows the observer pattern, with the DataService instance as
the "subject" and the callback functions as the "observers".
"""

View File

@ -1,20 +1,53 @@
from __future__ import annotations
import inspect import inspect
from collections.abc import Callable from collections.abc import Callable
from typing import Any, cast from typing import TYPE_CHECKING, Any
from loguru import logger from loguru import logger
from pydase.data_service.abstract_service_classes import AbstractDataService
from pydase.utils.helpers import get_class_and_instance_attributes from pydase.utils.helpers import get_class_and_instance_attributes
from .abstract_service_classes import AbstractCallbackManager, AbstractDataService
from .data_service_list import DataServiceList from .data_service_list import DataServiceList
if TYPE_CHECKING:
from .data_service import DataService
class CallbackManager(AbstractCallbackManager):
_notification_callbacks = []
_list_mapping = {}
def __init__(self, service: "AbstractDataService") -> None: class CallbackManager:
_notification_callbacks: list[Callable[[str, str, Any], Any]] = []
"""
A list of callback functions that are executed when a change occurs in the
DataService instance. These functions are intended to handle or respond to these
changes in some way, such as emitting a socket.io message to the frontend.
Each function in this list should be a callable that accepts three parameters:
- parent_path (str): The path to the parent of the attribute that was changed.
- name (str): The name of the attribute that was changed.
- value (Any): The new value of the attribute.
A callback function can be added to this list using the add_notification_callback
method. Whenever a change in the DataService instance occurs (or in its nested
DataService or DataServiceList instances), the emit_notification method is invoked,
which in turn calls all the callback functions in _notification_callbacks with the
appropriate arguments.
This implementation follows the observer pattern, with the DataService instance as
the "subject" and the callback functions as the "observers".
"""
_list_mapping: dict[int, DataServiceList] = {}
"""
A dictionary mapping the id of the original lists to the corresponding
DataServiceList instances.
This is used to ensure that all references to the same list within the DataService
object point to the same DataServiceList, so that any modifications to that list can
be tracked consistently. The keys of the dictionary are the ids of the original
lists, and the values are the DataServiceList instances that wrap these lists.
"""
def __init__(self, service: DataService) -> None:
self.callbacks: set[Callable[[str, Any], None]] = set() self.callbacks: set[Callable[[str, Any], None]] = set()
self.service = service self.service = service
@ -62,9 +95,7 @@ class CallbackManager(AbstractCallbackManager):
# value at the time the lambda is defined, not when it is called. This # value at the time the lambda is defined, not when it is called. This
# prevents attr_name from being overwritten in the next loop iteration. # prevents attr_name from being overwritten in the next loop iteration.
callback = ( callback = (
lambda index, value, attr_name=attr_name: cast( lambda index, value, attr_name=attr_name: self.service._callback_manager.emit_notification(
CallbackManager, self.service._callback_manager
).emit_notification(
parent_path=parent_path, parent_path=parent_path,
name=f"{attr_name}[{index}]", name=f"{attr_name}[{index}]",
value=value, value=value,
@ -125,9 +156,9 @@ class CallbackManager(AbstractCallbackManager):
# Create and register a callback for the object # Create and register a callback for the object
# only emit the notification when the call was registered by the root object # only emit the notification when the call was registered by the root object
callback: Callable[[str, Any], None] = ( callback: Callable[[str, Any], None] = (
lambda name, value: cast( lambda name, value: obj._callback_manager.emit_notification(
CallbackManager, obj._callback_manager parent_path=parent_path, name=name, value=value
).emit_notification(parent_path=parent_path, name=name, value=value) )
if self.service == obj.__root__ if self.service == obj.__root__
and not name.startswith("_") # we are only interested in public attributes and not name.startswith("_") # we are only interested in public attributes
and not isinstance( and not isinstance(
@ -136,7 +167,7 @@ class CallbackManager(AbstractCallbackManager):
else None else None
) )
cast(CallbackManager, obj._callback_manager).callbacks.add(callback) obj._callback_manager.callbacks.add(callback)
# Recursively register callbacks for all nested attributes of the object # Recursively register callbacks for all nested attributes of the object
attrs = get_class_and_instance_attributes(obj) attrs = get_class_and_instance_attributes(obj)
@ -199,7 +230,7 @@ class CallbackManager(AbstractCallbackManager):
# changed, not reassigned) # changed, not reassigned)
for item in obj_list: for item in obj_list:
if isinstance(item, AbstractDataService): if isinstance(item, AbstractDataService):
cast(CallbackManager, item._callback_manager).callbacks.add(callback) item._callback_manager.callbacks.add(callback)
for attr_name in set(dir(item)) - set(dir(object)) - {"__root__"}: for attr_name in set(dir(item)) - set(dir(object)) - {"__root__"}:
attr_value = getattr(item, attr_name) attr_value = getattr(item, attr_name)
if isinstance(attr_value, (AbstractDataService, DataServiceList)): if isinstance(attr_value, (AbstractDataService, DataServiceList)):
@ -261,9 +292,7 @@ class CallbackManager(AbstractCallbackManager):
dependency_value, (DataServiceList, AbstractDataService) dependency_value, (DataServiceList, AbstractDataService)
): ):
callback = ( callback = (
lambda name, value, dependent_attr=attr_name: cast( lambda name, value, dependent_attr=attr_name: obj._callback_manager.emit_notification(
CallbackManager, obj._callback_manager
).emit_notification(
parent_path=parent_path, parent_path=parent_path,
name=dependent_attr, name=dependent_attr,
value=getattr(obj, dependent_attr), value=getattr(obj, dependent_attr),
@ -278,9 +307,7 @@ class CallbackManager(AbstractCallbackManager):
) )
else: else:
callback = ( callback = (
lambda name, _, dep_attr=attr_name, dep=dependency: cast( # type: ignore lambda name, _, dep_attr=attr_name, dep=dependency: obj._callback_manager.emit_notification( # type: ignore
CallbackManager, obj._callback_manager
).emit_notification(
parent_path=parent_path, parent_path=parent_path,
name=dep_attr, name=dep_attr,
value=getattr(obj, dep_attr), value=getattr(obj, dep_attr),
@ -289,9 +316,7 @@ class CallbackManager(AbstractCallbackManager):
else None else None
) )
# Add to callbacks # Add to callbacks
cast(CallbackManager, obj._callback_manager).callbacks.add( obj._callback_manager.callbacks.add(callback)
callback
)
def _register_start_stop_task_callbacks( def _register_start_stop_task_callbacks(
self, obj: "AbstractDataService", parent_path: str self, obj: "AbstractDataService", parent_path: str
@ -313,9 +338,9 @@ class CallbackManager(AbstractCallbackManager):
# Create and register a callback for the object # Create and register a callback for the object
# only emit the notification when the call was registered by the root object # only emit the notification when the call was registered by the root object
callback: Callable[[str, dict[str, Any] | None], None] = ( callback: Callable[[str, dict[str, Any] | None], None] = (
lambda name, status: cast( lambda name, status: obj._callback_manager.emit_notification(
CallbackManager, obj._callback_manager parent_path=parent_path, name=name, value=status
).emit_notification(parent_path=parent_path, name=name, value=status) )
if self.service == obj.__root__ if self.service == obj.__root__
and not name.startswith("_") # we are only interested in public attributes and not name.startswith("_") # we are only interested in public attributes
else None else None

View File

@ -1,14 +1,23 @@
from __future__ import annotations
import asyncio import asyncio
import inspect import inspect
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Any from typing import TYPE_CHECKING, Any, TypedDict
from loguru import logger from loguru import logger
from .abstract_service_classes import AbstractDataService, AbstractTaskManager if TYPE_CHECKING:
from .data_service import DataService
class TaskManager(AbstractTaskManager): class TaskDict(TypedDict):
task: asyncio.Task[None]
kwargs: dict[str, Any]
class TaskManager:
""" """
The TaskManager class is a utility designed to manage asynchronous tasks. It The TaskManager class is a utility designed to manage asynchronous tasks. It
provides functionality for starting, stopping, and tracking these tasks. The class provides functionality for starting, stopping, and tracking these tasks. The class
@ -62,13 +71,21 @@ class TaskManager(AbstractTaskManager):
interfaces, but can also be used to write logs, etc. interfaces, but can also be used to write logs, etc.
""" """
def __init__(self, service: AbstractDataService) -> None: def __init__(self, service: DataService) -> None:
self.service = service self.service = service
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self.tasks = {} self.tasks: dict[str, TaskDict] = {}
"""A dictionary to keep track of running tasks. The keys are the names of the
tasks and the values are TaskDict instances which include the task itself and
its kwargs.
"""
self.task_status_change_callbacks = [] self.task_status_change_callbacks: list[
Callable[[str, dict[str, Any] | None], Any]
] = []
"""A list of callback functions to be invoked when the status of a task (start
or stop) changes."""
self._set_start_and_stop_for_async_methods() self._set_start_and_stop_for_async_methods()