import sys
sys.path.append('./src')

import torch
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from torchinfo import summary

import models
from datasets import *

### random seed for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

modelVersion = '251022'  # other available versions: '250909', '251020'
Energy = '15.3keV'

TrainLosses, ValLosses = [], []
LearningRates = []
TestLoss = -1

model = models.get_model_class(modelVersion)().cuda()
# summary(model, input_size=(128, 1, 3, 3))

LearningRate = 1e-3
Noise = 0.13  # noise level in keV
numberOfAugOps = 8  # 1 = no augmentation; 2-8 = number of augmentation operations


def weighted_loss(pred, target, alpha=7.0):
    """Weighted L1 loss on the (x, y) position components."""
    pred = pred[:, :2]
    target = target[:, :2]
    # Per-axis weight grows with distance from the pixel centre.
    direction_weight = 1.0 + alpha * torch.abs(target)  # (B, 2)
    # Radial weight grows with the overall distance from the pixel centre.
    beta = 3.0
    r = torch.norm(target, dim=1, keepdim=True)
    radial_weight = 1.0 + beta * r  # (B, 1)
    weights = radial_weight * direction_weight
    loss = weights * torch.abs(pred - target)
    return loss.mean()


LossFunction = weighted_loss


def train(model, trainLoader, optimizer):
    model.train()
    batchLoss = 0
    for batch_idx, (sample, label) in enumerate(trainLoader):
        sample, label = sample.cuda(), label.cuda()
        x, y, z, e = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
        optimizer.zero_grad()
        output = model(sample)
        loss = LossFunction(output, torch.stack((x, y, z), dim=1))
        loss.backward()
        optimizer.step()
        batchLoss += loss.item() * sample.shape[0]
    avgLoss = batchLoss / len(trainLoader.dataset)
    datasetName = trainLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
    TrainLosses.append(avgLoss)


def test(model, testLoader):
    model.eval()
    batchLoss = 0
    with torch.no_grad():
        for batch_idx, (sample, label) in enumerate(testLoader):
            sample, label = sample.cuda(), label.cuda()
            x, y, z, e = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
            output = model(sample)
            loss = LossFunction(output, torch.stack((x, y, z), dim=1))
            batchLoss += loss.item() * sample.shape[0]
    avgLoss = batchLoss / len(testLoader.dataset)
    datasetName = testLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
    if datasetName == 'Val':
        ValLosses.append(avgLoss)
    else:
        global TestLoss
        TestLoss = avgLoss
    return avgLoss


sampleFolder = '/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040'

trainDataset = singlePhotonDataset(
    [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13)],
    sampleRatio=1.0,
    datasetName='Train',
    noiseKeV=Noise,
    numberOfAugOps=numberOfAugOps,
)
valDataset = singlePhotonDataset(
    [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13, 14)],
    sampleRatio=1.0,
    datasetName='Val',
    noiseKeV=Noise,
    numberOfAugOps=numberOfAugOps,
)
testDataset = singlePhotonDataset(
    [f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15, 16)],
    sampleRatio=1.0,
    datasetName='Test',
    noiseKeV=Noise,
)

trainLoader = torch.utils.data.DataLoader(
    trainDataset,
    batch_size=4096,
    shuffle=True,
    num_workers=32,
    pin_memory=True,
)
valLoader = torch.utils.data.DataLoader(
    valDataset,
    batch_size=1024,
    shuffle=False,
    num_workers=32,
    pin_memory=True,
)
testLoader = torch.utils.data.DataLoader(
    testDataset,
    batch_size=1024,
    shuffle=False,
    num_workers=32,
    pin_memory=True,
)

optimizer = torch.optim.Adam(model.parameters(), lr=LearningRate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience=3)

if __name__ == "__main__":
    for epoch in tqdm(range(1, 1001)):
        train(model, trainLoader, optimizer)
        test(model, valLoader)
        scheduler.step(ValLosses[-1])
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
        if epoch in [20, 50, 100, 200, 300, 500, 750, 1000]:
            torch.save(
                model.state_dict(),
                f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}_aug{numberOfAugOps}.pth',
            )

    test(model, testLoader)
    torch.save(
        model.state_dict(),
        f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_aug{numberOfAugOps}.pth',
    )

    def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
        import matplotlib.pyplot as plt
        plt.figure(figsize=(8, 6))
        plt.plot(TrainLosses, label='Train Loss', color='blue')
        plt.plot(ValLosses, label='Validation Loss', color='orange')
        if TestLoss > 0:
            plt.axhline(y=TestLoss, color='green', linestyle='--', label='Test Loss')
        plt.yscale('log')
        plt.xlabel('Epoch')
        plt.ylabel('Weighted L1 Loss')
        plt.legend()
        plt.grid()
        plt.savefig(f'Results/loss_curve_singlePhoton_{modelVersion}.png', dpi=300)

    plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion)
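
# --- Hedged usage sketch (assumption, not part of the original training flow) ---
# Minimal example of reloading a saved checkpoint for inference. The checkpoint
# filename mirrors the naming scheme of the final torch.save() call above; the
# helper name load_trained_model and the keyword argument are illustrative only.
def load_trained_model(checkpointPath, version=modelVersion):
    """Rebuild the network for `version` and load its trained weights."""
    net = models.get_model_class(version)().cuda()
    net.load_state_dict(torch.load(checkpointPath))
    net.eval()
    return net

# Example usage (commented out so importing this script stays side-effect free):
# net = load_trained_model(
#     f'Models/singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_aug{numberOfAugOps}.pth'
# )
# with torch.no_grad():
#     sample, label = next(iter(testLoader))   # (B, 1, 3, 3) patches, cf. the summary() call
#     pred = net(sample.cuda())                # predicted (x, y, z)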