329 lines
12 KiB
Python
329 lines
12 KiB
Python
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch
|
|
import numpy as np
|
|
|
|
def get_model_class(version):
    """Resolve a single-photon network class from its version tag.

    The lookup is done by name in this module's global namespace, using the
    pattern ``singlePhotonNet_<version>`` (e.g. ``'250909'`` resolves to
    ``singlePhotonNet_250909``).

    Args:
        version: Version suffix, interpolated into the name with an f-string.

    Returns:
        The object bound to that global name (the class itself, not an
        instance).

    Raises:
        ValueError: If no global of that name exists, or it is bound to None.
    """
    wanted = f'singlePhotonNet_{version}'
    found = globals().get(wanted)
    if found is not None:
        return found
    raise ValueError(f"Model class '{wanted}' not found.")
|
|
|
|
def get_double_photon_model_class(version):
    """Resolve a double-photon network class from its version tag.

    Looks up ``doublePhotonNet_<version>`` among this module's globals, so a
    tag such as ``'251001_2'`` resolves to ``doublePhotonNet_251001_2``.

    Args:
        version: Version suffix, interpolated into the name with an f-string.

    Returns:
        The object bound to that global name (the class, not an instance).

    Raises:
        ValueError: If no such global exists, or it is bound to None.
    """
    name = f'doublePhotonNet_{version}'
    namespace = globals()
    candidate = namespace.get(name)
    if candidate is None:
        raise ValueError(f"Model class '{name}' not found.")
    return candidate
|
|
|
|
class singlePhotonNet_250909(nn.Module):
    """Three-layer CNN with a linear read-out producing 2 outputs.

    Every convolution is 3x3 with ``padding=1``, so the spatial size is kept
    (channels 1 -> 5 -> 10 -> 20). The final layer is sized for
    ``20 * 5 * 5`` features, i.e. a ``(B, 1, 5, 5)`` input. The output has
    shape ``(B, 2)``.

    NOTE(review): the constructor does not call ``weight_init``. A freshly
    built model keeps PyTorch's default initialisation unless the caller
    invokes ``weight_init()`` explicitly; confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        # Registration order (conv1, conv2, conv3, fc) fixes the state_dict
        # layout and the order in which default initialisation draws from the RNG.
        self.conv1 = nn.Conv2d(1, 5, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=1)
        self.fc = nn.Linear(20 * 5 * 5, 2)

    def weight_init(self):
        """Re-initialise parameters in place.

        * ``nn.Conv2d``: Kaiming-normal weights (``fan_out`` mode, ReLU gain);
          the bias, if present, is set to 0.
        * ``nn.Linear``: weights drawn from a normal with mean 0 and std 0.01;
          the bias is set to 0.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
                continue
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Map ``(B, 1, 5, 5)`` inputs to ``(B, 2)`` outputs."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        return self.fc(x.view(x.size(0), -1))
|
|
|
|
class singlePhotonNet_251020(nn.Module):
    """Variant of the single-photon CNN for smaller 3x3 inputs.

    Layout: three 3x3 convolutions with ``padding=1`` (channels
    1 -> 5 -> 10 -> 20, spatial size unchanged), each followed by ReLU. A
    linear layer then maps the flattened ``20 * 3 * 3`` features to 2 outputs.
    Input ``(B, 1, 3, 3)``; output ``(B, 2)``.

    NOTE(review): the constructor does not call ``weight_init``, so parameters
    keep PyTorch's default initialisation unless the caller invokes it.
    """

    def weight_init(self):
        """Re-initialise conv and linear parameters in place.

        * ``nn.Linear``: weights drawn from a normal with mean 0 and std 0.01;
          the bias is set to 0.
        * ``nn.Conv2d``: Kaiming-normal weights (``fan_out`` mode, ReLU gain);
          the bias, if present, is set to 0.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 5, kernel_size=3, padding=1)    # -> (B, 5, 3, 3)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=3, padding=1)   # -> (B, 10, 3, 3)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=1)  # -> (B, 20, 3, 3)
        self.fc = nn.Linear(20 * 3 * 3, 2)

    def forward(self, x):
        """Map ``(B, 1, 3, 3)`` inputs to ``(B, 2)`` outputs."""
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        hidden = F.relu(self.conv3(hidden))
        flat = hidden.view(hidden.size(0), -1)
        return self.fc(flat)
|
|
|
|
class singlePhotonNet_251022(nn.Module):
    """Single-photon CNN for 3x3 inputs with three input channels.

    Shape trace for an input of shape ``(B, 3, 3, 3)``:

        conv1 (3x3, padding 1)  -> (B, 5, 3, 3)
        conv2 (3x3, padding 1)  -> (B, 10, 3, 3)
        conv3 (3x3, no padding) -> (B, 20, 1, 1)
        flatten + fc            -> (B, 3)

    Each convolution is followed by ReLU. Unlike the 250909 / 251020 variants,
    the constructor applies ``weight_init`` itself.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 5, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=3, padding=1)
        # No padding here: a 3x3 map collapses to 1x1, leaving 20 features.
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3)
        self.fc = nn.Linear(20, 3)
        self.weight_init()

    def weight_init(self):
        """Re-initialise parameters in place.

        * ``nn.Conv2d``: Kaiming-normal weights (``fan_out`` mode, ReLU gain);
          the bias, if present, is set to 0.
        * ``nn.Linear``: weights drawn from a normal with mean 0 and std 0.01;
          the bias is set to 0.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Map ``(B, 3, 3, 3)`` inputs to ``(B, 3)`` outputs."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        return self.fc(x.view(x.size(0), -1))
|
|
|
|
class doublePhotonNet_250909(nn.Module):
    """Small CNN with a single linear read-out producing 4 outputs.

    Three 3x3 convolutions with ``padding=1`` (channels 1 -> 3 -> 5 -> 5)
    keep the spatial size, and ``fc1`` is sized for ``5 * 6 * 6`` features.
    The network therefore expects ``(B, 1, 6, 6)`` inputs and returns
    ``(B, 4)``. A two-layer head (``fc1`` -> ReLU -> ``fc2``) was tried
    previously; this version uses ``fc1`` alone.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(3, 5, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(5, 5, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(5 * 6 * 6, 4)

        # Explicit initialisation to make training a bit more stable.
        # Conv weights: Kaiming-normal (PyTorch's default fan_in mode, ReLU
        # gain); conv biases keep their default values.
        # Linear: Xavier-uniform weights and zero bias.
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
            elif isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Map ``(B, 1, 6, 6)`` inputs to ``(B, 4)`` outputs."""
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        return self.fc1(h.view(h.size(0), -1))
|
|
class doublePhotonNet_250910(nn.Module):
    """CNN for ``(B, 1, 6, 6)`` inputs whose 4 outputs are scaled by 6.

    Two 5x5 "same" convolutions are followed by one unpadded 3x3 convolution,
    then a linear layer on the ``20 * 4 * 4`` flattened features.
    """

    def __init__(self):
        super().__init__()
        # Shape trace for a (B, 1, 6, 6) input.
        self.conv1 = nn.Conv2d(1, 5, kernel_size=5, padding=2)    # -> (B, 5, 6, 6)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=5, padding=2)   # -> (B, 10, 6, 6)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=0)  # -> (B, 20, 4, 4)
        self.fc1 = nn.Linear(20 * 4 * 4, 4)

        # Conv weights: Kaiming-normal (default fan_in mode, ReLU gain); conv
        # biases keep their defaults. Linear: Xavier-uniform weights, zero bias.
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")

    def forward(self, x):
        """Map ``(B, 1, 6, 6)`` inputs to ``(B, 4)`` outputs (scaled by 6)."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        # NOTE(review): the linear output is multiplied by 6, presumably to map a
        # unit-scale prediction onto the 6-pixel input extent; confirm.
        return self.fc1(x.view(x.size(0), -1)) * 6
|
|
|
|
class doublePhotonNet_251001(nn.Module):
    """Fully convolutional backbone, global max pooling and an MLP head.

    Three 3x3 convolutions (``padding=1``, no pooling) keep the spatial
    resolution and grow the channels 1 -> 16 -> 32 -> 64. Per-channel global
    max pooling reduces the feature map to a 64-vector. A two-layer MLP with
    ``Dropout(0.3)`` maps it to 4 outputs, the coordinates
    ``(x1, y1, x2, y2)``.

    Input is ``(B, 1, H, W)``; the layer comments assume 6x6. Parameters keep
    PyTorch's default initialisation, since a custom Kaiming/Xavier init was
    tried and disabled.
    """

    def __init__(self):
        super().__init__()
        # Preserve spatial resolution: small kernels and no pooling.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)   # 6x6 -> 6x6
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # 6x6 -> 6x6
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 6x6 -> 6x6

        # Global feature extraction instead of a flatten + large FC layer.
        # forward() uses only the max pool; the average pool is kept as an
        # attribute but is currently unused.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # -> 64x1x1
        self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1))  # -> 64x1x1

        # Regression head that outputs the 4 coordinates directly. No output
        # activation: a trailing Sigmoid led to overfitting.
        self.fc = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 4),
        )

    def forward(self, x):
        """Map ``(B, 1, H, W)`` inputs to ``(B, 4)`` coordinates."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = torch.relu(conv(features))
        pooled = self.global_max_pool(features).view(features.size(0), -1)  # (B, 64)
        return self.fc(pooled)  # (B, 4)
|
|
|
|
class doublePhotonNet_251001_2(nn.Module):
    """Conv backbone with spatial attention, dual global pooling and an MLP head.

    Data flow (the original layer comments assume 6x6 inputs):

        x (B, 3, H, W)
        -> three 3x3 convs + ReLU, padding 1   (32, 64, 128 channels; H x W kept)
        -> multiply by a sigmoid attention map from a 1x1 conv, shape (B, 1, H, W)
        -> concat(global max pool, global avg pool)           -> (B, 256)
        -> MLP 256 -> 256 -> 128 -> 4 with ReLU + Dropout(0.3) -> (B, 4)

    Because of the global pooling, the output does not depend on H and W.

    NOTE: ``reduce1``, ``reduce2`` and ``fuse`` belong to an optional
    multi-scale fusion path that ``forward`` does not use. They are still
    constructed and re-initialised, so they appear in ``state_dict`` and
    consume RNG draws during construction.
    """

    def __init__(self):
        super().__init__()
        # Backbone: three plain 3x3 convolutions (no residual connections).
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Spatial attention (lightweight): 1x1 conv to one channel, then sigmoid.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(128, 1, kernel_size=1),
            nn.Sigmoid(),
        )

        # Multi-scale fusion layers (currently unused by forward()).
        self.reduce1 = nn.Conv2d(32, 32, kernel_size=1)   # from conv1
        self.reduce2 = nn.Conv2d(64, 32, kernel_size=1)   # from conv2
        self.fuse = nn.Conv2d(32 * 3, 128, kernel_size=1)

        # Global context: max pooling captures peaks, average pooling context.
        self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Regression head over the concatenated max + avg features. No output
        # activation (a Sigmoid squashing to [0, 1] is disabled).
        self.fc = nn.Sequential(
            nn.Linear(128 * 2, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 4),
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise weights in place.

        Linear layers get Xavier-uniform weights and zero bias. Conv layers
        get Kaiming-normal weights (``fan_out`` mode, ReLU gain); their biases
        keep PyTorch's default values.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Map ``(B, 3, H, W)`` inputs to ``(B, 4)`` outputs."""
        f1 = F.relu(self.conv1(x))    # (B, 32, H, W)
        f2 = F.relu(self.conv2(f1))   # (B, 64, H, W)
        f3 = F.relu(self.conv3(f2))   # (B, 128, H, W)

        # Re-weight every channel by the (B, 1, H, W) attention map, which is
        # meant to highlight photon peaks.
        f3 = f3 * self.spatial_attn(f3)

        # The reduce1 / reduce2 / fuse multi-scale path is intentionally skipped.

        pooled = torch.cat(
            (
                torch.flatten(self.global_max_pool(f3), 1),  # (B, 128)
                torch.flatten(self.global_avg_pool(f3), 1),  # (B, 128)
            ),
            dim=1,
        )  # (B, 256)
        return self.fc(pooled)  # (B, 4)
|
|
|
|
class doublePhotonNet_251124(nn.Module):  ### adapted for 7x7 input from 251001_2
    """7x7-input variant of ``doublePhotonNet_251001_2``.

    The layers are the same as in 251001_2; only the shape comments change
    (7x7 instead of 6x6). Since every conv uses ``padding=1`` and the head
    sits behind global pooling, the input size has no effect on the layer
    shapes.

    Data flow:

        x (B, 3, 7, 7)
        -> three 3x3 convs + ReLU                -> (B, 128, 7, 7)
        -> multiply by sigmoid attention map      (B, 1, 7, 7)
        -> concat(global max pool, global avg pool) -> (B, 256)
        -> MLP 256 -> 256 -> 128 -> 4             -> (B, 4)

    NOTE: ``reduce1``, ``reduce2`` and ``fuse`` are built and initialised but
    not used by ``forward``. They still appear in ``state_dict``.
    """

    def __init__(self):
        super().__init__()
        # Backbone: plain 3x3 convolutions, spatial size kept at 7x7.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Lightweight spatial attention: 1x1 conv to a single map + sigmoid.
        attn_layers = [nn.Conv2d(128, 1, kernel_size=1), nn.Sigmoid()]
        self.spatial_attn = nn.Sequential(*attn_layers)

        # Optional multi-scale fusion layers (not used in forward()).
        self.reduce1 = nn.Conv2d(32, 32, kernel_size=1)   # from conv1
        self.reduce2 = nn.Conv2d(64, 32, kernel_size=1)   # from conv2
        self.fuse = nn.Conv2d(32 * 3, 128, kernel_size=1)

        # Max pooling keeps peak responses; average pooling adds context.
        self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Regression head on the 256-d pooled vector; no output activation.
        head_layers = [
            nn.Linear(128 * 2, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 4),
        ]
        self.fc = nn.Sequential(*head_layers)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise weights in place.

        Conv layers get Kaiming-normal weights (``fan_out`` mode, ReLU gain)
        and keep their default biases. Linear layers get Xavier-uniform
        weights and zero bias.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Map ``(B, 3, 7, 7)`` inputs to ``(B, 4)`` outputs."""
        feat = F.relu(self.conv1(x))     # (B, 32, 7, 7)
        feat = F.relu(self.conv2(feat))  # (B, 64, 7, 7)
        feat = F.relu(self.conv3(feat))  # (B, 128, 7, 7)

        attention = self.spatial_attn(feat)  # (B, 1, 7, 7)
        feat = feat * attention

        peak = self.global_max_pool(feat).flatten(1)     # (B, 128)
        context = self.global_avg_pool(feat).flatten(1)  # (B, 128)
        return self.fc(torch.cat([peak, context], dim=1))  # (B, 4)