Add data augmentation
This commit is contained in:
@@ -3,10 +3,11 @@ import torch
|
||||
import numpy as np
|
||||
|
||||
class singlePhotonDataset(Dataset):
|
||||
def __init__(self, sampleList, sampleRatio, datasetName, noiseKeV=0):
|
||||
def __init__(self, sampleList, sampleRatio, datasetName, noiseKeV=0, numberOfAugOps=1):
|
||||
self.sampleFileList = sampleList
|
||||
self.sampleRatio = sampleRatio
|
||||
self.datasetName = datasetName
|
||||
self.numberOfAugOps = numberOfAugOps
|
||||
self._init_coords()
|
||||
|
||||
all_samples = []
|
||||
@@ -44,37 +45,60 @@ class singlePhotonDataset(Dataset):
|
||||
print(f'Adding Gaussian noise with sigma = {noiseKeV} keV to samples in {self.datasetName} dataset')
|
||||
noise = np.random.normal(loc=0.0, scale=noiseKeV, size=self.samples.shape)
|
||||
self.samples = self.samples + noise
|
||||
self.labels = torch.tensor(np.concatenate(all_labels, axis=0))
|
||||
self.labels = np.concatenate(all_labels, axis=0)
|
||||
self.referencePoint = np.concatenate(all_ref_pts, axis=0) if all_ref_pts else None
|
||||
|
||||
if self.samples.shape[1] == 5: ### if sample size is 5x5, remove border pixels to make it 3x3
|
||||
self.samples = self.samples[:, 1:-1, 1:-1] ### remove border pixels
|
||||
self.labels = self.labels - torch.tensor([1, 1, 0, 0]) ### adjust labels accordingly
|
||||
self.samples = torch.tensor(self.samples).unsqueeze(1).float()
|
||||
x_grids = self.x_grid.expand(self.samples.size(0), 1, -1, -1)
|
||||
y_grids = self.y_grid.expand(self.samples.size(0), 1, -1, -1)
|
||||
self.samples = torch.cat([self.samples, x_grids, y_grids], dim=1) ### concatenate coordinate channels
|
||||
self.labels -= torch.tensor([self.samples.shape[1]/2., self.samples.shape[1]/2., 0, 0]) ### adjust labels to be centered at (0,0)
|
||||
self.labels = self.labels - np.array([1, 1, 0, 0]) ### adjust labels accordingly
|
||||
self.samples = np.expand_dims(self.samples, axis=1)
|
||||
self.labels -= np.array([self.samples.shape[-1]/2., self.samples.shape[-1]/2., 0, 0]) ### B,D,3,3 adjust labels to be centered at (0,0)
|
||||
self.labels[:, 2] /= 650. ### normalize z coordinate (depth) to [0, 1]
|
||||
### total number of samples
|
||||
self.length = int(self.samples.shape[0] * self.sampleRatio)
|
||||
print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
|
||||
self.nSamples = int(self.samples.shape[0] * self.sampleRatio)
|
||||
self.effectiveLength = self.nSamples * self.numberOfAugOps
|
||||
print(f"[{self.datasetName} dataset] \t Total number of samples: {self.nSamples} \t Effective length (with augmentation): {self.effectiveLength}")
|
||||
|
||||
def _init_coords(self):
|
||||
# Create a coordinate grid for 3x3 input
|
||||
x = torch.linspace(-0.5, 0.5, 3)
|
||||
y = torch.linspace(-0.5, 0.5, 3)
|
||||
x_grid, y_grid = torch.meshgrid(x, y, indexing='ij') # (3,3), (3,3)
|
||||
self.x_grid = x_grid.unsqueeze(0) # (1, 3, 3)
|
||||
self.y_grid = y_grid.unsqueeze(0) # (1, 3, 3)
|
||||
x = np.linspace(-0.5, 0.5, 3)
|
||||
y = np.linspace(-0.5, 0.5, 3)
|
||||
x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (3,3), (3,3)
|
||||
self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous() # (1, 3, 3)
|
||||
self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, 3, 3)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
label = self.labels[index]
|
||||
sampleIdx, operationIdx = index // self.numberOfAugOps, index % self.numberOfAugOps
|
||||
sample = self.samples[sampleIdx]
|
||||
label = self.labels[sampleIdx]
|
||||
|
||||
###( flipAxes, swap, label_transform)
|
||||
### sample axes: 0 - y axis, 1 - x axis
|
||||
### label: (x, y, ...)
|
||||
TRANSFORMS = {
|
||||
0: (None, False, lambda l: l),
|
||||
1: ([1], False, lambda l: np.array([-l[0], l[1], l[2], l[3]])),
|
||||
2: ([0], False, lambda l: np.array([l[0], -l[1], l[2], l[3]])),
|
||||
3: ([0, 1], False, lambda l: -l),
|
||||
4: (None, True, lambda l: np.array([l[1], l[0], l[2], l[3]])),
|
||||
5: ([1], True, lambda l: np.array([-l[1], l[0], l[2], l[3]])),
|
||||
6: ([0], True, lambda l: np.array([l[1], -l[0], l[2], l[3]])),
|
||||
7: ([0, 1], True, lambda l: -np.array([l[1], l[0], l[2], l[3]])),
|
||||
}
|
||||
flipAxes, doSwap, labelTransform = TRANSFORMS[operationIdx]
|
||||
if doSwap:
|
||||
sample = np.swapaxes(sample, -1, -2)
|
||||
if flipAxes is not None:
|
||||
sample = np.flip(sample, axis=[ax+1 for ax in flipAxes])
|
||||
label = labelTransform(label)
|
||||
|
||||
sample = torch.from_numpy(np.ascontiguousarray(sample)).float()
|
||||
sample = torch.cat((sample, self.x_grid, self.y_grid), dim=0) ### concatenate coordinate channels
|
||||
label = torch.from_numpy(label).float()
|
||||
return sample, label
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
return self.effectiveLength
|
||||
|
||||
class doublePhotonDataset(Dataset):
|
||||
def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1):
|
||||
|
||||
Reference in New Issue
Block a user