import sys
sys.path.append('./src')

import torch
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from torchinfo import summary

import models
from datasets import doublePhotonDataset

### random seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)

modelVersion = '251001_2'  # alternatives: '250910' or '251001'
model = models.get_double_photon_model_class(modelVersion)().cuda()

Energy = '15.3keV'
TrainLosses, ValLosses = [], []
TestLoss = -1  # sentinel; overwritten by the final test() call
LearningRate = 1e-3
Noise = 0.13  # noise level in keV


def two_point_set_loss_l2(pred_xy, gt_xy):
    """Permutation-invariant squared-L2 loss for two unordered points per sample."""
    def pair_cost_l2sq(p, q):
        # p, q: (..., 2) -> squared L2 distance
        return ((p - q) ** 2).sum(dim=-1)

    p1, p2 = pred_xy[:, 0], pred_xy[:, 1]
    g1, g2 = gt_xy[:, 0], gt_xy[:, 1]
    c_a = pair_cost_l2sq(p1, g1) + pair_cost_l2sq(p2, g2)  # identity assignment
    c_b = pair_cost_l2sq(p1, g2) + pair_cost_l2sq(p2, g1)  # swapped assignment
    return torch.minimum(c_a, c_b).mean()


# summary(model, input_size=(128, 1, 6, 6))  ### print model summary

loss_fn = two_point_set_loss_l2


def train(model, trainLoader, optimizer):
    model.train()
    batchLoss = 0
    for batch_idx, (sample, label) in enumerate(trainLoader):
        sample, label = sample.cuda(), label.cuda()
        # label layout: (x1, y1, z1, e1, x2, y2, z2, e2); z and energy are unused here
        x1, y1, z1, e1 = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
        x2, y2, z2, e2 = label[:, 4], label[:, 5], label[:, 6], label[:, 7]
        gt_xy = torch.stack((torch.stack((x1, y1), dim=1),
                             torch.stack((x2, y2), dim=1)), dim=1)
        optimizer.zero_grad()
        output = model(sample)
        pred_xy = torch.stack((output[:, 0:2], output[:, 2:4]), dim=1)
        loss = loss_fn(pred_xy, gt_xy)
        loss.backward()
        optimizer.step()
        batchLoss += loss.item() * sample.shape[0]
    avgLoss = batchLoss / len(trainLoader.dataset) / 4  ### divide by 4 to get the average loss per photon per axis
    print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
    TrainLosses.append(avgLoss)


def test(model, testLoader):
    global TestLoss
    model.eval()
    batchLoss = 0
    gt_xy, out_xy = [], []
    with torch.no_grad():
        for batch_idx, (sample, label) in enumerate(testLoader):
            sample, label = sample.cuda(), label.cuda()
            x1, y1, z1, e1 = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
            x2, y2, z2, e2 = label[:, 4], label[:, 5], label[:, 6], label[:, 7]
            _gt_xy = torch.stack((torch.stack((x1, y1), dim=1),
                                  torch.stack((x2, y2), dim=1)), dim=1)
            output = model(sample)
            _pred_xy = torch.stack((output[:, 0:2], output[:, 2:4]), dim=1)
            loss = loss_fn(_pred_xy, _gt_xy)
            batchLoss += loss.item() * sample.shape[0]
            gt_xy.append(_gt_xy.cpu())
            out_xy.append(_pred_xy.cpu())
    gt_xy = torch.cat(gt_xy, dim=0)
    out_xy = torch.cat(out_xy, dim=0)
    avgLoss = batchLoss / len(testLoader.dataset) / 4  ### divide by 4 to get the average loss per photon per axis
    datasetName = testLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
    calculate_residuals(gt_xy, out_xy)
    if datasetName == 'Val':
        ValLosses.append(avgLoss)
    else:
        TestLoss = avgLoss
    return avgLoss


def calculate_residuals(gt_xy, out_xy):
    """
    gt_xy:  (N, 2, 2) ground-truth points  [[x1, y1], [x2, y2]]
    out_xy: (N, 2, 2) predicted points     [[x1', y1'], [x2', y2']]
    """
    # Option A: match (p1->g1, p2->g2)
    cost_a = (out_xy - gt_xy).pow(2).sum(dim=-1).sum(dim=-1)  # (N,)
    # Option B: match (p1->g2, p2->g1), i.e. swap out_xy
    out_swapped = out_xy[:, [1, 0], :]  # swap the two points: (N, 2, 2)
    cost_b = (out_swapped - gt_xy).pow(2).sum(dim=-1).sum(dim=-1)  # (N,)
    # Choose the best assignment per sample
    swap_mask = cost_b < cost_a  # (N,)
    # Apply the swap to get optimally matched predictions
    out_matched = out_xy.clone()
    out_matched[swap_mask] = out_xy[swap_mask][:, [1, 0], :]
    # Compute residuals under the optimal matching
    residuals = out_matched - gt_xy  # (N, 2, 2)
    # Flatten to get all residuals (2N points)
    residuals_x = residuals[:, :, 0].flatten().cpu().numpy()
    residuals_y = residuals[:, :, 1].flatten().cpu().numpy()
    # Print statistics
    print(f"\t\tResiduals X: mean={np.mean(residuals_x):.4f}, std={np.std(residuals_x):.4f}")
    print(f"\t\tResiduals Y: mean={np.mean(residuals_y):.4f}, std={np.std(residuals_y):.4f}")
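

# --- Illustrative sanity check (not in the original script) ---
# The set loss above should be invariant to the ordering of the two ground-truth
# points, because it takes the minimum over both possible pairings. A quick check
# on random tensors:
def _check_loss_permutation_invariance():
    pred = torch.randn(8, 2, 2)
    gt = torch.randn(8, 2, 2)
    gt_swapped = gt[:, [1, 0], :]  # swap the two points in each sample
    assert torch.isclose(two_point_set_loss_l2(pred, gt),
                         two_point_set_loss_l2(pred, gt_swapped))

# _check_loss_permutation_invariance()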
sampleFolder = '/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040'

trainDataset = doublePhotonDataset(
    [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13)],
    sampleRatio=1.0,
    datasetName='Train',
    reuselFactor=1,
    noiseKeV=Noise,
)
valDataset = doublePhotonDataset(
    [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13, 14)],
    sampleRatio=1.0,
    datasetName='Val',
    reuselFactor=1,
    noiseKeV=Noise,
)
testDataset = doublePhotonDataset(
    [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15, 16)],
    sampleRatio=1.0,
    datasetName='Test',
    reuselFactor=1,
    noiseKeV=Noise,
)

trainLoader = torch.utils.data.DataLoader(
    trainDataset, batch_size=1024, pin_memory=True, shuffle=True, num_workers=16
)
valLoader = torch.utils.data.DataLoader(
    valDataset, batch_size=4096, shuffle=False, num_workers=16
)
testLoader = torch.utils.data.DataLoader(
    testDataset, batch_size=4096, shuffle=False, num_workers=16
)

optimizer = torch.optim.Adam(model.parameters(), lr=LearningRate, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience=5)


def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
    import matplotlib.pyplot as plt
    plt.figure(figsize=(8, 6))
    plt.plot(TrainLosses, label='Train Loss', color='blue')
    plt.plot(ValLosses, label='Validation Loss', color='orange')
    if TestLoss > 0:
        plt.axhline(y=TestLoss, color='green', linestyle='--', label='Test Loss')
    plt.yscale('log')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid()
    plt.savefig(f'Results/loss_curve_doublePhoton_{modelVersion}.png', dpi=300)


if __name__ == "__main__":
    for epoch in tqdm(range(1, 301)):
        train(model, trainLoader, optimizer)
        test(model, valLoader)
        scheduler.step(ValLosses[-1])
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
        if epoch in [20, 50, 100, 200, 300, 500, 750, 1000]:
            torch.save(model.state_dict(),
                       f'Models/doublePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}.pth')
    test(model, testLoader)
    torch.save(model.state_dict(), f'Models/doublePhotonNet_{modelVersion}.pth')
    plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion)
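

# --- Illustrative addition (not in the original script): reloading a checkpoint ---
# A minimal sketch of how a saved state_dict from the run above could be restored
# for inference. `load_trained_model` is a hypothetical helper; the checkpoint path
# and `version` default mirror the names used in this script, so adjust them to
# whatever checkpoint actually exists on disk.
def load_trained_model(checkpointPath, version=modelVersion):
    m = models.get_double_photon_model_class(version)().cuda()
    m.load_state_dict(torch.load(checkpointPath, map_location='cuda'))
    m.eval()  # switch to inference mode
    return m

# Example usage (assumes the final checkpoint from the run above exists):
# model = load_trained_model(f'Models/doublePhotonNet_{modelVersion}.pth')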