diff --git a/src/models.py b/src/models.py
index c1575a7..6cddaa2 100644
--- a/src/models.py
+++ b/src/models.py
@@ -254,6 +254,76 @@ class doublePhotonNet_251001_2(nn.Module):
         g_avg = self.global_avg_pool(c3).flatten(1)      # [B, 128]
         global_feat = torch.cat([g_max, g_avg], dim=1)   # [B, 256]
 
+        # Regression
+        coords = self.fc(global_feat)
+        return coords  # [B, 4]
+
+class doublePhotonNet_251124(nn.Module):  ### adapted for 7x7 input from 251001_2
+    def __init__(self):
+        super().__init__()
+        # Backbone: deeper + residual-like blocks
+        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)    ### 7x7
+        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)   ### 7x7
+        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)  ### 7x7
+
+        # Spatial Attention Module (lightweight but effective)
+        self.spatial_attn = nn.Sequential(
+            nn.Conv2d(128, 1, kernel_size=1),
+            nn.Sigmoid()
+        )
+
+        # Multi-scale feature fusion (optional but helpful)
+        self.reduce1 = nn.Conv2d(32, 32, kernel_size=1)  # from conv1
+        self.reduce2 = nn.Conv2d(64, 32, kernel_size=1)  # from conv2
+        self.fuse = nn.Conv2d(32*3, 128, kernel_size=1)
+
+        # Global context with both Max and Avg pooling (better than GAP alone)
+        self.global_max_pool = nn.AdaptiveMaxPool2d((1,1))
+        self.global_avg_pool = nn.AdaptiveAvgPool2d((1,1))
+
+        # Enhanced regression head
+        self.fc = nn.Sequential(
+            nn.Linear(128 * 2, 256),  # concat max + avg
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(128, 4),
+            # nn.Sigmoid()  # output in [0,1]
+        )
+
+        self._init_weights()
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                nn.init.zeros_(m.bias)
+
+    def forward(self, x):
+        # Feature extraction
+        c1 = F.relu(self.conv1(x))   # [B, 32, 7, 7]
+        c2 = F.relu(self.conv2(c1))  # [B, 64, 7, 7]
+        c3 = F.relu(self.conv3(c2))  # [B,128, 7, 7]
+
+        # Spatial attention: highlight photon peaks
+        attn = self.spatial_attn(c3)  # [B, 1, 7, 7]
+        c3 = c3 * attn                # reweight features
+
+        # (Optional) Multi-scale fusion - uncomment if needed
+        # r1 = F.interpolate(self.reduce1(c1), size=(7,7), mode='nearest')
+        # r2 = self.reduce2(c2)
+        # fused = torch.cat([r1, r2, c3], dim=1)
+        # c3 = self.fuse(fused)
+
+        # Global context: MaxPool better captures peaks, Avg for context
+        g_max = self.global_max_pool(c3).flatten(1)      # [B, 128]
+        g_avg = self.global_avg_pool(c3).flatten(1)      # [B, 128]
+        global_feat = torch.cat([g_max, g_avg], dim=1)   # [B, 256]
+
         # Regression
         coords = self.fc(global_feat)
         return coords  # [B, 4]
\ No newline at end of file
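
Quick shape sanity check for the new 7x7 model (a minimal sketch, not part of the patch): it assumes src/ is importable as a package and uses a placeholder batch of random data. Reading the 4 outputs as two (x, y) photon positions is inferred from the class name, not stated in the diff.

import torch

from src.models import doublePhotonNet_251124

model = doublePhotonNet_251124()
model.eval()

# Dummy batch: 8 samples, 3 channels, 7x7 pixels, matching conv1's in_channels.
x = torch.randn(8, 3, 7, 7)

with torch.no_grad():
    coords = model(x)  # attention-weighted features -> max+avg pooled -> FC head

print(coords.shape)  # torch.Size([8, 4]); presumably two (x, y) positions per sample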