add configurable cluster size for double phtoon sample
This commit is contained in:
+21
-17
@@ -101,11 +101,12 @@ class singlePhotonDataset(Dataset):
|
|||||||
return self.effectiveLength
|
return self.effectiveLength
|
||||||
|
|
||||||
class doublePhotonDataset(Dataset):
|
class doublePhotonDataset(Dataset):
|
||||||
def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1, noiseKeV=0):
|
def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1, noiseKeV=0, nSize=6):
|
||||||
self.sampleFileList = sampleList
|
self.sampleFileList = sampleList
|
||||||
self.sampleRatio = sampleRatio
|
self.sampleRatio = sampleRatio
|
||||||
self.datasetName = datasetName
|
self.datasetName = datasetName
|
||||||
self.noiseKeV = noiseKeV
|
self.noiseKeV = noiseKeV
|
||||||
|
self.nSize = nSize
|
||||||
self._init_coords()
|
self._init_coords()
|
||||||
|
|
||||||
all_samples = []
|
all_samples = []
|
||||||
@@ -118,6 +119,8 @@ class doublePhotonDataset(Dataset):
|
|||||||
if self.noiseKeV != 0:
|
if self.noiseKeV != 0:
|
||||||
print(f'Adding Gaussian noise with sigma = {self.noiseKeV} keV to samples in {self.datasetName} dataset')
|
print(f'Adding Gaussian noise with sigma = {self.noiseKeV} keV to samples in {self.datasetName} dataset')
|
||||||
noise = np.random.normal(loc=0.0, scale=self.noiseKeV, size=self.samples.shape)
|
noise = np.random.normal(loc=0.0, scale=self.noiseKeV, size=self.samples.shape)
|
||||||
|
#### add noise only to pixels that not zero
|
||||||
|
noise[self.samples == 0] = 0
|
||||||
self.samples = self.samples + noise
|
self.samples = self.samples + noise
|
||||||
self.labels = np.concatenate(all_labels, axis=0)
|
self.labels = np.concatenate(all_labels, axis=0)
|
||||||
|
|
||||||
@@ -127,14 +130,14 @@ class doublePhotonDataset(Dataset):
|
|||||||
|
|
||||||
def _init_coords(self):
|
def _init_coords(self):
|
||||||
# Create a coordinate grid for 3x3 input
|
# Create a coordinate grid for 3x3 input
|
||||||
x = np.linspace(0, 5, 6)
|
x = np.linspace(-self.nSize/2. + 0.5, self.nSize/2. - 0.5, self.nSize)
|
||||||
y = np.linspace(0, 5, 6)
|
y = np.linspace(-self.nSize/2. + 0.5, self.nSize/2. - 0.5, self.nSize)
|
||||||
x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (6,6), (6,6)
|
x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (nSize,nSize), (nSize,nSize)
|
||||||
self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous() # (1, 6, 6)
|
self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous() # (1, nSize, nSize)
|
||||||
self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, 6, 6)
|
self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, nSize, nSize)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
sample = np.zeros((8, 8), dtype=np.float32)
|
sample = np.zeros((self.nSize+2, self.nSize+2), dtype=np.float32)
|
||||||
idx1 = np.random.randint(0, self.samples.shape[0])
|
idx1 = np.random.randint(0, self.samples.shape[0])
|
||||||
idx2 = np.random.randint(0, self.samples.shape[0])
|
idx2 = np.random.randint(0, self.samples.shape[0])
|
||||||
photon1 = self.samples[idx1]
|
photon1 = self.samples[idx1]
|
||||||
@@ -148,13 +151,12 @@ class doublePhotonDataset(Dataset):
|
|||||||
pos_x2 = np.random.randint(1, 4)
|
pos_x2 = np.random.randint(1, 4)
|
||||||
pos_y2 = np.random.randint(1, 4)
|
pos_y2 = np.random.randint(1, 4)
|
||||||
sample[pos_y2:pos_y2+singlePhotonSize, pos_x2:pos_x2+singlePhotonSize] += photon2
|
sample[pos_y2:pos_y2+singlePhotonSize, pos_x2:pos_x2+singlePhotonSize] += photon2
|
||||||
sample = sample[1:-1, 1:-1] ### sample size: 6x6
|
sample = sample[1:-1, 1:-1] ### sample size: nSize x nSize
|
||||||
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
|
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
|
||||||
sample = torch.cat((sample, self.x_grid, self.y_grid), dim=0) ### concatenate coordinate channels
|
sample = torch.cat((sample, self.x_grid, self.y_grid), dim=0) ### concatenate coordinate channels
|
||||||
doublePhotonSize = 6
|
|
||||||
|
|
||||||
label1 = self.labels[idx1] + np.array([pos_x1-1-doublePhotonSize//2, pos_y1-1-doublePhotonSize//2, 0, 0])
|
label1 = self.labels[idx1] + np.array([pos_x1-1-self.nSize/2., pos_y1-1-self.nSize/2., 0, 0])
|
||||||
label2 = self.labels[idx2] + np.array([pos_x2-1-doublePhotonSize//2, pos_y2-1-doublePhotonSize//2, 0, 0])
|
label2 = self.labels[idx2] + np.array([pos_x2-1-self.nSize/2., pos_y2-1-self.nSize/2., 0, 0])
|
||||||
label = np.concatenate((label1, label2), axis=0)
|
label = np.concatenate((label1, label2), axis=0)
|
||||||
return sample, torch.tensor(label, dtype=torch.float32)
|
return sample, torch.tensor(label, dtype=torch.float32)
|
||||||
|
|
||||||
@@ -163,10 +165,11 @@ class doublePhotonDataset(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
class doublePhotonInferenceDataset(Dataset):
|
class doublePhotonInferenceDataset(Dataset):
|
||||||
def __init__(self, sampleList, sampleRatio, datasetName):
|
def __init__(self, sampleList, sampleRatio, datasetName, nSize=6):
|
||||||
self.sampleFileList = sampleList
|
self.sampleFileList = sampleList
|
||||||
self.sampleRatio = sampleRatio
|
self.sampleRatio = sampleRatio
|
||||||
self.datasetName = datasetName
|
self.datasetName = datasetName
|
||||||
|
self.nSize = nSize
|
||||||
self._init_coords()
|
self._init_coords()
|
||||||
|
|
||||||
all_samples = []
|
all_samples = []
|
||||||
@@ -187,15 +190,16 @@ class doublePhotonInferenceDataset(Dataset):
|
|||||||
self.referencePoint = np.concatenate(all_ref_pts, axis=0) if all_ref_pts else None
|
self.referencePoint = np.concatenate(all_ref_pts, axis=0) if all_ref_pts else None
|
||||||
### total number of samples
|
### total number of samples
|
||||||
self.length = int(self.samples.shape[0] * self.sampleRatio)
|
self.length = int(self.samples.shape[0] * self.sampleRatio)
|
||||||
|
self.referencePoint = self.referencePoint[:self.length]
|
||||||
print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
|
print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
|
||||||
|
|
||||||
def _init_coords(self):
|
def _init_coords(self):
|
||||||
# Create a coordinate grid for 3x3 input
|
# Create a coordinate grid for 3x3 input
|
||||||
x = np.linspace(0, 5, 6)
|
x = np.linspace(-self.nSize/2. + 0.5, self.nSize/2. - 0.5, self.nSize)
|
||||||
y = np.linspace(0, 5, 6)
|
y = np.linspace(-self.nSize/2. + 0.5, self.nSize/2. - 0.5, self.nSize)
|
||||||
x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (6,6), (6,6)
|
x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (nSize,nSize), (nSize,nSize)
|
||||||
self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous() # (1, 6, 6)
|
self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous() # (1, nSize, nSize)
|
||||||
self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, 6, 6)
|
self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, nSize, nSize)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
sample = self.samples[index]
|
sample = self.samples[index]
|
||||||
|
|||||||
Reference in New Issue
Block a user