Update 2Photon training

This commit is contained in:
2026-05-06 16:34:49 +02:00
parent 73bbcb404b
commit 4b6453c1c2
+17 -4
View File
@@ -12,13 +12,15 @@ from torchinfo import summary
torch.manual_seed(0)
np.random.seed(0)
modelVersion = '251001_2' # '250910' or '251001'
modelVersion = '251124' # '250910' or '251001' or '251124'
model = models.get_double_photon_model_class(modelVersion)().cuda()
Energy = '15.3keV'
Energy = '12keV'
TrainLosses, ValLosses = [], []
TestLoss = -1
LearningRate = 1e-3
Noise = 0.13 # in keV
NoiseThreshold = 3*Noise
Normalize = False
def two_point_set_loss_l2(pred_xy, gt_xy):
def pair_cost_l2sq(p, q): # p,q: (...,2)
@@ -118,6 +120,9 @@ trainDataset = doublePhotonDataset(
datasetName='Train',
reuselFactor=1,
noiseKeV = Noise,
nSize=7,
noiseThreshold = NoiseThreshold,
normalize = Normalize
)
valDataset = doublePhotonDataset(
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13,14)],
@@ -125,6 +130,9 @@ valDataset = doublePhotonDataset(
datasetName='Val',
reuselFactor=1,
noiseKeV = Noise,
nSize=7,
noiseThreshold = NoiseThreshold,
normalize = Normalize
)
testDataset = doublePhotonDataset(
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15,16)],
@@ -132,6 +140,9 @@ testDataset = doublePhotonDataset(
datasetName='Test',
reuselFactor=1,
noiseKeV = Noise,
nSize=7,
noiseThreshold = NoiseThreshold,
normalize = Normalize
)
trainLoader = torch.utils.data.DataLoader(
trainDataset,
@@ -161,10 +172,12 @@ if __name__ == "__main__":
scheduler.step(ValLosses[-1])
print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
if epoch in [20, 50, 100, 200, 300, 500, 750, 1000]:
torch.save(model.state_dict(), f'Models/doublePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}.pth')
modelName = f'doublePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}'
if Normalize:
modelName += '_normalized'
torch.save(model.state_dict(), f'Models/{modelName}.pth')
test(model, testLoader)
torch.save(model.state_dict(), f'Models/doublePhotonNet_{modelVersion}.pth')
def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
import matplotlib.pyplot as plt