Files
DataProcess/GeneratePileupSample.py
2026-05-29 16:10:34 +02:00

94 lines
4.5 KiB
Python

import sys
sys.path.append("/home/xie_x1/software/aareBuild")
from omegaconf import OmegaConf ### for yaml config parsing
import numpy as np
import aare
from multiprocessing import Pool
### generation settings, loaded once at import time; the path is relative, so it is resolved against the current working directory
### keys read in this file: output_folder, sample_size, noise_keV, num_of_events, num_of_photons, sample_folder
conf = OmegaConf.load("Configs/generate_sample.yaml")
def _load_single_photon_clusters(file_list):
    """Load and concatenate the 'samples' and 'labels' arrays of every input file.

    Args:
        file_list: paths of .npz files, each holding a 'samples' and a 'labels' array.

    Returns:
        Tuple (samples, labels), concatenated along axis 0. If the last sample
        axis has length 5, only the central 3x3 part is kept and the first two
        label columns are shifted by -1 accordingly.

    Raises:
        ValueError: if file_list is empty or the files hold no clusters at all.
    """
    sample_chunks, label_chunks = [], []
    for file in file_list:
        ### the npz archive keeps its file handle open until closed, so release it as soon as the arrays are read
        with np.load(file) as data:
            sample_chunks.append(data['samples'])
            label_chunks.append(data['labels'])
    if not sample_chunks:
        raise ValueError("file_list is empty: no single-photon clusters to sample from")
    ### concatenate once instead of re-copying the growing array for every file
    all_samples = np.concatenate(sample_chunks, axis=0)
    all_labels = np.concatenate(label_chunks, axis=0)
    if all_samples.shape[0] == 0:
        raise ValueError("input files contain no single-photon clusters")
    if all_samples.shape[-1] == 5:
        all_samples = all_samples[:, 1:4, 1:4]  ### extract the central 3x3 part of the 5x5 cluster as the single-photon cluster
        all_labels[:, :2] -= 1  ### adjust the label accordingly
    return all_samples, all_labels


def _build_cluster_finder(frame_size):
    """Create the cluster finder used to test whether the photons merge into one cluster."""
    cluster_finder = aare.VarClusterFinder((frame_size, frame_size), 5)  ### arg1: frame size, arg2: threshold
    cluster_finder.set_peripheralThresholdFactor(3)
    cluster_finder.set_noiseMap(np.ones((frame_size, frame_size)) * conf.noise_keV)
    cluster_finder.set_numberOfNeighbours(4)
    cluster_finder.set_empty_surroundingPixels(False)
    return cluster_finder


def generate_sample(input_tuple):
    """Generate multi-photon pile-up samples from single-photon clusters and save them.

    The input files contain single-photon clusters (3x3, or 5x5 reduced to 3x3).
    conf.num_of_photons of them are dropped at random positions onto a noise
    frame of (3 * conf.sample_size) squared pixels. An event is kept only when
    the cluster finder reports exactly one cluster; it is then cropped to
    conf.sample_size x conf.sample_size around the energy-weighted cluster
    centre. This repeats until conf.num_of_events events are accepted.

    Args:
        input_tuple: (worker_idx, file_list) -- the worker index used in the
            output file name and progress messages, and the .npz input files.

    Returns:
        None. Writes 'samples' and 'labels' to
        {conf.output_folder}/pileupOf{conf.num_of_photons}phs_sample_{worker_idx}.npz.

    Raises:
        ValueError: if file_list is empty or the files hold no clusters.
    """
    worker_idx, file_list = input_tuple
    size = conf.sample_size
    frame_size = size * 3
    output_file = f'{conf.output_folder}/pileupOf{conf.num_of_photons}phs_sample_{worker_idx}.npz'

    all_samples, all_labels = _load_single_photon_clusters(file_list)
    cluster_finder = _build_cluster_finder(frame_size)

    ### forked pool workers inherit an identical global np.random state and would all draw the same noise and
    ### positions; a generator created here is seeded from fresh OS entropy in every worker
    rng = np.random.default_rng()

    count_of_event = 0
    pileup_samples = []
    pileup_labels = []
    while count_of_event < conf.num_of_events:
        sample = rng.normal(loc=0, scale=conf.noise_keV, size=(frame_size, frame_size))  ### start with a noise sample
        ### randomly select conf.num_of_photons single-photon clusters and their reference points
        labels = []
        for _ in range(conf.num_of_photons):
            idx = rng.integers(0, all_samples.shape[0])
            label = all_labels[idx].copy()  ### copy the label to avoid modifying the original label in all_labels
            ### random reference point of the lower-left corner of the photon cluster, restricted to the central
            ### third of the frame so the 3x3 cluster always fits
            off_x = rng.integers(size, size * 2 - 2)
            off_y = rng.integers(size, size * 2 - 2)
            sample[off_y:off_y + 3, off_x:off_x + 3] += all_samples[idx]  ### add the 3x3 cluster to the sample
            label[0] += off_x
            label[1] += off_y
            labels.append(label)

        ### test if photons form into one cluster with the cluster finder
        cluster_finder.find_clusters_X(sample)
        clusters = cluster_finder.hits()
        if len(clusters) != 1:
            continue  ### if not exactly one cluster is found, discard this sample and generate a new one

        cluster_size = clusters['size'][0]
        enes = clusters['enes'][0][:cluster_size]
        xs = clusters['cols'][0][:cluster_size]
        ys = clusters['rows'][0][:cluster_size]
        ### write the cluster energies back into the frame
        ### NOTE(review): presumably the finder modifies the pixels it assigns to a cluster -- confirm against aare
        sample[ys, xs] = enes

        ### cut into conf.sample_size x conf.sample_size, centering around the weighted center of the cluster
        total_energy = np.sum(enes)
        x_center = np.sum(xs * enes) / total_energy
        y_center = np.sum(ys * enes) / total_energy
        ref_x = int(x_center - size / 2) + 1
        ref_y = int(y_center - size / 2) + 1
        ### a cluster centred too close to the frame edge would give a truncated crop and a ragged output array
        if ref_x < 0 or ref_y < 0 or ref_x + size > frame_size or ref_y + size > frame_size:
            continue
        sample = sample[ref_y:ref_y + size, ref_x:ref_x + size]

        ### express the photon positions in the coordinates of the cropped sample
        for label in labels:
            label[0] -= ref_x
            label[1] -= ref_y
        pileup_samples.append(sample)
        pileup_labels.append(np.array(labels))
        count_of_event += 1
        if count_of_event % 10000 == 0:
            print(f"Worker {worker_idx}: {count_of_event} events generated")

    pileup_samples = np.array(pileup_samples)
    pileup_labels = np.array(pileup_labels)
    np.savez(output_file, samples=pileup_samples, labels=pileup_labels)
    print(f"Worker {worker_idx}: finished generating {count_of_event} events, saved to {output_file}")
if __name__ == "__main__":
    ### run 16 workers in parallel; worker i reads the i-th single-photon input file
    with Pool(processes=16) as pool:
        worker_args = []
        for worker_idx in range(16):
            input_file = f'{conf.sample_folder}/12keV_Moench040_150V_{worker_idx}.npz'
            worker_args.append((worker_idx, [input_file]))
        pool.map(generate_sample, worker_args)