"""Training script for the double-photon position regression model.

Reads ``Configs/train_2photon.yaml``, creates a fresh ``Results/<exp_name>``
folder, trains the model selected by ``conf.model.version`` with a
permutation-invariant two-point loss, saves checkpoints and loss curves, and
finally evaluates on the test split.
"""
import sys
sys.path.append('./src')  ### extend the import path before the local imports below

from omegaconf import OmegaConf  ### for yaml config parsing
import torch
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from torchinfo import summary
from pathlib import Path

from models import get_double_photon_model_class
from datasets import doublePhotonDataset

### random seed for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

conf = OmegaConf.load("Configs/train_2photon.yaml")


def prepare_output_folder(conf):
    """Create a new, uniquely numbered experiment folder under ``Results/``.

    The folder name is ``<YYMMDD>_2ph_<energy>keV_v<version>_<index>`` where
    ``index`` is the first two-digit counter (starting at 00) that is not
    taken yet. ``Models`` and ``Plots`` sub-folders are created and the
    configuration is saved as ``config.yaml`` inside the folder.

    Args:
        conf: Configuration providing ``data.energy`` and ``model.version``.

    Returns:
        The experiment name (folder name relative to ``Results/``).
    """
    from datetime import datetime

    date = datetime.now().strftime("%y%m%d")  ## YYMMDD format

    # Find the next free index. The folder is claimed with exist_ok=False so
    # that the "is it free?" check and the creation are a single step; two
    # runs started at the same time cannot end up sharing one folder.
    exp_index = 0
    while True:
        exp_name = f'{date}_2ph_{conf.data.energy}keV_v{conf.model.version}_{exp_index:02d}'
        exp_dir = Path(f'Results/{exp_name}')
        try:
            exp_dir.mkdir(parents=True, exist_ok=False)
            break
        except FileExistsError:
            exp_index += 1

    Path(f'Results/{exp_name}/Models').mkdir(parents=True, exist_ok=True)
    Path(f'Results/{exp_name}/Plots').mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, f'Results/{exp_name}/config.yaml')
    return exp_name


def _pair_cost_l2sq(p, q):
    """Squared L2 distance between ``p`` and ``q``, summed over the last dim."""
    return ((p - q) ** 2).sum(dim=-1)


def _two_point_set_loss(pred_xy, gt_xy, pair_cost):
    """Permutation-invariant loss between two predicted and two target points.

    Both assignments (pred1->gt1, pred2->gt2) and (pred1->gt2, pred2->gt1) are
    scored with ``pair_cost``; the cheaper one is kept per sample and the
    result is averaged over the batch.

    Args:
        pred_xy: Predictions indexed as ``[:, 0]`` / ``[:, 1]`` for the two points.
        gt_xy: Targets with the same layout as ``pred_xy``.
        pair_cost: Callable ``(p, q) -> per-sample cost``.

    Returns:
        Scalar tensor with the batch-mean of the per-sample minimum cost.
    """
    p1, p2 = pred_xy[:, 0], pred_xy[:, 1]
    g1, g2 = gt_xy[:, 0], gt_xy[:, 1]
    cost_direct = pair_cost(p1, g1) + pair_cost(p2, g2)
    cost_swapped = pair_cost(p1, g2) + pair_cost(p2, g1)
    return torch.minimum(cost_direct, cost_swapped).mean()


def get_loss_function(conf):
    """Return the loss callable selected by ``conf.loss.type``.

    Supported values are ``"two_point_set_loss_l2"`` and
    ``"two_point_set_loss_smooth_l1"``.

    Args:
        conf: Configuration providing ``loss.type``.

    Returns:
        A callable ``loss(pred_xy, gt_xy)`` returning a scalar tensor.

    Raises:
        ValueError: If ``conf.loss.type`` is not one of the supported names.
    """
    loss_type = conf.loss.type

    if loss_type == "two_point_set_loss_l2":
        def two_point_set_loss_l2(pred_xy, gt_xy):
            return _two_point_set_loss(pred_xy, gt_xy, _pair_cost_l2sq)
        return two_point_set_loss_l2

    if loss_type == "two_point_set_loss_smooth_l1":
        # Build the element-wise loss module once instead of on every call.
        smooth_l1 = torch.nn.SmoothL1Loss(reduction='none')

        def pair_cost_smooth_l1(p, q):
            return smooth_l1(p, q).sum(dim=-1)

        def two_point_set_loss_smooth_l1(pred_xy, gt_xy):
            return _two_point_set_loss(pred_xy, gt_xy, pair_cost_smooth_l1)
        return two_point_set_loss_smooth_l1

    # Fail here rather than returning None and crashing later, inside the
    # first training step, with an opaque "NoneType is not callable".
    raise ValueError(f"Unknown loss type: {loss_type!r}")


def _label_to_xy(label):
    """Extract the two (x, y) targets from a label batch.

    Columns 0-1 are used as (x1, y1) and columns 4-5 as (x2, y2); the
    remaining columns (unpacked as z and e in the original code) are unused.
    """
    return torch.stack((label[:, 0:2], label[:, 4:6]), dim=1)


def _output_to_xy(output):
    """Split the model output columns 0-1 and 2-3 into two (x, y) points."""
    return torch.stack((output[:, 0:2], output[:, 2:4]), dim=1)


def train(model, trainLoader, optimizer, loss_fn):
    """Run one training epoch and return the normalised average loss.

    Args:
        model: Network to train; called as ``model(sample)``.
        trainLoader: DataLoader yielding ``(sample, label)`` batches.
        optimizer: Optimizer stepping the model parameters.
        loss_fn: Callable ``loss_fn(pred_xy, gt_xy)`` returning a scalar tensor.

    Returns:
        The sample-weighted mean loss over the dataset, divided by 4.
    """
    model.train()
    batchLoss = 0
    for sample, label in trainLoader:
        sample, label = sample.cuda(), label.cuda()
        gt_xy = _label_to_xy(label)
        optimizer.zero_grad()
        pred_xy = _output_to_xy(model(sample))
        loss = loss_fn(pred_xy, gt_xy)
        loss.backward()
        optimizer.step()
        # loss is a batch mean; weight by batch size so the epoch average is
        # correct even when the last batch is smaller.
        batchLoss += loss.item() * sample.shape[0]
    avgLoss = batchLoss / len(trainLoader.dataset) / 4  ### divide by 4 to get the average loss per photon per axis
    print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
    return avgLoss


def evaluate(model, valLoader, loss_fn):
    """Evaluate the model on a loader without gradients.

    Args:
        model: Network to evaluate; called as ``model(sample)``.
        valLoader: DataLoader yielding ``(sample, label)`` batches.
        loss_fn: Callable ``loss_fn(pred_xy, gt_xy)`` returning a scalar tensor.

    Returns:
        The sample-weighted mean loss over the dataset, divided by 4.
    """
    model.eval()
    batchLoss = 0
    with torch.no_grad():
        for sample, label in valLoader:
            sample, label = sample.cuda(), label.cuda()
            gt_xy = _label_to_xy(label)
            pred_xy = _output_to_xy(model(sample))
            loss = loss_fn(pred_xy, gt_xy)
            batchLoss += loss.item() * sample.shape[0]
    avgLoss = batchLoss / len(valLoader.dataset) / 4  ### divide by 4 to get the average loss per photon per axis
    print(f"[Val]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
    return avgLoss


def get_dataloaders(conf):
    """construct all dataloaders

    For each of the Train / Val / Test splits the inclusive file index range
    ``conf.data.<split>_file_range`` is expanded into ``.npz`` paths, wrapped
    in a ``doublePhotonDataset`` and then in a DataLoader. Only the Train
    loader is shuffled.

    Args:
        conf: Configuration providing the ``data`` section.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    # split name -> (batch size key, file range key) inside conf.data
    split_keys = {
        'Train': ('batch_size_train', 'train_file_range'),
        'Val': ('batch_size_val', 'val_file_range'),
        'Test': ('batch_size_test', 'test_file_range'),
    }
    datasets = {}
    loaders = {}
    for split, (batch_key, file_range_key) in split_keys.items():
        first, last = conf.data[file_range_key][0], conf.data[file_range_key][1]
        files = [
            f"{conf.data.sample_folder}/{conf.data.energy}keV_Moench040_150V_{i}.npz"
            for i in range(first, last + 1)  # range end is inclusive in the config
        ]
        datasets[split] = doublePhotonDataset(
            files,
            sampleRatio = 1.0,
            datasetName = split,
            noiseKeV = conf.data.noise_keV,
            nSize = conf.data.n_size,
            noiseThreshold = conf.data.noise_threshold * conf.data.noise_keV,
            normalize = conf.data.normalize
        )
        loaders[split] = torch.utils.data.DataLoader(
            datasets[split],
            batch_size=conf.data[batch_key],
            shuffle=(split == 'Train'),
            num_workers=conf.data.num_workers,
            pin_memory=True
        )
    return loaders['Train'], loaders['Val'], loaders['Test']


def plot_loss_curves(train_losses, val_losses, test_loss, exp_name, conf):
    """Save the train/val loss curves as a PNG in the experiment's Plots folder.

    Args:
        train_losses: Per-epoch training losses.
        val_losses: Per-epoch validation losses.
        test_loss: Test loss drawn as a horizontal line; values <= 0 mean
            "not available" and the line is omitted.
        exp_name: Experiment folder name under ``Results/``.
        conf: Configuration providing ``model.version`` for the file name.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(8, 6))
    try:
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val_losses, label='Val Loss')
        if test_loss > 0:
            plt.axhline(y=test_loss, color='green', linestyle='--', label='Test Loss')
        plt.xlabel('Epoch')
        plt.ylabel('MSE Loss')
        plt.yscale('log')
        plt.legend()
        plt.grid()
        plotName = f'loss_curve_doublePhoton_{conf.model.version}.png'
        plt.savefig(f'Results/{exp_name}/Plots/{plotName}')
    finally:
        # This function is called repeatedly; pyplot keeps every figure alive
        # until it is closed, so release it to avoid accumulating figures.
        plt.close(fig)


def get_model_name(conf):
    """Build the checkpoint base name from model version, energy and noise.

    ``_normalized`` is appended when ``conf.data.normalize`` is truthy.
    """
    modelName = f'doublePhoton{conf.model.version}_{conf.data.energy}keV_Noise{conf.data.noise_keV}keV'
    if conf.data.normalize:
        modelName += '_normalized'
    return modelName


def main():
    """Train, checkpoint, plot and finally test using the module-level ``conf``."""
    exp_name = prepare_output_folder(conf)
    model = get_double_photon_model_class(conf.model.version)().cuda()
    # summary(model, input_size=(128, 3, conf.data.n_size, conf.data.n_size))
    loss_fn = get_loss_function(conf)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=conf.training.learning_rate,
        weight_decay=conf.training.weight_decay,
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        'min',
        factor=conf.training.scheduler_factor,
        patience=conf.training.scheduler_patience,
    )
    trainLoader, valLoader, testLoader = get_dataloaders(conf)
    modelName = get_model_name(conf)  # does not depend on the epoch

    TrainLosses, ValLosses = [], []
    for epoch in tqdm(range(1, conf.training.epochs + 1)):
        train_loss = train(model, trainLoader, optimizer, loss_fn)
        val_loss = evaluate(model, valLoader, loss_fn)
        TrainLosses.append(train_loss)
        ValLosses.append(val_loss)
        scheduler.step(val_loss)  # plateau scheduler is driven by the validation loss
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")

        # Always checkpoint the final epoch, plus any explicitly listed ones.
        if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs:
            torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}_E{epoch}.pth')
            print(f"Saved model checkpoint: {modelName}_E{epoch}.pth")

        # NOTE(review): the source indentation was lost; this refresh of the
        # loss plot is assumed to run once per epoch - confirm.
        plot_loss_curves(TrainLosses, ValLosses, test_loss=-1, exp_name=exp_name, conf=conf)

    test_loss = evaluate(model, testLoader, loss_fn)
    plot_loss_curves(TrainLosses, ValLosses, test_loss=test_loss, exp_name=exp_name, conf=conf)


if __name__ == "__main__":
    main()