refactoring the way DataService instance attributes are updated

This commit is contained in:
Mose Müller
2023-08-02 12:06:21 +02:00
parent df8ea404ae
commit d721ef05f5
3 changed files with 240 additions and 188 deletions

View File

@ -2,10 +2,9 @@ import asyncio
import inspect
import json
import os
import re
from collections.abc import Callable
from enum import Enum
from typing import Any, Optional, TypedDict, cast, get_type_hints
from typing import Any, Optional, cast, get_type_hints
import rpyc
from loguru import logger
@ -16,8 +15,10 @@ from pyDataInterface.utils import (
)
from pyDataInterface.utils.helpers import (
convert_arguments_to_hinted_types,
generate_paths_and_values_from_serialized_DataService,
get_DataService_attr_from_path,
generate_paths_from_DataService_dict,
get_nested_value_by_path_and_key,
get_object_attr_from_path,
parse_list_attr_and_index,
set_if_differs,
)
@ -25,100 +26,9 @@ from .data_service_list import DataServiceList
from .task_manager import TaskManager
class UpdateDict(TypedDict):
"""
A TypedDict subclass representing a dictionary used for updating attributes in a
DataService.
Attributes:
----------
name : str
The name of the attribute to be updated in the DataService instance.
If the attribute is part of a nested structure, this would be the name of the
attribute in the last nested object. For example, for an attribute access path
'attr1.list_attr[0].attr2', 'attr2' would be the name.
parent_path : str
The access path for the parent object of the attribute to be updated. This is
used to construct the full access path for the attribute. For example, for an
attribute access path 'attr1.list_attr[0].attr2', 'attr1.list_attr[0]' would be
the parent_path.
value : Any
The new value to be assigned to the attribute. The type of this value should
match the type of the attribute to be updated.
"""
name: str
parent_path: str
value: Any
def extract_path_list_and_name_and_index_from_UpdateDict(
data: UpdateDict,
) -> tuple[list[str], str, Optional[int]]:
path_list, attr_name = data["parent_path"].split("."), data["name"]
index: Optional[int] = None
index_search = re.search(r"\[(\d+)\]", attr_name)
if index_search:
attr_name = attr_name.split("[")[0]
index = int(index_search.group(1))
return path_list, attr_name, index
def get_target_object_and_attribute(
service: "DataService", path_list: list[str], attr_name: str
) -> tuple[Any, Any]:
target_obj = get_DataService_attr_from_path(service, path_list)
attr = getattr(target_obj, attr_name, None)
if attr is None:
logger.error(f"Attribute {attr_name} not found.")
return target_obj, attr
def update_each_DataService_attribute(
service: "DataService", parent_path: str, data_value: dict[str, Any]
) -> None:
for key, value in data_value.items():
update_DataService_by_path(
service,
{
"name": key,
"parent_path": parent_path,
"value": value,
},
)
def process_DataService_attribute(
service: "DataService", attr_name: str, data: UpdateDict
) -> None:
update_each_DataService_attribute(
service,
f"{data['parent_path']}.{attr_name}",
cast(dict[str, Any], data["value"]),
)
def process_list_attribute(
service: "DataService", attr: list[Any], index: int, data: UpdateDict
) -> None:
if isinstance(attr[index], DataService):
update_each_DataService_attribute(
service,
f"{data['parent_path']}.{data['name']}",
cast(dict[str, Any], data["value"]),
)
elif isinstance(attr[index], list):
logger.error("Nested lists are not supported yet.")
raise NotImplementedError
else:
set_if_differs(attr, index, data["value"])
def process_callable_attribute(attr: Any, data: UpdateDict) -> Any:
def process_callable_attribute(attr: Any, args: dict[str, Any]) -> Any:
converted_args_or_error_msg = convert_arguments_to_hinted_types(
data["value"]["args"], get_type_hints(attr)
args, get_type_hints(attr)
)
return (
attr(**converted_args_or_error_msg)
@ -127,27 +37,6 @@ def process_callable_attribute(attr: Any, data: UpdateDict) -> Any:
)
def update_DataService_by_path(service: "DataService", data: UpdateDict) -> Any:
(
path_list,
attr_name,
index,
) = extract_path_list_and_name_and_index_from_UpdateDict(data)
target_obj, attr = get_target_object_and_attribute(service, path_list, attr_name)
if attr is None:
return
if isinstance(attr, DataService):
process_DataService_attribute(service, attr_name, data)
elif isinstance(attr, Enum):
set_if_differs(target_obj, attr_name, attr.__class__[data["value"]])
elif callable(attr):
return process_callable_attribute(attr, data)
elif isinstance(attr, list) and index is not None:
process_list_attribute(service, attr, index, data)
else:
set_if_differs(target_obj, attr_name, data["value"])
class DataService(rpyc.Service, TaskManager):
_list_mapping: dict[int, DataServiceList] = {}
"""
@ -201,9 +90,7 @@ class DataService(rpyc.Service, TaskManager):
with open(self._filename, "r") as f:
# Load JSON data from file and update class attributes with these
# values
self.set_attributes_from_serialized_representation(
cast(dict[str, Any], json.load(f))
)
self.load_DataService_from_JSON(cast(dict[str, Any], json.load(f)))
def write_to_file(self) -> None:
"""
@ -221,35 +108,19 @@ class DataService(rpyc.Service, TaskManager):
'Skipping "write_to_file"...'
)
def set_attributes_from_serialized_representation(
self, serialized_representation: dict[str, Any]
) -> None:
def load_DataService_from_JSON(self, json_dict: dict[str, Any]) -> None:
# Traverse the serialized representation and set the attributes of the class
for path, value in generate_paths_and_values_from_serialized_DataService(
serialized_representation
).items():
# Split the path into elements
parent_path, attr_name = f"DataService.{path}".rsplit(".", 1)
for path in generate_paths_from_DataService_dict(json_dict):
value = get_nested_value_by_path_and_key(json_dict, path=path)
value_type = get_nested_value_by_path_and_key(
json_dict, path=path, key="type"
)
if isinstance(value, list):
for index, item in enumerate(value):
update_DataService_by_path(
self,
{
"name": f"{attr_name}[{index}]",
"parent_path": parent_path,
"value": item,
},
)
else:
update_DataService_by_path(
self,
{
"name": attr_name,
"parent_path": parent_path,
"value": value,
},
)
# Split the path into parts
parts = path.split(".")
attr_name = parts[-1]
self.update_DataService_attribute(parts[:-1], attr_name, value, value_type)
def __setattr__(self, __name: str, __value: Any) -> None:
current_value = getattr(self, __name, None)
@ -765,3 +636,34 @@ class DataService(rpyc.Service, TaskManager):
}
return result
def update_DataService_attribute(
self,
path_list: list[str],
attr_name: str,
value: Any,
attr_type: Optional[str] = None,
) -> None:
# If attr_name corresponds to a list entry, extract the attr_name and the index
attr_name, index = parse_list_attr_and_index(attr_name)
# Traverse the object according to the path parts
target_obj = get_object_attr_from_path(self, path_list)
attr = get_object_attr_from_path(target_obj, [attr_name])
if attr is None:
return
# Set the attribute at the terminal point of the path
if isinstance(attr, Enum):
set_if_differs(target_obj, attr_name, attr.__class__[value])
elif isinstance(attr, list) and index is not None:
set_if_differs(attr, index, value)
elif isinstance(attr, DataService) and isinstance(value, dict):
for key, v in value.items():
self.update_DataService_attribute([*path_list, attr_name], key, v)
elif callable(attr):
return process_callable_attribute(attr, value["args"])
else:
set_if_differs(target_obj, attr_name, value)