Fix the 2photon dataset

This commit is contained in:
2025-11-05 21:38:20 +01:00
parent b0a396b0d8
commit c4a86e32e3

View File

@@ -143,12 +143,10 @@ class doublePhotonDataset(Dataset):
### random position for photons in
pos_x1 = np.random.randint(1, 4)
# pos_y1 = np.random.randint(1, 4)
pos_y1 = pos_x1
pos_y1 = np.random.randint(1, 4)
sample[pos_y1:pos_y1+singlePhotonSize, pos_x1:pos_x1+singlePhotonSize] += photon1
pos_x2 = np.random.randint(1, 4)
# pos_y2 = np.random.randint(1, 4)
pos_y2 = pos_x2
pos_y2 = np.random.randint(1, 4)
sample[pos_y2:pos_y2+singlePhotonSize, pos_x2:pos_x2+singlePhotonSize] += photon2
sample = sample[1:-1, 1:-1] ### sample size: 6x6
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
@@ -160,5 +158,52 @@ class doublePhotonDataset(Dataset):
label = np.concatenate((label1, label2), axis=0)
return sample, torch.tensor(label, dtype=torch.float32)
def __len__(self):
return self.length
class doublePhotonInferenceDataset(Dataset):
def __init__(self, sampleList, sampleRatio, datasetName):
self.sampleFileList = sampleList
self.sampleRatio = sampleRatio
self.datasetName = datasetName
self._init_coords()
all_samples = []
all_ref_pts = []
for idx, sampleFile in enumerate(self.sampleFileList):
if '.npz' in sampleFile:
data = np.load(sampleFile)
all_samples.append(data['samples'])
all_ref_pts.append(data['referencePoint'])
elif '.h5' in sampleFile:
import h5py
with h5py.File(sampleFile, 'r') as f:
samples = f['clusters'][:]
ref_pts = f['referencePoint'][:]
all_samples.append(samples)
all_ref_pts.append(ref_pts)
self.samples = np.concatenate(all_samples, axis=0) if all_samples else None
self.referencePoint = np.concatenate(all_ref_pts, axis=0) if all_ref_pts else None
### total number of samples
self.length = int(self.samples.shape[0] * self.sampleRatio)
print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
def _init_coords(self):
# Create a coordinate grid for 3x3 input
x = np.linspace(0, 5, 6)
y = np.linspace(0, 5, 6)
x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (6,6), (6,6)
self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous() # (1, 6, 6)
self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, 6, 6)
def __getitem__(self, index):
sample = self.samples[index]
# sample[sample == 0] += np.random.normal(loc=0.0, scale=0.13, size=sample[sample == 0].shape) ### add noise to zero pixels
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
sample = torch.cat((sample, self.x_grid, self.y_grid), dim=0) ### concatenate coordinate channels
dummy_label = np.zeros((8,), dtype=np.float32) ### dummy label
return sample, torch.tensor(dummy_label, dtype=torch.float32)
def __len__(self):
return self.length