import sys

sys.path.append('./src')

from pathlib import Path

import numpy as np
import torch
import torch.optim as optim
from omegaconf import OmegaConf  # for yaml config parsing
from tqdm import tqdm
from torchinfo import summary

import models
from datasets import *
from models import get_model_class
from datasets import singlePhotonDataset

# random seed for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

conf = OmegaConf.load("Configs/train_1photon.yaml")


def prepare_output_folder(conf):
    """Create a fresh ``Results/<exp_name>`` folder tree and save the config.

    The experiment name encodes today's date (YYMMDD), photon energy, model
    version, and a running two-digit index; the first unused index is chosen
    so existing experiments are never overwritten.

    Returns:
        str: the chosen experiment name.
    """
    from datetime import datetime
    date = datetime.now().strftime("%y%m%d")  # YYMMDD format
    # find the next free index for the experiment name
    exp_index = 0
    while True:
        exp_name = f'{date}_1ph_{conf.data.energy}keV_v{conf.model.version}_{exp_index:02d}'
        if not Path(f'Results/{exp_name}').exists():
            break
        exp_index += 1
    Path(f'Results/{exp_name}').mkdir(parents=True, exist_ok=True)
    Path(f'Results/{exp_name}/Models').mkdir(parents=True, exist_ok=True)
    Path(f'Results/{exp_name}/Plots').mkdir(parents=True, exist_ok=True)
    # snapshot the config next to the results for reproducibility
    OmegaConf.save(conf, f'Results/{exp_name}/config.yaml')
    return exp_name


def get_loss_function(conf):
    """Build the training loss callable from the config.

    Supports ``conf.loss.type == "weightedL1"``: an L1 loss on the (x, y)
    position channels, weighted per-coordinate by distance from the center
    (``alpha``) and per-sample by the radial distance (``beta``). Any extra
    channels in ``pred``/``target`` (e.g. z) are ignored by this loss.

    Raises:
        ValueError: if ``conf.loss.type`` is not recognized. (Previously the
            function silently returned ``None``, which only failed later with
            an opaque ``TypeError`` at the first training step.)
    """
    if conf.loss.type == "weightedL1":
        alpha = conf.loss.alpha
        beta = conf.loss.beta

        def weighted_loss(pred, target):
            # weighted L1 loss for x,y position only
            pred = pred[:, :2]
            target = target[:, :2]
            direction_weight = 1.0 + alpha * torch.abs(target)  # (B, 2)
            r = torch.norm(target, dim=1, keepdim=True)
            radial_weight = 1.0 + beta * r  # (B, 1), broadcasts over (B, 2)
            weights = radial_weight * direction_weight
            loss = weights * torch.abs(pred - target)
            return loss.mean()

        return weighted_loss
    raise ValueError(f"Unknown loss type: {conf.loss.type}")


def train(model, trainLoader, optimizer, loss_fn):
    """Run one training epoch.

    Accumulates the loss weighted by batch size so the reported average is
    per-sample even when the last batch is smaller. Also prints the epoch
    RMS error in x and y (in label units).

    Returns:
        float: sample-averaged training loss for the epoch.
    """
    model.train()
    batchLoss = 0
    total_samples = 0
    rms_x, rms_y = 0, 0
    for batch_idx, (sample, label) in enumerate(trainLoader):
        sample, label = sample.cuda(), label.cuda()
        # label layout assumed (x, y, z, e) — TODO confirm against dataset
        x, y, z, e = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
        optimizer.zero_grad()
        output = model(sample)
        loss = loss_fn(output, torch.stack((x, y, z), dim=1))
        loss.backward()
        optimizer.step()
        batchLoss += loss.item() * sample.shape[0]
        total_samples += sample.shape[0]
        rms_x += torch.sum((output[:, 0] - x) ** 2).item()
        rms_y += torch.sum((output[:, 1] - y) ** 2).item()
    avgLoss = batchLoss / total_samples
    rms_x = np.sqrt(rms_x / total_samples)
    rms_y = np.sqrt(rms_y / total_samples)
    datasetName = trainLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
    return avgLoss


def evaluate(model, testLoader, loss_fn):
    """Evaluate the model on a loader without gradient tracking.

    Mirrors :func:`train`'s bookkeeping (batch-size-weighted loss, RMS x/y),
    now also accumulating the sample count per batch for consistency with
    ``train`` (identical to ``len(dataset)`` when ``drop_last`` is False).

    Returns:
        float: sample-averaged loss over the loader.
    """
    model.eval()
    batchLoss = 0
    total_samples = 0
    rms_x, rms_y = 0, 0
    with torch.no_grad():
        for batch_idx, (sample, label) in enumerate(testLoader):
            sample, label = sample.cuda(), label.cuda()
            x, y, z, e = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
            output = model(sample)
            loss = loss_fn(output, torch.stack((x, y, z), dim=1))
            batchLoss += loss.item() * sample.shape[0]
            total_samples += sample.shape[0]
            rms_x += torch.sum((output[:, 0] - x) ** 2).item()
            rms_y += torch.sum((output[:, 1] - y) ** 2).item()
    avgLoss = batchLoss / total_samples
    rms_x = np.sqrt(rms_x / total_samples)
    rms_y = np.sqrt(rms_y / total_samples)
    datasetName = testLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
    return avgLoss


def get_dataloaders(conf):
    """Construct the train/val/test dataLoaders from the config.

    Each split reads its own .npz file range, batch size, and augmentation
    settings from ``conf.data``; only the training loader is shuffled.

    Returns:
        tuple: (trainLoader, valLoader, testLoader).
    """
    built_datasets = {}  # renamed: a local `datasets` would shadow the imported module
    loaders = {}
    splits = ['train', 'val', 'test']
    keys = ['train_files', 'val_files', 'test_files']
    batch_keys = ['batch_size_train', 'batch_size_val', 'batch_size_test']
    file_range_keys = ['train_file_range', 'val_file_range', 'test_file_range']
    for split, f_key, b_key, fr_key in zip(splits, keys, batch_keys, file_range_keys):
        # file range is inclusive on both ends
        file_paths = [f"{conf.data.sample_folder}/{conf.data.energy}keV_Moench040_150V_{i}.npz"
                      for i in range(conf.data[fr_key][0], conf.data[fr_key][1] + 1)]
        dataset = singlePhotonDataset(
            file_paths,
            sampleRatio=1.0,
            datasetName=split.capitalize(),
            noiseKeV=conf.data.noise_keV,
            numberOfAugOps=conf.data.num_aug_ops,
            normalize=conf.data.normalize,
            noiseThreshold=conf.data.noise_threshold * conf.data.noise_keV
        )
        built_datasets[split] = dataset
        loaders[split] = torch.utils.data.DataLoader(
            dataset,
            batch_size=conf.data[b_key],
            shuffle=(split == 'train'),  # only shuffle the training split
            num_workers=conf.data.num_workers,
            pin_memory=True
        )
    return loaders['train'], loaders['val'], loaders['test']


def plot_loss_curve(train_losses, val_losses, test_loss, exp_name, conf):
    """Save a log-scale loss curve plot under ``Results/<exp_name>/Plots``.

    A negative ``test_loss`` suppresses the horizontal test-loss line (the
    main loop passes -1 for the mid-training plots).
    """
    import matplotlib.pyplot as plt
    fig = plt.figure(figsize=(8, 6))
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.plot(val_losses, label='Validation Loss', color='orange')
    if test_loss > 0:
        plt.axhline(y=test_loss, color='green', linestyle='--', label='Test Loss')
    plt.yscale('log')
    plt.xlabel('Epoch')
    # NOTE(review): label says MSE but the configured loss may be weightedL1 —
    # confirm whether this axis label should be generalized.
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid()
    plotName = f'loss_curve_singlePhoton_{conf.model.version}'
    if conf.data.normalize:
        plotName += '_normalized'
    plt.savefig(f'Results/{exp_name}/Plots/{plotName}.png')
    # close the figure: this function is called twice per run and would
    # otherwise leak matplotlib figures
    plt.close(fig)


def get_model_name(epoch, conf):
    """Return the checkpoint file stem encoding version, energy, noise, epoch."""
    modelName = f'singlePhoton{conf.model.version}_{conf.data.energy}keV_Noise{conf.data.noise_keV}keV_E{epoch}_aug{conf.data.num_aug_ops}'
    if conf.data.normalize:
        modelName += '_normalized'
    return modelName


if __name__ == "__main__":
    exp_name = prepare_output_folder(conf)
    model = models.get_model_class(conf.model.version)().cuda()
    # summary(model, input_size=(128, 1, 3, 3))
    loss_fn = get_loss_function(conf)
    optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min',
        factor=conf.training.scheduler_factor,
        patience=conf.training.scheduler_patience)
    trainLoader, valLoader, testLoader = get_dataloaders(conf)
    train_losses = []
    val_losses = []
    for epoch in tqdm(range(1, conf.training.epochs + 1)):
        t_loss = train(model, trainLoader, optimizer, loss_fn)
        train_losses.append(t_loss)
        v_loss = evaluate(model, valLoader, loss_fn)
        val_losses.append(v_loss)
        # scheduler steps on validation loss ('min' mode)
        scheduler.step(val_losses[-1])
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
        if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs:
            modelName = get_model_name(epoch, conf)
            torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}.pth')
            print(f"Saved model checkpoint: {modelName}.pth")
            # mid-training plot: test_loss=-1 suppresses the test line
            plot_loss_curve(train_losses, val_losses, test_loss=-1, exp_name=exp_name, conf=conf)
    test_loss = evaluate(model, testLoader, loss_fn)
    plot_loss_curve(train_losses, val_losses, test_loss=test_loss, exp_name=exp_name, conf=conf)