Correct double photon sample (remove ordering)

2025-11-04 16:56:42 +01:00
parent f4178ce50f
commit 88536af944
3 changed files with 149 additions and 70 deletions

View File

@@ -2,7 +2,7 @@ import sys
sys.path.append('./src')
import torch
import numpy as np
from models import *
import models
from datasets import *
import torch.optim as optim
from tqdm import tqdm
@@ -12,9 +12,13 @@ from torchinfo import summary
torch.manual_seed(0)
np.random.seed(0)
modelVersion = '251001_2' # '250910' or '251001'
modelVersion = '251104' # '250910', '251001', '251001_2', or '251104'
model = models.get_double_photon_model_class(modelVersion)().cuda()
Energy = '15.3keV'
TrainLosses, ValLosses = [], []
TestLoss = -1
LearningRate = 1e-3
Noise = 0.13 # in keV
def two_point_set_loss_l2(pred_xy, gt_xy):
def pair_cost_l2sq(p, q): # p,q: (...,2)
@@ -25,32 +29,8 @@ def two_point_set_loss_l2(pred_xy, gt_xy):
c_b = pair_cost_l2sq(p1,g2) + pair_cost_l2sq(p2,g1)
return torch.minimum(c_a, c_b).mean()
def min_matching_loss(pred, target):
"""
pred: [B, 4] -> (x1,y1,x2,y2)
target: [B, 4] -> (x1,y1,x2,y2)
"""
pred = pred.view(-1, 2, 2) # [B, 2, 2]
target = target.view(-1, 2, 2) # [B, 2, 2]
# compute the MSE for both possible matchings
loss1 = torch.mean((pred[:,0] - target[:,0])**2 + (pred[:,1] - target[:,1])**2)
loss2 = torch.mean((pred[:,0] - target[:,1])**2 + (pred[:,1] - target[:,0])**2)
return torch.min(loss1, loss2)
# switch modelVersion:
if modelVersion == '250910':
loss_fn = two_point_set_loss_l2
model = doublePhotonNet_250910().cuda()
elif modelVersion == '251001':
loss_fn = min_matching_loss
model = doublePhotonNet_251001().cuda()
elif modelVersion == '251001_2':
loss_fn = min_matching_loss
model = doublePhotonNet_251001_2().cuda()
# summary(model, input_size=(128, 3, 6, 6)) ### print model summary
loss_fn = two_point_set_loss_l2
def train(model, trainLoader, optimizer):
model.train()
@@ -67,33 +47,33 @@ def train(model, trainLoader, optimizer):
loss.backward()
optimizer.step()
batchLoss += loss.item() * sample.shape[0]
avgLoss = batchLoss / len(trainLoader.dataset) /2 ### divide by 2 to get the average loss per photon
avgLoss = batchLoss / len(trainLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
TrainLosses.append(avgLoss)
def test(model, testLoader):
model.eval()
batchLoss = 0
residuals_x, residuals_y = np.array([]), np.array([])
gt_xy, out_xy = [], []
with torch.no_grad():
for batch_idx, (sample, label) in enumerate(testLoader):
sample, label = sample.cuda(), label.cuda()
x1, y1, z1, e1 = label[:,0], label[:,1], label[:,2], label[:,3]
x2, y2, z2, e2 = label[:,4], label[:,5], label[:,6], label[:,7]
gt_xy = torch.stack((torch.stack((x1, y1), axis=1), torch.stack((x2, y2), axis=1)), axis=1)
_gt_xy = torch.stack((torch.stack((x1, y1), axis=1), torch.stack((x2, y2), axis=1)), axis=1)
output = model(sample)
pred_xy = torch.stack((output[:,0:2], output[:,2:4]), axis=1)
loss = loss_fn(pred_xy, gt_xy)
_pred_xy = torch.stack((output[:,0:2], output[:,2:4]), axis=1)
loss = loss_fn(_pred_xy, _gt_xy)
batchLoss += loss.item() * sample.shape[0]
### Collect residuals for analysis
residuals_x = np.concatenate((residuals_x, (pred_xy[:,0,0] - gt_xy[:,0,0]).cpu().numpy(), (pred_xy[:,1,0] - gt_xy[:,1,0]).cpu().numpy()))
residuals_y = np.concatenate((residuals_y, (pred_xy[:,0,1] - gt_xy[:,0,1]).cpu().numpy(), (pred_xy[:,1,1] - gt_xy[:,1,1]).cpu().numpy()))
avgLoss = batchLoss / len(testLoader.dataset)
gt_xy.append(_gt_xy.cpu())
out_xy.append(_pred_xy.cpu())
gt_xy = torch.cat(gt_xy, dim=0)
out_xy = torch.cat(out_xy, dim=0)
avgLoss = batchLoss / len(testLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
datasetName = testLoader.dataset.datasetName
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
print(f" Residuals X: mean={np.mean(residuals_x):.4f}, std={np.std(residuals_x):.4f}")
print(f" Residuals Y: mean={np.mean(residuals_y):.4f}, std={np.std(residuals_y):.4f}")
calculate_residuals(gt_xy, out_xy)
if datasetName == 'Val':
ValLosses.append(avgLoss)
else:
@@ -101,26 +81,57 @@ def test(model, testLoader):
TestLoss = avgLoss
return avgLoss
def calculate_residuals(gt_xy, out_xy):
"""
gt_xy: (N, 2, 2) — [ [x1, y1], [x2, y2] ]
out_xy: (N, 2, 2) — [ [x1', y1'], [x2', y2'] ]
"""
# Option A: match (p1->g1, p2->g2)
cost_a = (out_xy - gt_xy).pow(2).sum(dim=-1).sum(dim=-1) # (N,)
# Option B: match (p1->g2, p2->g1) → swap out_xy
out_swapped = out_xy[:, [1, 0], :] # swap the two points: (N, 2, 2)
cost_b = (out_swapped - gt_xy).pow(2).sum(dim=-1).sum(dim=-1) # (N,)
# Choose best assignment per sample
swap_mask = cost_b < cost_a # (N,)
# Apply swapping to get optimally matched predictions
out_matched = out_xy.clone()
out_matched[swap_mask] = out_xy[swap_mask][:, [1, 0], :]
# Compute residuals
residuals = out_matched - gt_xy # (N, 2, 2)
# Flatten to get all residuals (2N points)
residuals_x = residuals[:, :, 0].flatten().cpu().numpy()
residuals_y = residuals[:, :, 1].flatten().cpu().numpy()
# Print statistics
print(f"\t\tResiduals X: mean={np.mean(residuals_x):.4f}, std={np.std(residuals_x):.4f}")
print(f"\t\tResiduals Y: mean={np.mean(residuals_y):.4f}, std={np.std(residuals_y):.4f}")
sampleFolder = '/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040'
trainDataset = doublePhotonDataset(
[f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_samples.npy' for i in range(13)],
[f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_labels.npy' for i in range(13)],
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13)],
sampleRatio=1.0,
datasetName='Train',
reuselFactor=1,
noiseKeV = Noise,
)
valDataset = doublePhotonDataset(
[f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_samples.npy' for i in range(13,14)],
[f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_labels.npy' for i in range(13,14)],
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13,14)],
sampleRatio=1.0,
datasetName='Val',
reuselFactor=1,
noiseKeV = Noise,
)
testDataset = doublePhotonDataset(
[f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_samples.npy' for i in range(15,16)],
[f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_labels.npy' for i in range(15,16)],
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15,16)],
sampleRatio=1.0,
datasetName='Test',
reuselFactor=1,
noiseKeV = Noise,
)
trainLoader = torch.utils.data.DataLoader(
trainDataset,
@@ -141,17 +152,19 @@ testLoader = torch.utils.data.DataLoader(
shuffle=False,
num_workers=16
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=LearningRate, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience = 5)
# if __name__ == "__main__":
# for epoch in tqdm(range(1, 301)):
# train(model, trainLoader, optimizer)
# test(model, valLoader)
# scheduler.step(TrainLosses[-1])
if __name__ == "__main__":
for epoch in tqdm(range(1, 301)):
train(model, trainLoader, optimizer)
test(model, valLoader)
scheduler.step(ValLosses[-1])
print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
if epoch in [20, 50, 100, 200, 300, 500, 750, 1000]:
torch.save(model.state_dict(), f'Models/doublePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}.pth')
model.load_state_dict(torch.load(f'doublePhotonNet_{modelVersion}.pth', weights_only=True))
test(model, testLoader)
torch.save(model.state_dict(), f'doublePhotonNet_{modelVersion}.pth')
torch.save(model.state_dict(), f'Models/doublePhotonNet_{modelVersion}.pth')
def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
import matplotlib.pyplot as plt
@@ -165,6 +178,6 @@ def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
plt.ylabel('MSE Loss')
plt.legend()
plt.grid()
plt.savefig(f'loss_curve_doublePhoton_{modelVersion}.png', dpi=300)
plt.savefig(f'Results/loss_curve_doublePhoton_{modelVersion}.png', dpi=300)
# plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion)
plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion)
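Standalone sketch (illustrative, not part of this commit) of the permutation-invariant matching that two_point_set_loss_l2 and calculate_residuals rely on: both point orderings are costed and the cheaper one is kept, so the order of the two photons in the label no longer matters.

import torch

def demo_set_loss(pred_xy, gt_xy):
    # pred_xy, gt_xy: (B, 2, 2) -> two (x, y) points per sample
    cost_keep = (pred_xy - gt_xy).pow(2).sum(dim=(-1, -2))                 # p1->g1, p2->g2
    cost_swap = (pred_xy - gt_xy[:, [1, 0], :]).pow(2).sum(dim=(-1, -2))   # p1->g2, p2->g1
    return torch.minimum(cost_keep, cost_swap).mean()

pred = torch.tensor([[[1.0, 1.0], [4.0, 4.0]]])   # prediction lists photon A first
gt   = torch.tensor([[[4.0, 4.0], [1.0, 1.0]]])   # ground truth lists the photons the other way round
print(demo_set_loss(pred, gt))                     # tensor(0.) -- the swapped assignment is chosen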

View File

@@ -101,10 +101,12 @@ class singlePhotonDataset(Dataset):
return self.effectiveLength
class doublePhotonDataset(Dataset):
def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1):
def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1, noiseKeV=0):
self.sampleFileList = sampleList
self.sampleRatio = sampleRatio
self.datasetName = datasetName
self.noiseKeV = noiseKeV
self._init_coords()
all_samples = []
all_labels = []
@@ -113,11 +115,23 @@ class doublePhotonDataset(Dataset):
all_samples.append(data['samples'])
all_labels.append(data['labels'])
self.samples = np.concatenate(all_samples, axis=0)
if self.noiseKeV != 0:
print(f'Adding Gaussian noise with sigma = {self.noiseKeV} keV to samples in {self.datasetName} dataset')
noise = np.random.normal(loc=0.0, scale=self.noiseKeV, size=self.samples.shape)
self.samples = self.samples + noise
self.labels = np.concatenate(all_labels, axis=0)
### total number of samples
self.length = int(self.samples.shape[0] * self.sampleRatio) // 2 * reuselFactor
print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
def _init_coords(self):
# Create coordinate grids for the 6x6 input
x = np.linspace(0, 5, 6)
y = np.linspace(0, 5, 6)
x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (6,6), (6,6)
self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous() # (1, 6, 6)
self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, 6, 6)
def __getitem__(self, index):
sample = np.zeros((8, 8), dtype=np.float32)
@@ -128,21 +142,20 @@ class doublePhotonDataset(Dataset):
### random positions for the two photons within the 8x8 canvas
pos_x1 = np.random.randint(1, 4)
pos_y1 = np.random.randint(1, 4)
# pos_y1 = np.random.randint(1, 4)
pos_y1 = pos_x1
sample[pos_y1:pos_y1+5, pos_x1:pos_x1+5] += photon1
pos_x2 = np.random.randint(1, 4)
pos_y2 = np.random.randint(1, 4)
# pos_y2 = np.random.randint(1, 4)
pos_y2 = pos_x2
sample[pos_y2:pos_y2+5, pos_x2:pos_x2+5] += photon2
sample = sample[1:-1, 1:-1] ### sample size: 6x6
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0) / 1000. ### to keV
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
sample = torch.cat((sample, self.x_grid, self.y_grid), dim=0) ### concatenate coordinate channels
label1 = self.labels[idx1] + np.array([pos_x1, pos_y1, 0, 0]) - 1
label2 = self.labels[idx2] + np.array([pos_x2, pos_y2, 0, 0]) - 1
# label = np.concatenate((label1, label2), axis=0)
if label1[0] < label2[0]:
label = np.concatenate((label1, label2), axis=0)
else:
label = np.concatenate((label2, label1), axis=0)
label1 = self.labels[idx1] + np.array([pos_x1-1, pos_y1-1, 0, 0])
label2 = self.labels[idx2] + np.array([pos_x2-1, pos_y2-1, 0, 0])
label = np.concatenate((label1, label2), axis=0)
return sample, torch.tensor(label, dtype=torch.float32)
def __len__(self):
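Standalone sketch (illustrative, not part of this file) of the coordinate channels now added in __getitem__: the 6x6 charge map is stacked with the fixed x/y index grids built in _init_coords, so the network sees the pixel positions explicitly.

import numpy as np
import torch

x_grid, y_grid = np.meshgrid(np.arange(6), np.arange(6), indexing='ij')   # (6, 6) each
x_grid = torch.tensor(x_grid[None], dtype=torch.float32)                   # (1, 6, 6)
y_grid = torch.tensor(y_grid[None], dtype=torch.float32)                   # (1, 6, 6)

charge = torch.rand(1, 6, 6)                         # stand-in for one 6x6 pixel sample in keV
sample = torch.cat((charge, x_grid, y_grid), dim=0)  # (3, 6, 6), matches the 3-channel conv1
print(sample.shape)                                  # torch.Size([3, 6, 6])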

View File

@@ -10,6 +10,13 @@ def get_model_class(version):
raise ValueError(f"Model class '{class_name}' not found.")
return cls
def get_double_photon_model_class(version):
class_name = f'doublePhotonNet_{version}'
cls = globals().get(class_name)
if cls is None:
raise ValueError(f"Model class '{class_name}' not found.")
return cls
class singlePhotonNet_250909(nn.Module):
def weight_init(self):
for m in self.modules():
@@ -181,15 +188,11 @@ class doublePhotonNet_251001(nn.Module):
coords = self.fc(x)*6
return coords # shape: [B, 4]
import torch
import torch.nn as nn
import torch.nn.functional as F
class doublePhotonNet_251001_2(nn.Module):
def __init__(self):
super().__init__()
# Backbone: deeper + residual-like blocks
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
@@ -253,4 +256,54 @@ class doublePhotonNet_251001_2(nn.Module):
# Regression
coords = self.fc(global_feat) * 6.0 # scale to [0,6)
return coords # [B, 4]
class doublePhotonNet_251104(nn.Module):
def __init__(self):
super().__init__()
# Backbone: deeper
self.conv1 = nn.Conv2d(3, 32, kernel_size=3) # 6x6 -> 4x4
self.conv2 = nn.Conv2d(32, 64, kernel_size=3) # 4x4 -> 2x2
self.conv3 = nn.Conv2d(64, 128, kernel_size=2) # 2x2 -> 1x1
# Spatial attention module (lightweight but effective)
self.spatial_attn = nn.Sequential(
nn.Conv2d(128, 1, kernel_size=1),
nn.Sigmoid()
)
# Enhanced regression head
self.fc = nn.Sequential(
nn.Linear(128, 256), # flattened 128-dim conv3 output
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(128, 4),
nn.Sigmoid() # output in [0,1]
)
self._init_weights()
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
# Feature extraction
c1 = F.relu(self.conv1(x)) # [B, 32, 4, 4]
c2 = F.relu(self.conv2(c1)) # [B, 64, 2, 2]
c3 = F.relu(self.conv3(c2)) # [B, 128, 1, 1]
# Spatial attention: highlight photon peaks
attn = self.spatial_attn(c3) # [B, 1, 1, 1]
c3 = c3 * attn # reweight features
# Regression
coords = self.fc(c3.flatten(1)) * 6.0 # scale to [0,6)
return coords # [B, 4]
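Quick shape check (illustrative, not part of this commit), assuming models.py is importable as in the training script: the version string selects the class and a 3-channel 6x6 batch maps to four coordinates.

import torch
import models

net = models.get_double_photon_model_class('251104')()
out = net(torch.rand(2, 3, 6, 6))   # conv stack reduces 6x6 -> 4x4 -> 2x2 -> 1x1
print(out.shape)                     # torch.Size([2, 4]): (x1, y1, x2, y2), each in (0, 6)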