diff --git a/pyproject.toml b/pyproject.toml index 163ce53..6304dcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pydase" -version = "0.8.2" +version = "0.8.3" description = "A flexible and robust Python library for creating, managing, and interacting with data services, with built-in support for web and RPC servers, and customizable features for diverse use cases." authors = ["Mose Mueller "] readme = "README.md" diff --git a/src/pydase/observer_pattern/observable/observable_object.py b/src/pydase/observer_pattern/observable/observable_object.py index 82f0e70..aa39413 100644 --- a/src/pydase/observer_pattern/observable/observable_object.py +++ b/src/pydase/observer_pattern/observable/observable_object.py @@ -1,4 +1,5 @@ import logging +import weakref from abc import ABC, abstractmethod from collections.abc import Iterable from typing import TYPE_CHECKING, Any, ClassVar, SupportsIndex @@ -12,8 +13,8 @@ logger = logging.getLogger(__name__) class ObservableObject(ABC): - _list_mapping: ClassVar[dict[int, "_ObservableList"]] = {} - _dict_mapping: ClassVar[dict[int, "_ObservableDict"]] = {} + _list_mapping: ClassVar[dict[int, weakref.ReferenceType["_ObservableList"]]] = {} + _dict_mapping: ClassVar[dict[int, weakref.ReferenceType["_ObservableDict"]]] = {} def __init__(self) -> None: if not hasattr(self, "_observers"): @@ -88,19 +89,23 @@ class ObservableObject(ABC): if isinstance(value, list): if id(value) in self._list_mapping: # If the list `value` was already referenced somewhere else - new_value = self._list_mapping[id(value)] + new_value = self._list_mapping[id(value)]() else: # convert the builtin list into a ObservableList new_value = _ObservableList(original_list=value) - self._list_mapping[id(value)] = new_value + + # Use weakref to allow the GC to collect unused objects + self._list_mapping[id(value)] = weakref.ref(new_value) elif isinstance(value, dict): if id(value) in self._dict_mapping: # If the dict `value` was already referenced somewhere else - new_value = self._dict_mapping[id(value)] + new_value = self._dict_mapping[id(value)]() else: - # convert the builtin list into a ObservableList + # convert the builtin dict into a ObservableDict new_value = _ObservableDict(original_dict=value) - self._dict_mapping[id(value)] = new_value + + # Use weakref to allow the GC to collect unused objects + self._dict_mapping[id(value)] = weakref.ref(new_value) if isinstance(new_value, ObservableObject): new_value.add_observer(self, attr_name_or_key) return new_value @@ -139,6 +144,9 @@ class _ObservableList(ObservableObject, list[Any]): for i, item in enumerate(self._original_list): super().__setitem__(i, self._initialise_new_objects(f"[{i}]", item)) + def __del__(self) -> None: + self._list_mapping.pop(id(self._original_list)) + def __setitem__(self, key: int, value: Any) -> None: # type: ignore[override] if hasattr(self, "_observers"): self._remove_observer_if_observable(f"[{key}]") @@ -237,6 +245,9 @@ class _ObservableDict(ObservableObject, dict[str, Any]): for key, value in self._original_dict.items(): self.__setitem__(key, self._initialise_new_objects(f'["{key}"]', value)) + def __del__(self) -> None: + self._dict_mapping.pop(id(self._original_dict)) + def __setitem__(self, key: str, value: Any) -> None: if not isinstance(key, str): raise ValueError( diff --git a/tests/observer_pattern/observable/test_observable_object.py b/tests/observer_pattern/observable/test_observable_object.py index 6c9eb29..74c8a47 100644 --- a/tests/observer_pattern/observable/test_observable_object.py +++ b/tests/observer_pattern/observable/test_observable_object.py @@ -311,3 +311,51 @@ def test_list_remove(caplog: pytest.LogCaptureFixture) -> None: # checks if observer key was updated correctly (was index 1) other_observable_instance_2.greeting = "Ciao" assert "'my_list[0].greeting' changed to 'Ciao'" in caplog.text + + +def test_list_garbage_collection() -> None: + """Makes sure that the GC collects lists that are not referenced anymore.""" + + import gc + import json + + list_json = """ + [1] + """ + + class MyObservable(Observable): + def __init__(self) -> None: + super().__init__() + self.list_attr = json.loads(list_json) + + observable = MyObservable() + list_mapping_length = len(observable._list_mapping) + observable.list_attr = json.loads(list_json) + + gc.collect() + assert len(observable._list_mapping) <= list_mapping_length + + +def test_dict_garbage_collection() -> None: + """Makes sure that the GC collects dicts that are not referenced anymore.""" + + import gc + import json + + dict_json = """ + { + "foo": "bar" + } + """ + + class MyObservable(Observable): + def __init__(self) -> None: + super().__init__() + self.dict_attr = json.loads(dict_json) + + observable = MyObservable() + dict_mapping_length = len(observable._dict_mapping) + observable.dict_attr = json.loads(dict_json) + + gc.collect() + assert len(observable._dict_mapping) <= dict_mapping_length