From 86a02a95b5dedcc7f9f4e479cd9f343b9fca425e Mon Sep 17 00:00:00 2001 From: "xiangyu.xie" Date: Wed, 26 Nov 2025 11:42:16 +0100 Subject: [PATCH] add h5 compatibility --- src/datasets.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/datasets.py b/src/datasets.py index d77d51c..860373b 100644 --- a/src/datasets.py +++ b/src/datasets.py @@ -112,9 +112,17 @@ class doublePhotonDataset(Dataset): all_samples = [] all_labels = [] for idx, sampleFile in enumerate(self.sampleFileList): - data = np.load(sampleFile) - all_samples.append(data['samples']) - all_labels.append(data['labels']) + if '.npz' in sampleFile: + data = np.load(sampleFile) + all_samples.append(data['samples']) + all_labels.append(data['labels']) + elif '.h5' in sampleFile: + import h5py + with h5py.File(sampleFile, 'r') as f: + samples = f['clusters'][:] + labels = f['labels'][:] + all_samples.append(samples) + all_labels.append(labels) self.samples = np.concatenate(all_samples, axis=0) if self.noiseKeV != 0: print(f'Adding Gaussian noise with sigma = {self.noiseKeV} keV to samples in {self.datasetName} dataset')