diff --git a/pipelines/steps/adjust_uncertainty_column_in_nas_file.py b/pipelines/steps/adjust_uncertainty_column_in_nas_file.py index 93181f9..5eb161f 100644 --- a/pipelines/steps/adjust_uncertainty_column_in_nas_file.py +++ b/pipelines/steps/adjust_uncertainty_column_in_nas_file.py @@ -21,7 +21,29 @@ import pandas as pd from dima.instruments.readers.nasa_ames_reader import read_nasa_ames_as_dict from pipelines.steps.utils import compute_uncertainty_estimate -def main(path_to_data_file, base_column_name): +def main(path_to_data_file, base_column_names : list): + + """Adjust the error or uncertainty columns of data table, where data table is available + in input nas file and by specifying a list of columns to adjust. + + Parameters + ---------- + path_to_data_file (str) : Path to nas file + base_column_names (list) : list of column names + + Raises + ------ + RuntimeError + If the input file cannot be read or parsed as a NASA Ames file. + ValueError + If the input file does not have a .nas extension. + ValueError + If a base column name is not found in the dataset. + ValueError + If the matching 'err_' column is not found in the dataset. + RuntimeError + If an updated uncertainty value cannot be computed for a data line. + """ if not path_to_data_file.endswith('.nas'): @@ -42,67 +64,79 @@ def main(path_to_data_file, base_column_name): data_table = dataset['data'] # structured numpy array df = pd.DataFrame(data_table) - if base_column_name not in df.columns: - raise ValueError(f"Base column '{base_column_name}' not found in dataset.") + for col in base_column_names: + if col not in df.columns: raise ValueError(f"Base column '{col}' not found in dataset.") + + # filter out columns with name starting in 'err_' + base_column_names_cleaned = [col for col in base_column_names if not col.startswith('err_')] - err_column = f"err_{base_column_name}" - if err_column not in df.columns: - raise ValueError(f"Column '{err_column}' not found in dataset.") - - # Apply callback to base column - - err_index = data_table.dtype.names.index(err_column) # Read original lines from file with open(path_to_data_file, 'rb') as file: raw_lines = file.readlines() header_length = 
header_metadata_dict['header_length'] - data_table_lines = [] + + - # Iterate through data table lines - cnt = 0 - for line_idx in range(len(raw_lines)): - if line_idx >= header_length - 1: - line = raw_lines[line_idx] - fields = list(re.finditer(rb'\S+', line)) + for col in base_column_names_cleaned: + + data_table_lines = [] + base_column_name = col - if err_index < len(fields): - match = fields[err_index] - original_bytes = match.group() - original_str = original_bytes.decode('utf-8') + err_column = f"err_{base_column_name}" + if err_column not in df.columns: + raise ValueError(f"Column '{err_column}' not found in dataset.") - # Skip column header or fill values - clean_original_str = original_str.strip().replace('.', '') - if err_column in original_str: - data_table_lines.append(line) - continue + # Apply callback to base column - # Decimal precision - decimals = len(original_str.split('.')[1]) if '.' in original_str else 0 + err_index = data_table.dtype.names.index(err_column) + # Iterate through data table lines + cnt = 0 + for line_idx in range(len(raw_lines)): + if line_idx >= header_length - 1: + line = raw_lines[line_idx] + fields = list(re.finditer(rb'\S+', line)) - try: - original_err = float(original_str) - if not (clean_original_str and all(c == '9' for c in clean_original_str)): - additional_term = df.loc[cnt, base_column_name] - updated_value = compute_uncertainty_estimate(additional_term, original_err) - else: # if original value is missing, then keep the same - updated_value = original_err - except Exception as e: - raise RuntimeError(f"Error calculating updated value on line {line_idx}: {e}") + if err_index < len(fields): + match = fields[err_index] + original_bytes = match.group() + original_str = original_bytes.decode('utf-8') - # Preserve width and precision - start, end = match.span() - width = end - start - formatted_str = f"{updated_value:.{decimals}f}" + # Skip column header or fill values + clean_original_str = 
original_str.strip().replace('.', '') + if err_column in original_str: + data_table_lines.append(line) + continue - if len(formatted_str) > width: - print(f"Warning: formatted value '{formatted_str}' too wide for field of width {width} at line {line_idx}. Value may be truncated.") + # Decimal precision + decimals = len(original_str.split('.')[1]) if '.' in original_str else 0 - formatted_bytes = formatted_str.rjust(width).encode('utf-8') - new_line = line[:start] + formatted_bytes + line[end:] - data_table_lines.append(new_line) - cnt += 1 + try: + original_err = float(original_str) + if not (clean_original_str and all(c == '9' for c in clean_original_str)): + additional_term = df.loc[cnt, base_column_name] + updated_value = compute_uncertainty_estimate(additional_term, original_err) + else: # if original value is missing, then keep the same + updated_value = original_err + except Exception as e: + raise RuntimeError(f"Error calculating updated value on line {line_idx}: {e}") + + # Preserve width and precision + start, end = match.span() + width = end - start + formatted_str = f"{updated_value:.{decimals}f}" + + if len(formatted_str) > width: + print(f"Warning: formatted value '{formatted_str}' too wide for field of width {width} at line {line_idx}. 
Value may be truncated.") + + formatted_bytes = formatted_str.rjust(width).encode('utf-8') + new_line = line[:start] + formatted_bytes + line[end:] + data_table_lines.append(new_line) + cnt += 1 + # update raw lines + for line_idx in range(header_length - 1, len(raw_lines)): + raw_lines[line_idx] = data_table_lines[line_idx - header_length + 1] # Reconstruct the file processed_lines = ( diff --git a/pipelines/steps/utils.py b/pipelines/steps/utils.py index 197c7cd..95c8f78 100644 --- a/pipelines/steps/utils.py +++ b/pipelines/steps/utils.py @@ -93,8 +93,28 @@ def generate_missing_value_code(max_val, num_decimals): return missing_code -def compute_uncertainty_estimate(x,x_err): - return ((0.5*x_err)**2+(0.5*x)**2)**0.5 +import math +import numpy as np + +def compute_uncertainty_estimate(x, x_err): + """ + Computes uncertainty estimate: sqrt((0.5 * x_err)^2 + (0.5 * x)^2) + for scalar inputs. Prints errors if inputs are invalid. + """ + try: + x = float(x) + x_err = float(x_err) + + if math.isnan(x) or math.isnan(x_err): + print(f"Warning: One or both inputs are NaN -> x: {x}, x_err: {x_err}") + return np.nan + + return math.sqrt((0.5 * x_err)**2 + (0.5 * x)**2) + + except (ValueError, TypeError) as e: + print(f"Error computing uncertainty for x: {x}, x_err: {x_err} -> {e}") + return np.nan + def generate_error_dataframe(df: pd.DataFrame, datetime_var):