Refactor: abstract parameter loading into a `load_parameters` function to simplify the flag-generation code.

This commit is contained in:
2025-02-28 17:18:13 +01:00
parent 6890377134
commit a9533ae3e8

View File

@ -34,8 +34,39 @@ def compute_cpc_flags():
return 0
def load_parameters(flag_type):
    """Load the YAML parameter file associated with *flag_type*.

    Parameters
    ----------
    flag_type : str
        Either 'diagnostics' (validity thresholds used to flag diagnostic
        variables) or 'species' (calibration parameters naming the species
        to flag).

    Returns
    -------
    dict
        Parsed YAML content, or an empty dict when *flag_type* is not
        recognized or the file cannot be read/parsed.
    """
    # Implicit input: parameter files live under the project tree.
    # Map each known flag type to (relative path, error-message template).
    # NOTE: the original code left flag_type_file/error_message unbound for
    # any other flag_type, raising NameError instead of failing gracefully.
    known_params = {
        'diagnostics': (
            'pipelines/params/validity_thresholds.yaml',
            'Error accessing validation thresholds at: {}',
        ),
        'species': (
            'pipelines/params/calibration_params.yaml',
            'Error accessing calibration parameters at: {}',
        ),
    }
    if flag_type not in known_params:
        print(f"Unrecognized flag_type: {flag_type}")
        return {}

    rel_path, error_template = known_params[flag_type]
    flag_type_file = os.path.normpath(os.path.join(projectPath, rel_path))

    try:
        with open(flag_type_file, 'r') as stream:
            # safe_load: config files should not construct arbitrary
            # Python objects (FullLoader was broader than needed).
            # `or {}` guards against an empty YAML file parsing to None.
            return yaml.safe_load(stream) or {}
    except Exception:
        print(error_template.format(flag_type_file))
        return {}
#def compute_diagnostic_variable_flags(data_table, validity_thresholds_dict):
def generate_diagnostic_flags(data_table):
def generate_diagnostic_flags(data_table, validity_thresholds_dict):
"""
Create indicator variables that check whether a particular diagnostic variable is within
pre-specified/acceptable limits, which are defined by `variable_limits`.
@ -55,17 +86,7 @@ def generate_diagnostic_flags(data_table):
and additional indicator variables, representing flags.
"""
# Implicit input
validity_thersholds_file = 'pipelines/params/validity_thresholds.yaml'
validity_thresholds_dict = {}
try:
with open(validity_thersholds_file, 'r') as stream:
validity_thresholds_dict = yaml.load(stream, Loader=yaml.FullLoader)
except Exception as e:
print(f"Error accessing validation thresholds at: {validity_thersholds_file}")
return 1
# Define binary to ebas flag code map
# Specify labeling function to create numbered EBAS flags. It maps a column indicator,
@ -111,7 +132,7 @@ def generate_diagnostic_flags(data_table):
return new_data_table
# TODO: abstract some of the code in the command line main
def generate_species_flags(data_table : pd.DataFrame):
def generate_species_flags(data_table : pd.DataFrame, calib_param_dict : dict):
"""Generate flags for columns in data_table based on flags_table
@ -121,16 +142,12 @@ def generate_species_flags(data_table : pd.DataFrame):
_description_
"""
# Get name of the species to flag based on diagnostics and manual flags
path_to_calib_params = os.path.normpath(os.path.join(projectPath,'pipelines/params/calibration_params.yaml'))
if not os.path.exists(path_to_calib_params):
raise FileNotFoundError(f'Calibration params file:{path_to_calib_params}')
with open(path_to_calib_params,'r') as stream:
calib_param_dict = yaml.safe_load(stream)
predefined_species = calib_param_dict['variables']['species']
predefined_species = calib_param_dict.get('variables',{}).get('species',[])
if not predefined_species:
raise RuntimeError("Undefined species. Input argument 'calib_param_dict' must contain a 'variables' : {'species' : ['example1',...,'examplen']} ")
print('Predefined_species:', predefined_species)
@ -164,7 +181,7 @@ def generate_species_flags(data_table : pd.DataFrame):
if (not datetime_var == var) and (var in predefined_species):
renaming_map[var] = f'numflag_{var}'
print(f'numflag_{var}')
data_table[var] = pd.Series(flags_table['numflag_any_diagnostic_flag'].values)
data_table[var] = pd.Series(flags_table['numflag_any_diagnostic_flag'].values,dtype=np.int64)
print(renaming_map)
data_table.rename(columns=renaming_map, inplace=True)
else:
@ -223,7 +240,7 @@ def generate_species_flags(data_table : pd.DataFrame):
return data_table.loc[:,numflag_columns]
return data_table.loc[:,[datetime_var] + numflag_columns]
@ -251,7 +268,7 @@ def reconcile_flags(data_table, flag_code, t1_idx, t2_idx, numflag_columns):
new_values = np.where(current_ranks < flag_code_rank, flag_code, sub_table.values)
# Update the dataframe with the new values
data_table.loc[t1_idx:t2_idx, numflag_columns] = new_values
data_table.loc[t1_idx:t2_idx, numflag_columns] = new_values.astype(np.int64)
return data_table
@ -372,15 +389,18 @@ if __name__ == '__main__':
# Compute diagnostic flags based on validity thresholds defined in configuration_file_dict
if flag_type == 'diagnostics':
flags_table = generate_diagnostic_flags(data_table)
validity_thresholds_dict = load_parameters(flag_type)
flags_table = generate_diagnostic_flags(data_table, validity_thresholds_dict)
if flag_type == 'species':
flags_table = generate_species_flags(data_table)
if flag_type == 'species':
calib_param_dict = load_parameters(flag_type)
flags_table = generate_species_flags(data_table,calib_param_dict)
metadata = {'actris_level' : 1,
'processing_script': processingScriptRelPath.replace(os.sep,'/'),
'processing_date' : utils.created_at(),
'flag_type' : flag_type
'flag_type' : flag_type,
'datetime_var': datetime_var
}
# Save output tables to csv file and save/or update data lineage record