import sys
sys.path.append('./src')  # must run before importing the project packages below
from omegaconf import OmegaConf  ### for yaml config parsing
import torch
import numpy as np
import models
from datasets import *
import torch.optim as optim
from tqdm import tqdm
from torchinfo import summary
from pathlib import Path
from models import get_model_class
from datasets import singlePhotonDataset

### random seed for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

conf = OmegaConf.load("Configs/train_1photon.yaml")


def prepare_output_folder(conf):
    """Create a fresh ``Results/<exp_name>`` directory tree and save the config.

    The experiment name is built from today's date (YYMMDD), the configured
    energy and model version, and the first two-digit index for which no
    ``Results/<exp_name>`` path exists yet.

    Args:
        conf: configuration providing ``data.energy`` and ``model.version``.

    Returns:
        The experiment name (the directory name under ``Results/``).
    """
    from datetime import datetime

    date = datetime.now().strftime("%y%m%d")  ## YYMMDD format

    # find the next index for experiment name
    exp_index = 0
    while True:
        exp_name = f'{date}_1ph_{conf.data.energy}keV_v{conf.model.version}_{exp_index:02d}'
        if not Path(f'Results/{exp_name}').exists():
            break
        exp_index += 1

    exp_dir = Path(f'Results/{exp_name}')
    for folder in (exp_dir, exp_dir / 'Models', exp_dir / 'Plots'):
        folder.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, f'Results/{exp_name}/config.yaml')
    return exp_name


def get_loss_function(conf):
    """Build the loss callable selected by ``conf.loss.type``.

    Every returned callable takes ``(pred, target)`` and compares only the
    first two columns of each (the x, y position); any further columns are
    ignored.

    Supported types:
        * ``"weightedL1"``    - L1 error scaled by
          ``(1 + beta * r) * (1 + alpha * |target|)``.
        * ``"weightedSumL1"`` - L1 error scaled by
          ``1 + beta * r + alpha * |target|``.
        * ``"huber_loss"``    - ``torch.nn.SmoothL1Loss`` with
          ``beta = conf.loss.huber_beta``.

    where ``r`` is the per-sample Euclidean norm of the (x, y) target.

    Raises:
        ValueError: if ``conf.loss.type`` is not one of the supported names.
    """
    loss_type = conf.loss.type

    if loss_type == "weightedL1":
        alpha = conf.loss.alpha
        beta = conf.loss.beta

        def weighted_loss(pred, target):
            # weighted L1 loss for x,y position
            pred = pred[:, :2]
            target = target[:, :2]
            direction_weight = 1.0 + alpha * torch.abs(target)  # (B, 2)
            r = torch.norm(target, dim=1, keepdim=True)
            radial_weight = 1.0 + beta * r  # (B, 1), broadcasts against (B, 2)
            weights = radial_weight * direction_weight
            loss = weights * torch.abs(pred - target)
            return loss.mean()

        return weighted_loss

    if loss_type == "weightedSumL1":
        alpha = conf.loss.alpha_weighted_sum
        beta = conf.loss.beta_weighted_sum

        def weighted_loss(pred, target):
            pred = pred[:, :2]
            target = target[:, :2]
            r = torch.norm(target, dim=1, keepdim=True)  # (B, 1)
            weights = 1.0 + beta * r + alpha * torch.abs(target)
            loss = weights * torch.abs(pred - target)
            return loss.mean()

        return weighted_loss

    if loss_type == "huber_loss":
        huber_loss_fn = torch.nn.SmoothL1Loss(reduction='mean', beta=conf.loss.huber_beta)

        def huber_loss(pred, target):
            pred = pred[:, :2]
            target = target[:, :2]
            return huber_loss_fn(pred, target)

        return huber_loss

    # Fail here rather than returning None and crashing later with
    # "'NoneType' object is not callable" inside the training loop.
    raise ValueError(
        f"Unsupported loss type {loss_type!r}; "
        "expected 'weightedL1', 'weightedSumL1' or 'huber_loss'"
    )


def _summarize_epoch(loader, loss_sum, sq_err_x, sq_err_y, n_samples):
    """Turn per-epoch accumulators into (avg loss, RMS x, RMS y) and print them."""
    avg_loss = loss_sum / n_samples
    rms_x = np.sqrt(sq_err_x / n_samples)
    rms_y = np.sqrt(sq_err_y / n_samples)
    dataset_name = loader.dataset.datasetName
    print(f"[{dataset_name}]\t Average Loss: {avg_loss:.6f} \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
    return avg_loss, rms_x, rms_y


def train(model, trainLoader, optimizer, loss_fn):
    """Run one optimisation epoch over ``trainLoader``.

    Each batch is moved to the GPU, the first three label columns are stacked
    as the loss target, and one optimizer step is taken per batch.

    Returns:
        ``(avgLoss, rms_x, rms_y)``: the sample-weighted mean loss and the RMS
        error of output columns 0 and 1 against label columns 0 and 1.
    """
    model.train()
    loss_sum = 0
    total_samples = 0
    sq_err_x, sq_err_y = 0, 0
    for sample, label in trainLoader:
        sample, label = sample.cuda(), label.cuda()
        x, y = label[:, 0], label[:, 1]
        optimizer.zero_grad()
        output = model(sample)
        # target = label columns 0..2 (x, y, z); the loss functions only read x, y
        loss = loss_fn(output, torch.stack((x, y, label[:, 2]), axis=1))
        loss.backward()
        optimizer.step()

        batch_size = sample.shape[0]
        loss_sum += loss.item() * batch_size  # undo the batch mean
        total_samples += batch_size
        sq_err_x += torch.sum((output[:, 0] - x) ** 2).item()
        sq_err_y += torch.sum((output[:, 1] - y) ** 2).item()

    return _summarize_epoch(trainLoader, loss_sum, sq_err_x, sq_err_y, total_samples)


def evaluate(model, testLoader, loss_fn):
    """Evaluate ``model`` on ``testLoader`` without computing gradients.

    Returns:
        ``(avgLoss, rms_x, rms_y)`` computed the same way as in ``train``.
    """
    model.eval()
    loss_sum = 0
    total_samples = 0
    sq_err_x, sq_err_y = 0, 0
    with torch.no_grad():
        for sample, label in testLoader:
            sample, label = sample.cuda(), label.cuda()
            x, y = label[:, 0], label[:, 1]
            output = model(sample)
            loss = loss_fn(output, torch.stack((x, y, label[:, 2]), axis=1))

            batch_size = sample.shape[0]
            loss_sum += loss.item() * batch_size
            # Count samples actually seen (as train() does) instead of relying
            # on len(dataset); the two agree for a loader that drops no batch.
            total_samples += batch_size
            sq_err_x += torch.sum((output[:, 0] - x) ** 2).item()
            sq_err_y += torch.sum((output[:, 1] - y) ** 2).item()

    return _summarize_epoch(testLoader, loss_sum, sq_err_x, sq_err_y, total_samples)


def get_dataloaders(conf):
    """Construct the train, validation and test DataLoaders.

    For each split, the file list is built from the inclusive index range
    ``conf.data.<split>_file_range`` and the batch size is read from
    ``conf.data.batch_size_<split>``. Only the training loader shuffles.

    Returns:
        ``(trainLoader, valLoader, testLoader)``.
    """
    loaders = {}
    for split in ('train', 'val', 'test'):
        first, last = conf.data[f'{split}_file_range'][0], conf.data[f'{split}_file_range'][1]
        file_paths = [
            f"{conf.data.sample_folder}/{conf.data.energy}keV_Moench040_150V_{i}.npz"
            for i in range(first, last + 1)  # range end is inclusive in the config
        ]
        dataset = singlePhotonDataset(
            file_paths,
            sampleRatio=1.0,
            datasetName=split.capitalize(),
            noiseKeV=conf.data.noise_keV,
            numberOfAugOps=conf.data.num_aug_ops,
            normalize=conf.data.normalize,
            noiseThreshold=conf.data.noise_threshold * conf.data.noise_keV,
        )
        loaders[split] = torch.utils.data.DataLoader(
            dataset,
            batch_size=conf.data[f'batch_size_{split}'],
            shuffle=(split == 'train'),
            num_workers=conf.data.num_workers,
            pin_memory=True,
        )
    return loaders['train'], loaders['val'], loaders['test']


def plot_loss_curve(train_losses, val_losses, test_loss, exp_name, conf):
    """Save the train/validation loss curves (log y-axis) for an experiment.

    Args:
        train_losses: per-epoch training losses.
        val_losses: per-epoch validation losses.
        test_loss: final test loss; a non-positive value (callers pass ``-1``)
            means "not available yet" and suppresses the test-loss line and
            its text annotation.
        exp_name: experiment directory name under ``Results/``.
        conf: configuration providing ``model.version`` and ``data.normalize``.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(8, 6))
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.plot(val_losses, label='Validation Loss', color='orange')
    if test_loss > 0:
        plt.axhline(y=test_loss, color='green', linestyle='--', label='Test Loss')
        ### print test loss in the plot (only when a real value is available)
        plt.text(0.5, 0.5, f'Test Loss: {test_loss:.6f}', fontsize=12, color='green',
                 ha='center', va='center', transform=plt.gca().transAxes)
    plt.yscale('log')
    plt.xlabel('Epoch')
    # The selectable losses are weighted L1 / Huber, never MSE, so keep the label generic.
    plt.ylabel('Loss')
    plt.legend()
    plt.grid()

    plotName = f'loss_curve_singlePhoton_{conf.model.version}'
    if conf.data.normalize:
        plotName += '_normalized'
    plt.savefig(f'Results/{exp_name}/Plots/{plotName}.png')
    # This function runs at every checkpoint; close the figure so they do not pile up.
    plt.close(fig)


def plot_rms_curve(train_rms_x, train_rms_y, val_rms_x, val_rms_y, test_rms_x, test_rms_y, exp_name, conf):
    """Save the per-epoch RMS error curves in x and y for an experiment.

    ``test_rms_x`` / ``test_rms_y`` are drawn as horizontal lines only when
    both are positive (callers pass ``-1`` while no test result exists).
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(8, 6))
    plt.plot(train_rms_x, label='Train RMS X', color='blue')
    plt.plot(train_rms_y, label='Train RMS Y', color='cyan')
    plt.plot(val_rms_x, label='Val RMS X', color='orange')
    plt.plot(val_rms_y, label='Val RMS Y', color='magenta')
    if test_rms_x > 0 and test_rms_y > 0:
        plt.axhline(y=test_rms_x, color='green', linestyle='--', label='Test RMS X')
        plt.axhline(y=test_rms_y, color='lime', linestyle='--', label='Test RMS Y')
    plt.xlabel('Epoch')
    plt.ylabel('RMS Error [pixels]')
    plt.legend()
    plt.grid()

    plotName = f'rms_curve_singlePhoton_{conf.model.version}'
    if conf.data.normalize:
        plotName += '_normalized'
    plt.savefig(f'Results/{exp_name}/Plots/{plotName}.png')
    # Close the figure so repeated checkpoint plots do not accumulate in memory.
    plt.close(fig)


def get_model_name(epoch, conf):
    """Return the checkpoint file stem for ``epoch`` (without the ``.pth`` suffix)."""
    modelName = (
        f'singlePhoton{conf.model.version}_{conf.data.energy}keV'
        f'_Noise{conf.data.noise_keV}keV_E{epoch}_aug{conf.data.num_aug_ops}'
    )
    if conf.data.normalize:
        modelName += '_normalized'
    return modelName


if __name__ == "__main__":
    exp_name = prepare_output_folder(conf)

    model = models.get_model_class(conf.model.version)().cuda()
    # summary(model, input_size=(128, 3, 3, 3))
    loss_fn = get_loss_function(conf)
    optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min',
        factor=conf.training.scheduler_factor,
        patience=conf.training.scheduler_patience,
    )

    trainLoader, valLoader, testLoader = get_dataloaders(conf)

    train_losses = []
    val_losses = []
    train_rms_xs, train_rms_ys = [], []
    val_rms_xs, val_rms_ys = [], []

    for epoch in tqdm(range(1, conf.training.epochs + 1)):
        t_loss, train_rms_x, train_rms_y = train(model, trainLoader, optimizer, loss_fn)
        train_losses.append(t_loss)
        train_rms_xs.append(train_rms_x)
        train_rms_ys.append(train_rms_y)

        v_loss, val_rms_x, val_rms_y = evaluate(model, valLoader, loss_fn)
        val_losses.append(v_loss)
        val_rms_xs.append(val_rms_x)
        val_rms_ys.append(val_rms_y)

        # The plateau scheduler is driven by the latest validation loss.
        scheduler.step(val_losses[-1])
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")

        # Always checkpoint the last epoch, plus any epochs listed in the config.
        if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs:
            modelName = get_model_name(epoch, conf)
            torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}.pth')
            print(f"Saved model checkpoint: {modelName}.pth")
            # -1 marks "no test result yet"; the plots omit the test markers.
            plot_loss_curve(train_losses, val_losses, test_loss=-1, exp_name=exp_name, conf=conf)
            plot_rms_curve(train_rms_xs, train_rms_ys, val_rms_xs, val_rms_ys,
                           test_rms_x=-1, test_rms_y=-1, exp_name=exp_name, conf=conf)

    # Final evaluation on the held-out test split; overwrites the checkpoint plots.
    test_loss, test_rms_x, test_rms_y = evaluate(model, testLoader, loss_fn)
    plot_loss_curve(train_losses, val_losses, test_loss=test_loss, exp_name=exp_name, conf=conf)
    plot_rms_curve(train_rms_xs, train_rms_ys, val_rms_xs, val_rms_ys,
                   test_rms_x, test_rms_y, exp_name=exp_name, conf=conf)