diff --git a/src/pyDataInterface/data_service/data_service.py b/src/pyDataInterface/data_service/data_service.py index 438aa56..2784ee4 100644 --- a/src/pyDataInterface/data_service/data_service.py +++ b/src/pyDataInterface/data_service/data_service.py @@ -1,6 +1,11 @@ +import asyncio import inspect +import json +import os +import re from collections.abc import Callable -from typing import Any +from enum import Enum +from typing import Any, Optional, TypedDict, cast, get_type_hints import rpyc from loguru import logger @@ -9,13 +14,140 @@ from pyDataInterface.utils import ( get_class_and_instance_attributes, warn_if_instance_class_does_not_inherit_from_DataService, ) +from pyDataInterface.utils.helpers import ( + convert_arguments_to_hinted_types, + generate_paths_and_values_from_serialized_DataService, + get_DataService_attr_from_path, + set_if_differs, +) from .data_service_list import DataServiceList from .data_service_serializer import DataServiceSerializer from .task_manager import TaskManager -class DataService(rpyc.Service, TaskManager, DataServiceSerializer): +class UpdateDict(TypedDict): + """ + A TypedDict subclass representing a dictionary used for updating attributes in a + DataService. + + Attributes: + ---------- + name : str + The name of the attribute to be updated in the DataService instance. + If the attribute is part of a nested structure, this would be the name of the + attribute in the last nested object. For example, for an attribute access path + 'attr1.list_attr[0].attr2', 'attr2' would be the name. + + parent_path : str + The access path for the parent object of the attribute to be updated. This is + used to construct the full access path for the attribute. For example, for an + attribute access path 'attr1.list_attr[0].attr2', 'attr1.list_attr[0]' would be + the parent_path. + + value : Any + The new value to be assigned to the attribute. The type of this value should + match the type of the attribute to be updated. + """ + + name: str + parent_path: str + value: Any + + +def extract_path_list_and_name_and_index_from_UpdateDict( + data: UpdateDict, +) -> tuple[list[str], str, Optional[int]]: + path_list, attr_name = data["parent_path"].split("."), data["name"] + index: Optional[int] = None + index_search = re.search(r"\[(\d+)\]", attr_name) + if index_search: + attr_name = attr_name.split("[")[0] + index = int(index_search.group(1)) + return path_list, attr_name, index + + +def get_target_object_and_attribute( + service: "DataService", path_list: list[str], attr_name: str +) -> tuple[Any, Any]: + target_obj = get_DataService_attr_from_path(service, path_list) + attr = getattr(target_obj, attr_name, None) + if attr is None: + logger.error(f"Attribute {attr_name} not found.") + return target_obj, attr + + +def update_each_DataService_attribute( + service: "DataService", parent_path: str, data_value: dict[str, Any] +) -> None: + for key, value in data_value.items(): + update_DataService_by_path( + service, + { + "name": key, + "parent_path": parent_path, + "value": value, + }, + ) + + +def process_DataService_attribute( + service: "DataService", attr_name: str, data: UpdateDict +) -> None: + update_each_DataService_attribute( + service, + f"{data['parent_path']}.{attr_name}", + cast(dict[str, Any], data["value"]), + ) + + +def process_list_attribute( + service: "DataService", attr: list[Any], index: int, data: UpdateDict +) -> None: + if isinstance(attr[index], DataService): + update_each_DataService_attribute( + service, + f"{data['parent_path']}.{data['name']}", + cast(dict[str, Any], data["value"]), + ) + elif isinstance(attr[index], list): + logger.error("Nested lists are not supported yet.") + raise NotImplementedError + else: + set_if_differs(attr, index, data["value"]) + + +def process_callable_attribute(attr: Any, data: UpdateDict) -> Any: + converted_args_or_error_msg = convert_arguments_to_hinted_types( + data["value"]["args"], get_type_hints(attr) + ) + return ( + attr(**converted_args_or_error_msg) + if not isinstance(converted_args_or_error_msg, str) + else converted_args_or_error_msg + ) + + +def update_DataService_by_path(service: "DataService", data: UpdateDict) -> Any: + ( + path_list, + attr_name, + index, + ) = extract_path_list_and_name_and_index_from_UpdateDict(data) + target_obj, attr = get_target_object_and_attribute(service, path_list, attr_name) + if attr is None: + return + if isinstance(attr, DataService): + process_DataService_attribute(service, attr_name, data) + elif isinstance(attr, Enum): + set_if_differs(target_obj, attr_name, attr.__class__[data["value"]]) + elif callable(attr): + return process_callable_attribute(attr, data) + elif isinstance(attr, list) and index is not None: + process_list_attribute(service, attr, index, data) + else: + set_if_differs(target_obj, attr_name, data["value"]) + _list_mapping: dict[int, DataServiceList] = {} """ A dictionary mapping the id of the original lists to the corresponding diff --git a/src/pyDataInterface/server/web_server.py b/src/pyDataInterface/server/web_server.py index 2e75e0a..4d07397 100644 --- a/src/pyDataInterface/server/web_server.py +++ b/src/pyDataInterface/server/web_server.py @@ -9,18 +9,13 @@ from loguru import logger from pyDataInterface import DataService from pyDataInterface.config import OperationMode -from pyDataInterface.utils.apply_update_to_data_service import ( - apply_updates_to_data_service, +from pyDataInterface.data_service.data_service import ( + UpdateDict, + update_DataService_by_path, ) from pyDataInterface.version import __version__ -class FrontendUpdate(TypedDict): - name: str - parent_path: str - value: Any - - class WebAPI: __sio_app: socketio.ASGIApp __fastapi_app: FastAPI @@ -54,9 +49,9 @@ class WebAPI: sio = socketio.AsyncServer(async_mode="asgi") @sio.event # type: ignore - def frontend_update(sid: str, data: FrontendUpdate) -> Any: + def frontend_update(sid: str, data: UpdateDict) -> Any: logger.debug(f"Received frontend update: {data}") - return apply_updates_to_data_service(self.service, data) + return update_DataService_by_path(self.service, data) self.__sio = sio self.__sio_app = socketio.ASGIApp(self.__sio) diff --git a/src/pyDataInterface/utils/apply_update_to_data_service.py b/src/pyDataInterface/utils/apply_update_to_data_service.py deleted file mode 100644 index ea9bbe5..0000000 --- a/src/pyDataInterface/utils/apply_update_to_data_service.py +++ /dev/null @@ -1,67 +0,0 @@ -import re -from enum import Enum -from typing import Any, Optional, TypedDict, get_type_hints - -from loguru import logger - -from pyDataInterface.data_service.data_service import DataService - -from .helpers import get_attr_from_path - - -class UpdateDictionary(TypedDict): - name: str - """Name of the attribute.""" - parent_path: str - """Full access path of the attribute.""" - value: Any - """New value of the attribute.""" - - -def apply_updates_to_data_service(service: Any, data: UpdateDictionary) -> Any: - parent_path = data["parent_path"].split(".") - attr_name = data["name"] - - # Traverse the object tree according to parent_path - target_obj = get_attr_from_path(service, parent_path) - - # Check if attr_name contains an index for a list item - index: Optional[int] = None - if re.search(r"\[.*\]", attr_name): - attr_name, index_str = attr_name.split("[") - try: - index = int( - index_str.replace("]", "") - ) # Remove closing bracket and convert to int - except ValueError: - logger.error(f"Invalid list index: {index_str}") - return - - attr = getattr(target_obj, attr_name) - - if isinstance(attr, DataService): - attr.apply_updates(data["value"]) - elif isinstance(attr, Enum): - setattr(service, data["name"], attr.__class__[data["value"]["value"]]) - elif callable(attr): - args: dict[str, Any] = data["value"]["args"] - type_hints = get_type_hints(attr) - - # Convert arguments to their hinted types - for arg_name, arg_value in args.items(): - if arg_name in type_hints: - arg_type = type_hints[arg_name] - if isinstance(arg_type, type): - # Attempt to convert the argument to its hinted type - try: - args[arg_name] = arg_type(arg_value) - except ValueError: - msg = f"Failed to convert argument '{arg_name}' to type {arg_type.__name__}" - logger.error(msg) - return msg - - return attr(**args) - elif isinstance(attr, list): - attr[index] = data["value"] - else: - setattr(target_obj, attr_name, data["value"]) diff --git a/src/pyDataInterface/utils/helpers.py b/src/pyDataInterface/utils/helpers.py index a75a81e..04207a6 100644 --- a/src/pyDataInterface/utils/helpers.py +++ b/src/pyDataInterface/utils/helpers.py @@ -1,4 +1,3 @@ -import re from itertools import chain from typing import Any @@ -20,7 +19,7 @@ def get_class_and_instance_attributes(obj: object) -> dict[str, Any]: return attrs -def get_attr_from_path(target_obj: Any, path: list[str]) -> Any: +def get_DataService_attr_from_path(target_obj: Any, path: list[str]) -> Any: """ Traverse the object tree according to the given path. @@ -30,24 +29,145 @@ def get_attr_from_path(target_obj: Any, path: list[str]) -> Any: Returns: The attribute at the end of the path. If the path includes a list index, - the function returns the specific item at that index. + the function returns the specific item at that index. If an attribute in + the path does not exist, the function logs a debug message and returns None. Raises: ValueError: If a list index in the path is not a valid integer. """ for part in path: - if part != "DataService": # Skip the root object itself - # Check if part contains an index for a list item - if re.search(r"\[.*\]", part): - attr, index_str = part.split("[") - try: - index = int( - index_str.replace("]", "") - ) # Remove closing bracket and convert to int - except ValueError: - logger.error(f"Invalid list index: {index_str}") - raise ValueError(f"Invalid list index: {index_str}") - target_obj = getattr(target_obj, attr)[index] - else: - target_obj = getattr(target_obj, part) + # Skip the root object itself + if part == "DataService": + continue + + try: + # Try to split the part into attribute and index + attr, index_str = part.split("[", maxsplit=1) + index_str = index_str.replace("]", "") + index = int(index_str) + target_obj = getattr(target_obj, attr)[index] + except ValueError: + # No index, so just get the attribute + target_obj = getattr(target_obj, part) + except AttributeError: + # The attribute doesn't exist + logger.debug(f"Attribute {part} does not exist in the object.") + return None return target_obj + + +def generate_paths_and_values_from_serialized_DataService( + data: dict, +) -> dict[str, Any]: + """ + Recursively generate paths from a dictionary and return a dictionary of paths and + their corresponding values. + + This function traverses through a nested dictionary (usually the result of a + serialization of a DataService) and generates a dictionary where the keys are the + paths to each terminal value in the original dictionary and the values are the + corresponding terminal values in the original dictionary. + + The paths are represented as string keys with dots connecting the levels and + brackets indicating list indices. + + Args: + data (dict): The input dictionary to generate paths and values from. + parent_path (Optional[str], optional): The current path up to the current level + of recursion. Defaults to None. + + Returns: + dict[str, Any]: A dictionary with paths as keys and corresponding values as + values. + """ + + paths_and_values = {} + for key, value in data.items(): + if value["type"] == "method": + # ignoring methods + continue + if isinstance(value["value"], dict): + paths_and_values[ + key + ] = generate_paths_and_values_from_serialized_DataService(value["value"]) + + elif isinstance(value["value"], list): + for index, item in enumerate(value["value"]): + indexed_key_path = f"{key}[{index}]" + if isinstance(item["value"], dict): + paths_and_values[ + indexed_key_path + ] = generate_paths_and_values_from_serialized_DataService( + item["value"] + ) + else: + paths_and_values[indexed_key_path] = item["value"] # type: ignore + else: + paths_and_values[key] = value["value"] # type: ignore + return paths_and_values + + +def convert_arguments_to_hinted_types( + args: dict[str, Any], type_hints: dict[str, Any] +) -> dict[str, Any] | str: + """ + Convert the given arguments to their types hinted in the type_hints dictionary. + + This function attempts to convert each argument in the args dictionary to the type + specified for the argument in the type_hints dictionary. If the conversion is + successful, the function replaces the original argument in the args dictionary with + the converted argument. + + If a ValueError is raised during the conversion of an argument, the function logs + an error message and returns the error message as a string. + + Args: + args: A dictionary of arguments to be converted. The keys are argument names + and the values are the arguments themselves. + type_hints: A dictionary of type hints for the arguments. The keys are + argument names and the values are the hinted types. + + Returns: + A dictionary of the converted arguments if all conversions are successful, + or an error message string if a ValueError is raised during a conversion. + """ + + # Convert arguments to their hinted types + for arg_name, arg_value in args.items(): + if arg_name in type_hints: + arg_type = type_hints[arg_name] + if isinstance(arg_type, type): + # Attempt to convert the argument to its hinted type + try: + args[arg_name] = arg_type(arg_value) + except ValueError: + msg = ( + f"Failed to convert argument '{arg_name}' to type " + f"{arg_type.__name__}" + ) + logger.error(msg) + return msg + return args + + +def set_if_differs(target: Any, attr_name: str | int, new_value: Any) -> None: + """ + Set the value of an attribute or a list element on a target object to a new value, + but only if the current value of the attribute or the list element differs from the + new value. + + Args: + target: The object that has the attribute or the list. + attr_name: The name of the attribute or the index of the list element. + new_value: The new value for the attribute or the list element. + """ + if isinstance(target, list) and isinstance(attr_name, int): + # Case for a list + if target[attr_name] != new_value: + target[attr_name] = new_value + elif isinstance(attr_name, str): + # Case for an attribute + if getattr(target, attr_name) != new_value: + setattr(target, attr_name, new_value) + else: + logger.error(f"Incompatible arguments: {target}, {attr_name}.")