import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class singlePhotonNet_250909(nn.Module):
    def weight_init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def __init__(self):
        super(singlePhotonNet_250909, self).__init__()
        self.conv1 = nn.Conv2d(1, 5, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=1)
        self.fc = nn.Linear(20 * 5 * 5, 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


class singlePhotonNet_251020(nn.Module):
    '''
    Smaller input size (3x3)
    '''
    def weight_init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def __init__(self):
        super(singlePhotonNet_251020, self).__init__()
        self.conv1 = nn.Conv2d(1, 5, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=1)
        self.fc = nn.Linear(20 * 3 * 3, 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


class doublePhotonNet_250909(nn.Module):
    def __init__(self):
        super(doublePhotonNet_250909, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(3, 5, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(5, 5, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(5 * 6 * 6, 4)
        # self.fc2 = nn.Linear(50, 4)

        # Slightly more stable initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        # x = F.relu(self.fc1(x))
        # x = self.fc2(x)
        x = self.fc1(x)
        return x


class doublePhotonNet_250910(nn.Module):
    def __init__(self):
        super(doublePhotonNet_250910, self).__init__()
        ### x shape: (B, 1, 6, 6)
        self.conv1 = nn.Conv2d(1, 5, kernel_size=5, padding=2)    # (B, 5, 6, 6)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=5, padding=2)   # (B, 10, 6, 6)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=0)  # (B, 20, 4, 4)
        self.fc1 = nn.Linear(20 * 4 * 4, 4)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        # x = F.relu(self.fc1(x))
        # x = self.fc2(x)
        x = self.fc1(x) * 6  # scale outputs to the 6x6 coordinate range
        return x
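# ------------------------------------------------------------------
# Usage sketch (not part of the original file). The single-photon nets
# define weight_init() but never call it in __init__, so it has to be
# invoked explicitly; the expected input sizes (5x5 and 3x3) are
# inferred from the fc layer dimensions (20*5*5 and 20*3*3).
# ------------------------------------------------------------------
# net = singlePhotonNet_250909()
# net.weight_init()                               # explicit call; not done in __init__
# out = net(torch.zeros(2, 1, 5, 5))              # -> shape (2, 2)
#
# net_small = singlePhotonNet_251020()
# net_small.weight_init()
# out_small = net_small(torch.zeros(2, 1, 3, 3))  # -> shape (2, 2)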
class doublePhotonNet_251001(nn.Module):
    def __init__(self):
        super().__init__()
        # Keep spatial resolution: small kernels, no pooling
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)   # 6x6 -> 6x6
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # 6x6 -> 6x6
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 6x6 -> 6x6

        # Global feature extraction (replaces a flattening fully connected layer)
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # 64x1x1
        self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1))  # 64x1x1

        # Regression head: outputs the 4 coordinates (x1, y1, x2, y2)
        self.fc = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 4),  # output the coordinates directly
            # nn.Sigmoid()      # sigmoid leads to overfitting
        )

        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        #     if isinstance(m, nn.Linear):
        #         nn.init.xavier_uniform_(m.weight)
        #         nn.init.zeros_(m.bias)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        # x = self.global_avg_pool(x).view(x.size(0), -1)
        x = self.global_max_pool(x).view(x.size(0), -1)
        coords = self.fc(x) * 6  # scale to the 6x6 coordinate range
        return coords  # shape: [B, 4]


class doublePhotonNet_251001_2(nn.Module):
    def __init__(self):
        super().__init__()
        # Backbone: deeper + residual-like blocks
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Spatial attention module (lightweight but effective)
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(128, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Multi-scale feature fusion (optional but helpful)
        # Note: the fusion path in forward() is commented out; if enabled,
        # self.fuse would need 32 + 32 + 128 = 192 input channels to match
        # torch.cat([r1, r2, c3], dim=1).
        self.reduce1 = nn.Conv2d(32, 32, kernel_size=1)  # from conv1
        self.reduce2 = nn.Conv2d(64, 32, kernel_size=1)  # from conv2
        self.fuse = nn.Conv2d(32 * 3, 128, kernel_size=1)

        # Global context with both max and avg pooling (better than GAP alone)
        self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Enhanced regression head
        self.fc = nn.Sequential(
            nn.Linear(128 * 2, 256),  # concat max + avg
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 4),
            nn.Sigmoid()  # output in [0, 1]
        )

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        # Feature extraction
        c1 = F.relu(self.conv1(x))   # [B, 32, 6, 6]
        c2 = F.relu(self.conv2(c1))  # [B, 64, 6, 6]
        c3 = F.relu(self.conv3(c2))  # [B, 128, 6, 6]

        # Spatial attention: highlight photon peaks
        attn = self.spatial_attn(c3)  # [B, 1, 6, 6]
        c3 = c3 * attn                # reweight features

        # (Optional) Multi-scale fusion: uncomment if needed
        # r1 = F.interpolate(self.reduce1(c1), size=(6, 6), mode='nearest')
        # r2 = self.reduce2(c2)
        # fused = torch.cat([r1, r2, c3], dim=1)
        # c3 = self.fuse(fused)

        # Global context: max pooling better captures peaks, avg pooling adds context
        g_max = self.global_max_pool(c3).flatten(1)  # [B, 128]
        g_avg = self.global_avg_pool(c3).flatten(1)  # [B, 128]
        global_feat = torch.cat([g_max, g_avg], dim=1)  # [B, 256]

        # Regression
        coords = self.fc(global_feat) * 6.0  # scale to [0, 6)
        return coords  # [B, 4]
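# ------------------------------------------------------------------
# Minimal shape-check sketch (not part of the original file): runs each
# double-photon net on a dummy (B, 1, 6, 6) frame and prints the output
# shape. The 6x6 input size follows the shape comments above; batch
# size 2 is arbitrary.
# ------------------------------------------------------------------
if __name__ == "__main__":
    dummy = torch.zeros(2, 1, 6, 6)  # dummy detector frame
    for cls in (doublePhotonNet_250909, doublePhotonNet_250910,
                doublePhotonNet_251001, doublePhotonNet_251001_2):
        net = cls()
        net.eval()  # disable dropout for a deterministic check
        with torch.no_grad():
            out = net(dummy)
        print(cls.__name__, tuple(out.shape))  # expected: (2, 4)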