diff --git a/Train_SinglePhoton.py b/Train_SinglePhoton.py index e21627f..eef9f01 100644 --- a/Train_SinglePhoton.py +++ b/Train_SinglePhoton.py @@ -24,6 +24,7 @@ model = models.get_model_class(modelVersion)().cuda() # summary(model, input_size=(128, 1, 3, 3)) LearningRate = 1e-3 Noise = 0.13 # in keV +numberOfAugOps = 8 # 1 (no augmentation) or (1,8] (with augmentation) TrainLosses, TestLosses = [], [] def weighted_loss(pred, target, alpha=7.0): @@ -87,12 +88,14 @@ trainDataset = singlePhotonDataset( sampleRatio=1.0, datasetName='Train', noiseKeV = Noise, + numberOfAugOps=numberOfAugOps, ) valDataset = singlePhotonDataset( [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13,14)], sampleRatio=1.0, datasetName='Val', noiseKeV = Noise, + numberOfAugOps=numberOfAugOps, ) testDataset = singlePhotonDataset( [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15,16)], @@ -104,37 +107,37 @@ trainLoader = torch.utils.data.DataLoader( trainDataset, batch_size=4096, shuffle=True, - num_workers=16, + num_workers=32, pin_memory=True, ) valLoader = torch.utils.data.DataLoader( valDataset, batch_size=1024, shuffle=False, - num_workers=16, + num_workers=32, pin_memory=True, ) testLoader = torch.utils.data.DataLoader( testDataset, batch_size=1024, shuffle=4096, - num_workers=16, + num_workers=32, pin_memory=True, ) optimizer = torch.optim.Adam(model.parameters(), lr=LearningRate) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience = 3) if __name__ == "__main__": - for epoch in tqdm(range(1, 301)): + for epoch in tqdm(range(1, 1001)): train(model, trainLoader, optimizer) test(model, valLoader) scheduler.step(ValLosses[-1]) print(f"Learning Rate: {optimizer.param_groups[0]['lr']}") - if epoch in [20, 50, 100, 200, 300, 500]: - torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}.pth') + if epoch in [20, 50, 100, 200, 300, 500, 750, 1000]: + torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}_aug{numberOfAugOps}.pth') test(model, testLoader) -torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV.pth') +torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_aug{numberOfAugOps}.pth') def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion): import matplotlib.pyplot as plt