Merge pull request #127 from tiqi-group/feature/ignore_coroutine

Skip coroutines with arguments instead of raising an exception
This commit is contained in:
Mose Müller 2024-05-27 15:10:46 +02:00 committed by GitHub
commit 9fa8f06280
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 38 additions and 29 deletions

View File

@ -21,10 +21,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TaskDefinitionError(Exception):
pass
class TaskStatus(Enum): class TaskStatus(Enum):
RUNNING = "running" RUNNING = "running"
@ -107,12 +103,13 @@ class TaskManager:
method = getattr(self.service, name) method = getattr(self.service, name)
if inspect.iscoroutinefunction(method): if inspect.iscoroutinefunction(method):
if function_has_arguments(method): if function_has_arguments(method):
raise TaskDefinitionError( logger.info(
"Asynchronous functions (tasks) should be defined without " "Async function %a is defined with at least one argument. If "
f"arguments. The task '{method.__name__}' has at least one " "you want to use it as a task, remove the argument(s) from the "
"argument. Please remove the argument(s) from this function to " "function definition.",
"use it." method.__name__,
) )
continue
# create start and stop methods for each coroutine # create start and stop methods for each coroutine
setattr( setattr(

View File

@ -1,34 +1,37 @@
from __future__ import annotations
import logging import logging
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, ClassVar, SupportsIndex from typing import TYPE_CHECKING, Any, ClassVar, SupportsIndex
from pydase.utils.helpers import parse_serialized_key from pydase.utils.helpers import parse_serialized_key
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Iterable
from pydase.observer_pattern.observer.observer import Observer from pydase.observer_pattern.observer.observer import Observer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ObservableObject(ABC): class ObservableObject(ABC):
_list_mapping: ClassVar[dict[int, weakref.ReferenceType["_ObservableList"]]] = {} _list_mapping: ClassVar[dict[int, weakref.ReferenceType[_ObservableList]]] = {}
_dict_mapping: ClassVar[dict[int, weakref.ReferenceType["_ObservableDict"]]] = {} _dict_mapping: ClassVar[dict[int, weakref.ReferenceType[_ObservableDict]]] = {}
def __init__(self) -> None: def __init__(self) -> None:
if not hasattr(self, "_observers"): if not hasattr(self, "_observers"):
self._observers: dict[str, list["ObservableObject | Observer"]] = {} self._observers: dict[str, list[ObservableObject | Observer]] = {}
def add_observer( def add_observer(
self, observer: "ObservableObject | Observer", attr_name: str = "" self, observer: ObservableObject | Observer, attr_name: str = ""
) -> None: ) -> None:
if attr_name not in self._observers: if attr_name not in self._observers:
self._observers[attr_name] = [] self._observers[attr_name] = []
if observer not in self._observers[attr_name]: if observer not in self._observers[attr_name]:
self._observers[attr_name].append(observer) self._observers[attr_name].append(observer)
def _remove_observer(self, observer: "ObservableObject", attribute: str) -> None: def _remove_observer(self, observer: ObservableObject, attribute: str) -> None:
if attribute in self._observers: if attribute in self._observers:
self._observers[attribute].remove(observer) self._observers[attribute].remove(observer)

View File

@ -7,7 +7,6 @@ import pytest
from pydase import DataService from pydase import DataService
from pydase.data_service.data_service_observer import DataServiceObserver from pydase.data_service.data_service_observer import DataServiceObserver
from pydase.data_service.state_manager import StateManager from pydase.data_service.state_manager import StateManager
from pydase.data_service.task_manager import TaskDefinitionError
from pydase.utils.decorators import FunctionDefinitionError, frontend from pydase.utils.decorators import FunctionDefinitionError, frontend
from pytest import LogCaptureFixture from pytest import LogCaptureFixture
@ -37,7 +36,8 @@ def test_unexpected_type_change_warning(caplog: LogCaptureFixture) -> None:
def test_basic_inheritance_warning(caplog: LogCaptureFixture) -> None: def test_basic_inheritance_warning(caplog: LogCaptureFixture) -> None:
class SubService(DataService): ... class SubService(DataService):
...
class SomeEnum(Enum): class SomeEnum(Enum):
HI = 0 HI = 0
@ -57,9 +57,11 @@ def test_basic_inheritance_warning(caplog: LogCaptureFixture) -> None:
def name(self) -> str: def name(self) -> str:
return self._name return self._name
def some_method(self) -> None: ... def some_method(self) -> None:
...
async def some_task(self) -> None: ... async def some_task(self) -> None:
...
ServiceClass() ServiceClass()
@ -118,14 +120,7 @@ def test_protected_and_private_attribute_warning(caplog: LogCaptureFixture) -> N
) not in caplog.text ) not in caplog.text
def test_exposing_methods() -> None: def test_exposing_methods(caplog: LogCaptureFixture) -> None:
class ClassWithTask(pydase.DataService):
async def some_task(self, sleep_time: int) -> None:
pass
with pytest.raises(TaskDefinitionError):
ClassWithTask()
with pytest.raises(FunctionDefinitionError): with pytest.raises(FunctionDefinitionError):
class ClassWithMethod(pydase.DataService): class ClassWithMethod(pydase.DataService):
@ -133,6 +128,18 @@ def test_exposing_methods() -> None:
def some_method(self, *args: Any) -> str: def some_method(self, *args: Any) -> str:
return "some method" return "some method"
class ClassWithTask(pydase.DataService):
async def some_task(self, sleep_time: int) -> None:
pass
ClassWithTask()
assert (
"Async function 'some_task' is defined with at least one argument. If you want "
"to use it as a task, remove the argument(s) from the function definition."
in caplog.text
)
def test_dynamically_added_attribute(caplog: LogCaptureFixture) -> None: def test_dynamically_added_attribute(caplog: LogCaptureFixture) -> None:
class MyService(DataService): class MyService(DataService):

View File

@ -16,6 +16,7 @@ def test_inherited_property_dependency_resolution() -> None:
_name = "DerivedObservable" _name = "DerivedObservable"
class MyObserver(PropertyObserver): class MyObserver(PropertyObserver):
def on_change(self, full_access_path: str, value: Any) -> None: ... def on_change(self, full_access_path: str, value: Any) -> None:
...
assert MyObserver(DerivedObservable()).property_deps_dict == {"_name": ["name"]} assert MyObserver(DerivedObservable()).property_deps_dict == {"_name": ["name"]}

View File

@ -476,7 +476,8 @@ def test_derived_data_service_serialization() -> None:
def name(self, value: str) -> None: def name(self, value: str) -> None:
self._name = value self._name = value
class DerivedService(BaseService): ... class DerivedService(BaseService):
...
base_service_serialization = dump(BaseService()) base_service_serialization = dump(BaseService())
derived_service_serialization = dump(DerivedService()) derived_service_serialization = dump(DerivedService())