From b0a396b0d835479dbe43b70b258e62ca0ebd6e48 Mon Sep 17 00:00:00 2001 From: "xiangyu.xie" Date: Tue, 4 Nov 2025 17:37:39 +0100 Subject: [PATCH] Shift label to sample center; remove the sigmoid in FC --- Train_DoublePhoton.py | 2 +- src/datasets.py | 10 ++++---- src/models.py | 56 +++---------------------------------------- 3 files changed, 10 insertions(+), 58 deletions(-) diff --git a/Train_DoublePhoton.py b/Train_DoublePhoton.py index 9bb752d..7a662a1 100644 --- a/Train_DoublePhoton.py +++ b/Train_DoublePhoton.py @@ -12,7 +12,7 @@ from torchinfo import summary torch.manual_seed(0) np.random.seed(0) -modelVersion = '251104' # '250910' or '251001' +modelVersion = '251001_2' # '250910' or '251001' model = models.get_double_photon_model_class(modelVersion)().cuda() Energy = '15.3keV' TrainLosses, ValLosses = [], [] diff --git a/src/datasets.py b/src/datasets.py index 10cffa8..12f79f0 100644 --- a/src/datasets.py +++ b/src/datasets.py @@ -139,22 +139,24 @@ class doublePhotonDataset(Dataset): idx2 = np.random.randint(0, self.samples.shape[0]) photon1 = self.samples[idx1] photon2 = self.samples[idx2] + singlePhotonSize = photon1.shape[0] ### random position for photons in pos_x1 = np.random.randint(1, 4) # pos_y1 = np.random.randint(1, 4) pos_y1 = pos_x1 - sample[pos_y1:pos_y1+5, pos_x1:pos_x1+5] += photon1 + sample[pos_y1:pos_y1+singlePhotonSize, pos_x1:pos_x1+singlePhotonSize] += photon1 pos_x2 = np.random.randint(1, 4) # pos_y2 = np.random.randint(1, 4) pos_y2 = pos_x2 - sample[pos_y2:pos_y2+5, pos_x2:pos_x2+5] += photon2 + sample[pos_y2:pos_y2+singlePhotonSize, pos_x2:pos_x2+singlePhotonSize] += photon2 sample = sample[1:-1, 1:-1] ### sample size: 6x6 sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0) sample = torch.cat((sample, self.x_grid, self.y_grid), dim=0) ### concatenate coordinate channels + doublePhotonSize = 6 - label1 = self.labels[idx1] + np.array([pos_x1-1, pos_y1-1, 0, 0]) - label2 = self.labels[idx2] + np.array([pos_x2-1, pos_y2-1, 0, 0]) + label1 = self.labels[idx1] + np.array([pos_x1-1-doublePhotonSize//2, pos_y1-1-doublePhotonSize//2, 0, 0]) + label2 = self.labels[idx2] + np.array([pos_x2-1-doublePhotonSize//2, pos_y2-1-doublePhotonSize//2, 0, 0]) label = np.concatenate((label1, label2), axis=0) return sample, torch.tensor(label, dtype=torch.float32) diff --git a/src/models.py b/src/models.py index 8425948..c1575a7 100644 --- a/src/models.py +++ b/src/models.py @@ -185,7 +185,7 @@ class doublePhotonNet_251001(nn.Module): x = torch.relu(self.conv3(x)) # x = self.global_avg_pool(x).view(x.size(0), -1) x = self.global_max_pool(x).view(x.size(0), -1) - coords = self.fc(x)*6 + coords = self.fc(x) return coords # shape: [B, 4] class doublePhotonNet_251001_2(nn.Module): @@ -220,7 +220,7 @@ class doublePhotonNet_251001_2(nn.Module): nn.ReLU(), nn.Dropout(0.3), nn.Linear(128, 4), - nn.Sigmoid() # output in [0,1] + # nn.Sigmoid() # output in [0,1] ) self._init_weights() @@ -255,55 +255,5 @@ class doublePhotonNet_251001_2(nn.Module): global_feat = torch.cat([g_max, g_avg], dim=1) # [B, 256] # Regression - coords = self.fc(global_feat) * 6.0 # scale to [0,6) - return coords # [B, 4] - -class doublePhotonNet_251104(nn.Module): - def __init__(self): - super().__init__() - # Backbone: deeper - self.conv1 = nn.Conv2d(3, 32, kernel_size=3) # 6x6 -> 4x4 - self.conv2 = nn.Conv2d(32, 64, kernel_size=3) # 4x4 -> 2x2 - self.conv3 = nn.Conv2d(64, 128, kernel_size=2) # 2x2 -> 1x1 - - # Spatial Attention Module (轻量但有效) - self.spatial_attn = nn.Sequential( - nn.Conv2d(128, 1, kernel_size=1), - nn.Sigmoid() - ) - - # Enhanced regression head - self.fc = nn.Sequential( - nn.Linear(128, 256), # concat max + avg - nn.ReLU(), - nn.Dropout(0.3), - nn.Linear(256, 128), - nn.ReLU(), - nn.Dropout(0.3), - nn.Linear(128, 4), - nn.Sigmoid() # output in [0,1] - ) - - self._init_weights() - - def _init_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight) - nn.init.zeros_(m.bias) - - def forward(self, x): - # Feature extraction - c1 = F.relu(self.conv1(x)) # [B, 32, 6, 6] - c2 = F.relu(self.conv2(c1)) # [B, 64, 6, 6] - c3 = F.relu(self.conv3(c2)) # [B,128, 6, 6] - - # Spatial attention: highlight photon peaks - attn = self.spatial_attn(c3) # [B, 1, 6, 6] - c3 = c3 * attn # reweight features - - # Regression - coords = self.fc(c3.flatten(1)) * 6.0 # scale to [0,6) + coords = self.fc(global_feat) return coords # [B, 4] \ No newline at end of file