diff --git a/src/pydase/data_service/task_manager.py b/src/pydase/data_service/task_manager.py index 6bca414..af5d6b1 100644 --- a/src/pydase/data_service/task_manager.py +++ b/src/pydase/data_service/task_manager.py @@ -21,10 +21,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class TaskDefinitionError(Exception): - pass - - class TaskStatus(Enum): RUNNING = "running" @@ -107,12 +103,13 @@ class TaskManager: method = getattr(self.service, name) if inspect.iscoroutinefunction(method): if function_has_arguments(method): - raise TaskDefinitionError( - "Asynchronous functions (tasks) should be defined without " - f"arguments. The task '{method.__name__}' has at least one " - "argument. Please remove the argument(s) from this function to " - "use it." + logger.info( + "Async function %a is defined with at least one argument. If " + "you want to use it as a task, remove the argument(s) from the " + "function definition.", + method.__name__, ) + continue # create start and stop methods for each coroutine setattr( diff --git a/src/pydase/observer_pattern/observable/observable_object.py b/src/pydase/observer_pattern/observable/observable_object.py index d749ef5..6f9e16a 100644 --- a/src/pydase/observer_pattern/observable/observable_object.py +++ b/src/pydase/observer_pattern/observable/observable_object.py @@ -1,34 +1,37 @@ +from __future__ import annotations + import logging import weakref from abc import ABC, abstractmethod -from collections.abc import Iterable from typing import TYPE_CHECKING, Any, ClassVar, SupportsIndex from pydase.utils.helpers import parse_serialized_key if TYPE_CHECKING: + from collections.abc import Iterable + from pydase.observer_pattern.observer.observer import Observer logger = logging.getLogger(__name__) class ObservableObject(ABC): - _list_mapping: ClassVar[dict[int, weakref.ReferenceType["_ObservableList"]]] = {} - _dict_mapping: ClassVar[dict[int, weakref.ReferenceType["_ObservableDict"]]] = {} + _list_mapping: ClassVar[dict[int, weakref.ReferenceType[_ObservableList]]] = {} + _dict_mapping: ClassVar[dict[int, weakref.ReferenceType[_ObservableDict]]] = {} def __init__(self) -> None: if not hasattr(self, "_observers"): - self._observers: dict[str, list["ObservableObject | Observer"]] = {} + self._observers: dict[str, list[ObservableObject | Observer]] = {} def add_observer( - self, observer: "ObservableObject | Observer", attr_name: str = "" + self, observer: ObservableObject | Observer, attr_name: str = "" ) -> None: if attr_name not in self._observers: self._observers[attr_name] = [] if observer not in self._observers[attr_name]: self._observers[attr_name].append(observer) - def _remove_observer(self, observer: "ObservableObject", attribute: str) -> None: + def _remove_observer(self, observer: ObservableObject, attribute: str) -> None: if attribute in self._observers: self._observers[attribute].remove(observer) diff --git a/tests/data_service/test_data_service.py b/tests/data_service/test_data_service.py index 77fc65b..9ee9120 100644 --- a/tests/data_service/test_data_service.py +++ b/tests/data_service/test_data_service.py @@ -7,7 +7,6 @@ import pytest from pydase import DataService from pydase.data_service.data_service_observer import DataServiceObserver from pydase.data_service.state_manager import StateManager -from pydase.data_service.task_manager import TaskDefinitionError from pydase.utils.decorators import FunctionDefinitionError, frontend from pytest import LogCaptureFixture @@ -37,7 +36,8 @@ def test_unexpected_type_change_warning(caplog: LogCaptureFixture) -> None: def test_basic_inheritance_warning(caplog: LogCaptureFixture) -> None: - class SubService(DataService): ... + class SubService(DataService): + ... class SomeEnum(Enum): HI = 0 @@ -57,9 +57,11 @@ def test_basic_inheritance_warning(caplog: LogCaptureFixture) -> None: def name(self) -> str: return self._name - def some_method(self) -> None: ... + def some_method(self) -> None: + ... - async def some_task(self) -> None: ... + async def some_task(self) -> None: + ... ServiceClass() @@ -118,14 +120,7 @@ def test_protected_and_private_attribute_warning(caplog: LogCaptureFixture) -> N ) not in caplog.text -def test_exposing_methods() -> None: - class ClassWithTask(pydase.DataService): - async def some_task(self, sleep_time: int) -> None: - pass - - with pytest.raises(TaskDefinitionError): - ClassWithTask() - +def test_exposing_methods(caplog: LogCaptureFixture) -> None: with pytest.raises(FunctionDefinitionError): class ClassWithMethod(pydase.DataService): @@ -133,6 +128,18 @@ def test_exposing_methods() -> None: def some_method(self, *args: Any) -> str: return "some method" + class ClassWithTask(pydase.DataService): + async def some_task(self, sleep_time: int) -> None: + pass + + ClassWithTask() + + assert ( + "Async function 'some_task' is defined with at least one argument. If you want " + "to use it as a task, remove the argument(s) from the function definition." + in caplog.text + ) + def test_dynamically_added_attribute(caplog: LogCaptureFixture) -> None: class MyService(DataService): diff --git a/tests/observer_pattern/observer/test_property_observer.py b/tests/observer_pattern/observer/test_property_observer.py index 8901d81..199b9b7 100644 --- a/tests/observer_pattern/observer/test_property_observer.py +++ b/tests/observer_pattern/observer/test_property_observer.py @@ -16,6 +16,7 @@ def test_inherited_property_dependency_resolution() -> None: _name = "DerivedObservable" class MyObserver(PropertyObserver): - def on_change(self, full_access_path: str, value: Any) -> None: ... + def on_change(self, full_access_path: str, value: Any) -> None: + ... assert MyObserver(DerivedObservable()).property_deps_dict == {"_name": ["name"]} diff --git a/tests/utils/serialization/test_serializer.py b/tests/utils/serialization/test_serializer.py index 1be21c1..f8e6d0a 100644 --- a/tests/utils/serialization/test_serializer.py +++ b/tests/utils/serialization/test_serializer.py @@ -476,7 +476,8 @@ def test_derived_data_service_serialization() -> None: def name(self, value: str) -> None: self._name = value - class DerivedService(BaseService): ... + class DerivedService(BaseService): + ... base_service_serialization = dump(BaseService()) derived_service_serialization = dump(DerivedService())