Shift label to sample center; remove the sigmoid in FC

commit b0a396b0d8
parent 88536af944
Date: 2025-11-04 17:37:39 +01:00

3 changed files with 10 additions and 58 deletions


@@ -12,7 +12,7 @@ from torchinfo import summary
 torch.manual_seed(0)
 np.random.seed(0)
-modelVersion = '251104' # '250910' or '251001'
+modelVersion = '251001_2' # '250910' or '251001'
 model = models.get_double_photon_model_class(modelVersion)().cuda()
 Energy = '15.3keV'
 TrainLosses, ValLosses = [], []
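For context, the script resolves the network class from the version string before instantiating it on the GPU. A minimal sketch of one way such a lookup could work (hypothetical; the actual helper in the repo's models module may be implemented differently):

import sys
import torch.nn as nn

class doublePhotonNet_251001_2(nn.Module):  # stand-in for the real class
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(256, 4)

def get_double_photon_model_class(version):
    # Look up 'doublePhotonNet_<version>' in this module's namespace.
    return getattr(sys.modules[__name__], f'doublePhotonNet_{version}')

model_cls = get_double_photon_model_class('251001_2')
model = model_cls()  # followed by .cuda() on a GPU machine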

View File

@@ -139,22 +139,24 @@ class doublePhotonDataset(Dataset):
         idx2 = np.random.randint(0, self.samples.shape[0])
         photon1 = self.samples[idx1]
         photon2 = self.samples[idx2]
+        singlePhotonSize = photon1.shape[0]
         ### random position for photons in
         pos_x1 = np.random.randint(1, 4)
         # pos_y1 = np.random.randint(1, 4)
         pos_y1 = pos_x1
-        sample[pos_y1:pos_y1+5, pos_x1:pos_x1+5] += photon1
+        sample[pos_y1:pos_y1+singlePhotonSize, pos_x1:pos_x1+singlePhotonSize] += photon1
         pos_x2 = np.random.randint(1, 4)
         # pos_y2 = np.random.randint(1, 4)
         pos_y2 = pos_x2
-        sample[pos_y2:pos_y2+5, pos_x2:pos_x2+5] += photon2
+        sample[pos_y2:pos_y2+singlePhotonSize, pos_x2:pos_x2+singlePhotonSize] += photon2
         sample = sample[1:-1, 1:-1]  ### sample size: 6x6
         sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
         sample = torch.cat((sample, self.x_grid, self.y_grid), dim=0)  ### concatenate coordinate channels
-        label1 = self.labels[idx1] + np.array([pos_x1-1, pos_y1-1, 0, 0])
-        label2 = self.labels[idx2] + np.array([pos_x2-1, pos_y2-1, 0, 0])
+        doublePhotonSize = 6
+        label1 = self.labels[idx1] + np.array([pos_x1-1-doublePhotonSize//2, pos_y1-1-doublePhotonSize//2, 0, 0])
+        label2 = self.labels[idx2] + np.array([pos_x2-1-doublePhotonSize//2, pos_y2-1-doublePhotonSize//2, 0, 0])
         label = np.concatenate((label1, label2), axis=0)
         return sample, torch.tensor(label, dtype=torch.float32)
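The extra `doublePhotonSize//2` term moves the label origin from the top-left corner of the 6x6 crop to its center, so targets fall roughly in [-3, 3) instead of [0, 6). A minimal sketch of the arithmetic (the sub-pixel position here is a made-up example, not repo data):

import numpy as np

doublePhotonSize = 6                 # cropped sample is 6x6 pixels
pos_x1 = np.random.randint(1, 4)     # photon patch offset in the padded frame
x_in_patch = 2.4                     # hypothetical sub-pixel x inside the patch

# Old convention: origin at the crop's top-left corner, so x lies in [0, 6)
x_corner = x_in_patch + (pos_x1 - 1)

# New convention: origin at the crop center, so x lies roughly in [-3, 3)
x_center = x_in_patch + (pos_x1 - 1) - doublePhotonSize // 2

assert np.isclose(x_center, x_corner - 3.0)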


@@ -185,7 +185,7 @@ class doublePhotonNet_251001(nn.Module):
         x = torch.relu(self.conv3(x))
         # x = self.global_avg_pool(x).view(x.size(0), -1)
         x = self.global_max_pool(x).view(x.size(0), -1)
-        coords = self.fc(x)*6
+        coords = self.fc(x)
         return coords  # shape: [B, 4]
 class doublePhotonNet_251001_2(nn.Module):
@@ -220,7 +220,7 @@ class doublePhotonNet_251001_2(nn.Module):
             nn.ReLU(),
             nn.Dropout(0.3),
             nn.Linear(128, 4),
-            nn.Sigmoid()  # output in [0,1]
+            # nn.Sigmoid()  # output in [0,1]
         )
         self._init_weights()
@@ -255,55 +255,5 @@ class doublePhotonNet_251001_2(nn.Module):
         global_feat = torch.cat([g_max, g_avg], dim=1)  # [B, 256]
         # Regression
-        coords = self.fc(global_feat) * 6.0  # scale to [0,6)
+        coords = self.fc(global_feat)
         return coords  # [B, 4]
-
-class doublePhotonNet_251104(nn.Module):
-    def __init__(self):
-        super().__init__()
-        # Backbone: deeper
-        self.conv1 = nn.Conv2d(3, 32, kernel_size=3)   # 6x6 -> 4x4
-        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # 4x4 -> 2x2
-        self.conv3 = nn.Conv2d(64, 128, kernel_size=2) # 2x2 -> 1x1
-        # Spatial Attention Module (lightweight but effective)
-        self.spatial_attn = nn.Sequential(
-            nn.Conv2d(128, 1, kernel_size=1),
-            nn.Sigmoid()
-        )
-        # Enhanced regression head
-        self.fc = nn.Sequential(
-            nn.Linear(128, 256),
-            nn.ReLU(),
-            nn.Dropout(0.3),
-            nn.Linear(256, 128),
-            nn.ReLU(),
-            nn.Dropout(0.3),
-            nn.Linear(128, 4),
-            nn.Sigmoid()  # output in [0,1]
-        )
-        self._init_weights()
-
-    def _init_weights(self):
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
-            elif isinstance(m, nn.Linear):
-                nn.init.xavier_uniform_(m.weight)
-                nn.init.zeros_(m.bias)
-
-    def forward(self, x):
-        # Feature extraction
-        c1 = F.relu(self.conv1(x))   # [B, 32, 4, 4]
-        c2 = F.relu(self.conv2(c1))  # [B, 64, 2, 2]
-        c3 = F.relu(self.conv3(c2))  # [B, 128, 1, 1]
-        # Spatial attention: highlight photon peaks
-        attn = self.spatial_attn(c3)  # [B, 1, 1, 1]
-        c3 = c3 * attn  # reweight features
-        # Regression
-        coords = self.fc(c3.flatten(1)) * 6.0  # scale to [0,6)
-        return coords  # [B, 4]
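Taken together, the sigmoid and the `* 6.0` scaling are gone, so the regression head now emits unconstrained coordinates that match the recentered labels. A minimal sketch of the before/after behavior on dummy features (the 256-dim input is assumed from the max+avg concat in the diff):

import torch
import torch.nn as nn

torch.manual_seed(0)
B, D = 4, 256                      # batch size; D = concat of max- and avg-pooled features
feat = torch.randn(B, D)

old_head = nn.Sequential(nn.Linear(D, 4), nn.Sigmoid())
new_head = nn.Linear(D, 4)

old_coords = old_head(feat) * 6.0  # bounded to (0, 6): corner-origin targets
new_coords = new_head(feat)        # unbounded: center-origin targets in ~[-3, 3]

print(old_coords.min().item() >= 0.0)  # True: the old head could never leave [0, 6]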