Files
DeepLearning/src/models.py

219 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
class singlePhotonNet_250909(nn.Module):
    """Three-layer CNN mapping a single-channel 5x5 patch to two output values.

    Every convolution is 3x3 with padding 1, so the spatial size is preserved
    and the linear head (``20*5*5`` input features) requires 5x5 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 5, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=1)
        self.fc = nn.Linear(20*5*5, 2)

    def weight_init(self):
        """Re-initialise all conv and linear layers in place.

        Conv weights are drawn Kaiming-normal (fan_out, ReLU gain), linear
        weights from N(0, 0.01); every bias that exists is set to zero.
        ``__init__`` does not call this method, so the PyTorch default
        initialisation is in effect until it is invoked explicitly.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                # Same guard as the conv branch: a bias-free layer has bias=None.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return a ``(B, 2)`` tensor for an input ``x`` of shape ``(B, 1, 5, 5)``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # flatten(1) rather than view(B, -1): it also works for an empty batch
        # and for non-contiguous (e.g. channels_last) activations.
        x = x.flatten(1)
        return self.fc(x)
class singlePhotonNet_251020(nn.Module):
    """Three-layer CNN for a smaller input size (3x3), producing two output values.

    Every convolution is 3x3 with padding 1, so the spatial size is preserved
    and the linear head (``20*3*3`` input features) requires 3x3 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 5, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=1)
        self.fc = nn.Linear(20*3*3, 2)

    def weight_init(self):
        """Re-initialise all conv and linear layers in place.

        Conv weights are drawn Kaiming-normal (fan_out, ReLU gain), linear
        weights from N(0, 0.01); every bias that exists is set to zero.
        ``__init__`` does not call this method, so the PyTorch default
        initialisation is in effect until it is invoked explicitly.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                # Same guard as the conv branch: a bias-free layer has bias=None.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return a ``(B, 2)`` tensor for an input ``x`` of shape ``(B, 1, 3, 3)``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # flatten(1) rather than view(B, -1): it also works for an empty batch
        # and for non-contiguous (e.g. channels_last) activations.
        x = x.flatten(1)
        return self.fc(x)
class doublePhotonNet_250909(nn.Module):
    """Small CNN mapping a single-channel 6x6 patch to four output values.

    All convolutions are 3x3 with padding 1 (spatial size preserved), followed
    by one linear layer with ``5*6*6`` input features and no output activation.
    The four outputs are presumably two (x, y) positions; confirm against the
    training code.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(3, 5, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(5, 5, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(5*6*6, 4)
        # A more stable initialisation: Kaiming-normal conv weights,
        # Xavier-uniform linear weights with zero bias. Conv biases keep the
        # PyTorch default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return a ``(B, 4)`` tensor for an input ``x`` of shape ``(B, 1, 6, 6)``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # flatten(1) rather than view(B, -1): it also works for an empty batch
        # and for non-contiguous (e.g. channels_last) activations.
        x = x.flatten(1)
        return self.fc1(x)
class doublePhotonNet_250910(nn.Module):
    """CNN with two 5x5 convs and one unpadded 3x3 conv, four outputs scaled by 6.

    Expects input of shape ``(B, 1, 6, 6)``. The linear output is multiplied
    by 6, which matches the 6-pixel input side (presumably so the head
    predicts coordinates in normalised units; confirm against the training
    code).
    """

    def __init__(self):
        super().__init__()
        # x shape: (B, 1, 6, 6)
        self.conv1 = nn.Conv2d(1, 5, kernel_size=5, padding=2)    # (B, 5, 6, 6)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=5, padding=2)   # (B, 10, 6, 6)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=0)  # (B, 20, 4, 4)
        self.fc1 = nn.Linear(20*4*4, 4)
        # Kaiming-normal conv weights, Xavier-uniform linear weights with zero
        # bias. Conv biases keep the PyTorch default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return a ``(B, 4)`` tensor equal to ``6 * fc1(features)``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # flatten(1) rather than view(B, -1): it also works for an empty batch
        # and for non-contiguous (e.g. channels_last) activations.
        x = x.flatten(1)
        return self.fc1(x) * 6
class doublePhotonNet_251001(nn.Module):
    """CNN with global max pooling and an MLP head producing four coordinates.

    Three 3x3 convolutions (padding 1) keep the spatial resolution, a global
    max pool reduces the 64 feature maps to a 64-vector, and a two-layer MLP
    with dropout regresses ``(x1, y1, x2, y2)``, scaled by 6. No custom
    weight initialisation is applied; the PyTorch defaults are used.
    """

    def __init__(self):
        super().__init__()
        # Keep the spatial resolution: small kernels, no pooling between convs.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)   # 6x6 -> 6x6
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # 6x6 -> 6x6
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 6x6 -> 6x6
        # Global feature extraction (instead of flattening into a large
        # fully connected layer). Only the max pool is used in forward();
        # the average pool is kept as an attribute but is not called there.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # 64x1x1
        self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1))  # 64x1x1
        # Regression head: outputs the 4 coordinates (x1, y1, x2, y2) directly.
        # Original author's note: a final Sigmoid led to overfitting, so the
        # head has no output activation.
        self.fc = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 4),
        )

    def forward(self, x):
        """Return coordinates of shape ``[B, 4]`` for a ``(B, 1, H, W)`` input."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        # flatten(1) rather than view(B, -1): view cannot infer -1 when the
        # batch is empty, flatten handles that case.
        x = self.global_max_pool(x).flatten(1)
        coords = self.fc(x) * 6
        return coords  # shape: [B, 4]
import torch
import torch.nn as nn
import torch.nn.functional as F
class doublePhotonNet_251001_2(nn.Module):
    """CNN with spatial attention and max+avg global pooling, four outputs.

    Three 3x3 convolutions (1 -> 32 -> 64 -> 128 channels, padding 1) feed a
    1x1-conv sigmoid attention map that reweights the last feature map. The
    globally max- and avg-pooled features are concatenated and passed through
    an MLP that ends in a sigmoid; the result is multiplied by 6.0.
    """

    def __init__(self):
        super().__init__()
        # Backbone: three stacked 3x3 convolutions that keep the spatial size.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Lightweight spatial attention: one weight in (0, 1) per pixel.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(128, 1, kernel_size=1),
            nn.Sigmoid(),
        )

        # Layers for an optional multi-scale fusion. They are created (and
        # therefore part of the state dict) but forward() does not call them.
        self.reduce1 = nn.Conv2d(32, 32, kernel_size=1)  # would take conv1 features
        self.reduce2 = nn.Conv2d(64, 32, kernel_size=1)  # would take conv2 features
        self.fuse = nn.Conv2d(32*3, 128, kernel_size=1)

        # Global context from both max and average pooling.
        self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Regression head on the concatenated (max, avg) vector.
        head_layers = [
            nn.Linear(128 * 2, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 4),
            nn.Sigmoid(),  # squashes each output into [0, 1]
        ]
        self.fc = nn.Sequential(*head_layers)

        self._init_weights()

    def _init_weights(self):
        """Apply Xavier-uniform/zero-bias to linears and Kaiming-normal to convs.

        Conv biases are left at their PyTorch default initialisation.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)
                continue
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Return a ``[B, 4]`` tensor with values in ``[0, 6]``."""
        # Feature extraction: [B, 32, H, W] -> [B, 64, H, W] -> [B, 128, H, W].
        feat = F.relu(self.conv1(x))
        feat = F.relu(self.conv2(feat))
        feat = F.relu(self.conv3(feat))

        # Reweight every spatial location by its attention score ([B, 1, H, W]).
        feat = feat * self.spatial_attn(feat)

        # Max pooling captures peaks, average pooling the overall context;
        # each pooled tensor is flattened to [B, 128].
        pooled = [
            pool(feat).flatten(1)
            for pool in (self.global_max_pool, self.global_avg_pool)
        ]
        global_feat = torch.cat(pooled, dim=1)  # [B, 256]

        # Sigmoid output scaled by the 6-pixel range.
        return self.fc(global_feat) * 6.0