Extend pipelines/steps/adjust_uncertainty_column_in_nas_file.py to handle list of variables.

This commit is contained in:
2025-05-27 09:53:12 +02:00
parent 38fe2b8774
commit f3f830487e
2 changed files with 103 additions and 49 deletions

View File

@ -21,7 +21,29 @@ import pandas as pd
from dima.instruments.readers.nasa_ames_reader import read_nasa_ames_as_dict
from pipelines.steps.utils import compute_uncertainty_estimate
def main(path_to_data_file, base_column_name):
def main(path_to_data_file, base_column_names : list):
"""Adjust the error or uncertainty columns of data table, where data table is available
in input nas file and by specifying a list of columns to adjust.
Parameters:
------------
path_to_data_file (str) : Path to nas file
base_column_names (list) : list of column names
Raises
------
RuntimeError
_description_
ValueError
_description_
ValueError
_description_
ValueError
_description_
RuntimeError
_description_
"""
if not path_to_data_file.endswith('.nas'):
@ -42,67 +64,79 @@ def main(path_to_data_file, base_column_name):
data_table = dataset['data'] # structured numpy array
df = pd.DataFrame(data_table)
if base_column_name not in df.columns:
raise ValueError(f"Base column '{base_column_name}' not found in dataset.")
if any(col not in df.columns for col in base_column_names):
raise ValueError(f"Base column '{col}' not found in dataset.")
# filter out columns with name starting in 'err_'
base_column_names_cleaned = [col for col in base_column_names if not col.startswith('err_')]
err_column = f"err_{base_column_name}"
if err_column not in df.columns:
raise ValueError(f"Column '{err_column}' not found in dataset.")
# Apply callback to base column
err_index = data_table.dtype.names.index(err_column)
# Read original lines from file
with open(path_to_data_file, 'rb') as file:
raw_lines = file.readlines()
header_length = header_metadata_dict['header_length']
data_table_lines = []
# Iterate through data table lines
cnt = 0
for line_idx in range(len(raw_lines)):
if line_idx >= header_length - 1:
line = raw_lines[line_idx]
fields = list(re.finditer(rb'\S+', line))
for col in base_column_names_cleaned:
data_table_lines = []
base_column_name = col
if err_index < len(fields):
match = fields[err_index]
original_bytes = match.group()
original_str = original_bytes.decode('utf-8')
err_column = f"err_{base_column_name}"
if err_column not in df.columns:
raise ValueError(f"Column '{err_column}' not found in dataset.")
# Skip column header or fill values
clean_original_str = original_str.strip().replace('.', '')
if err_column in original_str:
data_table_lines.append(line)
continue
# Apply callback to base column
# Decimal precision
decimals = len(original_str.split('.')[1]) if '.' in original_str else 0
err_index = data_table.dtype.names.index(err_column)
# Iterate through data table lines
cnt = 0
for line_idx in range(len(raw_lines)):
if line_idx >= header_length - 1:
line = raw_lines[line_idx]
fields = list(re.finditer(rb'\S+', line))
try:
original_err = float(original_str)
if not (clean_original_str and all(c == '9' for c in clean_original_str)):
additional_term = df.loc[cnt, base_column_name]
updated_value = compute_uncertainty_estimate(additional_term, original_err)
else: # if original value is missing, then keep the same
updated_value = original_err
except Exception as e:
raise RuntimeError(f"Error calculating updated value on line {line_idx}: {e}")
if err_index < len(fields):
match = fields[err_index]
original_bytes = match.group()
original_str = original_bytes.decode('utf-8')
# Preserve width and precision
start, end = match.span()
width = end - start
formatted_str = f"{updated_value:.{decimals}f}"
# Skip column header or fill values
clean_original_str = original_str.strip().replace('.', '')
if err_column in original_str:
data_table_lines.append(line)
continue
if len(formatted_str) > width:
print(f"Warning: formatted value '{formatted_str}' too wide for field of width {width} at line {line_idx}. Value may be truncated.")
# Decimal precision
decimals = len(original_str.split('.')[1]) if '.' in original_str else 0
formatted_bytes = formatted_str.rjust(width).encode('utf-8')
new_line = line[:start] + formatted_bytes + line[end:]
data_table_lines.append(new_line)
cnt += 1
try:
original_err = float(original_str)
if not (clean_original_str and all(c == '9' for c in clean_original_str)):
additional_term = df.loc[cnt, base_column_name]
updated_value = compute_uncertainty_estimate(additional_term, original_err)
else: # if original value is missing, then keep the same
updated_value = original_err
except Exception as e:
raise RuntimeError(f"Error calculating updated value on line {line_idx}: {e}")
# Preserve width and precision
start, end = match.span()
width = end - start
formatted_str = f"{updated_value:.{decimals}f}"
if len(formatted_str) > width:
print(f"Warning: formatted value '{formatted_str}' too wide for field of width {width} at line {line_idx}. Value may be truncated.")
formatted_bytes = formatted_str.rjust(width).encode('utf-8')
new_line = line[:start] + formatted_bytes + line[end:]
data_table_lines.append(new_line)
cnt += 1
# update raw lines
for line_idx in range(header_length - 1, len(raw_lines)):
raw_lines[line_idx] = data_table_lines[line_idx - header_length + 1]
# Reconstruct the file
processed_lines = (

View File

@ -93,8 +93,28 @@ def generate_missing_value_code(max_val, num_decimals):
return missing_code
def compute_uncertainty_estimate(x,x_err):
return ((0.5*x_err)**2+(0.5*x)**2)**0.5
import math
import numpy as np
def compute_uncertainty_estimate(x, x_err):
"""
Computes uncertainty estimate: sqrt((0.5 * x_err)^2 + (0.5 * x)^2)
for scalar inputs. Prints errors if inputs are invalid.
"""
try:
x = float(x)
x_err = float(x_err)
if math.isnan(x) or math.isnan(x_err):
print(f"Warning: One or both inputs are NaN -> x: {x}, x_err: {x_err}")
return np.nan
return math.sqrt((0.5 * x_err)**2 + (0.5 * x)**2)
except (ValueError, TypeError) as e:
print(f"Error computing uncertainty for x: {x}, x_err: {x_err} -> {e}")
return np.nan
def generate_error_dataframe(df: pd.DataFrame, datetime_var):