From 88536af944a63cab1b2dd6d82961e23add92c0c5 Mon Sep 17 00:00:00 2001
From: "xiangyu.xie"
Date: Tue, 4 Nov 2025 16:56:42 +0100
Subject: [PATCH] Correct double photon sample (remove ordering)

---
 Train_DoublePhoton.py | 121 +++++++++++++++++++++++-------------------
 src/datasets.py       |  35 ++++++++----
 src/models.py         |  63 ++++++++++++++++++++--
 3 files changed, 149 insertions(+), 70 deletions(-)

diff --git a/Train_DoublePhoton.py b/Train_DoublePhoton.py
index 50c9ef6..9bb752d 100644
--- a/Train_DoublePhoton.py
+++ b/Train_DoublePhoton.py
@@ -2,7 +2,7 @@ import sys
 sys.path.append('./src')
 import torch
 import numpy as np
-from models import *
+import models
 from datasets import *
 import torch.optim as optim
 from tqdm import tqdm
@@ -12,9 +12,13 @@ from torchinfo import summary
 torch.manual_seed(0)
 np.random.seed(0)
 
-modelVersion = '251001_2' # '250910' or '251001'
+modelVersion = '251104' # '250910', '251001', '251001_2' or '251104'
+model = models.get_double_photon_model_class(modelVersion)().cuda()
+Energy = '15.3keV'
 TrainLosses, ValLosses = [], []
 TestLoss = -1
+LearningRate = 1e-3
+Noise = 0.13 # in keV
 
 def two_point_set_loss_l2(pred_xy, gt_xy):
     def pair_cost_l2sq(p, q): # p,q: (...,2)
@@ -25,32 +29,8 @@ def two_point_set_loss_l2(pred_xy, gt_xy):
     c_a = pair_cost_l2sq(p1,g1) + pair_cost_l2sq(p2,g2)
     c_b = pair_cost_l2sq(p1,g2) + pair_cost_l2sq(p2,g1)
     return torch.minimum(c_a, c_b).mean()
 
-def min_matching_loss(pred, target):
-    """
-    pred: [B, 4] -> (x1,y1,x2,y2)
-    target: [B, 4] -> (x1,y1,x2,y2)
-    """
-    pred = pred.view(-1, 2, 2) # [B, 2, 2]
-    target = target.view(-1, 2, 2) # [B, 2, 2]
-
-    # 计算所有匹配的MSE
-    loss1 = torch.mean((pred[:,0] - target[:,0])**2 + (pred[:,1] - target[:,1])**2)
-    loss2 = torch.mean((pred[:,0] - target[:,1])**2 + (pred[:,1] - target[:,0])**2)
-
-    return torch.min(loss1, loss2)
-
-# switch modelVersion:
-if modelVersion == '250910':
-    loss_fn = two_point_set_loss_l2
-    model = doublePhotonNet_250910().cuda()
-elif modelVersion == '251001':
-    loss_fn = min_matching_loss
-    model = doublePhotonNet_251001().cuda()
-elif modelVersion == '251001_2':
-    loss_fn = min_matching_loss
-    model = doublePhotonNet_251001_2().cuda()
-    # summary(model, input_size=(128, 1, 6, 6)) ### print model summary
+loss_fn = two_point_set_loss_l2
 
 def train(model, trainLoader, optimizer):
     model.train()
@@ -67,33 +47,33 @@ def train(model, trainLoader, optimizer):
         loss.backward()
         optimizer.step()
         batchLoss += loss.item() * sample.shape[0]
-    avgLoss = batchLoss / len(trainLoader.dataset) /2 ### divide by 2 to get the average loss per photon
+    avgLoss = batchLoss / len(trainLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
     print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
     TrainLosses.append(avgLoss)
 
 
 def test(model, testLoader):
     model.eval()
     batchLoss = 0
-    residuals_x, residuals_y = np.array([]), np.array([])
+    gt_xy, out_xy = [], []
     with torch.no_grad():
         for batch_idx, (sample, label) in enumerate(testLoader):
             sample, label = sample.cuda(), label.cuda()
             x1, y1, z1, e1 = label[:,0], label[:,1], label[:,2], label[:,3]
             x2, y2, z2, e2 = label[:,4], label[:,5], label[:,6], label[:,7]
-            gt_xy = torch.stack((torch.stack((x1, y1), axis=1), torch.stack((x2, y2), axis=1)), axis=1)
+            _gt_xy = torch.stack((torch.stack((x1, y1), axis=1), torch.stack((x2, y2), axis=1)), axis=1)
             output = model(sample)
-            pred_xy = torch.stack((output[:,0:2], output[:,2:4]), axis=1)
-            loss = loss_fn(pred_xy, gt_xy)
+            _pred_xy = torch.stack((output[:,0:2], output[:,2:4]), axis=1)
+            loss = loss_fn(_pred_xy, _gt_xy)
             batchLoss += loss.item() * sample.shape[0]
-            ### Collect residuals for analysis
-            residuals_x = np.concatenate((residuals_x, (pred_xy[:,0,0] - gt_xy[:,0,0]).cpu().numpy(), (pred_xy[:,1,0] - gt_xy[:,1,0]).cpu().numpy()))
-            residuals_y = np.concatenate((residuals_y, (pred_xy[:,0,1] - gt_xy[:,0,1]).cpu().numpy(), (pred_xy[:,1,1] - gt_xy[:,1,1]).cpu().numpy()))
-    avgLoss = batchLoss / len(testLoader.dataset)
+            gt_xy.append(_gt_xy.cpu())
+            out_xy.append(_pred_xy.cpu())
+    gt_xy = torch.cat(gt_xy, dim=0)
+    out_xy = torch.cat(out_xy, dim=0)
+    avgLoss = batchLoss / len(testLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
     datasetName = testLoader.dataset.datasetName
     print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
-    print(f" Residuals X: mean={np.mean(residuals_x):.4f}, std={np.std(residuals_x):.4f}")
-    print(f" Residuals Y: mean={np.mean(residuals_y):.4f}, std={np.std(residuals_y):.4f}")
+    calculate_residuals(gt_xy, out_xy)
     if datasetName == 'Val':
         ValLosses.append(avgLoss)
     else:
@@ -101,26 +81,57 @@ def test(model, testLoader):
         TestLoss = avgLoss
     return avgLoss
 
+def calculate_residuals(gt_xy, out_xy):
+    """
+    gt_xy: (N, 2, 2) — [ [x1, y1], [x2, y2] ]
+    out_xy: (N, 2, 2) — [ [x1', y1'], [x2', y2'] ]
+    """
+    # Option A: match (p1->g1, p2->g2)
+    cost_a = (out_xy - gt_xy).pow(2).sum(dim=-1).sum(dim=-1) # (N,)
+
+    # Option B: match (p1->g2, p2->g1) → swap out_xy
+    out_swapped = out_xy[:, [1, 0], :] # swap the two points: (N, 2, 2)
+    cost_b = (out_swapped - gt_xy).pow(2).sum(dim=-1).sum(dim=-1) # (N,)
+
+    # Choose best assignment per sample
+    swap_mask = cost_b < cost_a # (N,)
+
+    # Apply swapping to get optimally matched predictions
+    out_matched = out_xy.clone()
+    out_matched[swap_mask] = out_xy[swap_mask][:, [1, 0], :]
+
+    # Compute residuals
+    residuals = out_matched - gt_xy # (N, 2, 2)
+
+    # Flatten to get all residuals (2N points)
+    residuals_x = residuals[:, :, 0].flatten().cpu().numpy()
+    residuals_y = residuals[:, :, 1].flatten().cpu().numpy()
+
+    # Print statistics
+    print(f"\t\tResiduals X: mean={np.mean(residuals_x):.4f}, std={np.std(residuals_x):.4f}")
+    print(f"\t\tResiduals Y: mean={np.mean(residuals_y):.4f}, std={np.std(residuals_y):.4f}")
+
+sampleFolder = '/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040'
 trainDataset = doublePhotonDataset(
-    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_samples.npy' for i in range(13)],
-    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_labels.npy' for i in range(13)],
+    [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13)],
     sampleRatio=1.0,
     datasetName='Train',
     reuselFactor=1,
+    noiseKeV = Noise,
 )
 valDataset = doublePhotonDataset(
-    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_samples.npy' for i in range(13,14)],
-    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_labels.npy' for i in range(13,14)],
+    [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13,14)],
    sampleRatio=1.0,
     datasetName='Val',
     reuselFactor=1,
+    noiseKeV = Noise,
 )
 testDataset = doublePhotonDataset(
-    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_samples.npy' for i in range(15,16)],
-    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_labels.npy' for i in range(15,16)],
+    [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15,16)],
     sampleRatio=1.0,
     datasetName='Test',
     reuselFactor=1,
+    noiseKeV = Noise,
 )
 trainLoader = torch.utils.data.DataLoader(
     trainDataset,
@@ -141,17 +152,19 @@ testLoader = torch.utils.data.DataLoader(
     shuffle=False,
     num_workers=16
 )
-optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
+optimizer = torch.optim.Adam(model.parameters(), lr=LearningRate, weight_decay=1e-4)
 scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience = 5)
 
-# if __name__ == "__main__":
-#     for epoch in tqdm(range(1, 301)):
-#         train(model, trainLoader, optimizer)
-#         test(model, valLoader)
-#         scheduler.step(TrainLosses[-1])
+if __name__ == "__main__":
+    for epoch in tqdm(range(1, 301)):
+        train(model, trainLoader, optimizer)
+        test(model, valLoader)
+        scheduler.step(ValLosses[-1])
+        print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
+        if epoch in [20, 50, 100, 200, 300, 500, 750, 1000]:
+            torch.save(model.state_dict(), f'Models/doublePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}.pth')
-model.load_state_dict(torch.load(f'doublePhotonNet_{modelVersion}.pth', weights_only=True))
 test(model, testLoader)
-torch.save(model.state_dict(), f'doublePhotonNet_{modelVersion}.pth')
+torch.save(model.state_dict(), f'Models/doublePhotonNet_{modelVersion}.pth')
 
 def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
     import matplotlib.pyplot as plt
@@ -165,6 +178,6 @@ def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
     plt.ylabel('MSE Loss')
     plt.legend()
     plt.grid()
-    plt.savefig(f'loss_curve_doublePhoton_{modelVersion}.png', dpi=300)
+    plt.savefig(f'Results/loss_curve_doublePhoton_{modelVersion}.png', dpi=300)
 
-# plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion)
\ No newline at end of file
+plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion)
\ No newline at end of file
diff --git a/src/datasets.py b/src/datasets.py
index f1a1464..10cffa8 100644
--- a/src/datasets.py
+++ b/src/datasets.py
@@ -101,10 +101,12 @@ class singlePhotonDataset(Dataset):
         return self.effectiveLength
 
 class doublePhotonDataset(Dataset):
-    def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1):
+    def __init__(self, sampleList, sampleRatio, datasetName, reuselFactor=1, noiseKeV=0):
         self.sampleFileList = sampleList
         self.sampleRatio = sampleRatio
         self.datasetName = datasetName
+        self.noiseKeV = noiseKeV
+        self._init_coords()
 
         all_samples = []
         all_labels = []
@@ -113,11 +115,23 @@ class doublePhotonDataset(Dataset):
             all_samples.append(data['samples'])
             all_labels.append(data['labels'])
         self.samples = np.concatenate(all_samples, axis=0)
+        if self.noiseKeV != 0:
+            print(f'Adding Gaussian noise with sigma = {self.noiseKeV} keV to samples in {self.datasetName} dataset')
+            noise = np.random.normal(loc=0.0, scale=self.noiseKeV, size=self.samples.shape)
+            self.samples = self.samples + noise
         self.labels = np.concatenate(all_labels, axis=0)
 
         ### total number of samples
         self.length = int(self.samples.shape[0] * self.sampleRatio) // 2 * reuselFactor
         print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
+
+    def _init_coords(self):
+        # Create a coordinate grid for the 6x6 input
+        x = np.linspace(0, 5, 6)
+        y = np.linspace(0, 5, 6)
+        x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (6,6), (6,6)
+        self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous() # (1, 6, 6)
+        self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous() # (1, 6, 6)
 
     def __getitem__(self, index):
         sample = np.zeros((8, 8), dtype=np.float32)
@@ -128,21 +142,20 @@ class doublePhotonDataset(Dataset):
         ### random position for photons in
         pos_x1 = np.random.randint(1, 4)
-        pos_y1 = np.random.randint(1, 4)
+        # pos_y1 = np.random.randint(1, 4)
+        pos_y1 = pos_x1
         sample[pos_y1:pos_y1+5, pos_x1:pos_x1+5] += photon1
         pos_x2 = np.random.randint(1, 4)
-        pos_y2 = np.random.randint(1, 4)
+        # pos_y2 = np.random.randint(1, 4)
+        pos_y2 = pos_x2
         sample[pos_y2:pos_y2+5, pos_x2:pos_x2+5] += photon2
 
         sample = sample[1:-1, 1:-1] ### sample size: 6x6
-        sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0) / 1000. ### to keV
+        sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
+        sample = torch.cat((sample, self.x_grid, self.y_grid), dim=0) ### concatenate coordinate channels
 
-        label1 = self.labels[idx1] + np.array([pos_x1, pos_y1, 0, 0]) - 1
-        label2 = self.labels[idx2] + np.array([pos_x2, pos_y2, 0, 0]) - 1
-        # label = np.concatenate((label1, label2), axis=0)
-        if label1[0] < label2[0]:
-            label = np.concatenate((label1, label2), axis=0)
-        else:
-            label = np.concatenate((label2, label1), axis=0)
+        label1 = self.labels[idx1] + np.array([pos_x1-1, pos_y1-1, 0, 0])
+        label2 = self.labels[idx2] + np.array([pos_x2-1, pos_y2-1, 0, 0])
+        label = np.concatenate((label1, label2), axis=0)
         return sample, torch.tensor(label, dtype=torch.float32)
 
     def __len__(self):
diff --git a/src/models.py b/src/models.py
index 9086b25..8425948 100644
--- a/src/models.py
+++ b/src/models.py
@@ -10,6 +10,13 @@ def get_model_class(version):
         raise ValueError(f"Model class '{class_name}' not found.")
     return cls
 
+def get_double_photon_model_class(version):
+    class_name = f'doublePhotonNet_{version}'
+    cls = globals().get(class_name)
+    if cls is None:
+        raise ValueError(f"Model class '{class_name}' not found.")
+    return cls
+
 
 class singlePhotonNet_250909(nn.Module):
     def weight_init(self):
@@ -181,15 +188,11 @@ class doublePhotonNet_251001(nn.Module):
         coords = self.fc(x)*6
         return coords # shape: [B, 4]
 
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
 class doublePhotonNet_251001_2(nn.Module):
     def __init__(self):
         super().__init__()
 
         # Backbone: deeper + residual-like blocks
-        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
+        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
         self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
         self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
@@ -253,4 +256,54 @@ class doublePhotonNet_251001_2(nn.Module):
 
         # Regression
         coords = self.fc(global_feat) * 6.0 # scale to [0,6)
+        return coords # [B, 4]
+
+class doublePhotonNet_251104(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # Backbone: deeper
+        self.conv1 = nn.Conv2d(3, 32, kernel_size=3) # 6x6 -> 4x4
+        self.conv2 = nn.Conv2d(32, 64, kernel_size=3) # 4x4 -> 2x2
+        self.conv3 = nn.Conv2d(64, 128, kernel_size=2) # 2x2 -> 1x1
+
+        # Spatial Attention Module (lightweight but effective)
+        self.spatial_attn = nn.Sequential(
+            nn.Conv2d(128, 1, kernel_size=1),
+            nn.Sigmoid()
+        )
+
+        # Enhanced regression head
+        self.fc = nn.Sequential(
+            nn.Linear(128, 256), # 128-dim input: flattened 1x1 feature map
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(128, 4),
+            nn.Sigmoid() # output in [0,1]
+        )
+
+        self._init_weights()
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                nn.init.zeros_(m.bias)
+
+    def forward(self, x):
+        # Feature extraction
+        c1 = F.relu(self.conv1(x)) # [B, 32, 4, 4]
+        c2 = F.relu(self.conv2(c1)) # [B, 64, 2, 2]
+        c3 = F.relu(self.conv3(c2)) # [B, 128, 1, 1]
+
+        # Spatial attention: highlight photon peaks
+        attn = self.spatial_attn(c3) # [B, 1, 1, 1]
+        c3 = c3 * attn # reweight features
+
+        # Regression
+        coords = self.fc(c3.flatten(1)) * 6.0 # scale to [0,6)
+        return coords # [B, 4]
\ No newline at end of file
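
Reviewer note (not part of the patch): a minimal, self-contained sanity check of the idea behind this change, namely that once the dataset no longer orders the two photons, the training loss must be invariant under swapping the two ground-truth points. The snippet re-implements two_point_set_loss_l2 locally rather than importing the repository code (the helper pair_cost_l2sq body is an assumption, since the patch only shows part of the function); batch size, tensor shapes, and the 0-6 coordinate range follow the patch above.

import torch

def two_point_set_loss_l2(pred_xy, gt_xy):
    # Permutation-invariant matching: take the cheaper of the two possible
    # photon-to-label assignments, then average over the batch.
    def pair_cost_l2sq(p, q):  # p, q: (..., 2)
        return ((p - q) ** 2).sum(dim=-1)
    p1, p2 = pred_xy[:, 0], pred_xy[:, 1]
    g1, g2 = gt_xy[:, 0], gt_xy[:, 1]
    c_a = pair_cost_l2sq(p1, g1) + pair_cost_l2sq(p2, g2)
    c_b = pair_cost_l2sq(p1, g2) + pair_cost_l2sq(p2, g1)
    return torch.minimum(c_a, c_b).mean()

torch.manual_seed(0)
pred = torch.rand(128, 2, 2) * 6   # predicted (x, y) pairs, shape [B, 2, 2]
gt = torch.rand(128, 2, 2) * 6     # ground-truth pairs; their order is now arbitrary
gt_swapped = gt[:, [1, 0], :]      # swap photon 1 and photon 2 in every sample

# With the ordering removed from the labels, the loss must not depend on it.
assert torch.allclose(two_point_set_loss_l2(pred, gt),
                      two_point_set_loss_l2(pred, gt_swapped))
print("two_point_set_loss_l2 is invariant under label swapping")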
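Reviewer note (not part of the patch): a quick shape check for the new doublePhotonNet_251104 backbone, assuming the 3-channel 6x6 input that doublePhotonDataset now builds (energy map plus the x/y coordinate grids). It uses plain torch.nn layers with the same kernel sizes as the patch, so it only illustrates why the flattened feature entering the regression head is 128-dimensional, matching nn.Linear(128, 256).

import torch
import torch.nn as nn

conv1 = nn.Conv2d(3, 32, kernel_size=3)    # 6x6 -> 4x4 (no padding)
conv2 = nn.Conv2d(32, 64, kernel_size=3)   # 4x4 -> 2x2
conv3 = nn.Conv2d(64, 128, kernel_size=2)  # 2x2 -> 1x1

x = torch.zeros(2, 3, 6, 6)                # [B, energy + x/y coordinate channels, 6, 6]
c1 = conv1(x)
c2 = conv2(c1)
c3 = conv3(c2)
print(c1.shape, c2.shape, c3.shape)        # [2, 32, 4, 4] [2, 64, 2, 2] [2, 128, 1, 1]
print(c3.flatten(1).shape)                 # [2, 128] -> input size of nn.Linear(128, 256)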