From 206a831473ff2e74bf3549a3452b1b6dbd4f5309 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mose=20M=C3=BCller?= Date: Wed, 2 Aug 2023 12:06:19 +0200 Subject: [PATCH] feat: added property callbacks, added warnings --- .../data_service/data_service.py | 194 +++++++++-- .../data_service/data_service_list.py | 7 + src/pyDataInterface/utils/__init__.py | 3 + src/pyDataInterface/utils/warnings.py | 14 + tests/__init__.py | 12 + tests/test_properties.py | 326 +++++++++++++++++- tests/test_warnings.py | 34 ++ 7 files changed, 549 insertions(+), 41 deletions(-) create mode 100644 src/pyDataInterface/utils/__init__.py create mode 100644 src/pyDataInterface/utils/warnings.py create mode 100644 tests/test_warnings.py diff --git a/src/pyDataInterface/data_service/data_service.py b/src/pyDataInterface/data_service/data_service.py index afba824..df32b0f 100644 --- a/src/pyDataInterface/data_service/data_service.py +++ b/src/pyDataInterface/data_service/data_service.py @@ -3,11 +3,16 @@ import inspect import threading from collections.abc import Callable from concurrent.futures import Future +from itertools import chain from typing import Any import rpyc from loguru import logger +from pyDataInterface.utils import ( + warn_if_instance_class_does_not_inherit_from_DataService, +) + from .data_service_list import DataServiceList @@ -39,18 +44,127 @@ class DataService(rpyc.Service): self._start_async_loop_in_thread() self._start_autostart_tasks() - self._register_callbacks(self, f"{self.__class__.__name__}") - self._turn_lists_into_notify_lists(self, f"{self.__class__.__name__}") - self._do_something_with_properties() + self._register_DataService_callbacks(self, f"{self.__class__.__name__}") + self._register_list_change_callbacks(self, f"{self.__class__.__name__}") + self._register_property_callbacks(self, f"{self.__class__.__name__}") + self._check_instance_classes() self._initialised = True - def _do_something_with_properties(self) -> None: - for attr_name in dir(self.__class__): - attr_value = getattr(self.__class__, attr_name) - if isinstance(attr_value, property): # If attribute is a property - logger.debug(attr_value.fget.__code__.co_names) + def _check_instance_classes(self) -> None: + for attr_name, attr_value in self.__get_class_and_instance_attributes().items(): + # every class defined by the user should inherit from DataService + if not attr_name.startswith("_DataService__"): + warn_if_instance_class_does_not_inherit_from_DataService(attr_value) - def _turn_lists_into_notify_lists( + def __register_recursive_parameter_callback( + self, + obj: "DataService | DataServiceList", + callback: Callable[[str | int, Any], None], + ) -> None: + """ + Register callback to the DataService instance and all its nested instances. + + This method recursively traverses all attributes of the DataService `obj` and + adds the callback to each instance's `_callbacks` set when an attribute is a + DataService instance. This ensures any modification of attributes within + nested instances will trigger the provided callback. + """ + + if isinstance(obj, DataServiceList): + # emits callback when item in list gets reassigned + obj.add_callback(callback=callback) + obj_list: DataServiceList | list[DataService] = obj + else: + obj_list = [obj] + + # this enables notifications when a class instance was changed (-> item is + # changed, not reassigned) + for item in obj_list: + if isinstance(item, DataService): + item._callbacks.add(callback) + for attr_name in set(dir(item)) - set(dir(object)) - {"_root"}: + attr_value = getattr(item, attr_name) + if isinstance(attr_value, (DataService, DataServiceList)): + self.__register_recursive_parameter_callback( + attr_value, callback + ) + + def _register_property_callbacks( + self, + obj: "DataService", + parent_path: str, + ) -> None: + """ + Register callbacks to emit notifications when attributes used in a property + getter are changed. + + This method iterates over all attributes of the class. For each attribute that + is a property, it gets the names of the attributes used inside the property's + getter method. It then creates a callback for each of these dependent + attributes. + + If the dependent attribute is a DataServiceList, the callback is added to the + list. So, if any element in the list is changed, the callback will be triggered + and a notification will be emitted. + + If the dependent attribute is an instance of DataService, the callback is + registered to all nested DataService instances of this attribute using + `_register_recursive_callback`. + + For all other types of attributes, the callback is simply added to the + `_callbacks` set of the instance. + """ + + attrs = obj.__get_class_and_instance_attributes() + + for attr_name, attr_value in attrs.items(): + if isinstance(attr_value, DataService): + self._register_property_callbacks( + attr_value, parent_path=f"{parent_path}.{attr_name}" + ) + elif isinstance(attr_value, DataServiceList): + for i, item in enumerate(attr_value): + if isinstance(item, DataService): + self._register_property_callbacks( + item, parent_path=f"{parent_path}.{attr_name}[{i}]" + ) + if isinstance(attr_value, property): + dependencies = attr_value.fget.__code__.co_names # type: ignore + + for dependency in dependencies: + # use `obj` instead of `type(obj)` to get DataServiceList + # instead of list + dependency_value = getattr(obj, dependency) + + if isinstance(dependency_value, (DataServiceList, DataService)): + callback = ( + lambda name, value, dependent_attr=attr_name: obj._emit_notification( + parent_path=parent_path, + name=dependent_attr, + value=getattr(obj, dependent_attr), + ) + if self == obj._root + else None + ) + + self.__register_recursive_parameter_callback( + dependency_value, + callback=callback, + ) + else: + callback = ( + lambda name, value, dependent_attr=attr_name, dep=dependency: obj._emit_notification( + parent_path=parent_path, + name=dependent_attr, + value=getattr(obj, dependent_attr), + ) + if name == dep and self == obj._root + else None + ) + # Add to _callbacks + obj._callbacks.add(callback) + + def _register_list_change_callbacks( self, obj: "DataService", parent_path: str ) -> None: """ @@ -82,12 +196,12 @@ class DataService(rpyc.Service): """ # Convert all list attributes (both class and instance) to DataServiceList - for attr_name in set(dir(obj)) - set(dir(object)) - {"_root"}: - attr_value = getattr(obj, attr_name) + attrs = obj.__get_class_and_instance_attributes() + for attr_name, attr_value in attrs.items(): if isinstance(attr_value, DataService): new_path = f"{parent_path}.{attr_name}" - self._turn_lists_into_notify_lists(attr_value, new_path) + self._register_list_change_callbacks(attr_value, new_path) elif isinstance(attr_value, list): # Create callback for current attr_name # Default arguments solve the late binding problem by capturing the @@ -107,7 +221,7 @@ class DataService(rpyc.Service): if isinstance(attr_value, DataServiceList): attr_value.add_callback(callback) continue - elif id(attr_value) in self._list_mapping: + if id(attr_value) in self._list_mapping: notifying_list = self._list_mapping[id(attr_value)] notifying_list.add_callback(callback) else: @@ -120,7 +234,7 @@ class DataService(rpyc.Service): for i, item in enumerate(attr_value): if isinstance(item, DataService): new_path = f"{parent_path}.{attr_name}[{i}]" - self._turn_lists_into_notify_lists(item, new_path) + self._register_list_change_callbacks(item, new_path) def _start_autostart_tasks(self) -> None: if self._autostart_tasks is not None: @@ -166,7 +280,9 @@ class DataService(rpyc.Service): setattr(self, f"start_{name}", start_task) setattr(self, f"stop_{name}", stop_task) - def _register_callbacks(self, obj: "DataService", parent_path: str) -> None: + def _register_DataService_callbacks( + self, obj: "DataService", parent_path: str + ) -> None: """ This function is a key part of the observer pattern implemented by the DataService class. @@ -200,16 +316,20 @@ class DataService(rpyc.Service): lambda name, value: obj._emit_notification( parent_path=parent_path, name=name, value=value ) - if self == self._root + if self == obj._root + and not name.startswith("_") # we are only interested in public attributes + and not isinstance( + getattr(type(obj), name, None), property + ) # exlude proerty notifications -> those are handled in separate callbacks else None ) obj._callbacks.add(callback) # Recursively register callbacks for all nested attributes of the object - attribute_set = set(dir(obj)) - set(dir(object)) - {"_root"} - for nested_attr_name in attribute_set: - nested_attr = getattr(obj, nested_attr_name) + attrs = obj.__get_class_and_instance_attributes() + + for nested_attr_name, nested_attr in attrs.items(): if isinstance(nested_attr, list): self._register_list_callbacks( nested_attr, parent_path, nested_attr_name @@ -238,7 +358,7 @@ class DataService(rpyc.Service): nested_attr._root = self._root new_path = f"{parent_path}.{attr_name}" - self._register_callbacks(nested_attr, new_path) + self._register_DataService_callbacks(nested_attr, new_path) def _start_loop(self) -> None: asyncio.set_event_loop(self.__loop) @@ -256,8 +376,12 @@ class DataService(rpyc.Service): if self.__dict__.get("_initialised") and not __name == "_initialised": for callback in self._callbacks: callback(__name, __value) - # TODO: add emits for properties -> can use co_names, which is a tuple - # containing the names used by the bytecode + elif __name.startswith(f"_{self.__class__.__name__}__"): + logger.warning( + f"Warning: You should not set private but rather protected attributes! " + f"Use {__name.replace(f'_{self.__class__.__name__}__', '_')} instead " + f"of {__name.replace(f'_{self.__class__.__name__}__', '__')}." + ) def _emit_notification(self, parent_path: str, name: str, value: Any) -> None: logger.debug(f"{parent_path}.{name} changed to {value}!") @@ -269,7 +393,7 @@ class DataService(rpyc.Service): # allow all other attributes return getattr(self, name) - def _rpyc_setattr(self, name: str, value: Any): + def _rpyc_setattr(self, name: str, value: Any) -> None: if name.startswith("_"): # disallow special and private attributes raise AttributeError("cannot access private/special names") @@ -282,6 +406,20 @@ class DataService(rpyc.Service): # allow all other attributes setattr(self, name, value) + def __get_class_and_instance_attributes(self) -> dict[str, Any]: + """Dictionary containing all attributes (both instance and class level) of a + given object. + + If an attribute exists at both the instance and class level,the value from the + instance attribute takes precedence. + The _root object is removed as this will lead to endless recursion in the for + loops. + """ + + attrs = dict(chain(type(self).__dict__.items(), self.__dict__.items())) + attrs.pop("_root") + return attrs + def serialize(self, prefix: str = "") -> dict[str, dict[str, Any]]: """ Serializes the instance into a dictionary, preserving the structure of the @@ -314,18 +452,18 @@ class DataService(rpyc.Service): result: dict[str, dict[str, Any]] = {} # Get the dictionary of the base class - base_dict = set(super().__class__.__dict__) + base_set = set(type(super()).__dict__) # Get the dictionary of the derived class - derived_dict = set(self.__class__.__dict__) + derived_set = set(type(self).__dict__) # Get the difference between the two dictionaries - derived_only_dict = derived_dict - base_dict + derived_only_set = derived_set - base_set instance_dict = set(self.__dict__) # Merge the class and instance dictionaries - merged_dict = derived_only_dict | instance_dict + merged_set = derived_only_set | instance_dict # Iterate over attributes, properties, class attributes, and methods - for key in merged_dict: + for key in merged_set: if key.startswith("_"): continue # Skip attributes that start with underscore diff --git a/src/pyDataInterface/data_service/data_service_list.py b/src/pyDataInterface/data_service/data_service_list.py index 2813768..7166f6e 100644 --- a/src/pyDataInterface/data_service/data_service_list.py +++ b/src/pyDataInterface/data_service/data_service_list.py @@ -1,6 +1,10 @@ from collections.abc import Callable from typing import Any +from pyDataInterface.utils import ( + warn_if_instance_class_does_not_inherit_from_DataService, +) + class DataServiceList(list): """ @@ -36,6 +40,9 @@ class DataServiceList(list): if isinstance(callback, list): self.callbacks = callback + for item in args[0]: + warn_if_instance_class_does_not_inherit_from_DataService(item) + # prevent gc to delete the passed list by keeping a reference self._original_list = args[0] diff --git a/src/pyDataInterface/utils/__init__.py b/src/pyDataInterface/utils/__init__.py new file mode 100644 index 0000000..c65b108 --- /dev/null +++ b/src/pyDataInterface/utils/__init__.py @@ -0,0 +1,3 @@ +from .warnings import warn_if_instance_class_does_not_inherit_from_DataService + +__all__ = ["warn_if_instance_class_does_not_inherit_from_DataService"] diff --git a/src/pyDataInterface/utils/warnings.py b/src/pyDataInterface/utils/warnings.py new file mode 100644 index 0000000..05fbeb2 --- /dev/null +++ b/src/pyDataInterface/utils/warnings.py @@ -0,0 +1,14 @@ +from loguru import logger + + +def warn_if_instance_class_does_not_inherit_from_DataService(__value: object) -> None: + base_class_name = __value.__class__.__base__.__name__ + module_name = __value.__class__.__module__ + + if module_name not in ["builtins", "__builtin__"] and base_class_name not in [ + "DataService", + "list", + ]: + logger.warning( + f"Warning: Class {type(__value).__name__} does not inherit from DataService." + ) diff --git a/tests/__init__.py b/tests/__init__.py index 8aea465..f913d2c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,8 +1,20 @@ +from collections.abc import Generator from typing import Any +import pytest +from loguru import logger +from pytest import LogCaptureFixture + from pyDataInterface import DataService +@pytest.fixture +def caplog(caplog: LogCaptureFixture) -> Generator[LogCaptureFixture, Any, None]: + handler_id = logger.add(caplog.handler, format="{message}") + yield caplog + logger.remove(handler_id) + + def emit(self: Any, parent_path: str, name: str, value: Any) -> None: if isinstance(value, DataService): value = value.serialize() diff --git a/tests/test_properties.py b/tests/test_properties.py index da593fe..27b0ccc 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -5,30 +5,330 @@ from pyDataInterface import DataService def test_properties(capsys: CaptureFixture) -> None: class ServiceClass(DataService): - _power = True + _voltage = 10.0 + _current = 1.0 @property - def power(self) -> bool: - return self._power - - @power.setter - def power(self, value: bool) -> None: - self._power = value + def power(self) -> float: + return self._voltage * self.current @property - def power_two(self) -> bool: - return self._power + def voltage(self) -> float: + return self._voltage + + @voltage.setter + def voltage(self, value: float) -> None: + self._voltage = value + + @property + def current(self) -> float: + return self._current + + @current.setter + def current(self, value: float) -> None: + self._current = value test_service = ServiceClass() - test_service.power = False + test_service.voltage = 1 captured = capsys.readouterr() expected_output = sorted( [ - "ServiceClass.power = False", - "ServiceClass.power_two = False", - "ServiceClass._power = False", + "ServiceClass.power = 1.0", + "ServiceClass.voltage = 1", ] ) actual_output = sorted(captured.out.strip().split("\n")) assert actual_output == expected_output + + test_service.current = 12.0 + + captured = capsys.readouterr() + expected_output = sorted( + [ + "ServiceClass.power = 12.0", + "ServiceClass.current = 12.0", + ] + ) + actual_output = sorted(captured.out.strip().split("\n")) + assert actual_output == expected_output + + +def test_nested_properties(capsys: CaptureFixture) -> None: + class SubSubClass(DataService): + name = "Hello" + + class SubClass(DataService): + name = "Hello" + class_attr = SubSubClass() + + class ServiceClass(DataService): + class_attr = SubClass() + name = "World" + + @property + def subsub_name(self) -> str: + return f"{self.class_attr.class_attr.name} {self.name}" + + @property + def sub_name(self) -> str: + return f"{self.class_attr.name} {self.name}" + + test_service = ServiceClass() + test_service.name = "Peepz" + + captured = capsys.readouterr() + expected_output = sorted( + [ + "ServiceClass.name = Peepz", + "ServiceClass.sub_name = Hello Peepz", + "ServiceClass.subsub_name = Hello Peepz", + ] + ) + actual_output = sorted(captured.out.strip().split("\n")) + assert actual_output == expected_output + + test_service.class_attr.name = "Hi" + + captured = capsys.readouterr() + expected_output = sorted( + [ + "ServiceClass.sub_name = Hi Peepz", + "ServiceClass.subsub_name = Hello Peepz", # registers subclass changes + "ServiceClass.class_attr.name = Hi", + ] + ) + actual_output = sorted(captured.out.strip().split("\n")) + assert actual_output == expected_output + + test_service.class_attr.class_attr.name = "Ciao" + + captured = capsys.readouterr() + expected_output = sorted( + [ + "ServiceClass.sub_name = Hi Peepz", # registers subclass changes + "ServiceClass.subsub_name = Ciao Peepz", + "ServiceClass.class_attr.class_attr.name = Ciao", + ] + ) + actual_output = sorted(captured.out.strip().split("\n")) + assert actual_output == expected_output + + +def test_simple_list_properties(capsys: CaptureFixture) -> None: + class ServiceClass(DataService): + list = ["Hello", "Ciao"] + name = "World" + + @property + def total_name(self) -> str: + return f"{self.list[0]} {self.name}" + + test_service = ServiceClass() + test_service.name = "Peepz" + + captured = capsys.readouterr() + expected_output = sorted( + [ + "ServiceClass.name = Peepz", + "ServiceClass.total_name = Hello Peepz", + ] + ) + actual_output = sorted(captured.out.strip().split("\n")) + assert actual_output == expected_output + + test_service.list[0] = "Hi" + + captured = capsys.readouterr() + expected_output = sorted( + [ + "ServiceClass.total_name = Hi Peepz", + "ServiceClass.list[0] = Hi", + ] + ) + actual_output = sorted(captured.out.strip().split("\n")) + assert actual_output == expected_output + + +def test_class_list_properties(capsys: CaptureFixture) -> None: + class SubClass(DataService): + name = "Hello" + + class ServiceClass(DataService): + list = [SubClass()] + name = "World" + + @property + def total_name(self) -> str: + return f"{self.list[0].name} {self.name}" + + test_service = ServiceClass() + test_service.name = "Peepz" + + captured = capsys.readouterr() + expected_output = sorted( + [ + "ServiceClass.name = Peepz", + "ServiceClass.total_name = Hello Peepz", + ] + ) + actual_output = sorted(captured.out.strip().split("\n")) + assert actual_output == expected_output + + test_service.list[0].name = "Hi" + + captured = capsys.readouterr() + expected_output = sorted( + [ + "ServiceClass.total_name = Hi Peepz", + "ServiceClass.list[0].name = Hi", + ] + ) + actual_output = sorted(captured.out.strip().split("\n")) + assert actual_output == expected_output + + +def test_subclass_properties(capsys: CaptureFixture) -> None: + class SubClass(DataService): + name = "Hello" + _voltage = 10.0 + _current = 1.0 + + @property + def power(self) -> float: + return self._voltage * self.current + + @property + def voltage(self) -> float: + return self._voltage + + @voltage.setter + def voltage(self, value: float) -> None: + self._voltage = value + + @property + def current(self) -> float: + return self._current + + @current.setter + def current(self, value: float) -> None: + self._current = value + + class ServiceClass(DataService): + class_attr = SubClass() + + test_service = ServiceClass() + test_service.class_attr.voltage = 10.0 + + captured = capsys.readouterr() + expected_output = sorted( + [ + "ServiceClass.class_attr.voltage = 10.0", + "ServiceClass.class_attr.power = 10.0", + ] + ) + actual_output = sorted(captured.out.strip().split("\n")) + assert actual_output == expected_output + + +def test_subclass_properties(capsys: CaptureFixture) -> None: + class SubClass(DataService): + name = "Hello" + _voltage = 10.0 + _current = 1.0 + + @property + def power(self) -> float: + return self._voltage * self.current + + @property + def voltage(self) -> float: + return self._voltage + + @voltage.setter + def voltage(self, value: float) -> None: + self._voltage = value + + @property + def current(self) -> float: + return self._current + + @current.setter + def current(self, value: float) -> None: + self._current = value + + class ServiceClass(DataService): + class_attr = SubClass() + + @property + def voltage(self) -> float: + return self.class_attr.voltage + + test_service = ServiceClass() + test_service.class_attr.voltage = 10.0 + + captured = capsys.readouterr() + expected_output = sorted( + { + "ServiceClass.class_attr.voltage = 10.0", + "ServiceClass.class_attr.power = 10.0", + "ServiceClass.voltage = 10.0", + } + ) + # using a set here as "ServiceClass.voltage = 10.0" is emitted twice. Once for + # changing voltage, and once for changing power. + actual_output = sorted(set(captured.out.strip().split("\n"))) + assert actual_output == expected_output + + +def test_subclass_properties_2(capsys: CaptureFixture) -> None: + class SubClass(DataService): + name = "Hello" + _voltage = 10.0 + _current = 1.0 + + @property + def power(self) -> float: + return self._voltage * self.current + + @property + def voltage(self) -> float: + return self._voltage + + @voltage.setter + def voltage(self, value: float) -> None: + self._voltage = value + + @property + def current(self) -> float: + return self._current + + @current.setter + def current(self, value: float) -> None: + self._current = value + + class ServiceClass(DataService): + class_attr = [SubClass() for i in range(2)] + + @property + def voltage(self) -> float: + return self.class_attr[0].voltage + + test_service = ServiceClass() + test_service.class_attr[1].current = 10.0 + + captured = capsys.readouterr() + expected_output = sorted( + { + "ServiceClass.class_attr[1].current = 10.0", + "ServiceClass.class_attr[1].power = 100.0", + "ServiceClass.voltage = 10.0", + } + ) + # using a set here as "ServiceClass.voltage = 10.0" is emitted twice. Once for + # changing current, and once for changing power. Note that the voltage property is + # only dependent on class_attr[0] but still emits an update notification. This is + # because every time any item in the list `test_service.class_attr` is changed, + # a notification will be emitted. + actual_output = sorted(set(captured.out.strip().split("\n"))) + assert actual_output == expected_output diff --git a/tests/test_warnings.py b/tests/test_warnings.py new file mode 100644 index 0000000..4ec215d --- /dev/null +++ b/tests/test_warnings.py @@ -0,0 +1,34 @@ +from pytest import LogCaptureFixture + +from pyDataInterface import DataService + +from . import caplog # noqa + + +def test_setattr_warnings(caplog: LogCaptureFixture) -> None: # noqa + # def test_setattr_warnings(capsys: CaptureFixture) -> None: + class SubClass: + name = "Hello" + + class ServiceClass(DataService): + def __init__(self) -> None: + self.attr_1 = SubClass() + super().__init__() + + ServiceClass() + + assert "Warning: Class SubClass does not inherit from DataService." in caplog.text + + +def test_private_attribute_warning(caplog: LogCaptureFixture) -> None: # noqa + class ServiceClass(DataService): + def __init__(self) -> None: + self.__something = "" + super().__init__() + + ServiceClass() + + assert ( + " Warning: You should not set private but rather protected attributes! Use " + "_something instead of __something." in caplog.text + )