diff --git a/src/pydase/data_service/data_service.py b/src/pydase/data_service/data_service.py index 0735237..fd7e904 100644 --- a/src/pydase/data_service/data_service.py +++ b/src/pydase/data_service/data_service.py @@ -17,7 +17,7 @@ from pydase.utils.helpers import ( generate_paths_from_DataService_dict, get_class_and_instance_attributes, get_component_class_names, - get_nested_value_by_path_and_key, + get_nested_value_from_DataService_by_path_and_key, get_object_attr_from_path, parse_list_attr_and_index, update_value_if_changed, @@ -133,16 +133,20 @@ class DataService(rpyc.Service, AbstractDataService): # Traverse the serialized representation and set the attributes of the class serialized_class = self.serialize() for path in generate_paths_from_DataService_dict(json_dict): - value = get_nested_value_by_path_and_key(json_dict, path=path) - value_type = get_nested_value_by_path_and_key( + value = get_nested_value_from_DataService_by_path_and_key( + json_dict, path=path + ) + value_type = get_nested_value_from_DataService_by_path_and_key( json_dict, path=path, key="type" ) - class_value_type = get_nested_value_by_path_and_key( + class_value_type = get_nested_value_from_DataService_by_path_and_key( serialized_class, path=path, key="type" ) if class_value_type == value_type: - class_attr_is_read_only = get_nested_value_by_path_and_key( - serialized_class, path=path, key="readonly" + class_attr_is_read_only = ( + get_nested_value_from_DataService_by_path_and_key( + serialized_class, path=path, key="readonly" + ) ) if class_attr_is_read_only: logger.debug( diff --git a/src/pydase/utils/helpers.py b/src/pydase/utils/helpers.py index f48bcf0..c3b1d00 100644 --- a/src/pydase/utils/helpers.py +++ b/src/pydase/utils/helpers.py @@ -1,6 +1,6 @@ import re from itertools import chain -from typing import Any, Optional +from typing import Any, Optional, cast from loguru import logger @@ -126,7 +126,86 @@ def generate_paths_from_DataService_dict( return paths -def get_nested_value_by_path_and_key(data: dict, path: str, key: str = "value") -> Any: +def extract_dict_or_list_entry(data: dict[str, Any], key: str) -> dict[str, Any] | None: + """ + Extract a nested dictionary or list entry based on the provided key. + + Given a dictionary and a key, this function retrieves the corresponding nested + dictionary or list entry. If the key includes an index in the format "[]", + the function assumes that the corresponding entry in the dictionary is a list, and + it will attempt to retrieve the indexed item from that list. + + Args: + data (dict): The input dictionary containing nested dictionaries or lists. + key (str): The key specifying the desired entry within the dictionary. The key + can be a regular dictionary key or can include an index in the format + "[]" to retrieve an item from a nested list. + + Returns: + dict | None: The nested dictionary or list item found for the given key. If the + key is invalid, or if the specified index is out of bounds for a list, it + returns None. + + Example: + >>> data = { + ... "attr1": [ + ... {"type": "int", "value": 10}, {"type": "string", "value": "hello"} + ... ], + ... "attr2": { + ... "type": "MyClass", + ... "value": {"sub_attr": {"type": "float", "value": 20.5}} + ... } + ... } + + >>> extract_dict_or_list_entry(data, "attr1[1]") + {"type": "string", "value": "hello"} + + >>> extract_dict_or_list_entry(data, "attr2") + {"type": "MyClass", "value": {"sub_attr": {"type": "float", "value": 20.5}}} + """ + + attr_name = key + index: Optional[int] = None + + # Check if the key contains an index part like '[]' + if "[" in key and key.endswith("]"): + attr_name, index_part = key.split("[", 1) + index_part = index_part.rstrip("]") # remove the closing bracket + + # Convert the index part to an integer + if index_part.isdigit(): + index = int(index_part) + else: + logger.error(f"Invalid index format in key: {key}") + + current_data: dict[str, Any] | list[dict[str, Any]] | None = data.get( + attr_name, None + ) + if not isinstance(current_data, dict): + # key does not exist in dictionary, e.g. when class does not have this + # attribute + return None + + if isinstance(current_data["value"], list): + current_data = current_data["value"] + + if index is not None and 0 <= index < len(current_data): + current_data = current_data[index] + else: + return None + + # When the attribute is a class instance, the attributes are nested in the + # "value" key + if current_data["type"] not in STANDARD_TYPES: + current_data = cast(dict[str, Any], current_data.get("value", None)) + assert isinstance(current_data, dict) + + return current_data + + +def get_nested_value_from_DataService_by_path_and_key( + data: dict[str, Any], path: str, key: str = "value" +) -> Any: """ Get the value associated with a specific key from a dictionary given a path. @@ -161,35 +240,26 @@ def get_nested_value_by_path_and_key(data: dict, path: str, key: str = "value") >>> } The function can be used to get the value of 'attr1' as follows: - >>> get_value_of_key_from_path(data, "attr1") + >>> get_nested_value_by_path_and_key(data, "attr1") 10 It can also be used to get the value of 'attr3', which is nested within 'attr2', as follows: - >>> get_value_of_key_from_path(data, "attr2.attr3", "type") + >>> get_nested_value_by_path_and_key(data, "attr2.attr3", "type") float """ # Split the path into parts - parts = re.split(r"\.|(?=\[\d+\])", path) # Split by '.' or '[' + parts: list[str] = re.split(r"\.", path) # Split by '.' + current_data: dict[str, Any] | None = data - # Traverse the dictionary according to the path parts for part in parts: - if part.startswith("["): - # List index - idx = int(part[1:-1]) # Strip the brackets and convert to integer - data = data[idx] - else: - # Dictionary key - data = data[part] + if current_data is None: + return + current_data = extract_dict_or_list_entry(current_data, part) - # When the attribute is a class instance, the attributes are nested in the - # "value" key - if data["type"] not in STANDARD_TYPES: - data = data["value"] - - # Return the value at the terminal point of the path - return data[key] + if isinstance(current_data, dict): + return current_data.get(key, None) def convert_arguments_to_hinted_types(