diff --git a/Train_DoublePhoton.py b/Train_DoublePhoton.py deleted file mode 100644 index 26abdb9..0000000 --- a/Train_DoublePhoton.py +++ /dev/null @@ -1,196 +0,0 @@ -import sys -sys.path.append('./src') -import torch -import numpy as np -import models -from datasets import * -import torch.optim as optim -from tqdm import tqdm -from torchinfo import summary - -### random seed for reproducibility -torch.manual_seed(0) -np.random.seed(0) - -modelVersion = '251124' # '250910' or '251001' or '251124' -model = models.get_double_photon_model_class(modelVersion)().cuda() -Energy = '12keV' -TrainLosses, ValLosses = [], [] -TestLoss = -1 -LearningRate = 1e-3 -Noise = 0.13 # in keV -NoiseThreshold = 3*Noise -Normalize = False - -def two_point_set_loss_l2(pred_xy, gt_xy): - def pair_cost_l2sq(p, q): # p,q: (...,2) - return ((p - q)**2).sum(dim=-1) # squared L2 - p1, p2 = pred_xy[:,0], pred_xy[:,1] - g1, g2 = gt_xy[:,0], gt_xy[:,1] - c_a = pair_cost_l2sq(p1,g1) + pair_cost_l2sq(p2,g2) - c_b = pair_cost_l2sq(p1,g2) + pair_cost_l2sq(p2,g1) - return torch.minimum(c_a, c_b).mean() - -# summary(model, input_size=(128, 1, 6, 6)) ### print model summary -loss_fn = two_point_set_loss_l2 - -def train(model, trainLoader, optimizer): - model.train() - batchLoss = 0 - for batch_idx, (sample, label) in enumerate(trainLoader): - sample, label = sample.cuda(), label.cuda() - x1, y1, z1, e1 = label[:,0], label[:,1], label[:,2], label[:,3] - x2, y2, z2, e2 = label[:,4], label[:,5], label[:,6], label[:,7] - gt_xy = torch.stack((torch.stack((x1, y1), axis=1), torch.stack((x2, y2), axis=1)), axis=1) - optimizer.zero_grad() - output = model(sample) - pred_xy = torch.stack((output[:,0:2], output[:,2:4]), axis=1) - loss = loss_fn(pred_xy, gt_xy) - loss.backward() - optimizer.step() - batchLoss += loss.item() * sample.shape[0] - avgLoss = batchLoss / len(trainLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis - print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})") - TrainLosses.append(avgLoss) - -def test(model, testLoader): - model.eval() - batchLoss = 0 - gt_xy, out_xy = [], [] - with torch.no_grad(): - for batch_idx, (sample, label) in enumerate(testLoader): - sample, label = sample.cuda(), label.cuda() - x1, y1, z1, e1 = label[:,0], label[:,1], label[:,2], label[:,3] - x2, y2, z2, e2 = label[:,4], label[:,5], label[:,6], label[:,7] - _gt_xy = torch.stack((torch.stack((x1, y1), axis=1), torch.stack((x2, y2), axis=1)), axis=1) - output = model(sample) - _pred_xy = torch.stack((output[:,0:2], output[:,2:4]), axis=1) - loss = loss_fn(_pred_xy, _gt_xy) - batchLoss += loss.item() * sample.shape[0] - gt_xy.append(_gt_xy.cpu()) - out_xy.append(_pred_xy.cpu()) - gt_xy = torch.cat(gt_xy, dim=0) - out_xy = torch.cat(out_xy, dim=0) - avgLoss = batchLoss / len(testLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis - - datasetName = testLoader.dataset.datasetName - print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})") - calculate_residuals(gt_xy, out_xy) - if datasetName == 'Val': - ValLosses.append(avgLoss) - else: - global TestLoss - TestLoss = avgLoss - return avgLoss - -def calculate_residuals(gt_xy, out_xy): - """ - gt_xy: (N, 2, 2) — [ [x1, y1], [x2, y2] ] - out_xy: (N, 2, 2) — [ [x1', y1'], [x2', y2'] ] - """ - # Option A: match (p1->g1, p2->g2) - cost_a = (out_xy - gt_xy).pow(2).sum(dim=-1).sum(dim=-1) # (N,) - - # Option B: match (p1->g2, p2->g1) → swap out_xy - out_swapped = out_xy[:, [1, 0], :] # swap the two points: (N, 2, 2) - cost_b = (out_swapped - gt_xy).pow(2).sum(dim=-1).sum(dim=-1) # (N,) - - # Choose best assignment per sample - swap_mask = cost_b < cost_a # (N,) - - # Apply swapping to get optimally matched predictions - out_matched = out_xy.clone() - out_matched[swap_mask] = out_xy[swap_mask][:, [1, 0], :] - - # Compute residuals - residuals = out_matched - gt_xy # (N, 2, 2) - - # Flatten to get all residuals (2N points) - residuals_x = residuals[:, :, 0].flatten().cpu().numpy() - residuals_y = residuals[:, :, 1].flatten().cpu().numpy() - - # Print statistics - print(f"\t\tResiduals X: mean={np.mean(residuals_x):.4f}, std={np.std(residuals_x):.4f}") - print(f"\t\tResiduals Y: mean={np.mean(residuals_y):.4f}, std={np.std(residuals_y):.4f}") - -sampleFolder = '/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040' -trainDataset = doublePhotonDataset( - [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13)], - sampleRatio=1.0, - datasetName='Train', - reuselFactor=1, - noiseKeV = Noise, - nSize=7, - noiseThreshold = NoiseThreshold, - normalize = Normalize - ) -valDataset = doublePhotonDataset( - [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13,14)], - sampleRatio=1.0, - datasetName='Val', - reuselFactor=1, - noiseKeV = Noise, - nSize=7, - noiseThreshold = NoiseThreshold, - normalize = Normalize - ) -testDataset = doublePhotonDataset( - [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15,16)], - sampleRatio=1.0, - datasetName='Test', - reuselFactor=1, - noiseKeV = Noise, - nSize=7, - noiseThreshold = NoiseThreshold, - normalize = Normalize - ) -trainLoader = torch.utils.data.DataLoader( - trainDataset, - batch_size=1024, - pin_memory = True, - shuffle=True, - num_workers=16 - ) -valLoader = torch.utils.data.DataLoader( - valDataset, - batch_size=4096, - shuffle=False, - num_workers=16 - ) -testLoader = torch.utils.data.DataLoader( - testDataset, - batch_size=4096, - shuffle=False, - num_workers=16 - ) -optimizer = torch.optim.Adam(model.parameters(), lr=LearningRate, weight_decay=1e-4) -scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience = 5) -if __name__ == "__main__": - for epoch in tqdm(range(1, 301)): - train(model, trainLoader, optimizer) - test(model, valLoader) - scheduler.step(ValLosses[-1]) - print(f"Learning Rate: {optimizer.param_groups[0]['lr']}") - if epoch in [20, 50, 100, 200, 300, 500, 750, 1000]: - modelName = f'doublePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}' - if Normalize: - modelName += '_normalized' - torch.save(model.state_dict(), f'Models/{modelName}.pth') - -test(model, testLoader) - -def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion): - import matplotlib.pyplot as plt - plt.figure(figsize=(8,6)) - plt.plot(TrainLosses, label='Train Loss', color='blue') - plt.plot(ValLosses, label='Validation Loss', color='orange') - if TestLoss > 0: - plt.axhline(y=TestLoss, color='green', linestyle='--', label='Test Loss') - plt.yscale('log') - plt.xlabel('Epoch') - plt.ylabel('MSE Loss') - plt.legend() - plt.grid() - plt.savefig(f'Results/loss_curve_doublePhoton_{modelVersion}.png', dpi=300) - -plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion) \ No newline at end of file