# SP2XR/scripts/sp2xr_pipeline.py
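#
# Chunked Dask pipeline for SP2XR data: for each time chunk it reads the raw
# per-particle (PbP) and housekeeping (HK) parquet stores, calibrates single
# particles, joins them with the resampled sample flow, writes calibrated PbP
# output, and optionally derives dt-resolution concentrations and histograms.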

from __future__ import annotations

import gc
import signal
import sys
import time

import dask.dataframe as dd
import numpy as np
import pandas as pd
import yaml

from sp2xr.helpers import (
    parse_args,
    load_and_resolve_config,
    initialize_cluster,
    extract_partitioned_datetimes,
    get_time_chunks_from_range,
    delete_partition_if_exists,
    partition_rowcount,
    validate_config_compatibility,
)
from sp2xr.calibration import calibrate_single_particle
from sp2xr.resample_pbp_hk import (
    build_dt_summary,
    resample_hk_partition,
    join_pbp_with_flow,
    aggregate_dt,
)
from sp2xr.distribution import (
    bin_lims_to_ctrs,
    process_histograms,
)
from sp2xr.concentrations import add_concentrations
from sp2xr.schema import (
    cast_and_arrow,
    CANONICAL_DTYPES,
    DEFAULT_FLOAT,
    enforce_schema,
)


def main():
    args = parse_args()
    run_config = load_and_resolve_config(args)
    client, cluster = initialize_cluster(run_config)
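
    # Graceful shutdown: on SIGTERM (e.g. from a job scheduler) close the Dask
    # client and cluster before exiting so workers are not left orphaned.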
    def handle_sigterm(signum, frame):
        print(
            f"\nSIGTERM received (signal {signum}), shutting down Dask...", flush=True
        )
        try:
            client.close()
            cluster.close()
        except Exception as e:
            print(f"Error during cleanup: {e}", flush=True)
        sys.exit(0)

    signal.signal(signal.SIGTERM, handle_sigterm)

    try:
"""# -1. chunking
pbp_times = extract_partitioned_datetimes(run_config["input_pbp"])
hk_times = extract_partitioned_datetimes(run_config["input_hk"])
global_start = min(min(pbp_times), min(hk_times))
global_end = max(max(pbp_times), max(hk_times))
chunk_freq = run_config["chunking"]["freq"] # e.g. "6h", "3d"
time_chunks = get_time_chunks_from_range(global_start, global_end, chunk_freq)
"""
        # -1. Validate config compatibility
        validate_config_compatibility(run_config)

        # -2. Chunking
        pbp_times = extract_partitioned_datetimes(run_config["input_pbp"])
        hk_times = extract_partitioned_datetimes(run_config["input_hk"])

        # Use config date range if specified, otherwise use data extent
        if run_config["chunking"]["start_date"]:
            global_start = pd.to_datetime(run_config["chunking"]["start_date"])
            print(f"Using config start_date: {global_start}")
        else:
            global_start = min(min(pbp_times), min(hk_times))
            print(f"Using data extent start: {global_start}")
        if run_config["chunking"]["end_date"]:
            global_end = pd.to_datetime(run_config["chunking"]["end_date"])
            print(f"Using config end_date: {global_end}")
        else:
            global_end = max(max(pbp_times), max(hk_times))
            print(f"Using data extent end: {global_end}")
        chunk_freq = run_config["chunking"]["freq"]  # e.g. "6h", "3d"
        time_chunks = get_time_chunks_from_range(global_start, global_end, chunk_freq)
        print(
            f"Processing {len(time_chunks)} time chunks from {global_start} to {global_end}"
        )

        # 0. calibration stage --------------------------------------------
        with open(run_config["instr_cfg"]) as f:
            instr_config = yaml.safe_load(f)
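
        # Histogram bin edges: log-spaced for incandescence mass and for
        # scattering diameter, linearly spaced for the time-lag axis; bin
        # centres are derived from the edges with bin_lims_to_ctrs.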
        # 1. Bins
        inc_mass_bin_lims = np.logspace(
            np.log10(run_config["histo"]["inc"]["min_mass"]),
            np.log10(run_config["histo"]["inc"]["max_mass"]),
            run_config["histo"]["inc"]["n_bins"],
        )
        inc_mass_bin_ctrs = bin_lims_to_ctrs(inc_mass_bin_lims)
        scatt_bin_lims = np.logspace(
            np.log10(run_config["histo"]["scatt"]["min_D"]),
            np.log10(run_config["histo"]["scatt"]["max_D"]),
            run_config["histo"]["scatt"]["n_bins"],
        )
        scatt_bin_ctrs = bin_lims_to_ctrs(scatt_bin_lims)
        timelag_bins_lims = np.linspace(
            run_config["histo"]["timelag"]["min"],
            run_config["histo"]["timelag"]["max"],
            run_config["histo"]["timelag"]["n_bins"],
        )
        timelag_bin_ctrs = bin_lims_to_ctrs(timelag_bins_lims)
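
        # Process one time chunk at a time: each chunk is read, calibrated,
        # aggregated and written independently so the Dask task graph and
        # worker memory stay bounded across many chunks.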
        for chunk_start, chunk_end in time_chunks:
            print(f"Processing: {chunk_start} to {chunk_end}")
            pbp_filters = [
                ("date", ">=", chunk_start.date().strftime("%Y-%m-%d")),
                ("date", "<", chunk_end.date().strftime("%Y-%m-%d")),
            ]
            if "hour" in run_config["chunking"]["freq"]:  # optionally filter by hour
                pbp_filters.append(("hour", ">=", chunk_start.hour))
                pbp_filters.append(("hour", "<", chunk_end.hour))

            # client.restart()
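            # Broadcast the histogram bin definitions to every worker for this
            # chunk (client.scatter with broadcast=True keeps one copy per worker).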
            scattered_bins = {
                "inc_mass_bin_lims": client.scatter(inc_mass_bin_lims, broadcast=True),
                "inc_mass_bin_ctrs": client.scatter(inc_mass_bin_ctrs, broadcast=True),
                "scatt_bin_lims": client.scatter(scatt_bin_lims, broadcast=True),
                "scatt_bin_ctrs": client.scatter(scatt_bin_ctrs, broadcast=True),
                "timelag_bins_lims": client.scatter(timelag_bins_lims, broadcast=True),
                "timelag_bin_ctrs": client.scatter(timelag_bin_ctrs, broadcast=True),
            }

            dask_objects = []
            try:
                # 2. HK processing --------------------------------------
                try:
                    ddf_hk = dd.read_parquet(
                        run_config["input_hk"],
                        engine="pyarrow",
                        filters=pbp_filters,
                        calculate_divisions=True,
                    )
                    dask_objects.append(ddf_hk)
                except (FileNotFoundError, OSError):
                    print(" → no HK files for this chunk; skipping.")
                    continue
                if ddf_hk.npartitions == 0 or partition_rowcount(ddf_hk) == 0:
                    print(" → HK frame is empty; skipping.")
                    continue
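
                # Ensure each partition is time-sorted and that the frame has a
                # known DatetimeIndex ('calculated_time') with divisions, which
                # later repartitioning and time-based joins rely on.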
                ddf_hk = ddf_hk.map_partitions(lambda pdf: pdf.sort_index())
                if not ddf_hk.known_divisions:
                    # 'calculated_time' becomes a column, then the new index;
                    # Dask infers divisions from it
                    ddf_hk = ddf_hk.reset_index().set_index(
                        "calculated_time", sorted=False, shuffle="tasks"
                    )

                # Repartition with consistent scheme for both HK and PBP
                target_freq = run_config["repartition"]
                target_partition_size = run_config.get("max_partition_size", "200MB")
                fallback_partition_size = "100MB"
                # Try to determine which method works for both datasets
                hk_method_used = None
                pbp_method_used = None
                # First, try preferred method for HK
                if target_freq:
                    try:
                        ddf_hk = ddf_hk.repartition(freq=target_freq)
                        hk_method_used = "freq"
                        print(f"HK repartitioned successfully using freq={target_freq}")
                    except Exception as e:
                        print(f"HK freq repartitioning failed: {e}")
                        try:
                            ddf_hk = ddf_hk.repartition(
                                partition_size=target_partition_size
                            )
                            hk_method_used = "size"
                            print(
                                f"HK repartitioned with fallback size={target_partition_size}"
                            )
                        except Exception as e2:
                            print(f"HK size repartitioning also failed: {e2}")
                            hk_method_used = None
                else:
                    try:
                        ddf_hk = ddf_hk.repartition(
                            partition_size=target_partition_size
                        )
                        hk_method_used = "size"
                        print(
                            f"HK repartitioned successfully using size={target_partition_size}"
                        )
                    except Exception as e:
                        print(f"HK repartitioning failed: {e}")
                        hk_method_used = None
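
                # Explicit meta for the resampled HK output: map_partitions
                # cannot reliably infer the columns returned by
                # resample_hk_partition, so they are declared up front.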
                meta = pd.DataFrame(
                    {
                        "Sample Flow Controller Read (sccm)": pd.Series(dtype="float64"),
                        "Sample Flow Controller Read (vccm)": pd.Series(dtype="float64"),
                        "date": pd.Series(dtype="str"),
                        "hour": pd.Series(dtype="int64"),
                    },
                    index=pd.DatetimeIndex([]),
                )
                ddf_hk_dt = ddf_hk.map_partitions(
                    resample_hk_partition, dt=f"{run_config['dt']}s", meta=meta
                )
                dask_objects.append(ddf_hk_dt)

                # flow_dt = ddf_hk_dt["Sample Flow Controller Read (vccm)"].compute()
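                # The dt-resampled sample flow is needed when joining with the
                # per-particle data, so compute it once on the cluster (with
                # retries) and keep the resulting pandas Series on the client.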
                flow_series = ddf_hk_dt["Sample Flow Controller Read (vccm)"]
                # Add retry logic for computation
                max_retries = 3
                for attempt in range(max_retries):
                    try:
                        flow_dt_future = client.compute(
                            flow_series, sync=False, retries=2
                        )
                        flow_dt = flow_dt_future.result()
                        flow_dt_scattered = client.scatter(flow_dt, broadcast=True)
                        break
                    except Exception as e:
                        print(f"Attempt {attempt+1} failed for flow computation: {e}")
                        if attempt == max_retries - 1:
                            raise
                        # Wait before retry ('time' is imported at module level)
                        time.sleep(5)

                # 3. PBP processing --------------------------------------
                try:
                    ddf_raw = dd.read_parquet(
                        run_config["input_pbp"],
                        engine="pyarrow",
                        filters=pbp_filters,
                        calculate_divisions=True,
                    )
                    dask_objects.append(ddf_raw)
                except (FileNotFoundError, OSError):
                    print(" → no PbP files for this chunk; skipping.")
                    continue
                if ddf_raw.npartitions == 0 or partition_rowcount(ddf_raw) == 0:
                    print(" → PbP frame is empty; skipping.")
                    continue
                ddf_raw = ddf_raw.map_partitions(lambda pdf: pdf.sort_index())
                if not ddf_raw.known_divisions:
                    # 'calculated_time' becomes a column, then the new index;
                    # Dask infers divisions from it
                    ddf_raw = ddf_raw.reset_index().set_index(
                        "calculated_time", sorted=False, shuffle="tasks"
                    )

                # Now try the same for PBP
                if target_freq:
                    try:
                        ddf_raw = ddf_raw.repartition(freq=target_freq)
                        pbp_method_used = "freq"
                        print(
                            f"PBP repartitioned successfully using freq={target_freq}"
                        )
                    except Exception as e:
                        print(f"PBP freq repartitioning failed: {e}")
                        try:
                            ddf_raw = ddf_raw.repartition(
                                partition_size=target_partition_size
                            )
                            pbp_method_used = "size"
                            print(
                                f"PBP repartitioned with fallback size={target_partition_size}"
                            )
                        except Exception as e2:
                            print(f"PBP size repartitioning also failed: {e2}")
                            pbp_method_used = None
                else:
                    try:
                        ddf_raw = ddf_raw.repartition(
                            partition_size=target_partition_size
                        )
                        pbp_method_used = "size"
                        print(
                            f"PBP repartitioned successfully using size={target_partition_size}"
                        )
                    except Exception as e:
                        print(f"PBP repartitioning failed: {e}")
                        pbp_method_used = None

                # Check if both datasets used the SAME method
                if (
                    hk_method_used != pbp_method_used
                    or hk_method_used is None
                    or pbp_method_used is None
                ):
                    print(
                        f"Partition methods differ (HK: {hk_method_used}, PBP: {pbp_method_used})"
                    )
                    print(
                        f"Applying consistent fallback partitioning ({fallback_partition_size}) to both datasets"
                    )
                    try:
                        ddf_hk = ddf_hk.repartition(
                            partition_size=fallback_partition_size
                        )
                        ddf_raw = ddf_raw.repartition(
                            partition_size=fallback_partition_size
                        )
                        print(
                            "Both datasets successfully repartitioned with consistent fallback method"
                        )
                    except Exception as e:
                        print(f"Fallback repartitioning also failed: {e}")
                        print("Using original partitions - joins may be inefficient")
                else:
                    print(
                        f"Both datasets successfully using consistent method: {hk_method_used}"
                    )
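
                # Calibrate the raw single-particle records with the instrument
                # config, then attach the dt-resampled sample flow so
                # flow-dependent quantities can be derived downstream.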
                ddf_cal = calibrate_single_particle(ddf_raw, instr_config, run_config)
                dask_objects.append(ddf_cal)
                ddf_pbp_with_flow = join_pbp_with_flow(ddf_cal, flow_dt, run_config)
                dask_objects.append(ddf_pbp_with_flow)
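
                # Make reruns of a chunk idempotent: drop any previously written
                # partition for this chunk before appending, and enforce the
                # canonical schema so appended files stay mutually compatible.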
                delete_partition_if_exists(
                    output_path=f"{run_config['output']}/pbp_calibrated",
                    partition_values={
                        "date": chunk_start.strftime("%Y-%m-%d"),
                        "hour": chunk_start.hour,
                    },
                )
                ddf_pbp_with_flow = enforce_schema(ddf_pbp_with_flow)

                # Add retry logic for parquet writes
                max_retries = 3
                for attempt in range(max_retries):
                    try:
                        write_future = ddf_pbp_with_flow.to_parquet(
                            path=f"{run_config['output']}/pbp_calibrated",
                            partition_on=run_config["saving_schema"],
                            engine="pyarrow",
                            write_index=True,
                            write_metadata_file=True,
                            append=True,
                            schema="infer",
                            compute=False,
                        )
                        write_future.compute(retries=2)
                        break
                    except Exception as e:
                        print(f"Attempt {attempt+1} failed for parquet write: {e}")
                        if attempt == max_retries - 1:
                            raise
                        # Wait before retry and force garbage collection
                        time.sleep(10)
                        client.run(gc.collect)
                        gc.collect()

                # ddf_pbp_with_flow = ddf_pbp_with_flow.persist()
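
                # Collapse the calibrated per-particle data to dt-resolution
                # summaries and merge them with the resampled HK frame.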
                # 4. Aggregate PBP ---------------------------------------------
                ddf_pbp_dt = ddf_cal.map_partitions(
                    build_dt_summary,
                    dt_s=run_config["dt"],
                    meta=build_dt_summary(ddf_cal._meta),
                )
                ddf_pbp_hk_dt = aggregate_dt(ddf_pbp_dt, ddf_hk_dt, run_config)
                dask_objects.append(ddf_pbp_hk_dt)
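
                # Bulk concentrations (optional): build the meta by running
                # add_concentrations on the empty _meta frame and casting it to
                # the canonical dtypes, so the map_partitions output matches it.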
                # 4. (optional) dt bulk conc --------------------------
                if run_config["do_conc"]:
                    meta_conc = add_concentrations(
                        ddf_pbp_hk_dt._meta, dt=run_config["dt"]
                    )
                    meta_conc = meta_conc.astype(
                        {
                            c: CANONICAL_DTYPES.get(c, DEFAULT_FLOAT)
                            for c in meta_conc.columns
                        },
                        copy=False,
                    ).convert_dtypes(dtype_backend="pyarrow")
                    ddf_conc = ddf_pbp_hk_dt.map_partitions(
                        add_concentrations, dt=run_config["dt"], meta=meta_conc
                    ).map_partitions(cast_and_arrow, meta=meta_conc)
                    dask_objects.append(ddf_conc)

                    # 1) normalise the index dtype before writing
                    idx_target = "datetime64[ns]"
                    ddf_conc = ddf_conc.map_partitions(
                        lambda pdf: pdf.set_index(
                            pdf.index.astype(idx_target, copy=False)
                        ),
                        meta=ddf_conc._meta,
                    )
                    # 2) cast partition columns *before* Dask strips them off;
                    # keep date as string to avoid Windows path issues with datetime partitions
                    ddf_conc["date"] = ddf_conc["date"].astype("str")
                    ddf_conc["hour"] = ddf_conc["hour"].astype("int64")
                    conc_future = ddf_conc.to_parquet(
                        f"{run_config['output']}/conc_{run_config['dt']}s",
                        partition_on=run_config["saving_schema"],
                        engine="pyarrow",
                        write_index=True,
                        write_metadata_file=True,
                        append=True,
                        schema="infer",
                        compute=False,
                    )
                    conc_future.compute()
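
                # Histograms (optional): build BC mass, scattering-diameter and
                # time-lag distributions for this chunk from the calibrated PbP
                # output written above.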
                # 5. (optional) dt histograms --------------------------
                if any(
                    [
                        run_config["do_BC_hist"],
                        run_config["do_scatt_hist"],
                        run_config["do_timelag_hist"],
                    ]
                ):
                    # Reread the saved data to avoid graph buildup
                    ddf_pbp_with_flow_fresh = dd.read_parquet(
                        f"{run_config['output']}/pbp_calibrated",
                        filters=pbp_filters,
                        engine="pyarrow",
                    )
                    process_histograms(
                        ddf_pbp_with_flow_fresh,
                        run_config,
                        inc_mass_bin_lims,
                        inc_mass_bin_ctrs,
                        scatt_bin_lims,
                        scatt_bin_ctrs,
                        timelag_bins_lims,
                        timelag_bin_ctrs,
                        chunk_start,
                        client,
                    )
            finally:
                # Comprehensive cleanup so the next chunk starts from a clean slate
                try:
                    # Cancel all dask objects
                    if dask_objects:
                        client.cancel(dask_objects)
                    # Clean up scattered data
                    if "flow_dt_scattered" in locals():
                        client.cancel(flow_dt_scattered)
                    # Drop local references so the underlying graphs can be
                    # released (mutating locals() is a no-op in CPython, so
                    # rebind the names instead of deleting them from the dict)
                    ddf_hk = ddf_hk_dt = ddf_raw = ddf_cal = None
                    ddf_pbp_with_flow = ddf_pbp_dt = ddf_pbp_hk_dt = ddf_conc = None
                    flow_dt = flow_dt_scattered = None
                    # Force garbage collection on workers and client
                    client.run(gc.collect)
                    gc.collect()
                except Exception as cleanup_error:
                    print(f"Warning: Cleanup error: {cleanup_error}")
    finally:
        print("Final cleanup...", flush=True)
        client.close()
        cluster.close()


if __name__ == "__main__":
    start_time = time.time()
    main()
    print(f"Total runtime: {time.time() - start_time:.1f} s", flush=True)