From 4b6453c1c235175cba44a9863d2439cdbdd7d729 Mon Sep 17 00:00:00 2001 From: "xiangyu.xie" Date: Wed, 6 May 2026 16:34:49 +0200 Subject: [PATCH] Update 2Photon training --- Train_DoublePhoton.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/Train_DoublePhoton.py b/Train_DoublePhoton.py index 7a662a1..26abdb9 100644 --- a/Train_DoublePhoton.py +++ b/Train_DoublePhoton.py @@ -12,13 +12,15 @@ from torchinfo import summary torch.manual_seed(0) np.random.seed(0) -modelVersion = '251001_2' # '250910' or '251001' +modelVersion = '251124' # '250910' or '251001' or '251124' model = models.get_double_photon_model_class(modelVersion)().cuda() -Energy = '15.3keV' +Energy = '12keV' TrainLosses, ValLosses = [], [] TestLoss = -1 LearningRate = 1e-3 Noise = 0.13 # in keV +NoiseThreshold = 3*Noise +Normalize = False def two_point_set_loss_l2(pred_xy, gt_xy): def pair_cost_l2sq(p, q): # p,q: (...,2) @@ -118,6 +120,9 @@ trainDataset = doublePhotonDataset( datasetName='Train', reuselFactor=1, noiseKeV = Noise, + nSize=7, + noiseThreshold = NoiseThreshold, + normalize = Normalize ) valDataset = doublePhotonDataset( [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13,14)], @@ -125,6 +130,9 @@ valDataset = doublePhotonDataset( datasetName='Val', reuselFactor=1, noiseKeV = Noise, + nSize=7, + noiseThreshold = NoiseThreshold, + normalize = Normalize ) testDataset = doublePhotonDataset( [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15,16)], @@ -132,6 +140,9 @@ testDataset = doublePhotonDataset( datasetName='Test', reuselFactor=1, noiseKeV = Noise, + nSize=7, + noiseThreshold = NoiseThreshold, + normalize = Normalize ) trainLoader = torch.utils.data.DataLoader( trainDataset, @@ -161,10 +172,12 @@ if __name__ == "__main__": scheduler.step(ValLosses[-1]) print(f"Learning Rate: {optimizer.param_groups[0]['lr']}") if epoch in [20, 50, 100, 200, 300, 500, 750, 1000]: - torch.save(model.state_dict(), f'Models/doublePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}.pth') + modelName = f'doublePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}' + if Normalize: + modelName += '_normalized' + torch.save(model.state_dict(), f'Models/{modelName}.pth') test(model, testLoader) -torch.save(model.state_dict(), f'Models/doublePhotonNet_{modelVersion}.pth') def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion): import matplotlib.pyplot as plt