From e04e783fb68bb992bcec396e2a2631aee1998c81 Mon Sep 17 00:00:00 2001 From: "xiangyu.xie" Date: Thu, 7 May 2026 13:49:10 +0200 Subject: [PATCH] New model and loss_fn to improve uniformity Co-authored-by: Copilot --- Train_2Photon.py | 17 ++++++++++++++-- src/models.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 67 insertions(+), 3 deletions(-) diff --git a/Train_2Photon.py b/Train_2Photon.py index 152196a..bce77c1 100644 --- a/Train_2Photon.py +++ b/Train_2Photon.py @@ -49,6 +49,21 @@ def get_loss_function(conf): c_b = pair_cost_l2sq(p1,g2) + pair_cost_l2sq(p2,g1) return torch.minimum(c_a, c_b).mean() return two_point_set_loss_l2 + elif conf.loss.type == "two_point_set_loss_smooth_l1": + def two_point_set_loss_smooth_l1(pred_xy, gt_xy): + # Smooth L1 对异常值不那么敏感,但对于细小的亚像素误差能提供恒定的梯度 + loss_fn = torch.nn.SmoothL1Loss(reduction='none') + + p1, p2 = pred_xy[:,0], pred_xy[:,1] + g1, g2 = gt_xy[:,0], gt_xy[:,1] + + # 计算两种排列组合的损失 + c_a = loss_fn(p1, g1).sum(dim=-1) + loss_fn(p2, g2).sum(dim=-1) + c_b = loss_fn(p1, g2).sum(dim=-1) + loss_fn(p2, g1).sum(dim=-1) + + return torch.minimum(c_a, c_b).mean() + + return two_point_set_loss_smooth_l1 def train(model, trainLoader, optimizer, loss_fn): model.train() @@ -67,7 +82,6 @@ def train(model, trainLoader, optimizer, loss_fn): batchLoss += loss.item() * sample.shape[0] avgLoss = batchLoss / len(trainLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})") - TrainLosses.append(avgLoss) return avgLoss def evaluate(model, valLoader, loss_fn): @@ -85,7 +99,6 @@ def evaluate(model, valLoader, loss_fn): batchLoss += loss.item() * sample.shape[0] avgLoss = batchLoss / len(valLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis print(f"[Val]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})") - ValLosses.append(avgLoss) return avgLoss def get_dataloaders(conf): diff --git a/src/models.py b/src/models.py index 6cddaa2..62a4b48 100644 --- a/src/models.py +++ b/src/models.py @@ -326,4 +326,55 @@ class doublePhotonNet_251124(nn.Module): ### adapted for 7x7 input from 251001_2 # Regression coords = self.fc(global_feat) - return coords # [B, 4] \ No newline at end of file + return coords # [B, 4] + +class doublePhotonNet_260507(nn.Module): ## adapted from 251124, removed max pooling, added more capacity in FC layers + def __init__(self): + super().__init__() + # Backbone 保持不变 + self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) + self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) + + # 空间注意力模块 (Spatial Attention Module) + self.spatial_attn = nn.Sequential( + nn.Conv2d(128, 1, kernel_size=1), + nn.Sigmoid() + ) + + # 我们移除了全局池化层 (Global Pooling)。 + # 7x7 的空间特征图有 128 个通道,展平后是 128 * 7 * 7 = 6272 个特征。 + # 这使得全连接层能够直接“看到”精确的空间分布规律。 + self.fc = nn.Sequential( + nn.Linear(128 * 7 * 7, 512), # 增加了中间层的维度以适配展平后的输入 + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(512, 128), + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(128, 4) + ) + + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.zeros_(m.bias) + + def forward(self, x): + c1 = F.relu(self.conv1(x)) # [B, 32, 7, 7] + c2 = F.relu(self.conv2(c1)) # [B, 64, 7, 7] + c3 = F.relu(self.conv3(c2)) # [B, 128, 7, 7] + + attn = self.spatial_attn(c3) # [B, 1, 7, 7] + c3 = c3 * attn # [B, 128, 7, 7] + + # 直接展平 (Flatten) 而不是池化 + flat_feat = c3.view(c3.size(0), -1) # [B, 6272] + + coords = self.fc(flat_feat) # [B, 4] + return coords \ No newline at end of file