234 lines
12 KiB
Python
234 lines
12 KiB
Python
from torch.utils.data import Dataset
|
|
import torch
|
|
import numpy as np
|
|
|
|
class singlePhotonDataset(Dataset):
    """Single-photon cluster dataset with optional noise injection, charge
    normalization and 8-fold dihedral (flip/transpose) augmentation.

    Each item is (sample, label):
      * sample: float tensor (3, 3, 3) -- channel 0 is the charge cluster,
        channels 1-2 are constant x/y coordinate grids in [-0.5, 0.5].
      * label: float tensor (4,) = (x, y, z, E); x/y are shifted so the
        cluster center is (0, 0) and z is divided by 650
        (presumably the full sensor depth -- TODO confirm units).
    """

    def __init__(self, sampleList, sampleRatio, datasetName, noiseKeV=0, numberOfAugOps=1, normalize=False, noiseThreshold=0):
        """
        Args:
            sampleList: list of .npz/.h5 file paths. .npz files must contain
                'samples' (B, H, W) and may contain 'labels' (B, 4) and
                'referencePoint' (B, 2); .h5 files use 'clusters' instead of
                'samples'. Missing labels/reference points are zero-filled.
            sampleRatio: fraction of loaded samples exposed via __len__.
            datasetName: name used only for log messages.
            noiseKeV: sigma of Gaussian noise added to every pixel (0 = off).
            numberOfAugOps: augmentation ops applied per sample (1..8).
            normalize: if True, scale each cluster so its total charge is 15.
            noiseThreshold: zero out pixels below this value (only applied
                when noiseKeV != 0).
        """
        self.sampleFileList = sampleList
        self.sampleRatio = sampleRatio
        self.datasetName = datasetName
        self.numberOfAugOps = numberOfAugOps
        self.normalize = normalize
        self.noiseThreshold = noiseThreshold
        self._init_coords()

        all_samples = []
        all_labels = []
        all_ref_pts = []
        for sampleFile in self.sampleFileList:
            if '.npz' in sampleFile:
                data = np.load(sampleFile)
                all_samples.append(data['samples'])
                if 'referencePoint' in data:
                    all_ref_pts.append(data['referencePoint'])
                else:
                    ### dummy reference points so arrays stay aligned
                    all_ref_pts.append(np.zeros((data['samples'].shape[0], 2), dtype=np.float32))
                if 'labels' in data:
                    all_labels.append(data['labels'])
                else:
                    ### dummy labels (inference-style files without truth)
                    all_labels.append(np.zeros((data['samples'].shape[0], 4), dtype=np.float32))
            elif '.h5' in sampleFile:
                import h5py
                with h5py.File(sampleFile, 'r') as f:
                    samples = f['clusters'][:]
                    if 'referencePoint' in f:
                        ref_pts = f['referencePoint'][:]
                        all_ref_pts.append(ref_pts)
                    else:
                        all_ref_pts.append(np.zeros((samples.shape[0], 2), dtype=np.float32))  ### dummy reference points
                    if 'labels' in f:
                        labels = f['labels'][:]
                    else:
                        labels = np.zeros((samples.shape[0], 4), dtype=np.float32)  ### dummy labels
                    all_samples.append(samples)
                    all_labels.append(labels)
        self.samples = np.concatenate(all_samples, axis=0)
        if noiseKeV != 0:
            print(f'Adding Gaussian noise with sigma = {noiseKeV} keV to samples in {self.datasetName} dataset')
            noise = np.random.normal(loc=0.0, scale=noiseKeV, size=self.samples.shape)
            self.samples = self.samples + noise
        if self.noiseThreshold != 0 and noiseKeV != 0:
            print(f'[{self.datasetName} dataset] \t Setting values below noise threshold ({self.noiseThreshold} keV) to zero')
            self.samples[self.samples < self.noiseThreshold] = 0  ### set values below threshold to zero
        self.labels = np.concatenate(all_labels, axis=0)
        self.referencePoint = np.concatenate(all_ref_pts, axis=0) if all_ref_pts else None

        if self.normalize:
            print(f'Normalizing samples in {self.datasetName} dataset by total charge')
            total_charge = np.sum(self.samples, axis=(1, 2), keepdims=True)  # (B, 1, 1)
            total_charge[total_charge == 0] = 1  # avoid division by zero
            self.samples = self.samples / total_charge * 15.  # normalize each sample by its total charge

        if self.samples.shape[1] == 5:  ### if sample size is 5x5, remove border pixels to make it 3x3
            self.samples = self.samples[:, 1:-1, 1:-1]  ### remove border pixels
            self.labels = self.labels - np.array([1, 1, 0, 0])  ### shift x/y into the cropped frame
            if self.referencePoint is not None:
                self.referencePoint = self.referencePoint + np.array([1, 1])  ### adjust reference points accordingly
        self.samples = np.expand_dims(self.samples, axis=1)  # (B, 1, H, W)
        ### re-center x/y so the cluster center is at (0, 0)
        self.labels -= np.array([self.samples.shape[-1]/2., self.samples.shape[-1]/2., 0, 0])
        self.labels[:, 2] /= 650.  ### normalize z coordinate (depth) to [0, 1]

        ### total number of samples
        self.nSamples = int(self.samples.shape[0] * self.sampleRatio)
        self.effectiveLength = self.nSamples * self.numberOfAugOps
        print(f"[{self.datasetName} dataset] \t Total number of samples: {self.nSamples} \t Effective length (with augmentation): {self.effectiveLength}")

    def _init_coords(self):
        """Precompute the (1, 3, 3) x/y coordinate grids appended as channels."""
        x = np.linspace(-0.5, 0.5, 3)
        y = np.linspace(-0.5, 0.5, 3)
        x_grid, y_grid = np.meshgrid(x, y, indexing='ij')  # (3,3), (3,3)
        self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous()  # (1, 3, 3)
        self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous()  # (1, 3, 3)

    def __getitem__(self, index):
        """Return (sample, label) with the dihedral augmentation selected by
        index % numberOfAugOps applied to sample index // numberOfAugOps."""
        sampleIdx, operationIdx = index // self.numberOfAugOps, index % self.numberOfAugOps
        sample = self.samples[sampleIdx]
        label = self.labels[sampleIdx]

        ### (flipAxes, swap, label_transform)
        ### sample axes: 0 - y axis, 1 - x axis
        ### label: (x, y, z, E) -- z and E are invariant under the spatial
        ### flips/transposes; only x/y may change sign or swap.
        TRANSFORMS = {
            0: (None, False, lambda l: l),
            1: ([1], False, lambda l: np.array([-l[0], l[1], l[2], l[3]])),
            2: ([0], False, lambda l: np.array([l[0], -l[1], l[2], l[3]])),
            # BUGFIX: was `-l`, which wrongly negated z (depth) and E (energy)
            # in addition to x/y; only x and y flip under a double mirror.
            3: ([0, 1], False, lambda l: np.array([-l[0], -l[1], l[2], l[3]])),
            4: (None, True, lambda l: np.array([l[1], l[0], l[2], l[3]])),
            5: ([1], True, lambda l: np.array([-l[1], l[0], l[2], l[3]])),
            6: ([0], True, lambda l: np.array([l[1], -l[0], l[2], l[3]])),
            # BUGFIX: same issue as op 3 -- z and E stay unchanged.
            7: ([0, 1], True, lambda l: np.array([-l[1], -l[0], l[2], l[3]])),
        }
        flipAxes, doSwap, labelTransform = TRANSFORMS[operationIdx]
        if doSwap:
            sample = np.swapaxes(sample, -1, -2)
        if flipAxes is not None:
            # +1 because sample carries a leading channel axis: (1, H, W)
            sample = np.flip(sample, axis=[ax+1 for ax in flipAxes])
        label = labelTransform(label)

        sample = torch.from_numpy(np.ascontiguousarray(sample)).float()
        sample = torch.cat((sample, self.x_grid, self.y_grid), dim=0)  ### concatenate coordinate channels
        label = torch.from_numpy(label).float()
        return sample, label

    def __len__(self):
        return self.effectiveLength
|
|
|
|
class doublePhotonDataset(Dataset):
    """Synthesizes two-photon clusters by overlaying two randomly chosen
    single-photon clusters at random offsets on an nSize x nSize canvas.

    Each item is (sample, label): sample is (3, nSize, nSize) -- the overlay
    plus constant x/y coordinate grids -- and label concatenates the two
    position-shifted 4-vector photon labels (8 values).  Pairs are drawn
    randomly on every access, so the same index yields different items.
    """

    def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1, noiseKeV=0, nSize=6):
        """
        Args:
            sampleList: .npz ('samples'/'labels') or .h5 ('clusters'/'labels')
                file paths holding single-photon clusters.
            sampleRatio: fraction of loaded singles counted toward the length.
            datasetName: name used only for log messages.
            reuselFactor: multiplier on the nominal length (photon reuse).
            noiseKeV: Gaussian noise sigma applied to non-zero pixels (0 = off).
            nSize: side length of the synthesized square cluster.
        """
        self.sampleFileList = sampleList
        self.sampleRatio = sampleRatio
        self.datasetName = datasetName
        self.noiseKeV = noiseKeV
        self.nSize = nSize
        self._init_coords()

        loaded = [self._load_file(path) for path in self.sampleFileList]
        loaded = [pair for pair in loaded if pair is not None]
        self.samples = np.concatenate([pair[0] for pair in loaded], axis=0)
        if self.noiseKeV != 0:
            print(f'Adding Gaussian noise with sigma = {self.noiseKeV} keV to samples in {self.datasetName} dataset')
            perturbation = np.random.normal(loc=0.0, scale=self.noiseKeV, size=self.samples.shape)
            #### add noise only to pixels that are not zero
            perturbation[self.samples == 0] = 0
            self.samples = self.samples + perturbation
        self.labels = np.concatenate([pair[1] for pair in loaded], axis=0)

        ### total number of samples (two singles consumed per double)
        self.length = int(self.samples.shape[0] * self.sampleRatio) // 2 * reuselFactor
        print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")

    @staticmethod
    def _load_file(path):
        """Load (samples, labels) from one .npz or .h5 file; None if neither."""
        if '.npz' in path:
            archive = np.load(path)
            return archive['samples'], archive['labels']
        if '.h5' in path:
            import h5py
            with h5py.File(path, 'r') as handle:
                return handle['clusters'][:], handle['labels'][:]
        return None

    def _init_coords(self):
        """Build the constant (1, nSize, nSize) x/y coordinate channel grids."""
        half = self.nSize / 2. - 0.5
        axis_vals = np.linspace(-half, half, self.nSize)
        gx, gy = np.meshgrid(axis_vals, axis_vals, indexing='ij')  # (nSize, nSize) each
        self.x_grid = torch.tensor(gx[None, :, :]).float().contiguous()  # (1, nSize, nSize)
        self.y_grid = torch.tensor(gy[None, :, :]).float().contiguous()  # (1, nSize, nSize)

    def __getitem__(self, index):
        """Return a randomly synthesized two-photon (sample, label) pair."""
        canvas = np.zeros((self.nSize + 2, self.nSize + 2), dtype=np.float32)
        n_total = self.samples.shape[0]
        first = np.random.randint(0, n_total)
        second = np.random.randint(0, n_total)
        cluster_a = self.samples[first]
        cluster_b = self.samples[second]
        patch = cluster_a.shape[0]

        ### random placement; NOTE(review): offsets 1..3 assume 3x3 singles on
        ### a padded nSize=6 canvas -- confirm before changing nSize.
        col_a = np.random.randint(1, 4)
        row_a = np.random.randint(1, 4)
        canvas[row_a:row_a + patch, col_a:col_a + patch] += cluster_a
        col_b = np.random.randint(1, 4)
        row_b = np.random.randint(1, 4)
        canvas[row_b:row_b + patch, col_b:col_b + patch] += cluster_b
        cropped = canvas[1:-1, 1:-1]  ### sample size: nSize x nSize
        stacked = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0)
        stacked = torch.cat((stacked, self.x_grid, self.y_grid), dim=0)  ### concatenate coordinate channels

        shift_a = np.array([col_a - 1 - self.nSize / 2., row_a - 1 - self.nSize / 2., 0, 0])
        shift_b = np.array([col_b - 1 - self.nSize / 2., row_b - 1 - self.nSize / 2., 0, 0])
        merged = np.concatenate((self.labels[first] + shift_a, self.labels[second] + shift_b), axis=0)
        return stacked, torch.tensor(merged, dtype=torch.float32)

    def __len__(self):
        return self.length
|
|
|
|
|
|
class doublePhotonInferenceDataset(Dataset):
    """Inference-time wrapper over pre-built clusters with no ground truth.

    Yields (sample, dummy_label) where sample is (3, nSize, nSize): the stored
    cluster plus constant x/y coordinate grids, and dummy_label is an all-zero
    8-vector (placeholder matching the two-photon label layout).
    """

    def __init__(self, sampleList, sampleRatio, datasetName, nSize=6):
        """
        Args:
            sampleList: .npz ('samples'/'referencePoint') or .h5
                ('clusters'/'referencePoint') file paths.
            sampleRatio: fraction of loaded samples exposed via __len__.
            datasetName: name used only for log messages.
            nSize: side length of the stored square clusters.
        """
        self.sampleFileList = sampleList
        self.sampleRatio = sampleRatio
        self.datasetName = datasetName
        self.nSize = nSize
        self._init_coords()

        sample_chunks = []
        ref_chunks = []
        for path in self.sampleFileList:
            if '.npz' in path:
                archive = np.load(path)
                sample_chunks.append(archive['samples'])
                ref_chunks.append(archive['referencePoint'])
            elif '.h5' in path:
                import h5py
                with h5py.File(path, 'r') as handle:
                    sample_chunks.append(handle['clusters'][:])
                    ref_chunks.append(handle['referencePoint'][:])
        # NOTE(review): self.samples stays None when no file matched, and the
        # shape access below then raises -- pre-existing behavior, unchanged.
        self.samples = np.concatenate(sample_chunks, axis=0) if sample_chunks else None
        self.referencePoint = np.concatenate(ref_chunks, axis=0) if ref_chunks else None
        ### total number of samples
        self.length = int(self.samples.shape[0] * self.sampleRatio)
        self.referencePoint = self.referencePoint[:self.length]
        print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")

    def _init_coords(self):
        """Build the constant (1, nSize, nSize) x/y coordinate channel grids."""
        half = self.nSize / 2. - 0.5
        axis_vals = np.linspace(-half, half, self.nSize)
        gx, gy = np.meshgrid(axis_vals, axis_vals, indexing='ij')  # (nSize, nSize) each
        self.x_grid = torch.tensor(gx[None, :, :]).float().contiguous()  # (1, nSize, nSize)
        self.y_grid = torch.tensor(gy[None, :, :]).float().contiguous()  # (1, nSize, nSize)

    def __getitem__(self, index):
        """Return (sample, zero-label) for the cluster at `index`."""
        cluster = torch.tensor(self.samples[index], dtype=torch.float32).unsqueeze(0)
        stacked = torch.cat((cluster, self.x_grid, self.y_grid), dim=0)  ### concatenate coordinate channels
        placeholder = torch.zeros(8, dtype=torch.float32)  ### dummy label
        return stacked, placeholder

    def __len__(self):
        return self.length