add model doublePhotonNet_251124 for 7x7 cluster size

This commit is contained in:
2025-11-25 15:24:10 +01:00
parent bee7a30a88
commit 1c4f03a308

View File

@@ -254,6 +254,76 @@ class doublePhotonNet_251001_2(nn.Module):
g_avg = self.global_avg_pool(c3).flatten(1) # [B, 128]
global_feat = torch.cat([g_max, g_avg], dim=1) # [B, 256]
# Regression
coords = self.fc(global_feat)
return coords # [B, 4]
class doublePhotonNet_251124(nn.Module): ### adapted for 7x7 input from 251001_2
def __init__(self):
super().__init__()
# Backbone: deeper + residual-like blocks
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) ### 7x7
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) ### 7x7
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) ### 7x7
# Spatial Attention Module (轻量但有效)
self.spatial_attn = nn.Sequential(
nn.Conv2d(128, 1, kernel_size=1),
nn.Sigmoid()
)
# Multi-scale feature fusion (optional but helpful)
self.reduce1 = nn.Conv2d(32, 32, kernel_size=1) # from conv1
self.reduce2 = nn.Conv2d(64, 32, kernel_size=1) # from conv2
self.fuse = nn.Conv2d(32*3, 128, kernel_size=1)
# Global context with both Max and Avg pooling (better than GAP alone)
self.global_max_pool = nn.AdaptiveMaxPool2d((1,1))
self.global_avg_pool = nn.AdaptiveAvgPool2d((1,1))
# Enhanced regression head
self.fc = nn.Sequential(
nn.Linear(128 * 2, 256), # concat max + avg
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(128, 4),
# nn.Sigmoid() # output in [0,1]
)
self._init_weights()
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
# Feature extraction
c1 = F.relu(self.conv1(x)) # [B, 32, 7, 7]
c2 = F.relu(self.conv2(c1)) # [B, 64, 7, 7]
c3 = F.relu(self.conv3(c2)) # [B,128, 7, 7]
# Spatial attention: highlight photon peaks
attn = self.spatial_attn(c3) # [B, 1, 7, 7]
c3 = c3 * attn # reweight features
# (Optional) Multi-scale fusion — uncomment if needed
# r1 = F.interpolate(self.reduce1(c1), size=(7,7), mode='nearest')
# r2 = self.reduce2(c2)
# fused = torch.cat([r1, r2, c3], dim=1)
# c3 = self.fuse(fused)
# Global context: MaxPool better captures peaks, Avg for context
g_max = self.global_max_pool(c3).flatten(1) # [B, 128]
g_avg = self.global_avg_pool(c3).flatten(1) # [B, 128]
global_feat = torch.cat([g_max, g_avg], dim=1) # [B, 256]
# Regression
coords = self.fc(global_feat)
return coords # [B, 4]