diff --git a/src/datasets.py b/src/datasets.py index 215b7ee..117c058 100644 --- a/src/datasets.py +++ b/src/datasets.py @@ -7,6 +7,7 @@ class singlePhotonDataset(Dataset): self.sampleFileList = sampleList self.sampleRatio = sampleRatio self.datasetName = datasetName + self._init_coords() all_samples = [] all_labels = [] @@ -43,21 +44,34 @@ class singlePhotonDataset(Dataset): print(f'Adding Gaussian noise with sigma = {noiseKeV} keV to samples in {self.datasetName} dataset') noise = np.random.normal(loc=0.0, scale=noiseKeV, size=self.samples.shape) self.samples = self.samples + noise - self.labels = np.concatenate(all_labels, axis=0) + self.labels = torch.tensor(np.concatenate(all_labels, axis=0)) self.referencePoint = np.concatenate(all_ref_pts, axis=0) if all_ref_pts else None if self.samples.shape[1] == 5: ### if sample size is 5x5, remove border pixels to make it 3x3 self.samples = self.samples[:, 1:-1, 1:-1] ### remove border pixels - self.labels = self.labels - np.array([1, 1, 0, 0]) ### adjust labels accordingly + self.labels = self.labels - torch.tensor([1, 1, 0, 0]) ### adjust labels accordingly + self.samples = torch.tensor(self.samples).unsqueeze(1).float() + x_grids = self.x_grid.expand(self.samples.size(0), 1, -1, -1) + y_grids = self.y_grid.expand(self.samples.size(0), 1, -1, -1) + self.samples = torch.cat([self.samples, x_grids, y_grids], dim=1) ### concatenate coordinate channels + self.labels -= torch.tensor([self.samples.shape[1]/2., self.samples.shape[1]/2., 0, 0]) ### adjust labels to be centered at (0,0) + self.labels[:, 2] /= 650. ### normalize z coordinate (depth) to [0, 1] ### total number of samples self.length = int(self.samples.shape[0] * self.sampleRatio) print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}") + def _init_coords(self): + # Create a coordinate grid for 3x3 input + x = torch.linspace(-0.5, 0.5, 3) + y = torch.linspace(-0.5, 0.5, 3) + x_grid, y_grid = torch.meshgrid(x, y, indexing='ij') # (3,3), (3,3) + self.x_grid = x_grid.unsqueeze(0) # (1, 3, 3) + self.y_grid = y_grid.unsqueeze(0) # (1, 3, 3) + def __getitem__(self, index): sample = self.samples[index] - sample = np.expand_dims(sample, axis=0) label = self.labels[index] - return torch.tensor(sample, dtype=torch.float32), torch.tensor(label, dtype=torch.float32) + return sample, label def __len__(self): return self.length