diff --git a/src/datasets.py b/src/datasets.py index 5acc0b7..71a874f 100644 --- a/src/datasets.py +++ b/src/datasets.py @@ -141,8 +141,6 @@ class doublePhotonDataset(Dataset): if self.noiseKeV != 0: print(f'Adding Gaussian noise with sigma = {self.noiseKeV} keV to samples in {self.datasetName} dataset') noise = np.random.normal(loc=0.0, scale=self.noiseKeV, size=self.samples.shape) - #### add noise only to pixels that not zero - noise[self.samples == 0] = 0 self.samples = self.samples + noise if noiseThreshold != 0: print(f'[{self.datasetName} dataset] \t Setting values below noise threshold ({noiseThreshold} keV) to zero') @@ -167,7 +165,8 @@ class doublePhotonDataset(Dataset): self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, nSize, nSize) def __getitem__(self, index): - sample = np.zeros((self.nSize+2, self.nSize+2), dtype=np.float32) + # sample = np.zeros((self.nSize+2, self.nSize+2), dtype=np.float32) + sample = np.random.normal(loc=0.0, scale=self.noiseKeV, size=(self.nSize+2, self.nSize+2)) ### add noise to the whole sample idx1 = np.random.randint(0, self.samples.shape[0]) idx2 = np.random.randint(0, self.samples.shape[0]) photon1 = self.samples[idx1]