94 lines
4.5 KiB
Python
94 lines
4.5 KiB
Python
import sys
|
|
sys.path.append("/home/xie_x1/software/aareBuild")
|
|
from omegaconf import OmegaConf ### for yaml config parsing
|
|
import numpy as np
|
|
import aare
|
|
from multiprocessing import Pool
|
|
|
|
# YAML configuration shared by every worker. The path is relative, so it is
# resolved against the current working directory (NOTE(review): the script
# presumably has to be launched from the directory containing `Configs/`).
CONFIG_PATH = "Configs/generate_sample.yaml"
conf = OmegaConf.load(CONFIG_PATH)
|
|
|
|
def _load_single_photon_clusters(file_list):
    """Load and merge the single-photon clusters stored in ``file_list``.

    Each ``.npz`` file must contain ``samples`` and ``labels`` arrays; they are
    concatenated along axis 0. If the clusters are 5x5 (last axis of size 5),
    only their central 3x3 part is kept and the first two label columns are
    shifted by -1 so they refer to the cropped patch (columns 0/1 are the x/y
    offsets that ``generate_sample`` later adds the placement position to).

    Returns:
        ``(samples, labels)`` arrays.

    Raises:
        ValueError: if ``file_list`` is empty.
    """
    if not file_list:
        raise ValueError("file_list must contain at least one .npz file")

    sample_chunks, label_chunks = [], []
    for path in file_list:
        # np.load returns an NpzFile that keeps the archive open; close it.
        with np.load(path) as data:
            sample_chunks.append(data['samples'])
            label_chunks.append(data['labels'])

    # Concatenate once instead of re-copying the accumulated arrays per file.
    all_samples = np.concatenate(sample_chunks, axis=0)
    all_labels = np.concatenate(label_chunks, axis=0)

    if all_samples.shape[-1] == 5:
        all_samples = all_samples[:, 1:4, 1:4]  ### central 3x3 part of the 5x5 cluster
        all_labels[:, :2] -= 1  ### one pixel was removed on each side, shift x/y accordingly
    return all_samples, all_labels


def _make_cluster_finder(frame_size):
    """Build the aare cluster finder for a ``frame_size`` x ``frame_size`` frame.

    The noise map is uniform and equal to ``conf.noise_keV``.
    """
    finder = aare.VarClusterFinder((frame_size, frame_size), 5)  ### arg1: frame size, arg2: threshold
    finder.set_peripheralThresholdFactor(3)
    finder.set_noiseMap(np.ones((frame_size, frame_size)) * conf.noise_keV)
    finder.set_numberOfNeighbours(4)
    finder.set_empty_surroundingPixels(False)
    return finder


def _crop_around_cluster(sample, clusters, crop_size):
    """Crop ``sample`` to ``crop_size`` x ``crop_size`` around the first found cluster.

    The cluster's pixels in ``sample`` are first overwritten in place with the
    energies reported by the finder (fields ``enes``/``cols``/``rows``,
    truncated to ``size``). The crop offset is derived from the cluster's
    energy-weighted centre.

    Returns:
        ``(cropped, ref_x, ref_y)``, where ``cropped`` is a view into
        ``sample`` and ``ref_x``/``ref_y`` are its column/row offsets. Returns
        ``None`` if the crop would start at a negative index or run past the
        frame edge; such a sample would otherwise have the wrong shape.
    """
    cluster_size = clusters['size'][0]
    enes = clusters['enes'][0][:cluster_size]
    xs = clusters['cols'][0][:cluster_size]
    ys = clusters['rows'][0][:cluster_size]

    sample[ys, xs] = enes

    total_energy = np.sum(enes)
    x_center = np.sum(xs * enes) / total_energy
    y_center = np.sum(ys * enes) / total_energy

    ref_x = int(x_center - crop_size / 2) + 1
    ref_y = int(y_center - crop_size / 2) + 1
    if ref_x < 0 or ref_y < 0:
        return None
    cropped = sample[ref_y:ref_y + crop_size, ref_x:ref_x + crop_size]
    if cropped.shape != (crop_size, crop_size):
        return None
    return cropped, ref_x, ref_y


def generate_sample(input_tuple):
    """Generate multi-photon pile-up samples for one worker and save them.

    The input files contain single-photon clusters (3x3, or 5x5 from which the
    central 3x3 is taken). ``conf.num_of_photons`` of them are added at random
    positions inside the central third of a Gaussian-noise frame of side
    ``3 * conf.sample_size``. The frame is kept only if the cluster finder
    reports exactly one cluster, i.e. the photons merged. A kept frame is
    cropped to ``conf.sample_size`` square around the cluster's
    energy-weighted centre, and the photon labels are shifted into crop
    coordinates.

    Args:
        input_tuple: ``(worker_idx, file_list)``. ``worker_idx`` is used in the
            log messages and the output file name. ``file_list`` lists the
            ``.npz`` files holding ``samples``/``labels`` arrays of
            single-photon clusters.

    Side effects:
        Writes
        ``{conf.output_folder}/pileupOf{conf.num_of_photons}phs_sample_{worker_idx}.npz``
        (arrays ``samples`` and ``labels``) and prints progress to stdout.

    Raises:
        ValueError: if ``file_list`` is empty.
    """
    worker_idx, file_list = input_tuple
    output_path = conf.output_folder
    size = conf.sample_size
    frame_size = size * 3

    all_samples, all_labels = _load_single_photon_clusters(file_list)
    finder = _make_cluster_finder(frame_size)

    # Pool workers forked from one parent would all inherit the same legacy
    # np.random global state and draw identical noise/positions. A fresh
    # Generator pulls new OS entropy in each worker.
    rng = np.random.default_rng()

    pileup_samples = []
    pileup_labels = []
    count_of_event = 0
    while count_of_event < conf.num_of_events:
        sample = rng.normal(loc=0, scale=conf.noise_keV, size=(frame_size, frame_size))  ### start with a noise sample

        ### place conf.num_of_photons randomly chosen single-photon clusters
        labels = []
        for _ in range(conf.num_of_photons):
            idx = rng.integers(0, all_samples.shape[0])
            label = all_labels[idx].copy()  ### copy so all_labels is not modified
            # Corner drawn from the central third, so the 3x3 patch always fits.
            place_x = rng.integers(size, size * 2 - 2)
            place_y = rng.integers(size, size * 2 - 2)
            sample[place_y:place_y + 3, place_x:place_x + 3] += all_samples[idx]
            label[0] += place_x
            label[1] += place_y
            labels.append(label)

        ### keep the sample only if the photons form exactly one cluster
        finder.find_clusters_X(sample)
        clusters = finder.hits()
        if len(clusters) != 1:
            continue

        crop = _crop_around_cluster(sample, clusters, size)
        if crop is None:
            continue
        sample, crop_x, crop_y = crop
        for label in labels:
            label[0] -= crop_x
            label[1] -= crop_y

        pileup_samples.append(sample)
        pileup_labels.append(np.array(labels))
        count_of_event += 1
        if count_of_event % 10000 == 0:
            print(f"Worker {worker_idx}: {count_of_event} events generated")

    pileup_samples = np.array(pileup_samples)
    pileup_labels = np.array(pileup_labels)

    out_file = f'{output_path}/pileupOf{conf.num_of_photons}phs_sample_{worker_idx}.npz'
    np.savez(out_file, samples=pileup_samples, labels=pileup_labels)
    print(f"Worker {worker_idx}: finished generating {count_of_event} events, saved to {out_file}")
|
|
|
|
if __name__ == "__main__":
    # One worker per input file: worker i reads file i and writes output i.
    num_workers = 16
    jobs = []
    for worker in range(num_workers):
        input_files = [f'{conf.sample_folder}/12keV_Moench040_150V_{worker}.npz']
        jobs.append((worker, input_files))

    with Pool(processes=num_workers) as pool:
        pool.map(generate_sample, jobs)