Merge pull request #126 from tiqi-group/fix/memory_leak

Fix memory leak in ObservableObject
This commit is contained in:
Mose Müller 2024-05-21 13:51:03 +02:00 committed by GitHub
commit f783d0b25c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 8 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "pydase" name = "pydase"
version = "0.8.2" version = "0.8.3"
description = "A flexible and robust Python library for creating, managing, and interacting with data services, with built-in support for web and RPC servers, and customizable features for diverse use cases." description = "A flexible and robust Python library for creating, managing, and interacting with data services, with built-in support for web and RPC servers, and customizable features for diverse use cases."
authors = ["Mose Mueller <mosmuell@ethz.ch>"] authors = ["Mose Mueller <mosmuell@ethz.ch>"]
readme = "README.md" readme = "README.md"

View File

@ -1,4 +1,5 @@
import logging import logging
import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, ClassVar, SupportsIndex from typing import TYPE_CHECKING, Any, ClassVar, SupportsIndex
@ -12,8 +13,8 @@ logger = logging.getLogger(__name__)
class ObservableObject(ABC): class ObservableObject(ABC):
_list_mapping: ClassVar[dict[int, "_ObservableList"]] = {} _list_mapping: ClassVar[dict[int, weakref.ReferenceType["_ObservableList"]]] = {}
_dict_mapping: ClassVar[dict[int, "_ObservableDict"]] = {} _dict_mapping: ClassVar[dict[int, weakref.ReferenceType["_ObservableDict"]]] = {}
def __init__(self) -> None: def __init__(self) -> None:
if not hasattr(self, "_observers"): if not hasattr(self, "_observers"):
@ -88,19 +89,23 @@ class ObservableObject(ABC):
if isinstance(value, list): if isinstance(value, list):
if id(value) in self._list_mapping: if id(value) in self._list_mapping:
# If the list `value` was already referenced somewhere else # If the list `value` was already referenced somewhere else
new_value = self._list_mapping[id(value)] new_value = self._list_mapping[id(value)]()
else: else:
# convert the builtin list into a ObservableList # convert the builtin list into a ObservableList
new_value = _ObservableList(original_list=value) new_value = _ObservableList(original_list=value)
self._list_mapping[id(value)] = new_value
# Use weakref to allow the GC to collect unused objects
self._list_mapping[id(value)] = weakref.ref(new_value)
elif isinstance(value, dict): elif isinstance(value, dict):
if id(value) in self._dict_mapping: if id(value) in self._dict_mapping:
# If the dict `value` was already referenced somewhere else # If the dict `value` was already referenced somewhere else
new_value = self._dict_mapping[id(value)] new_value = self._dict_mapping[id(value)]()
else: else:
# convert the builtin list into a ObservableList # convert the builtin dict into a ObservableDict
new_value = _ObservableDict(original_dict=value) new_value = _ObservableDict(original_dict=value)
self._dict_mapping[id(value)] = new_value
# Use weakref to allow the GC to collect unused objects
self._dict_mapping[id(value)] = weakref.ref(new_value)
if isinstance(new_value, ObservableObject): if isinstance(new_value, ObservableObject):
new_value.add_observer(self, attr_name_or_key) new_value.add_observer(self, attr_name_or_key)
return new_value return new_value
@ -139,6 +144,9 @@ class _ObservableList(ObservableObject, list[Any]):
for i, item in enumerate(self._original_list): for i, item in enumerate(self._original_list):
super().__setitem__(i, self._initialise_new_objects(f"[{i}]", item)) super().__setitem__(i, self._initialise_new_objects(f"[{i}]", item))
def __del__(self) -> None:
self._list_mapping.pop(id(self._original_list))
def __setitem__(self, key: int, value: Any) -> None: # type: ignore[override] def __setitem__(self, key: int, value: Any) -> None: # type: ignore[override]
if hasattr(self, "_observers"): if hasattr(self, "_observers"):
self._remove_observer_if_observable(f"[{key}]") self._remove_observer_if_observable(f"[{key}]")
@ -237,6 +245,9 @@ class _ObservableDict(ObservableObject, dict[str, Any]):
for key, value in self._original_dict.items(): for key, value in self._original_dict.items():
self.__setitem__(key, self._initialise_new_objects(f'["{key}"]', value)) self.__setitem__(key, self._initialise_new_objects(f'["{key}"]', value))
def __del__(self) -> None:
self._dict_mapping.pop(id(self._original_dict))
def __setitem__(self, key: str, value: Any) -> None: def __setitem__(self, key: str, value: Any) -> None:
if not isinstance(key, str): if not isinstance(key, str):
raise ValueError( raise ValueError(

View File

@ -311,3 +311,51 @@ def test_list_remove(caplog: pytest.LogCaptureFixture) -> None:
# checks if observer key was updated correctly (was index 1) # checks if observer key was updated correctly (was index 1)
other_observable_instance_2.greeting = "Ciao" other_observable_instance_2.greeting = "Ciao"
assert "'my_list[0].greeting' changed to 'Ciao'" in caplog.text assert "'my_list[0].greeting' changed to 'Ciao'" in caplog.text
def test_list_garbage_collection() -> None:
"""Makes sure that the GC collects lists that are not referenced anymore."""
import gc
import json
list_json = """
[1]
"""
class MyObservable(Observable):
def __init__(self) -> None:
super().__init__()
self.list_attr = json.loads(list_json)
observable = MyObservable()
list_mapping_length = len(observable._list_mapping)
observable.list_attr = json.loads(list_json)
gc.collect()
assert len(observable._list_mapping) <= list_mapping_length
def test_dict_garbage_collection() -> None:
"""Makes sure that the GC collects dicts that are not referenced anymore."""
import gc
import json
dict_json = """
{
"foo": "bar"
}
"""
class MyObservable(Observable):
def __init__(self) -> None:
super().__init__()
self.dict_attr = json.loads(dict_json)
observable = MyObservable()
dict_mapping_length = len(observable._dict_mapping)
observable.dict_attr = json.loads(dict_json)
gc.collect()
assert len(observable._dict_mapping) <= dict_mapping_length