Add h5 support; add refpoints
@@ -3,22 +3,55 @@ import torch
 import numpy as np
 
 class singlePhotonDataset(Dataset):
-    def __init__(self, sampleList, labelList, sampleRatio):
+    def __init__(self, sampleList, sampleRatio, datasetName, noiseKeV=0):
         self.sampleFileList = sampleList
-        self.labelFileList = labelList
         self.sampleRatio = sampleRatio
+        self.datasetName = datasetName
+
+        all_samples = []
+        all_labels = []
+        all_ref_pts = []
+        for sampleFile in self.sampleFileList:
+            if '.npz' in sampleFile:
+                data = np.load(sampleFile)
+                all_samples.append(data['samples'])
+                if 'referencePoint' in data:
+                    all_ref_pts.append(data['referencePoint'])
+                else:
+                    all_ref_pts.append(np.zeros((data['samples'].shape[0], 2), dtype=np.float32)) ### dummy reference points
+                if 'labels' in data:
+                    all_labels.append(data['labels'])
+                else:
+                    all_labels.append(np.zeros((data['samples'].shape[0], 4), dtype=np.float32)) ### dummy labels
+            elif '.h5' in sampleFile:
+                import h5py
+                with h5py.File(sampleFile, 'r') as f:
+                    samples = f['clusters'][:]
+                    if 'referencePoint' in f:
+                        ref_pts = f['referencePoint'][:]
+                        all_ref_pts.append(ref_pts)
+                    else:
+                        all_ref_pts.append(np.zeros((samples.shape[0], 2), dtype=np.float32)) ### dummy reference points
+                    if 'labels' in f:
+                        labels = f['labels'][:]
+                    else:
+                        labels = np.zeros((samples.shape[0], 4), dtype=np.float32) ### dummy labels
+                all_samples.append(samples)
+                all_labels.append(labels)
+        self.samples = np.concatenate(all_samples, axis=0)
+        if noiseKeV != 0:
+            print(f'Adding Gaussian noise with sigma = {noiseKeV} keV to samples in {self.datasetName} dataset')
+            noise = np.random.normal(loc=0.0, scale=noiseKeV, size=self.samples.shape)
+            self.samples = self.samples + noise
+        self.labels = np.concatenate(all_labels, axis=0)
+        self.referencePoint = np.concatenate(all_ref_pts, axis=0) if all_ref_pts else None
-
-        for idx, sampleFile in enumerate(self.sampleFileList):
-            if idx == 0:
-                self.samples = np.load(sampleFile)
-                self.labels = np.load(self.labelFileList[idx])
-            else:
-                self.samples = np.concatenate((self.samples, np.load(sampleFile)), axis=0)
-                self.labels = np.concatenate((self.labels, np.load(self.labelFileList[idx])), axis=0)
 
+        if self.samples.shape[1] == 5: ### if sample size is 5x5, remove border pixels to make it 3x3
+            self.samples = self.samples[:, 1:-1, 1:-1] ### remove border pixels
+            self.labels = self.labels - np.array([1, 1, 0, 0]) ### adjust labels accordingly
         ### total number of samples
         self.length = int(self.samples.shape[0] * self.sampleRatio)
-        print(f"Total number of samples: {self.length}")
+        print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
 
     def __getitem__(self, index):
         sample = self.samples[index]
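Note on the new '.h5' branch above: each HDF5 file is expected to contain a 'clusters' dataset, plus optional 'referencePoint' and 'labels' datasets (dummy zeros are substituted when either is missing). A minimal sketch of writing a compatible file with h5py; the file name and array shapes are illustrative assumptions, not taken from this repository:

import numpy as np
import h5py

# illustrative shapes: 1000 clusters of 3x3 pixels, 2D reference points, 4-value labels
clusters = np.random.rand(1000, 3, 3).astype(np.float32)
refpoints = np.zeros((1000, 2), dtype=np.float32)
labels = np.zeros((1000, 4), dtype=np.float32)

with h5py.File('example_photons.h5', 'w') as f:  # hypothetical file name
    f.create_dataset('clusters', data=clusters)
    f.create_dataset('referencePoint', data=refpoints)  # optional; loader falls back to zeros
    f.create_dataset('labels', data=labels)             # optional; loader falls back to zeros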
@@ -30,20 +63,20 @@ class singlePhotonDataset(Dataset):
         return self.length
 
 class doublePhotonDataset(Dataset):
-    def __init__(self, sampleList, labelList, sampleRatio, datasetName, reuselFactor=1):
+    def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1):
         self.sampleFileList = sampleList
-        self.labelFileList = labelList
         self.sampleRatio = sampleRatio
         self.datasetName = datasetName
 
+        all_samples = []
+        all_labels = []
         for idx, sampleFile in enumerate(self.sampleFileList):
-            if idx == 0:
-                self.samples = np.load(sampleFile)
-                self.labels = np.load(self.labelFileList[idx])
-            else:
-                self.samples = np.concatenate((self.samples, np.load(sampleFile)), axis=0)
-                self.labels = np.concatenate((self.labels, np.load(self.labelFileList[idx])), axis=0)
-
+            data = np.load(sampleFile)
+            all_samples.append(data['samples'])
+            all_labels.append(data['labels'])
+        self.samples = np.concatenate(all_samples, axis=0)
+        self.labels = np.concatenate(all_labels, axis=0)
+
         ### total number of samples
         self.length = int(self.samples.shape[0] * self.sampleRatio) // 2 * reuselFactor
         print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
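With the changed signatures, neither constructor takes a separate label-file list any more; labels (and, for single photons, reference points) come from the sample files themselves. A minimal usage sketch under that reading; the module name, file names, sampleRatio, noiseKeV, and batch size are illustrative assumptions:

from torch.utils.data import DataLoader
from dataset import singlePhotonDataset, doublePhotonDataset  # module name is an assumption

train_files = ['train_part1.npz', 'train_part2.h5']  # hypothetical file names
train_set = singlePhotonDataset(train_files, sampleRatio=1.0, datasetName='train', noiseKeV=0.5)
train_loader = DataLoader(train_set, batch_size=256, shuffle=True)

# the double-photon loader reads 'samples' and 'labels' from .npz files only
pair_set = doublePhotonDataset(['pairs_part1.npz'], sampleRatio=1.0, datasetName='pairs', reuselFactor=2)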