diff --git a/Configs/train_1photon.yaml b/Configs/train_1photon.yaml new file mode 100644 index 0000000..d166e5a --- /dev/null +++ b/Configs/train_1photon.yaml @@ -0,0 +1,40 @@ +# configs/train_1photon.yaml +experiment: + name: "1photon_15keV" + +data: + sample_folder: "/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040" + energy: 15 ### in keV + noise_keV: 0.13 + noise_threshold: 0.0 ### set values below (noise * noise_threshold) to zero + num_aug_ops: 1 + normalize: true + + batch_size_train: 4096 + batch_size_val: 1024 + batch_size_test: 1024 + num_workers: 32 + train_file_range: [0, 12] + val_file_range: [13, 14] + test_file_range: [15, 15] + +model: + version: "251022" + +training: + epochs: 15 + learning_rate: 1.0e-3 + loss_alpha: 7.0 + loss_beta: 3.0 + scheduler_factor: 0.7 + scheduler_patience: 3 + checkpoint_epochs: [10, 30, 50, 100, 150] + +loss: + type: "weightedL1" + alpha: 7.0 + beta: 3.0 + +logging: + save_dir: "Results" + model_dir: "Models" \ No newline at end of file diff --git a/Train_1Photon.py b/Train_1Photon.py index 4af36bc..fe4de1c 100644 --- a/Train_1Photon.py +++ b/Train_1Photon.py @@ -1,5 +1,6 @@ import sys sys.path.append('./src') +from omegaconf import OmegaConf ### for yaml config parsing import torch import numpy as np import models @@ -7,6 +8,10 @@ from datasets import * import torch.optim as optim from tqdm import tqdm from torchinfo import summary +from pathlib import Path + +from models import get_model_class +from datasets import singlePhotonDataset ### random seed for reproducibility torch.manual_seed(0) @@ -15,168 +20,169 @@ np.random.seed(0) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False -modelVersion = '251022' # '250909' or '251020' -Energy = '15keV' -TrainLosses, ValLosses = [], [] -LearningRates = [] -TestLoss = -1 -model = models.get_model_class(modelVersion)().cuda() -# summary(model, input_size=(128, 1, 3, 3)) -LearningRate = 1e-3 -Noise = 0.23 # in keV -NoiseThreshold = 
conf = OmegaConf.load("Configs/train_1photon.yaml")


def prepare_output_folder(conf):
    """Create a unique experiment folder tree and snapshot the config.

    Folder name: <YYMMDD>_1ph_<energy>keV_v<model version>_<NN>, where NN is
    the first free two-digit index for today.  Creates Models/ and Plots/
    subfolders and saves the resolved config alongside the results.

    Returns:
        The experiment name (folder name, not the full path).
    """
    from datetime import datetime
    date = datetime.now().strftime("%y%m%d")  # YYMMDD
    # Honor the configured save dir instead of the previous hard-coded
    # 'Results' (the shipped config sets save_dir: "Results", so this is
    # behavior-identical there).
    save_dir = Path(conf.logging.save_dir)
    exp_index = 0
    while True:
        exp_name = f'{date}_1ph_{conf.data.energy}keV_v{conf.model.version}_{exp_index:02d}'
        exp_path = save_dir / exp_name
        if not exp_path.exists():
            break
        exp_index += 1
    # NOTE(review): exists()/mkdir() is racy if two trainings launch at the
    # same moment; acceptable for interactive use.
    exp_path.mkdir(parents=True, exist_ok=True)
    (exp_path / 'Models').mkdir(parents=True, exist_ok=True)
    (exp_path / 'Plots').mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, exp_path / 'config.yaml')
    return exp_name


def get_loss_function(conf):
    """Build the loss function selected by ``conf.loss``.

    Currently supports only ``"weightedL1"``: an L1 loss on the (x, y)
    position whose weight grows per-axis with |target| (alpha) and with the
    radial distance of the target from the origin (beta).

    Raises:
        ValueError: if ``conf.loss.type`` names an unsupported loss.
    """
    if conf.loss.type == "weightedL1":
        alpha = conf.loss.alpha
        beta = conf.loss.beta

        def weighted_loss(pred, target):
            # Only the x,y position columns enter the loss; the z column
            # passed by train/evaluate is ignored here.
            pred = pred[:, :2]
            target = target[:, :2]

            direction_weight = 1.0 + alpha * torch.abs(target)  # (B, 2)
            r = torch.norm(target, dim=1, keepdim=True)
            radial_weight = 1.0 + beta * r                      # (B, 1)
            weights = radial_weight * direction_weight

            loss = weights * torch.abs(pred - target)
            return loss.mean()

        return weighted_loss
    # Previously fell through and returned None, which only surfaced later
    # as a confusing "'NoneType' object is not callable" — fail fast instead.
    raise ValueError(f"Unsupported loss type: {conf.loss.type!r}")


def train(model, trainLoader, optimizer, loss_fn):
    """Run one training epoch.

    Prints the per-epoch average loss and RMS errors in x and y, and
    returns the sample-weighted average loss.
    """
    model.train()
    batchLoss = 0
    total_samples = 0
    rms_x, rms_y = 0, 0
    for batch_idx, (sample, label) in enumerate(trainLoader):
        sample, label = sample.cuda(), label.cuda()
        x, y, z, e = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
        optimizer.zero_grad()
        output = model(sample)
        loss = loss_fn(output, torch.stack((x, y, z), dim=1))
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per-sample even when
        # the last batch is smaller.
        batchLoss += loss.item() * sample.shape[0]
        total_samples += sample.shape[0]
        rms_x += torch.sum((output[:, 0] - x)**2).item()
        rms_y += torch.sum((output[:, 1] - y)**2).item()
    avgLoss = batchLoss / total_samples
    rms_x = np.sqrt(rms_x / total_samples)
    rms_y = np.sqrt(rms_y / total_samples)
    datasetName = trainLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
    return avgLoss


def evaluate(model, testLoader, loss_fn):
    """Evaluate the model on a loader without gradients.

    Prints and returns the sample-weighted average loss; also reports RMS
    errors in x and y.
    """
    model.eval()
    batchLoss = 0
    total_samples = 0
    rms_x, rms_y = 0, 0
    with torch.no_grad():
        for batch_idx, (sample, label) in enumerate(testLoader):
            sample, label = sample.cuda(), label.cuda()
            x, y, z, e = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
            output = model(sample)
            loss = loss_fn(output, torch.stack((x, y, z), dim=1))
            batchLoss += loss.item() * sample.shape[0]
            total_samples += sample.shape[0]
            rms_x += torch.sum((output[:, 0] - x)**2).item()
            rms_y += torch.sum((output[:, 1] - y)**2).item()
    # Count the samples actually seen (consistent with train()) instead of
    # trusting len(dataset); identical for a full, non-dropping loader.
    avgLoss = batchLoss / total_samples
    rms_x = np.sqrt(rms_x / total_samples)
    rms_y = np.sqrt(rms_y / total_samples)
    datasetName = testLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
    return avgLoss
def get_dataloaders(conf):
    """Construct the train/val/test datasets and DataLoaders from config.

    File index ranges (``*_file_range``) are inclusive on both ends.  Only
    the training loader is shuffled.

    Returns:
        (trainLoader, valLoader, testLoader)
    """
    loaders = {}
    # (split name, batch-size key, file-range key) per split.
    split_cfg = [
        ('train', 'batch_size_train', 'train_file_range'),
        ('val', 'batch_size_val', 'val_file_range'),
        ('test', 'batch_size_test', 'test_file_range'),
    ]
    for split, b_key, fr_key in split_cfg:
        lo, hi = conf.data[fr_key]
        file_paths = [
            f"{conf.data.sample_folder}/{conf.data.energy}keV_Moench040_150V_{i}.npz"
            for i in range(lo, hi + 1)
        ]
        dataset = singlePhotonDataset(
            file_paths,
            sampleRatio=1.0,
            datasetName=split.capitalize(),
            noiseKeV=conf.data.noise_keV,
            numberOfAugOps=conf.data.num_aug_ops,
            normalize=conf.data.normalize,
            # noise_threshold is a multiple of the noise level (in keV)
            noiseThreshold=conf.data.noise_threshold * conf.data.noise_keV,
        )
        loaders[split] = torch.utils.data.DataLoader(
            dataset,
            batch_size=conf.data[b_key],
            shuffle=(split == 'train'),
            num_workers=conf.data.num_workers,
            pin_memory=True,
        )
    return loaders['train'], loaders['val'], loaders['test']


def plot_loss_curve(train_losses, val_losses, test_loss, exp_name, conf):
    """Save a log-scale loss-curve plot into the experiment's Plots folder.

    Pass ``test_loss <= 0`` to omit the test-loss reference line.
    """
    import matplotlib.pyplot as plt
    plt.figure(figsize=(8, 6))
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.plot(val_losses, label='Validation Loss', color='orange')
    if test_loss > 0:
        plt.axhline(y=test_loss, color='green', linestyle='--', label='Test Loss')
    plt.yscale('log')
    plt.xlabel('Epoch')
    # Was 'MSE Loss', but the configured loss is a weighted L1.
    plt.ylabel('Loss')
    plt.legend()
    plt.grid()
    plotName = f'loss_curve_singlePhoton_{conf.model.version}'
    if conf.data.normalize:
        plotName += '_normalized'
    plt.savefig(f'Results/{exp_name}/Plots/{plotName}.png')
    # Called once per epoch: close the figure so matplotlib does not
    # accumulate open figures (and warn/leak) over a long training run.
    plt.close()


def get_model_name(epoch, conf):
    """Return the checkpoint file stem encoding version, energy, noise,
    epoch, and augmentation settings."""
    modelName = f'singlePhoton{conf.model.version}_{conf.data.energy}keV_Noise{conf.data.noise_keV}keV_E{epoch}_aug{conf.data.num_aug_ops}'
    if conf.data.normalize:
        modelName += '_normalized'
    return modelName


if __name__ == "__main__":
    exp_name = prepare_output_folder(conf)
    model = models.get_model_class(conf.model.version)().cuda()
    # summary(model, input_size=(128, 1, 3, 3))

    loss_fn = get_loss_function(conf)
    optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min',
        factor=conf.training.scheduler_factor,
        patience=conf.training.scheduler_patience)

    trainLoader, valLoader, testLoader = get_dataloaders(conf)

    train_losses = []
    val_losses = []

    for epoch in tqdm(range(1, conf.training.epochs + 1)):
        t_loss = train(model, trainLoader, optimizer, loss_fn)
        train_losses.append(t_loss)
        v_loss = evaluate(model, valLoader, loss_fn)
        val_losses.append(v_loss)
        scheduler.step(val_losses[-1])
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
        # Save at the configured checkpoint epochs and always at the end.
        if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs:
            modelName = get_model_name(epoch, conf)
            torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}.pth')
            print(f"Saved model checkpoint: {modelName}.pth")
        # Refresh the loss curve every epoch (no test line yet).
        plot_loss_curve(train_losses, val_losses, test_loss=-1, exp_name=exp_name, conf=conf)
    test_loss = evaluate(model, testLoader, loss_fn)
    plot_loss_curve(train_losses, val_losses, test_loss=test_loss, exp_name=exp_name, conf=conf)