From d721ef05f5760f1a116bcffdc7a94527ccf0af91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mose=20M=C3=BCller?= Date: Wed, 2 Aug 2023 12:06:21 +0200 Subject: [PATCH] refactoring the way DataService instance attributes are updated --- .../data_service/data_service.py | 198 +++++------------- src/pyDataInterface/server/web_server.py | 39 +++- src/pyDataInterface/utils/helpers.py | 191 +++++++++++++---- 3 files changed, 240 insertions(+), 188 deletions(-) diff --git a/src/pyDataInterface/data_service/data_service.py b/src/pyDataInterface/data_service/data_service.py index 441e7c7..94fca58 100644 --- a/src/pyDataInterface/data_service/data_service.py +++ b/src/pyDataInterface/data_service/data_service.py @@ -2,10 +2,9 @@ import asyncio import inspect import json import os -import re from collections.abc import Callable from enum import Enum -from typing import Any, Optional, TypedDict, cast, get_type_hints +from typing import Any, Optional, cast, get_type_hints import rpyc from loguru import logger @@ -16,8 +15,10 @@ from pyDataInterface.utils import ( ) from pyDataInterface.utils.helpers import ( convert_arguments_to_hinted_types, - generate_paths_and_values_from_serialized_DataService, - get_DataService_attr_from_path, + generate_paths_from_DataService_dict, + get_nested_value_by_path_and_key, + get_object_attr_from_path, + parse_list_attr_and_index, set_if_differs, ) @@ -25,100 +26,9 @@ from .data_service_list import DataServiceList from .task_manager import TaskManager -class UpdateDict(TypedDict): - """ - A TypedDict subclass representing a dictionary used for updating attributes in a - DataService. - - Attributes: - ---------- - name : str - The name of the attribute to be updated in the DataService instance. - If the attribute is part of a nested structure, this would be the name of the - attribute in the last nested object. For example, for an attribute access path - 'attr1.list_attr[0].attr2', 'attr2' would be the name. 
- - parent_path : str - The access path for the parent object of the attribute to be updated. This is - used to construct the full access path for the attribute. For example, for an - attribute access path 'attr1.list_attr[0].attr2', 'attr1.list_attr[0]' would be - the parent_path. - - value : Any - The new value to be assigned to the attribute. The type of this value should - match the type of the attribute to be updated. - """ - - name: str - parent_path: str - value: Any - - -def extract_path_list_and_name_and_index_from_UpdateDict( - data: UpdateDict, -) -> tuple[list[str], str, Optional[int]]: - path_list, attr_name = data["parent_path"].split("."), data["name"] - index: Optional[int] = None - index_search = re.search(r"\[(\d+)\]", attr_name) - if index_search: - attr_name = attr_name.split("[")[0] - index = int(index_search.group(1)) - return path_list, attr_name, index - - -def get_target_object_and_attribute( - service: "DataService", path_list: list[str], attr_name: str -) -> tuple[Any, Any]: - target_obj = get_DataService_attr_from_path(service, path_list) - attr = getattr(target_obj, attr_name, None) - if attr is None: - logger.error(f"Attribute {attr_name} not found.") - return target_obj, attr - - -def update_each_DataService_attribute( - service: "DataService", parent_path: str, data_value: dict[str, Any] -) -> None: - for key, value in data_value.items(): - update_DataService_by_path( - service, - { - "name": key, - "parent_path": parent_path, - "value": value, - }, - ) - - -def process_DataService_attribute( - service: "DataService", attr_name: str, data: UpdateDict -) -> None: - update_each_DataService_attribute( - service, - f"{data['parent_path']}.{attr_name}", - cast(dict[str, Any], data["value"]), - ) - - -def process_list_attribute( - service: "DataService", attr: list[Any], index: int, data: UpdateDict -) -> None: - if isinstance(attr[index], DataService): - update_each_DataService_attribute( - service, - 
f"{data['parent_path']}.{data['name']}", - cast(dict[str, Any], data["value"]), - ) - elif isinstance(attr[index], list): - logger.error("Nested lists are not supported yet.") - raise NotImplementedError - else: - set_if_differs(attr, index, data["value"]) - - -def process_callable_attribute(attr: Any, data: UpdateDict) -> Any: +def process_callable_attribute(attr: Any, args: dict[str, Any]) -> Any: converted_args_or_error_msg = convert_arguments_to_hinted_types( - data["value"]["args"], get_type_hints(attr) + args, get_type_hints(attr) ) return ( attr(**converted_args_or_error_msg) @@ -127,27 +37,6 @@ def process_callable_attribute(attr: Any, data: UpdateDict) -> Any: ) -def update_DataService_by_path(service: "DataService", data: UpdateDict) -> Any: - ( - path_list, - attr_name, - index, - ) = extract_path_list_and_name_and_index_from_UpdateDict(data) - target_obj, attr = get_target_object_and_attribute(service, path_list, attr_name) - if attr is None: - return - if isinstance(attr, DataService): - process_DataService_attribute(service, attr_name, data) - elif isinstance(attr, Enum): - set_if_differs(target_obj, attr_name, attr.__class__[data["value"]]) - elif callable(attr): - return process_callable_attribute(attr, data) - elif isinstance(attr, list) and index is not None: - process_list_attribute(service, attr, index, data) - else: - set_if_differs(target_obj, attr_name, data["value"]) - - class DataService(rpyc.Service, TaskManager): _list_mapping: dict[int, DataServiceList] = {} """ @@ -201,9 +90,7 @@ class DataService(rpyc.Service, TaskManager): with open(self._filename, "r") as f: # Load JSON data from file and update class attributes with these # values - self.set_attributes_from_serialized_representation( - cast(dict[str, Any], json.load(f)) - ) + self.load_DataService_from_JSON(cast(dict[str, Any], json.load(f))) def write_to_file(self) -> None: """ @@ -221,35 +108,19 @@ class DataService(rpyc.Service, TaskManager): 'Skipping "write_to_file"...' 
) - def set_attributes_from_serialized_representation( - self, serialized_representation: dict[str, Any] - ) -> None: + def load_DataService_from_JSON(self, json_dict: dict[str, Any]) -> None: # Traverse the serialized representation and set the attributes of the class - for path, value in generate_paths_and_values_from_serialized_DataService( - serialized_representation - ).items(): - # Split the path into elements - parent_path, attr_name = f"DataService.{path}".rsplit(".", 1) + for path in generate_paths_from_DataService_dict(json_dict): + value = get_nested_value_by_path_and_key(json_dict, path=path) + value_type = get_nested_value_by_path_and_key( + json_dict, path=path, key="type" + ) - if isinstance(value, list): - for index, item in enumerate(value): - update_DataService_by_path( - self, - { - "name": f"{attr_name}[{index}]", - "parent_path": parent_path, - "value": item, - }, - ) - else: - update_DataService_by_path( - self, - { - "name": attr_name, - "parent_path": parent_path, - "value": value, - }, - ) + # Split the path into parts + parts = path.split(".") + attr_name = parts[-1] + + self.update_DataService_attribute(parts[:-1], attr_name, value, value_type) def __setattr__(self, __name: str, __value: Any) -> None: current_value = getattr(self, __name, None) @@ -765,3 +636,34 @@ class DataService(rpyc.Service, TaskManager): } return result + + def update_DataService_attribute( + self, + path_list: list[str], + attr_name: str, + value: Any, + attr_type: Optional[str] = None, + ) -> None: + # If attr_name corresponds to a list entry, extract the attr_name and the index + attr_name, index = parse_list_attr_and_index(attr_name) + + # Traverse the object according to the path parts + target_obj = get_object_attr_from_path(self, path_list) + + attr = get_object_attr_from_path(target_obj, [attr_name]) + + if attr is None: + return + + # Set the attribute at the terminal point of the path + if isinstance(attr, Enum): + set_if_differs(target_obj, attr_name, 
attr.__class__[value]) + elif isinstance(attr, list) and index is not None: + set_if_differs(attr, index, value) + elif isinstance(attr, DataService) and isinstance(value, dict): + for key, v in value.items(): + self.update_DataService_attribute([*path_list, attr_name], key, v) + elif callable(attr): + return process_callable_attribute(attr, value["args"]) + else: + set_if_differs(target_obj, attr_name, value) diff --git a/src/pyDataInterface/server/web_server.py b/src/pyDataInterface/server/web_server.py index 4d07397..e0fa71f 100644 --- a/src/pyDataInterface/server/web_server.py +++ b/src/pyDataInterface/server/web_server.py @@ -9,13 +9,38 @@ from loguru import logger from pyDataInterface import DataService from pyDataInterface.config import OperationMode -from pyDataInterface.data_service.data_service import ( - UpdateDict, - update_DataService_by_path, -) from pyDataInterface.version import __version__ +class UpdateDict(TypedDict): + """ + A TypedDict subclass representing a dictionary used for updating attributes in a + DataService. + + Attributes: + ---------- + name : str + The name of the attribute to be updated in the DataService instance. + If the attribute is part of a nested structure, this would be the name of the + attribute in the last nested object. For example, for an attribute access path + 'attr1.list_attr[0].attr2', 'attr2' would be the name. + + parent_path : str + The access path for the parent object of the attribute to be updated. This is + used to construct the full access path for the attribute. For example, for an + attribute access path 'attr1.list_attr[0].attr2', 'attr1.list_attr[0]' would be + the parent_path. + + value : Any + The new value to be assigned to the attribute. The type of this value should + match the type of the attribute to be updated. 
+ """ + + name: str + parent_path: str + value: Any + + class WebAPI: __sio_app: socketio.ASGIApp __fastapi_app: FastAPI @@ -51,7 +76,11 @@ class WebAPI: @sio.event # type: ignore def frontend_update(sid: str, data: UpdateDict) -> Any: logger.debug(f"Received frontend update: {data}") - return update_DataService_by_path(self.service, data) + path_list, attr_name = data["parent_path"].split("."), data["name"] + path_list.remove("DataService") # always at the start, does not do anything + return self.service.update_DataService_attribute( + path_list=path_list, attr_name=attr_name, value=data["value"] + ) self.__sio = sio self.__sio_app = socketio.ASGIApp(self.__sio) diff --git a/src/pyDataInterface/utils/helpers.py b/src/pyDataInterface/utils/helpers.py index 04207a6..c7493d9 100644 --- a/src/pyDataInterface/utils/helpers.py +++ b/src/pyDataInterface/utils/helpers.py @@ -1,5 +1,6 @@ +import re from itertools import chain -from typing import Any +from typing import Any, Optional from loguru import logger @@ -19,7 +20,7 @@ def get_class_and_instance_attributes(obj: object) -> dict[str, Any]: return attrs -def get_DataService_attr_from_path(target_obj: Any, path: list[str]) -> Any: +def get_object_attr_from_path(target_obj: Any, path: list[str]) -> Any: """ Traverse the object tree according to the given path. @@ -36,10 +37,6 @@ def get_DataService_attr_from_path(target_obj: Any, path: list[str]) -> Any: ValueError: If a list index in the path is not a valid integer. 
""" for part in path: - # Skip the root object itself - if part == "DataService": - continue - try: # Try to split the part into attribute and index attr, index_str = part.split("[", maxsplit=1) @@ -56,55 +53,140 @@ def get_DataService_attr_from_path(target_obj: Any, path: list[str]) -> Any: return target_obj -def generate_paths_and_values_from_serialized_DataService( - data: dict, -) -> dict[str, Any]: +def generate_paths_from_DataService_dict( + data: dict, parent_path: str = "" +) -> list[str]: """ - Recursively generate paths from a dictionary and return a dictionary of paths and - their corresponding values. + Recursively generate paths from a dictionary representing a DataService object. - This function traverses through a nested dictionary (usually the result of a - serialization of a DataService) and generates a dictionary where the keys are the - paths to each terminal value in the original dictionary and the values are the - corresponding terminal values in the original dictionary. + This function traverses through a nested dictionary, which is typically obtained + from serializing a DataService object. The function generates a list where each + element is a string representing the path to each terminal value in the original + dictionary. - The paths are represented as string keys with dots connecting the levels and - brackets indicating list indices. + The paths are represented as strings, with dots ('.') denoting nesting levels and + square brackets ('[]') denoting list indices. Args: - data (dict): The input dictionary to generate paths and values from. - parent_path (Optional[str], optional): The current path up to the current level - of recursion. Defaults to None. + data (dict): The input dictionary to generate paths from. This is typically + obtained from serializing a DataService object. + parent_path (str, optional): The current path up to the current level of + recursion. Defaults to ''. 
Returns: - dict[str, Any]: A dictionary with paths as keys and corresponding values as - values. + list[str]: A list with paths as elements. + + Note: + The function ignores keys whose "type" is "method", as these represent methods of the + DataService object and not its state. + + Example: + ------- + + >>> nested_dict = { + ... "attr1": {"type": "int", "value": 10}, + ... "attr2": { + ... "type": "list", + ... "value": [{"type": "int", "value": 1}, {"type": "int", "value": 2}], + ... }, + ... "add": { + ... "type": "method", + ... "async": False, + ... "parameters": {"a": "float", "b": "int"}, + ... "doc": "Returns the sum of the numbers a and b.", + ... }, + ... } + >>> print(generate_paths_from_DataService_dict(nested_dict)) + ['attr1', 'attr2[0]', 'attr2[1]'] """ - paths_and_values = {} + paths = [] for key, value in data.items(): if value["type"] == "method": # ignoring methods continue + new_path = f"{parent_path}.{key}" if parent_path else key if isinstance(value["value"], dict): - paths_and_values[ - key - ] = generate_paths_and_values_from_serialized_DataService(value["value"]) - + paths.extend(generate_paths_from_DataService_dict(value["value"], new_path)) # type: ignore elif isinstance(value["value"], list): for index, item in enumerate(value["value"]): - indexed_key_path = f"{key}[{index}]" + indexed_key_path = f"{new_path}[{index}]" if isinstance(item["value"], dict): - paths_and_values[ - indexed_key_path - ] = generate_paths_and_values_from_serialized_DataService( - item["value"] + paths.extend( # type: ignore + generate_paths_from_DataService_dict( + item["value"], indexed_key_path + ) ) else: - paths_and_values[indexed_key_path] = item["value"] # type: ignore + paths.append(indexed_key_path) # type: ignore else: - paths_and_values[key] = value["value"] # type: ignore - return paths_and_values + paths.append(new_path) # type: ignore + return paths + + +STANDARD_TYPES = ("int", "float", "bool", "str", "Enum") + + +def get_nested_value_by_path_and_key(data: dict, path:
str, key: str = "value") -> Any: + """ + Get the value associated with a specific key from a dictionary given a path. + + This function traverses the dictionary according to the path provided and + returns the value associated with the specified key at that path. The path is + a string with dots connecting the levels and brackets indicating list indices. + + The function can handle complex dictionaries where data is nested within different + types of objects. It checks the type of each object it encounters and correctly + descends into the object if it is not a standard type (i.e., int, float, bool, str, + Enum). + + Args: + data (dict): The input dictionary to get the value from. + path (str): The path to the value in the dictionary. + key (str, optional): The key associated with the value to be returned. + Default is "value". + + Returns: + Any: The value associated with the specified key at the given path in the + dictionary. + + Examples: + Let's consider the following dictionary: + + >>> data = { + ... "attr1": {"type": "int", "value": 10}, + ... "attr2": {"type": "MyClass", "value": {"attr3": {"type": "float", "value": 20.5}}} + ... } + + The function can be used to get the value of 'attr1' as follows: + >>> get_nested_value_by_path_and_key(data, "attr1") + 10 + + It can also be used to get the type of 'attr3', which is nested within 'attr2', as follows: + >>> get_nested_value_by_path_and_key(data, "attr2.attr3", "type") + 'float' + """ + + # Split the path into parts + parts = re.split(r"\.|(?=\[\d+\])", path) # Split by '.'
or '[' + + # Traverse the dictionary according to the path parts + for part in parts: + if part.startswith("["): + # List index + idx = int(part[1:-1]) # Strip the brackets and convert to integer + data = data[idx] + else: + # Dictionary key + data = data[part] + + # When the attribute is a class instance, the attributes are nested in the + # "value" key + if data["type"] not in STANDARD_TYPES: + data = data["value"] + + # Return the value at the terminal point of the path + return data[key] def convert_arguments_to_hinted_types( @@ -171,3 +253,42 @@ def set_if_differs(target: Any, attr_name: str | int, new_value: Any) -> None: setattr(target, attr_name, new_value) else: logger.error(f"Incompatible arguments: {target}, {attr_name}.") + + +def parse_list_attr_and_index(attr_string: str) -> tuple[str, Optional[int]]: + """ + Parses an attribute string and extracts a potential list attribute name and its + index. + + This function examines the provided attribute string. If the string contains square + brackets, it assumes that it's a list attribute and the string within brackets is + the index of an element. It then returns the attribute name and the index as an + integer. If no brackets are present, the function assumes it's a regular attribute + and returns the attribute name and None as the index. + + Parameters: + ----------- + attr_string: str + The attribute string to parse. Can be a regular attribute name (e.g. + 'attr_name') or a list attribute with an index (e.g. 'list_attr[2]'). + + Returns: + -------- + tuple: (str, Optional[int]) + A tuple containing the attribute name as a string and the index as an integer if + present, otherwise None. 
+ + Example: + -------- + >>> parse_list_attr_and_index('list_attr[2]') + ('list_attr', 2) + >>> parse_list_attr_and_index('attr_name') + ('attr_name', None) + """ + + attr_name = attr_string + index = None + if "[" in attr_string and "]" in attr_string: + attr_name, idx = attr_string[:-1].split("[") + index = int(idx) + return attr_name, index