diff --git a/scripts/sp2xr_pipeline.py b/scripts/sp2xr_pipeline.py
index 186820e..b03f66a 100644
--- a/scripts/sp2xr_pipeline.py
+++ b/scripts/sp2xr_pipeline.py
@@ -7,6 +7,7 @@ import time
 import dask.dataframe as dd
 import pandas as pd
 import numpy as np
+import gc
 from sp2xr.helpers import (
     parse_args,
     load_and_resolve_config,
@@ -25,8 +26,7 @@ from sp2xr.resample_pbp_hk import (
 )
 from sp2xr.distribution import (
     bin_lims_to_ctrs,
-    process_hist_and_dist_partition,
-    make_hist_meta,
+    process_histograms,
 )
 from sp2xr.concentrations import add_concentrations
 from sp2xr.schema import (
@@ -89,6 +89,16 @@ def main():
     )
     timelag_bin_ctrs = bin_lims_to_ctrs(timelag_bins_lims)

+    # Scatter these arrays once to avoid repeated serialization
+    scattered_bins = {
+        "inc_mass_bin_lims": client.scatter(inc_mass_bin_lims, broadcast=True),
+        "inc_mass_bin_ctrs": client.scatter(inc_mass_bin_ctrs, broadcast=True),
+        "scatt_bin_lims": client.scatter(scatt_bin_lims, broadcast=True),
+        "scatt_bin_ctrs": client.scatter(scatt_bin_ctrs, broadcast=True),
+        "timelag_bins_lims": client.scatter(timelag_bins_lims, broadcast=True),
+        "timelag_bin_ctrs": client.scatter(timelag_bin_ctrs, broadcast=True),
+    }
+
     for chunk_start, chunk_end in time_chunks:
         print(f"Processing: {chunk_start} to {chunk_end}")

@@ -100,127 +110,362 @@ def main():
         pbp_filters.append(("hour", ">=", chunk_start.hour))
         pbp_filters.append(("hour", "<", chunk_end.hour))

-        # 2. HK processing --------------------------------------
-        try:
-            ddf_hk = dd.read_parquet(
-                run_config["input_hk"],
-                engine="pyarrow",
-                filters=pbp_filters,
-                calculate_divisions=True,
-            )
-        except (FileNotFoundError, OSError):
-            print(" → no HK files for this chunk; skipping.")
-            continue
-
-        if ddf_hk.npartitions == 0 or partition_rowcount(ddf_hk) == 0:
-            print(" → HK frame is empty; skipping.")
-            continue
-        ddf_hk = ddf_hk.map_partitions(lambda pdf: pdf.sort_index())
-        if not ddf_hk.known_divisions:
-            ddf_hk = ddf_hk.reset_index().set_index(  # 'calculated_time' becomes a column
-                "calculated_time", sorted=False, shuffle="tasks"
-            )  # Dask now infers divisions
-        ddf_hk = ddf_hk.repartition(freq="1h")
-        meta = pd.DataFrame(
-            {
-                "Sample Flow Controller Read (sccm)": pd.Series(dtype="float64"),
-                "Sample Flow Controller Read (vccm)": pd.Series(dtype="float64"),
-                "date": pd.Series(dtype="datetime64[ns]"),
-                "hour": pd.Series(dtype="int64"),
-            },
-            index=pd.DatetimeIndex([]),
-        )
-        ddf_hk_dt = ddf_hk.map_partitions(
-            resample_hk_partition, dt=f"{run_config['dt']}s", meta=meta
-        )
-
-        flow_dt = ddf_hk_dt["Sample Flow Controller Read (vccm)"].compute()
-
-        # 3. PBP processing --------------------------------------
-        try:
-            ddf_raw = dd.read_parquet(
-                run_config["input_pbp"],
-                engine="pyarrow",
-                filters=pbp_filters,
-                calculate_divisions=True,
-            )
-        except (FileNotFoundError, OSError):
-            print(" → no PbP files for this chunk; skipping.")
-            continue
-
-        if ddf_raw.npartitions == 0 or partition_rowcount(ddf_raw) == 0:
-            print(" → PbP frame is empty; skipping.")
-            continue
-
-        ddf_raw = ddf_raw.map_partitions(lambda pdf: pdf.sort_index())
-        if not ddf_raw.known_divisions:
-            ddf_raw = ddf_raw.reset_index().set_index(  # 'calculated_time' becomes a column
-                "calculated_time", sorted=False, shuffle="tasks"
-            )  # Dask now infers divisions
-        ddf_raw = ddf_raw.repartition(freq="1h")
-
-        ddf_cal = calibrate_single_particle(ddf_raw, instr_config, run_config)
-
-        ddf_pbp_with_flow = join_pbp_with_flow(ddf_cal, flow_dt, run_config)
-
-        delete_partition_if_exists(
-            output_path=f"{run_config['output']}/pbp_calibrated",
-            partition_values={
-                "date": chunk_start.strftime("%Y-%m-%d 00:00:00"),
-                "hour": chunk_start.hour,
-            },
-        )
-        ddf_pbp_with_flow = enforce_schema(ddf_pbp_with_flow)
-
-        ddf_pbp_with_flow.to_parquet(
-            path=f"{run_config['output']}/pbp_calibrated",
-            partition_on=["date", "hour"],
-            engine="pyarrow",
-            write_index=True,
-            write_metadata_file=True,
-            append=True,
-            schema="infer",
-        )
-
-        # 4. Aggregate PBP ---------------------------------------------
-        ddf_pbp_dt = ddf_cal.map_partitions(
-            build_dt_summary,
-            dt_s=run_config["dt"],
-            meta=build_dt_summary(ddf_cal._meta),
-        )
-
-        ddf_pbp_hk_dt = aggregate_dt(ddf_pbp_dt, ddf_hk_dt, run_config)
-
-        # 4. (optional) dt bulk conc --------------------------
-        if run_config["do_conc"]:
-            meta_conc = add_concentrations(ddf_pbp_hk_dt._meta, dt=run_config["dt"])
-            meta_conc = meta_conc.astype(
-                {
-                    c: CANONICAL_DTYPES.get(c, DEFAULT_FLOAT)
-                    for c in meta_conc.columns
-                },
-                copy=False,
-            ).convert_dtypes(dtype_backend="pyarrow")
-
-            ddf_conc = ddf_pbp_hk_dt.map_partitions(
-                add_concentrations, dt=run_config["dt"], meta=meta_conc
-            ).map_partitions(cast_and_arrow, meta=meta_conc)
-
-            idx_target = "datetime64[ns]"
-            ddf_conc = ddf_conc.map_partitions(
-                lambda pdf: pdf.set_index(pdf.index.astype(idx_target, copy=False)),
-                meta=ddf_conc._meta,
-            )
-
-            # 2) cast partition columns *before* Dask strips them off
-            ddf_conc["date"] = dd.to_datetime(ddf_conc["date"]).astype(
-                "datetime64[ns]"
-            )
-            ddf_conc["hour"] = ddf_conc["hour"].astype("int64")
-            ddf_conc.to_parquet(
-                f"{run_config['output']}/conc_{run_config['dt']}s",
-                partition_on=["date", "hour"],
-                engine="pyarrow",
-                write_index=True,
-                write_metadata_file=True,
-                append=True,
-                schema="infer",
-            )
+        client.restart()
+
+        # client.restart() invalidates the futures scattered above, so scatter again
+        scattered_bins = {
+            "inc_mass_bin_lims": client.scatter(inc_mass_bin_lims, broadcast=True),
+            "inc_mass_bin_ctrs": client.scatter(inc_mass_bin_ctrs, broadcast=True),
+            "scatt_bin_lims": client.scatter(scatt_bin_lims, broadcast=True),
+            "scatt_bin_ctrs": client.scatter(scatt_bin_ctrs, broadcast=True),
+            "timelag_bins_lims": client.scatter(timelag_bins_lims, broadcast=True),
+            "timelag_bin_ctrs": client.scatter(timelag_bin_ctrs, broadcast=True),
+        }
+
+        dask_objects = []
+        try:
+            # 2. HK processing --------------------------------------
+            try:
+                ddf_hk = dd.read_parquet(
+                    run_config["input_hk"],
+                    engine="pyarrow",
+                    filters=pbp_filters,
+                    calculate_divisions=True,
+                )
+                dask_objects.append(ddf_hk)
+            except (FileNotFoundError, OSError):
+                print(" → no HK files for this chunk; skipping.")
+                continue
+
+            if ddf_hk.npartitions == 0 or partition_rowcount(ddf_hk) == 0:
+                print(" → HK frame is empty; skipping.")
+                continue
+
+            ddf_hk = ddf_hk.map_partitions(lambda pdf: pdf.sort_index())
+            if not ddf_hk.known_divisions:
+                ddf_hk = ddf_hk.reset_index().set_index(  # 'calculated_time' becomes a column
+                    "calculated_time", sorted=False, shuffle="tasks"
+                )  # Dask now infers divisions
+            ddf_hk = ddf_hk.repartition(freq="1h")
+            meta = pd.DataFrame(
+                {
+                    "Sample Flow Controller Read (sccm)": pd.Series(dtype="float64"),
+                    "Sample Flow Controller Read (vccm)": pd.Series(dtype="float64"),
+                    "date": pd.Series(dtype="datetime64[ns]"),
+                    "hour": pd.Series(dtype="int64"),
+                },
+                index=pd.DatetimeIndex([]),
+            )
+            ddf_hk_dt = ddf_hk.map_partitions(
+                resample_hk_partition, dt=f"{run_config['dt']}s", meta=meta
+            )
+            dask_objects.append(ddf_hk_dt)
+
+            flow_series = ddf_hk_dt["Sample Flow Controller Read (vccm)"]
+            flow_dt_future = client.compute(flow_series, sync=False)
+            flow_dt = flow_dt_future.result()
+            flow_dt_scattered = client.scatter(flow_dt, broadcast=True)
+
+            # 3. PBP processing --------------------------------------
+            try:
+                ddf_raw = dd.read_parquet(
+                    run_config["input_pbp"],
+                    engine="pyarrow",
+                    filters=pbp_filters,
+                    calculate_divisions=True,
+                )
+                dask_objects.append(ddf_raw)
+            except (FileNotFoundError, OSError):
+                print(" → no PbP files for this chunk; skipping.")
+                continue
+
+            if ddf_raw.npartitions == 0 or partition_rowcount(ddf_raw) == 0:
+                print(" → PbP frame is empty; skipping.")
+                continue
+
+            ddf_raw = ddf_raw.map_partitions(lambda pdf: pdf.sort_index())
+            if not ddf_raw.known_divisions:
+                ddf_raw = ddf_raw.reset_index().set_index(  # 'calculated_time' becomes a column
+                    "calculated_time", sorted=False, shuffle="tasks"
+                )  # Dask now infers divisions
+            ddf_raw = ddf_raw.repartition(freq="1h")
+
+            ddf_cal = calibrate_single_particle(ddf_raw, instr_config, run_config)
+            dask_objects.append(ddf_cal)
+
+            ddf_pbp_with_flow = join_pbp_with_flow(ddf_cal, flow_dt, run_config)
+            dask_objects.append(ddf_pbp_with_flow)
+
+            delete_partition_if_exists(
+                output_path=f"{run_config['output']}/pbp_calibrated",
+                partition_values={
+                    "date": chunk_start.strftime("%Y-%m-%d 00:00:00"),
+                    "hour": chunk_start.hour,
+                },
+            )
+            ddf_pbp_with_flow = enforce_schema(ddf_pbp_with_flow)
+
+            # to_parquet(compute=False) returns a Delayed; trigger the write now
+            write_future = ddf_pbp_with_flow.to_parquet(
+                path=f"{run_config['output']}/pbp_calibrated",
+                partition_on=["date", "hour"],
+                engine="pyarrow",
+                write_index=True,
+                write_metadata_file=True,
+                append=True,
+                schema="infer",
+                compute=False,
+            )
+            write_future.compute()
+
+            # 4. Aggregate PBP ---------------------------------------------
+            ddf_pbp_dt = ddf_cal.map_partitions(
+                build_dt_summary,
+                dt_s=run_config["dt"],
+                meta=build_dt_summary(ddf_cal._meta),
+            )
+
+            ddf_pbp_hk_dt = aggregate_dt(ddf_pbp_dt, ddf_hk_dt, run_config)
+            dask_objects.append(ddf_pbp_hk_dt)
+
+            # 4. (optional) dt bulk conc --------------------------
+            if run_config["do_conc"]:
+                meta_conc = add_concentrations(ddf_pbp_hk_dt._meta, dt=run_config["dt"])
+                meta_conc = meta_conc.astype(
+                    {
+                        c: CANONICAL_DTYPES.get(c, DEFAULT_FLOAT)
+                        for c in meta_conc.columns
+                    },
+                    copy=False,
+                ).convert_dtypes(dtype_backend="pyarrow")
+
+                ddf_conc = ddf_pbp_hk_dt.map_partitions(
+                    add_concentrations, dt=run_config["dt"], meta=meta_conc
+                ).map_partitions(cast_and_arrow, meta=meta_conc)
+                dask_objects.append(ddf_conc)
+
+                idx_target = "datetime64[ns]"
+                ddf_conc = ddf_conc.map_partitions(
+                    lambda pdf: pdf.set_index(pdf.index.astype(idx_target, copy=False)),
+                    meta=ddf_conc._meta,
+                )
+
+                # cast partition columns *before* Dask strips them off
+                ddf_conc["date"] = dd.to_datetime(ddf_conc["date"]).astype(
+                    "datetime64[ns]"
+                )
+                ddf_conc["hour"] = ddf_conc["hour"].astype("int64")
+
+                conc_future = ddf_conc.to_parquet(
+                    f"{run_config['output']}/conc_{run_config['dt']}s",
+                    partition_on=["date", "hour"],
+                    engine="pyarrow",
+                    write_index=True,
+                    write_metadata_file=True,
+                    append=True,
+                    schema="infer",
+                    compute=False,
+                )
+                conc_future.compute()
+
+            # 5. (optional) dt histograms --------------------------
+            if any(
+                [
+                    run_config["do_BC_hist"],
+                    run_config["do_scatt_hist"],
+                    run_config["do_timelag_hist"],
+                ]
+            ):
+                # Re-read the saved data to avoid graph buildup
+                ddf_pbp_with_flow_fresh = dd.read_parquet(
+                    f"{run_config['output']}/pbp_calibrated",
+                    filters=pbp_filters,
+                    engine="pyarrow",
+                )
+
+                process_histograms(
+                    ddf_pbp_with_flow_fresh,
+                    run_config,
+                    inc_mass_bin_lims,
+                    inc_mass_bin_ctrs,
+                    scatt_bin_lims,
+                    scatt_bin_ctrs,
+                    timelag_bins_lims,
+                    timelag_bin_ctrs,
+                    chunk_start,
+                    client,
+                )
@@ -228,173 +473,44 @@ def main():
-        # 5. (optional) dt histograms --------------------------
-
-        if run_config["do_BC_hist"]:
-            print("Computing BC distributions...")
-            # --- Mass histogram
-            BC_hist_configs = [
-                {"flag_col": None, "flag_value": None},
-                {"flag_col": "cnts_thin", "flag_value": 1},
-                {"flag_col": "cnts_thin_noScatt", "flag_value": 1},
-                {"flag_col": "cnts_thick", "flag_value": 1},
-                {"flag_col": "cnts_thick_sat", "flag_value": 1},
-                {"flag_col": "cnts_thin_sat", "flag_value": 1},
-                {"flag_col": "cnts_ntl_sat", "flag_value": 1},
-                {"flag_col": "cnts_ntl", "flag_value": 1},
-                {
-                    "flag_col": "cnts_extreme_positive_timelag",
-                    "flag_value": 1,
-                },
-                {
-                    "flag_col": "cnts_thin_low_inc_scatt_ratio",
-                    "flag_value": 1,
-                },
-                {"flag_col": "cnts_thin_total", "flag_value": 1},
-                {"flag_col": "cnts_thick_total", "flag_value": 1},
-                {"flag_col": "cnts_unclassified", "flag_value": 1},
-            ]
-
-            results = []
-
-            for cfg_hist in BC_hist_configs[:2]:
-                meta_hist = (
-                    make_hist_meta(
-                        bin_ctrs=inc_mass_bin_ctrs,
-                        kind="mass",
-                        flag_col=cfg_hist["flag_col"],
-                        rho_eff=run_config["rho_eff"],
-                        BC_type=run_config["BC_type"],
-                    )
-                    .astype(DEFAULT_FLOAT, copy=False)
-                    .convert_dtypes(dtype_backend="pyarrow")
-                )
-                ddf_out = ddf_pbp_with_flow.map_partitions(
-                    process_hist_and_dist_partition,
-                    col="BC mass within range",
-                    flag_col=cfg_hist["flag_col"],
-                    flag_value=cfg_hist["flag_value"],
-                    bin_lims=inc_mass_bin_lims,
-                    bin_ctrs=inc_mass_bin_ctrs,
-                    dt=run_config["dt"],
-                    calculate_conc=True,
-                    flow=None,
-                    rho_eff=run_config["rho_eff"],
-                    BC_type=run_config["BC_type"],
-                    t=1,
-                    meta=meta_hist,
-                ).map_partitions(cast_and_arrow, meta=meta_hist)
-                results.append(ddf_out)
-
-        # --- Scattering histogram
-        if run_config["do_scatt_hist"]:
-            print("Computing scattering distribution...")
-            meta_hist = (
-                make_hist_meta(
-                    bin_ctrs=scatt_bin_ctrs,
-                    kind="scatt",
-                    flag_col=None,
-                    rho_eff=None,
-                    BC_type=None,
-                )
-                .astype(DEFAULT_FLOAT, copy=False)
-                .convert_dtypes(dtype_backend="pyarrow")
-            )
-            ddf_scatt = ddf_pbp_with_flow.map_partitions(
-                process_hist_and_dist_partition,
-                col="Opt diam scatt only",
-                flag_col=None,
-                flag_value=None,
-                bin_lims=scatt_bin_lims,
-                bin_ctrs=scatt_bin_ctrs,
-                dt=run_config["dt"],
-                calculate_conc=True,
-                flow=None,
-                rho_eff=None,
-                BC_type=None,
-                t=1,
-                meta=meta_hist,
-            ).map_partitions(cast_and_arrow, meta=meta_hist)
-            results.append(ddf_scatt)
-
-        # --- Timelag histogram
-        if run_config["do_timelag_hist"]:
-            print("Computing time delay distribution...")
-            mass_bins = (
-                ddf_pbp_with_flow[["BC mass bin"]]
-                .compute()
-                .astype("Int64")
-                .drop_duplicates()
-                .dropna()
-            )
-
-            for idx, mass_bin in enumerate(mass_bins[:1]):
-                ddf_bin = ddf_pbp_with_flow[
-                    ddf_pbp_with_flow["BC mass bin"] == mass_bin
-                ]
-
-                name_prefix = f"dNdlogDmev_{inc_mass_bin_ctrs[idx]:.2f}_timelag"
-
-                meta_hist = make_hist_meta(
-                    bin_ctrs=timelag_bin_ctrs,
-                    kind="timelag",
-                    flag_col="cnts_particles_for_tl_dist",
-                    name_prefix=name_prefix,
-                    rho_eff=None,
-                    BC_type=None,
-                )
-
-                tl_ddf = ddf_bin.map_partitions(
-                    process_hist_and_dist_partition,
-                    col="time_lag",
-                    flag_col="cnts_particles_for_tl_dist",
-                    flag_value=1,
-                    bin_lims=timelag_bins_lims,
-                    bin_ctrs=timelag_bin_ctrs,
-                    dt=run_config["dt"],
-                    calculate_conc=True,
-                    flow=None,
-                    rho_eff=None,
-                    BC_type=None,
-                    t=1,
-                    name_prefix=name_prefix,
-                    meta=meta_hist,
-                )
-
-                results.append(tl_ddf)
-
-        # --- Merge all hists
-        merged_ddf = dd.concat(results, axis=1, interleave_partitions=True)
-
-        idx_target = "datetime64[ns]"
-        merged_ddf = merged_ddf.map_partitions(
-            lambda pdf: pdf.set_index(pdf.index.astype(idx_target, copy=False)),
-            meta=merged_ddf._meta,
-        )
-
-        index_as_dt = dd.to_datetime(merged_ddf.index.to_series())
-        merged_ddf["date"] = index_as_dt.map_partitions(
-            lambda s: s.dt.normalize(), meta=("date", "datetime64[ns]")
-        )
-
-        # --- Save hists to parquet
-
-        delete_partition_if_exists(
-            output_path=f"{run_config['output']}/hists_{run_config['dt']}s",
-            partition_values={
-                "date": chunk_start.strftime("%Y-%m-%d"),
-                "hour": chunk_start.hour,
-            },
-        )
-        merged_ddf.to_parquet(
-            f"{run_config['output']}/hists_{run_config['dt']}s",
-            partition_on=["date"],
-            engine="pyarrow",
-            write_index=True,
-            write_metadata_file=True,
-            append=True,
-            schema="infer",
-        )
+        finally:
+            # Comprehensive per-chunk cleanup
+            try:
+                # Cancel every Dask collection created for this chunk
+                if dask_objects:
+                    client.cancel(dask_objects)
+
+                # Clean up scattered data
+                if "flow_dt_scattered" in locals():
+                    client.cancel(flow_dt_scattered)
+
+                # Drop local references so the underlying futures can be freed
+                # (``del locals()[name]`` is a no-op inside a function)
+                ddf_hk = ddf_hk_dt = ddf_raw = ddf_cal = None
+                ddf_pbp_with_flow = ddf_pbp_dt = ddf_pbp_hk_dt = None
+                ddf_conc = flow_dt = flow_dt_scattered = None
+
+                # Force garbage collection on workers and client
+                client.run(gc.collect)
+                gc.collect()
+
+            except Exception as cleanup_error:
+                print(f"Warning: Cleanup error: {cleanup_error}")
     finally:
         print("Final cleanup...", flush=True)
         client.close()
diff --git a/src/sp2xr/distribution.py b/src/sp2xr/distribution.py
index fc0f956..800ef3f 100644
--- a/src/sp2xr/distribution.py
+++ b/src/sp2xr/distribution.py
@@ -3,8 +3,12 @@ import numpy as np
 import pandas as pd
 from typing import Tuple
 from typing import Optional
+import gc
+import dask.dataframe as dd
 from dask.dataframe.utils import make_meta
 from .calibration import BC_mass_to_diam
+from sp2xr.helpers import delete_partition_if_exists
+from sp2xr.schema import cast_and_arrow, DEFAULT_FLOAT


 def make_bin_arrays(
@@ -214,11 +218,12 @@ def process_hist_and_dist_partition(
     flow=None,  # kept for API compatibility; not used
     rho_eff: Optional[float] = 1800,
     BC_type: Optional[str] = "",
-    t: float = 1.0,
+    # t: float = 1.0,
     name_prefix: Optional[str] = None,  # <-- NEW: force exact output name prefix
 ):
     # normalize dt -> "Xs"
     dt_str = f"{dt}s" if isinstance(dt, (int, float)) else str(dt)
+    t_sec: float = pd.to_timedelta(dt_str).total_seconds()

     # filter by flag (when provided)
     if flag_col and flag_value is not None:
@@ -281,7 +286,7 @@ def process_hist_and_dist_partition(

     # compute number concentration (per-log-bin)
     # last column is flow; left part are counts per bin
-    dNd = counts2numConc(inc_hist_flow.iloc[:, :-1], inc_hist_flow.iloc[:, -1], t=t)
+    dNd = counts2numConc(inc_hist_flow.iloc[:, :-1], inc_hist_flow.iloc[:, -1], t=t_sec)

     # choose the abscissa used for dlog (Dmev for BC, or bin_ctrs otherwise)
     if rho_eff is not None and BC_type is not None:
@@ -370,3 +375,218 @@ def make_hist_meta(
         index=pd.DatetimeIndex([], name="calculated_time"),
     )
     return make_meta(empty_df)
+
+
+def process_histograms(
+    ddf_pbp_with_flow,
+    run_config,
+    inc_mass_bin_lims,
+    inc_mass_bin_ctrs,
+    scatt_bin_lims,
+    scatt_bin_ctrs,
+    timelag_bins_lims,
+    timelag_bin_ctrs,
+    chunk_start,
+    client,  # kept for API symmetry with the pipeline; not currently used
+):
+    """Process the histograms for one chunk, computing each one eagerly to
+    avoid task-graph buildup on the scheduler."""
+    computed_results = []
+
+    if run_config["do_BC_hist"]:
+        print("Computing BC distributions...")
+        # --- Mass histogram
+        BC_hist_configs = [
+            {"flag_col": None, "flag_value": None},
+            {"flag_col": "cnts_thin", "flag_value": 1},
+            {"flag_col": "cnts_thin_noScatt", "flag_value": 1},
+            {"flag_col": "cnts_thick", "flag_value": 1},
+            {"flag_col": "cnts_thick_sat", "flag_value": 1},
+            {"flag_col": "cnts_thin_sat", "flag_value": 1},
+            {"flag_col": "cnts_ntl_sat", "flag_value": 1},
+            {"flag_col": "cnts_ntl", "flag_value": 1},
+            {
+                "flag_col": "cnts_extreme_positive_timelag",
+                "flag_value": 1,
+            },
+            {
+                "flag_col": "cnts_thin_low_inc_scatt_ratio",
+                "flag_value": 1,
+            },
+            {"flag_col": "cnts_thin_total", "flag_value": 1},
+            {"flag_col": "cnts_thick_total", "flag_value": 1},
+            {"flag_col": "cnts_unclassified", "flag_value": 1},
+        ]
+
+        for cfg_hist in BC_hist_configs:
+            meta_hist = (
+                make_hist_meta(
+                    bin_ctrs=inc_mass_bin_ctrs,
+                    kind="mass",
+                    flag_col=cfg_hist["flag_col"],
+                    rho_eff=run_config["rho_eff"],
+                    BC_type=run_config["BC_type"],
+                )
+                .astype(DEFAULT_FLOAT, copy=False)
+                .convert_dtypes(dtype_backend="pyarrow")
+            )
+            ddf_out = ddf_pbp_with_flow.map_partitions(
+                process_hist_and_dist_partition,
+                col="BC mass within range",
+                flag_col=cfg_hist["flag_col"],
+                flag_value=cfg_hist["flag_value"],
+                bin_lims=inc_mass_bin_lims,
+                bin_ctrs=inc_mass_bin_ctrs,
+                dt=run_config["dt"],
+                calculate_conc=True,
+                flow=None,
+                rho_eff=run_config["rho_eff"],
+                BC_type=run_config["BC_type"],
+                meta=meta_hist,
+            ).map_partitions(cast_and_arrow, meta=meta_hist)
+
+            # Compute eagerly so each histogram's graph is released right away
+            computed_hist = ddf_out.compute()
+            computed_results.append(computed_hist)
+            del ddf_out
+
+    if run_config["do_scatt_hist"]:
+        print("Computing scattering distribution...")
+        meta_hist = (
+            make_hist_meta(
+                bin_ctrs=scatt_bin_ctrs,
+                kind="scatt",
+                flag_col=None,
+                rho_eff=None,
+                BC_type=None,
+            )
+            .astype(DEFAULT_FLOAT, copy=False)
+            .convert_dtypes(dtype_backend="pyarrow")
+        )
+        ddf_scatt = ddf_pbp_with_flow.map_partitions(
+            process_hist_and_dist_partition,
+            col="Opt diam scatt only",
+            flag_col=None,
+            flag_value=None,
+            bin_lims=scatt_bin_lims,
+            bin_ctrs=scatt_bin_ctrs,
+            dt=run_config["dt"],
+            calculate_conc=True,
+            flow=None,
+            rho_eff=None,
+            BC_type=None,
+            meta=meta_hist,
+        ).map_partitions(cast_and_arrow, meta=meta_hist)
+        computed_scatt = ddf_scatt.compute()
+        computed_results.append(computed_scatt)
+
+    if run_config["do_timelag_hist"]:
+        print("Computing time delay distribution...")
+        # Single-column selection so iterating yields the distinct bin values
+        mass_bins = (
+            ddf_pbp_with_flow["BC mass bin"]
+            .compute()
+            .astype("Int64")
+            .drop_duplicates()
+            .dropna()
+        )
+
+        for idx, mass_bin in enumerate(mass_bins):
+            ddf_bin = ddf_pbp_with_flow[ddf_pbp_with_flow["BC mass bin"] == mass_bin]
+
+            name_prefix = f"dNdlogDmev_{inc_mass_bin_ctrs[idx]:.2f}_timelag"
+
+            meta_hist = make_hist_meta(
+                bin_ctrs=timelag_bin_ctrs,
+                kind="timelag",
+                flag_col="cnts_particles_for_tl_dist",
+                name_prefix=name_prefix,
+                rho_eff=None,
+                BC_type=None,
+            )
+
+            tl_ddf = ddf_bin.map_partitions(
+                process_hist_and_dist_partition,
+                col="time_lag",
+                flag_col="cnts_particles_for_tl_dist",
+                flag_value=1,
+                bin_lims=timelag_bins_lims,
+                bin_ctrs=timelag_bin_ctrs,
+                dt=run_config["dt"],
+                calculate_conc=True,
+                flow=None,
+                rho_eff=None,
+                BC_type=None,
+                name_prefix=name_prefix,
+                meta=meta_hist,
+            )
+            tl_ddf = tl_ddf.map_partitions(cast_and_arrow, meta=meta_hist)
+
+            computed_tl = tl_ddf.compute()
+            computed_results.append(computed_tl)
+
+    if computed_results:
+        for i, df in enumerate(computed_results):
+            # Ensure the index is nanosecond precision
+            if hasattr(df.index, "dtype") and "datetime" in str(df.index.dtype):
+                df.index = df.index.astype("datetime64[ns]")
+
+            # Fix any datetime columns
+            for col in df.columns:
+                if hasattr(df[col], "dtype") and "datetime" in str(df[col].dtype):
+                    computed_results[i][col] = df[col].astype("datetime64[ns]")
+
+        merged_df = pd.concat(computed_results, axis=1)
+
+        # Double-check the merged result
+        if hasattr(merged_df.index, "dtype") and "datetime" in str(
+            merged_df.index.dtype
+        ):
+            merged_df.index = merged_df.index.astype("datetime64[ns]")
+
+        merged_ddf = dd.from_pandas(merged_df, npartitions=1)
+
+        merged_ddf["date"] = dd.to_datetime(merged_ddf.index.to_series()).dt.normalize()
+
+        # --- Save hists to parquet
+        delete_partition_if_exists(
+            output_path=f"{run_config['output']}/hists_{run_config['dt']}s",
+            partition_values={
+                "date": chunk_start.strftime("%Y-%m-%d"),
+                "hour": chunk_start.hour,
+            },
+        )
+
+        # Build the delayed write, then trigger it immediately
+        hist_future = merged_ddf.to_parquet(
+            f"{run_config['output']}/hists_{run_config['dt']}s",
+            partition_on=["date"],
+            engine="pyarrow",
+            write_index=True,
+            write_metadata_file=True,
+            append=True,
+            schema="infer",
+            compute=False,
+        )
+        hist_future.compute()
+
+        del merged_df, merged_ddf, computed_results
+        gc.collect()