Correct double photon sample (remove ordering)

This commit is contained in:
2025-11-04 16:56:42 +01:00
parent f4178ce50f
commit 88536af944
3 changed files with 149 additions and 70 deletions
+58 -5
View File
@@ -10,6 +10,13 @@ def get_model_class(version):
raise ValueError(f"Model class '{class_name}' not found.")
return cls
def get_double_photon_model_class(version):
class_name = f'doublePhotonNet_{version}'
cls = globals().get(class_name)
if cls is None:
raise ValueError(f"Model class '{class_name}' not found.")
return cls
class singlePhotonNet_250909(nn.Module):
def weight_init(self):
for m in self.modules():
@@ -181,15 +188,11 @@ class doublePhotonNet_251001(nn.Module):
coords = self.fc(x)*6
return coords # shape: [B, 4]
import torch
import torch.nn as nn
import torch.nn.functional as F
class doublePhotonNet_251001_2(nn.Module):
def __init__(self):
super().__init__()
# Backbone: deeper + residual-like blocks
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
@@ -253,4 +256,54 @@ class doublePhotonNet_251001_2(nn.Module):
# Regression
coords = self.fc(global_feat) * 6.0 # scale to [0,6)
return coords # [B, 4]
class doublePhotonNet_251104(nn.Module):
def __init__(self):
super().__init__()
# Backbone: deeper
self.conv1 = nn.Conv2d(3, 32, kernel_size=3) # 6x6 -> 4x4
self.conv2 = nn.Conv2d(32, 64, kernel_size=3) # 4x4 -> 2x2
self.conv3 = nn.Conv2d(64, 128, kernel_size=2) # 2x2 -> 1x1
# Spatial Attention Module (轻量但有效)
self.spatial_attn = nn.Sequential(
nn.Conv2d(128, 1, kernel_size=1),
nn.Sigmoid()
)
# Enhanced regression head
self.fc = nn.Sequential(
nn.Linear(128, 256), # concat max + avg
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(128, 4),
nn.Sigmoid() # output in [0,1]
)
self._init_weights()
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
# Feature extraction
c1 = F.relu(self.conv1(x)) # [B, 32, 6, 6]
c2 = F.relu(self.conv2(c1)) # [B, 64, 6, 6]
c3 = F.relu(self.conv3(c2)) # [B,128, 6, 6]
# Spatial attention: highlight photon peaks
attn = self.spatial_attn(c3) # [B, 1, 6, 6]
c3 = c3 * attn # reweight features
# Regression
coords = self.fc(c3.flatten(1)) * 6.0 # scale to [0,6)
return coords # [B, 4]