diff --git a/Train_SinglePhoton.py b/Train_SinglePhoton.py index e9cf067..e21627f 100644 --- a/Train_SinglePhoton.py +++ b/Train_SinglePhoton.py @@ -10,14 +10,37 @@ from torchinfo import summary ### random seed for reproducibility torch.manual_seed(0) +torch.cuda.manual_seed(0) np.random.seed(0) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False modelVersion = '251022' # '250909' or '251020' +Energy = '15.3keV' TrainLosses, ValLosses = [], [] +LearningRates = [] TestLoss = -1 +model = models.get_model_class(modelVersion)().cuda() +# summary(model, input_size=(128, 1, 3, 3)) +LearningRate = 1e-3 Noise = 0.13 # in keV TrainLosses, TestLosses = [], [] +def weighted_loss(pred, target, alpha=7.0): + # weighted L1 loss for x,y position + pred = pred[:, :2] + target = target[:, :2] + + # weights = 1.0 + alpha * torch.abs(target) + direction_weight = 1.0 + alpha * torch.abs(target) # (B, 2) + beta = 3. + r = torch.norm(target, dim=1, keepdim=True) + radial_weight = 1.0 + beta * r # (B, 1) → + weights = radial_weight * direction_weight + + loss = weights * torch.abs(pred - target) + return loss.mean() +LossFunction = weighted_loss def train(model, trainLoader, optimizer): model.train() @@ -27,7 +50,7 @@ def train(model, trainLoader, optimizer): x, y, z, e = label[:,0], label[:,1], label[:,2], label[:,3] optimizer.zero_grad() output = model(sample) - loss = torch.nn.functional.mse_loss(output, torch.stack((x, y), axis=1)) + loss = LossFunction(output, torch.stack((x, y, z), axis=1)) loss.backward() optimizer.step() batchLoss += loss.item() * sample.shape[0] @@ -45,7 +68,7 @@ def test(model, testLoader): sample, label = sample.cuda(), label.cuda() x, y, z, e = label[:,0], label[:,1], label[:,2], label[:,3] output = model(sample) - loss = torch.nn.functional.mse_loss(output, torch.stack((x, y), axis=1)) + loss = LossFunction(output, torch.stack((x, y, z), axis=1)) batchLoss += loss.item() * sample.shape[0] avgLoss = batchLoss / len(testLoader.dataset) @@ -58,32 +81,28 @@ def test(model, testLoader): TestLoss = avgLoss return avgLoss -model = models.get_model_class(modelVersion)().cuda() -# summary(model, input_size=(128, 1, 3, 3)) - sampleFolder = '/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040' trainDataset = singlePhotonDataset( - [f'{sampleFolder}/15keV_Moench040_150V_{i}.npz' for i in range(13)], + [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13)], sampleRatio=1.0, datasetName='Train', noiseKeV = Noise, ) valDataset = singlePhotonDataset( - [f'{sampleFolder}/15keV_Moench040_150V_{i}.npz' for i in range(13,14)], + [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13,14)], sampleRatio=1.0, datasetName='Val', noiseKeV = Noise, ) testDataset = singlePhotonDataset( - [f'{sampleFolder}/15keV_Moench040_150V_{i}.npz' for i in range(15,16)], + [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15,16)], sampleRatio=1.0, datasetName='Test', noiseKeV = Noise, ) - trainLoader = torch.utils.data.DataLoader( trainDataset, - batch_size=1024, + batch_size=4096, shuffle=True, num_workers=16, pin_memory=True, @@ -103,16 +122,19 @@ testLoader = torch.utils.data.DataLoader( pin_memory=True, ) -optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) -scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience = 5) +optimizer = torch.optim.Adam(model.parameters(), lr=LearningRate) +scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience = 3) if __name__ == "__main__": - for epoch in tqdm(range(1, 101)): + for epoch in tqdm(range(1, 301)): train(model, trainLoader, optimizer) test(model, valLoader) - scheduler.step(TrainLosses[-1]) + scheduler.step(ValLosses[-1]) + print(f"Learning Rate: {optimizer.param_groups[0]['lr']}") + if epoch in [20, 50, 100, 200, 300, 500]: + torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}.pth') test(model, testLoader) -torch.save(model.state_dict(), f'Models/singlePhotonNet_Noise{Noise}keV_{modelVersion}.pth') +torch.save(model.state_dict(), f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV.pth') def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion): import matplotlib.pyplot as plt