From eef6a87f0603e72e893de380d359750b80784eb5 Mon Sep 17 00:00:00 2001 From: "xiangyu.xie" Date: Tue, 25 Nov 2025 15:25:07 +0100 Subject: [PATCH] add configurable cluster size for double photon sample --- src/datasets.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/datasets.py b/src/datasets.py index 1781529..d77d51c 100644 --- a/src/datasets.py +++ b/src/datasets.py @@ -101,11 +101,12 @@ class singlePhotonDataset(Dataset): return self.effectiveLength class doublePhotonDataset(Dataset): - def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1, noiseKeV=0): + def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1, noiseKeV=0, nSize=6): self.sampleFileList = sampleList self.sampleRatio = sampleRatio self.datasetName = datasetName self.noiseKeV = noiseKeV + self.nSize = nSize self._init_coords() all_samples = [] @@ -118,6 +119,8 @@ class doublePhotonDataset(Dataset): if self.noiseKeV != 0: print(f'Adding Gaussian noise with sigma = {self.noiseKeV} keV to samples in {self.datasetName} dataset') noise = np.random.normal(loc=0.0, scale=self.noiseKeV, size=self.samples.shape) + #### add noise only to pixels that are not zero + noise[self.samples == 0] = 0 self.samples = self.samples + noise self.labels = np.concatenate(all_labels, axis=0) @@ -127,14 +130,14 @@ class doublePhotonDataset(Dataset): def _init_coords(self): # Create a coordinate grid for 3x3 input - x = np.linspace(0, 5, 6) - y = np.linspace(0, 5, 6) - x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (6,6), (6,6) - self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous() # (1, 6, 6) - self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, 6, 6) + x = np.linspace(-self.nSize/2. + 0.5, self.nSize/2. - 0.5, self.nSize) + y = np.linspace(-self.nSize/2. + 0.5, self.nSize/2. 
- 0.5, self.nSize) + x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (nSize,nSize), (nSize,nSize) + self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous() # (1, nSize, nSize) + self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, nSize, nSize) def __getitem__(self, index): - sample = np.zeros((8, 8), dtype=np.float32) + sample = np.zeros((self.nSize+2, self.nSize+2), dtype=np.float32) idx1 = np.random.randint(0, self.samples.shape[0]) idx2 = np.random.randint(0, self.samples.shape[0]) photon1 = self.samples[idx1] @@ -148,13 +151,12 @@ class doublePhotonDataset(Dataset): pos_x2 = np.random.randint(1, 4) pos_y2 = np.random.randint(1, 4) sample[pos_y2:pos_y2+singlePhotonSize, pos_x2:pos_x2+singlePhotonSize] += photon2 - sample = sample[1:-1, 1:-1] ### sample size: 6x6 + sample = sample[1:-1, 1:-1] ### sample size: nSize x nSize sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0) sample = torch.cat((sample, self.x_grid, self.y_grid), dim=0) ### concatenate coordinate channels - doublePhotonSize = 6 - label1 = self.labels[idx1] + np.array([pos_x1-1-doublePhotonSize//2, pos_y1-1-doublePhotonSize//2, 0, 0]) - label2 = self.labels[idx2] + np.array([pos_x2-1-doublePhotonSize//2, pos_y2-1-doublePhotonSize//2, 0, 0]) + label1 = self.labels[idx1] + np.array([pos_x1-1-self.nSize/2., pos_y1-1-self.nSize/2., 0, 0]) + label2 = self.labels[idx2] + np.array([pos_x2-1-self.nSize/2., pos_y2-1-self.nSize/2., 0, 0]) label = np.concatenate((label1, label2), axis=0) return sample, torch.tensor(label, dtype=torch.float32) @@ -163,10 +165,11 @@ class doublePhotonDataset(Dataset): class doublePhotonInferenceDataset(Dataset): - def __init__(self, sampleList, sampleRatio, datasetName): + def __init__(self, sampleList, sampleRatio, datasetName, nSize=6): self.sampleFileList = sampleList self.sampleRatio = sampleRatio self.datasetName = datasetName + self.nSize = nSize self._init_coords() all_samples = [] @@ -187,15 
+190,16 @@ class doublePhotonInferenceDataset(Dataset): self.referencePoint = np.concatenate(all_ref_pts, axis=0) if all_ref_pts else None ### total number of samples self.length = int(self.samples.shape[0] * self.sampleRatio) + self.referencePoint = self.referencePoint[:self.length] print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}") def _init_coords(self): # Create a coordinate grid for 3x3 input - x = np.linspace(0, 5, 6) - y = np.linspace(0, 5, 6) - x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (6,6), (6,6) - self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous() # (1, 6, 6) - self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, 6, 6) + x = np.linspace(-self.nSize/2. + 0.5, self.nSize/2. - 0.5, self.nSize) + y = np.linspace(-self.nSize/2. + 0.5, self.nSize/2. - 0.5, self.nSize) + x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (nSize,nSize), (nSize,nSize) + self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous() # (1, nSize, nSize) + self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, nSize, nSize) def __getitem__(self, index): sample = self.samples[index]