add h5 compatibility
This commit is contained in:
@@ -112,9 +112,17 @@ class doublePhotonDataset(Dataset):
|
|||||||
all_samples = []
|
all_samples = []
|
||||||
all_labels = []
|
all_labels = []
|
||||||
for idx, sampleFile in enumerate(self.sampleFileList):
|
for idx, sampleFile in enumerate(self.sampleFileList):
|
||||||
data = np.load(sampleFile)
|
if '.npz' in sampleFile:
|
||||||
all_samples.append(data['samples'])
|
data = np.load(sampleFile)
|
||||||
all_labels.append(data['labels'])
|
all_samples.append(data['samples'])
|
||||||
|
all_labels.append(data['labels'])
|
||||||
|
elif '.h5' in sampleFile:
|
||||||
|
import h5py
|
||||||
|
with h5py.File(sampleFile, 'r') as f:
|
||||||
|
samples = f['clusters'][:]
|
||||||
|
labels = f['labels'][:]
|
||||||
|
all_samples.append(samples)
|
||||||
|
all_labels.append(labels)
|
||||||
self.samples = np.concatenate(all_samples, axis=0)
|
self.samples = np.concatenate(all_samples, axis=0)
|
||||||
if self.noiseKeV != 0:
|
if self.noiseKeV != 0:
|
||||||
print(f'Adding Gaussian noise with sigma = {self.noiseKeV} keV to samples in {self.datasetName} dataset')
|
print(f'Adding Gaussian noise with sigma = {self.noiseKeV} keV to samples in {self.datasetName} dataset')
|
||||||
|
|||||||
Reference in New Issue
Block a user