Merge pull request #127 from tiqi-group/feature/ignore_coroutine

Skip coroutines with arguments instead of raising an exception
This commit is contained in:
Mose Müller 2024-05-27 15:10:46 +02:00 committed by GitHub
commit 9fa8f06280
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 38 additions and 29 deletions

View File

@ -21,10 +21,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class TaskDefinitionError(Exception):
pass
class TaskStatus(Enum):
RUNNING = "running"
@ -107,12 +103,13 @@ class TaskManager:
method = getattr(self.service, name)
if inspect.iscoroutinefunction(method):
if function_has_arguments(method):
raise TaskDefinitionError(
"Asynchronous functions (tasks) should be defined without "
f"arguments. The task '{method.__name__}' has at least one "
"argument. Please remove the argument(s) from this function to "
"use it."
logger.info(
"Async function %a is defined with at least one argument. If "
"you want to use it as a task, remove the argument(s) from the "
"function definition.",
method.__name__,
)
continue
# create start and stop methods for each coroutine
setattr(

View File

@ -1,34 +1,37 @@
from __future__ import annotations
import logging
import weakref
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, ClassVar, SupportsIndex
from pydase.utils.helpers import parse_serialized_key
if TYPE_CHECKING:
from collections.abc import Iterable
from pydase.observer_pattern.observer.observer import Observer
logger = logging.getLogger(__name__)
class ObservableObject(ABC):
_list_mapping: ClassVar[dict[int, weakref.ReferenceType["_ObservableList"]]] = {}
_dict_mapping: ClassVar[dict[int, weakref.ReferenceType["_ObservableDict"]]] = {}
_list_mapping: ClassVar[dict[int, weakref.ReferenceType[_ObservableList]]] = {}
_dict_mapping: ClassVar[dict[int, weakref.ReferenceType[_ObservableDict]]] = {}
def __init__(self) -> None:
if not hasattr(self, "_observers"):
self._observers: dict[str, list["ObservableObject | Observer"]] = {}
self._observers: dict[str, list[ObservableObject | Observer]] = {}
def add_observer(
self, observer: "ObservableObject | Observer", attr_name: str = ""
self, observer: ObservableObject | Observer, attr_name: str = ""
) -> None:
if attr_name not in self._observers:
self._observers[attr_name] = []
if observer not in self._observers[attr_name]:
self._observers[attr_name].append(observer)
def _remove_observer(self, observer: "ObservableObject", attribute: str) -> None:
def _remove_observer(self, observer: ObservableObject, attribute: str) -> None:
if attribute in self._observers:
self._observers[attribute].remove(observer)

View File

@ -7,7 +7,6 @@ import pytest
from pydase import DataService
from pydase.data_service.data_service_observer import DataServiceObserver
from pydase.data_service.state_manager import StateManager
from pydase.data_service.task_manager import TaskDefinitionError
from pydase.utils.decorators import FunctionDefinitionError, frontend
from pytest import LogCaptureFixture
@ -37,7 +36,8 @@ def test_unexpected_type_change_warning(caplog: LogCaptureFixture) -> None:
def test_basic_inheritance_warning(caplog: LogCaptureFixture) -> None:
class SubService(DataService): ...
class SubService(DataService):
...
class SomeEnum(Enum):
HI = 0
@ -57,9 +57,11 @@ def test_basic_inheritance_warning(caplog: LogCaptureFixture) -> None:
def name(self) -> str:
return self._name
def some_method(self) -> None: ...
def some_method(self) -> None:
...
async def some_task(self) -> None: ...
async def some_task(self) -> None:
...
ServiceClass()
@ -118,14 +120,7 @@ def test_protected_and_private_attribute_warning(caplog: LogCaptureFixture) -> N
) not in caplog.text
def test_exposing_methods() -> None:
class ClassWithTask(pydase.DataService):
async def some_task(self, sleep_time: int) -> None:
pass
with pytest.raises(TaskDefinitionError):
ClassWithTask()
def test_exposing_methods(caplog: LogCaptureFixture) -> None:
with pytest.raises(FunctionDefinitionError):
class ClassWithMethod(pydase.DataService):
@ -133,6 +128,18 @@ def test_exposing_methods() -> None:
def some_method(self, *args: Any) -> str:
return "some method"
class ClassWithTask(pydase.DataService):
async def some_task(self, sleep_time: int) -> None:
pass
ClassWithTask()
assert (
"Async function 'some_task' is defined with at least one argument. If you want "
"to use it as a task, remove the argument(s) from the function definition."
in caplog.text
)
def test_dynamically_added_attribute(caplog: LogCaptureFixture) -> None:
class MyService(DataService):

View File

@ -16,6 +16,7 @@ def test_inherited_property_dependency_resolution() -> None:
_name = "DerivedObservable"
class MyObserver(PropertyObserver):
def on_change(self, full_access_path: str, value: Any) -> None: ...
def on_change(self, full_access_path: str, value: Any) -> None:
...
assert MyObserver(DerivedObservable()).property_deps_dict == {"_name": ["name"]}

View File

@ -476,7 +476,8 @@ def test_derived_data_service_serialization() -> None:
def name(self, value: str) -> None:
self._name = value
class DerivedService(BaseService): ...
class DerivedService(BaseService):
...
base_service_serialization = dump(BaseService())
derived_service_serialization = dump(DerivedService())