class doublePhotonInferenceDataset(Dataset):
    """Inference-only dataset serving pre-built clusters with coordinate channels.

    Every file in ``sampleList`` is read eagerly in ``__init__`` and the
    per-file arrays are concatenated along axis 0:

    * ``.npz`` files must provide the arrays ``samples`` and ``referencePoint``.
    * ``.h5`` files must provide the datasets ``clusters`` and ``referencePoint``
      (``h5py`` is imported lazily, only when such a file is encountered).

    Each item is a ``(sample, label)`` pair. ``sample`` is a float32 tensor of
    shape ``(3, 6, 6)``: the stored cluster followed by a row-index channel and
    a column-index channel, so every stored cluster must be 6x6. ``label`` is
    an all-zero float32 tensor of shape ``(8,)`` used as a placeholder.

    Args:
        sampleList: Iterable of file paths (``.npz`` or ``.h5``).
        sampleRatio: Fraction of the loaded samples to expose through
            ``__len__``/``__getitem__``; the resulting length is clamped to
            ``[0, number of loaded samples]``.
        datasetName: Label used in the summary line printed after loading.

    Raises:
        ValueError: If no samples could be loaded from ``sampleList``.
    """

    def __init__(self, sampleList, sampleRatio, datasetName):
        # Imported locally so this class does not depend on the file's
        # top-level import block, mirroring the lazy ``h5py`` import below.
        import warnings

        self.sampleFileList = sampleList
        self.sampleRatio = sampleRatio
        self.datasetName = datasetName
        self._init_coords()

        all_samples = []
        all_ref_pts = []
        for sampleFile in self.sampleFileList:
            file_name = str(sampleFile)
            # Match on the suffix rather than a substring so a path such as
            # "run.npz_dump/x.h5" is routed to the HDF5 reader.
            if file_name.endswith('.npz'):
                # NpzFile keeps a file handle open until closed.
                with np.load(sampleFile) as data:
                    all_samples.append(data['samples'])
                    all_ref_pts.append(data['referencePoint'])
            elif file_name.endswith('.h5'):
                import h5py
                with h5py.File(sampleFile, 'r') as f:
                    all_samples.append(f['clusters'][:])
                    all_ref_pts.append(f['referencePoint'][:])
            else:
                # Still skipped, as before, but no longer silently.
                warnings.warn(
                    f"Skipping sample file with unsupported extension: {file_name}",
                    stacklevel=2,
                )

        if not all_samples:
            raise ValueError(
                f"[{self.datasetName} dataset] no samples loaded; "
                "sampleList must contain at least one .npz or .h5 file"
            )

        self.samples = np.concatenate(all_samples, axis=0)
        self.referencePoint = np.concatenate(all_ref_pts, axis=0)

        ### total number of samples (never more than what was actually loaded)
        total = self.samples.shape[0]
        self.length = max(0, min(total, int(total * self.sampleRatio)))
        print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")

    def _init_coords(self):
        """Build the two constant (1, 6, 6) coordinate channels for 6x6 inputs."""
        x = np.linspace(0, 5, 6)
        y = np.linspace(0, 5, 6)
        # With indexing='ij', x_grid[i, j] == i and y_grid[i, j] == j.
        x_grid, y_grid = np.meshgrid(x, y, indexing='ij')  # (6,6), (6,6)
        self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous()  # (1, 6, 6)
        self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous()  # (1, 6, 6)

    def __getitem__(self, index):
        """Return ``(sample, dummy_label)`` for ``index``.

        Negative indices count back from ``len(self)``.

        Raises:
            IndexError: If ``index`` lies outside ``[-len(self), len(self))``.
        """
        if index < 0:
            index += self.length
        # Respect the sampleRatio-truncated length so that plain iteration
        # stops where __len__ says it should.
        if not 0 <= index < self.length:
            raise IndexError(
                f"index {index} out of range for dataset of length {self.length}"
            )

        # NOTE: noise injection on zero pixels is intentionally disabled:
        # sample[sample == 0] += np.random.normal(loc=0.0, scale=0.13, size=sample[sample == 0].shape)
        sample = torch.tensor(self.samples[index], dtype=torch.float32).unsqueeze(0)
        sample = torch.cat((sample, self.x_grid, self.y_grid), dim=0)  ### concatenate coordinate channels
        dummy_label = torch.zeros(8, dtype=torch.float32)  ### dummy label
        return sample, dummy_label

    def __len__(self):
        """Return the number of samples exposed after applying ``sampleRatio``."""
        return self.length