Shift label to sample center; remove the sigmoid in FC
This commit is contained in:
+6
-4
@@ -139,22 +139,24 @@ class doublePhotonDataset(Dataset):
|
||||
idx2 = np.random.randint(0, self.samples.shape[0])
|
||||
photon1 = self.samples[idx1]
|
||||
photon2 = self.samples[idx2]
|
||||
singlePhotonSize = photon1.shape[0]
|
||||
|
||||
### random position for photons in
|
||||
pos_x1 = np.random.randint(1, 4)
|
||||
# pos_y1 = np.random.randint(1, 4)
|
||||
pos_y1 = pos_x1
|
||||
sample[pos_y1:pos_y1+5, pos_x1:pos_x1+5] += photon1
|
||||
sample[pos_y1:pos_y1+singlePhotonSize, pos_x1:pos_x1+singlePhotonSize] += photon1
|
||||
pos_x2 = np.random.randint(1, 4)
|
||||
# pos_y2 = np.random.randint(1, 4)
|
||||
pos_y2 = pos_x2
|
||||
sample[pos_y2:pos_y2+5, pos_x2:pos_x2+5] += photon2
|
||||
sample[pos_y2:pos_y2+singlePhotonSize, pos_x2:pos_x2+singlePhotonSize] += photon2
|
||||
sample = sample[1:-1, 1:-1] ### sample size: 6x6
|
||||
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
|
||||
sample = torch.cat((sample, self.x_grid, self.y_grid), dim=0) ### concatenate coordinate channels
|
||||
doublePhotonSize = 6
|
||||
|
||||
label1 = self.labels[idx1] + np.array([pos_x1-1, pos_y1-1, 0, 0])
|
||||
label2 = self.labels[idx2] + np.array([pos_x2-1, pos_y2-1, 0, 0])
|
||||
label1 = self.labels[idx1] + np.array([pos_x1-1-doublePhotonSize//2, pos_y1-1-doublePhotonSize//2, 0, 0])
|
||||
label2 = self.labels[idx2] + np.array([pos_x2-1-doublePhotonSize//2, pos_y2-1-doublePhotonSize//2, 0, 0])
|
||||
label = np.concatenate((label1, label2), axis=0)
|
||||
return sample, torch.tensor(label, dtype=torch.float32)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user