"""Train a two-photon position-regression model from a YAML configuration.

The script loads ``Configs/train_2photon.yaml``, builds train/val/test
dataloaders from ``doublePhotonDataset``, trains the model selected by
``conf.model.version`` with Adam + ``ReduceLROnPlateau``, writes checkpoints
and a loss-curve plot under ``Results/<experiment>/`` and finally evaluates
on the test split.

Expected configuration file (``Configs/train_2photon.yaml``)::

    # Configs/train_2photon.yaml
    experiment:
      name: "2photon_12keV"

    data:
      sample_folder: "/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040"
      energy: 12    ### in keV
      noise_keV: 0.13
      noise_threshold: 0.0   ### set values below (noise * noise_threshold) to zero
      normalize: false

      batch_size_train: 4096
      batch_size_val: 1024
      batch_size_test: 1024
      num_workers: 32
      train_file_range: [0, 12]
      val_file_range: [13, 14]
      test_file_range: [15, 15]
      n_size: 7   ### size of sub-images containing 2 photons

    model:
      version: "251124"

    training:
      epochs: 150
      learning_rate: 1.0e-3
      weight_decay: 1.0e-4
      scheduler_factor: 0.7
      scheduler_patience: 5
      checkpoint_epochs: [10, 30, 50, 100, 150]

    loss:
      type: "two_point_set_loss_l2"
"""
import sys
sys.path.append('./src')  # must precede the project imports (models, datasets)

from datetime import datetime
from pathlib import Path

from omegaconf import OmegaConf  ### for yaml config parsing
import torch
import numpy as np
import models
from datasets import *
import torch.optim as optim
from tqdm import tqdm
from torchinfo import summary

from models import get_model_class
from datasets import doublePhotonDataset, singlePhotonDataset

### random seed for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

conf = OmegaConf.load("Configs/train_2photon.yaml")

# The set loss sums squared errors over 2 photons x 2 axes; dividing by this
# turns it into a mean squared error per photon per axis.
_TERMS_PER_SAMPLE = 4


def prepare_output_folder(conf) -> str:
    """Create a fresh ``Results/<exp_name>`` tree and store the config in it.

    The experiment name is ``<YYMMDD>_2ph_<energy>keV_v<version>_<NN>`` where
    ``NN`` is the first two-digit index whose directory does not exist yet.

    Args:
        conf: configuration providing ``data.energy`` and ``model.version``.

    Returns:
        The experiment name (the directory name below ``Results/``).
    """
    date = datetime.now().strftime("%y%m%d")  ## YYMMDD format
    exp_index = 0
    while True:
        exp_name = f'{date}_2ph_{conf.data.energy}keV_v{conf.model.version}_{exp_index:02d}'
        exp_dir = Path(f'Results/{exp_name}')
        try:
            # Creating the directory *is* the existence check, so two runs
            # started at the same time cannot claim the same index.
            exp_dir.mkdir(parents=True, exist_ok=False)
            break
        except FileExistsError:
            exp_index += 1
    (exp_dir / 'Models').mkdir(parents=True, exist_ok=True)
    (exp_dir / 'Plots').mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, f'Results/{exp_name}/config.yaml')
    return exp_name


def two_point_set_loss_l2(pred_xy, gt_xy):
    """Order-invariant squared-L2 loss between two predicted and two true points.

    Both tensors are indexed as ``[sample, point, coordinate]``. For every
    sample the summed squared distance is computed for both possible
    prediction/ground-truth assignments and the smaller one is kept.

    Returns:
        Scalar tensor: mean over the batch of the per-sample minimum cost.
    """
    def pair_cost_l2sq(p, q):  # p, q: (..., 2)
        return ((p - q) ** 2).sum(dim=-1)  # squared L2

    p1, p2 = pred_xy[:, 0], pred_xy[:, 1]
    g1, g2 = gt_xy[:, 0], gt_xy[:, 1]
    cost_direct = pair_cost_l2sq(p1, g1) + pair_cost_l2sq(p2, g2)
    cost_swapped = pair_cost_l2sq(p1, g2) + pair_cost_l2sq(p2, g1)
    return torch.minimum(cost_direct, cost_swapped).mean()


# Loss functions selectable through ``conf.loss.type``.
_LOSS_FUNCTIONS = {
    "two_point_set_loss_l2": two_point_set_loss_l2,
}


def get_loss_function(conf):
    """Return the loss callable named by ``conf.loss.type``.

    Raises:
        ValueError: if the configured loss type is not known. (Previously
            ``None`` was returned and training failed later with an obscure
            "NoneType is not callable" error.)
    """
    loss_type = conf.loss.type
    try:
        return _LOSS_FUNCTIONS[loss_type]
    except KeyError:
        supported = ", ".join(sorted(_LOSS_FUNCTIONS))
        raise ValueError(
            f"Unknown loss type {loss_type!r}; supported: {supported}"
        ) from None


def _as_point_pairs(output, label):
    """Arrange model output and labels as ``[sample, point, (x, y)]`` tensors.

    Label columns 0-1 and 4-5 hold the (x, y) of the first and second photon
    (columns 2-3 and 6-7 were named z/e in the original code and are unused).
    Output columns 0-1 and 2-3 hold the two predicted (x, y) positions.
    """
    gt_xy = torch.stack((label[:, 0:2], label[:, 4:6]), dim=1)
    pred_xy = torch.stack((output[:, 0:2], output[:, 2:4]), dim=1)
    return pred_xy, gt_xy


def train(model, trainLoader, optimizer, loss_fn) -> float:
    """Run one optimisation epoch over ``trainLoader``.

    Returns:
        Sample-weighted average loss divided by 4, i.e. the mean squared
        error per photon per axis. The caller owns the loss history; this
        function no longer appends to a module-level list (doing so recorded
        every epoch twice and raised ``NameError`` when imported).
    """
    model.train()
    weighted_loss = 0.0
    for sample, label in trainLoader:
        sample, label = sample.cuda(), label.cuda()
        optimizer.zero_grad()
        pred_xy, gt_xy = _as_point_pairs(model(sample), label)
        loss = loss_fn(pred_xy, gt_xy)
        loss.backward()
        optimizer.step()
        # loss is a batch mean; weight by batch size so a smaller final batch
        # does not skew the epoch average.
        weighted_loss += loss.item() * sample.shape[0]
    avg_loss = weighted_loss / len(trainLoader.dataset) / _TERMS_PER_SAMPLE
    print(f"[Train]\t Average Loss: {avg_loss:.6f} (RMS = {np.sqrt(avg_loss):.6f})")
    return avg_loss


def evaluate(model, valLoader, loss_fn, tag="Val") -> float:
    """Compute the average loss over ``valLoader`` without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode).
        valLoader: dataloader yielding ``(sample, label)`` batches.
        loss_fn: callable taking ``(pred_xy, gt_xy)``.
        tag: label used in the printed summary line (e.g. ``"Test"``).

    Returns:
        Sample-weighted average loss divided by 4 (per photon per axis).
        Nothing is appended to any global history, so evaluating the test
        split no longer pollutes the validation curve.
    """
    model.eval()
    weighted_loss = 0.0
    with torch.no_grad():
        for sample, label in valLoader:
            sample, label = sample.cuda(), label.cuda()
            pred_xy, gt_xy = _as_point_pairs(model(sample), label)
            loss = loss_fn(pred_xy, gt_xy)
            weighted_loss += loss.item() * sample.shape[0]
    avg_loss = weighted_loss / len(valLoader.dataset) / _TERMS_PER_SAMPLE
    print(f"[{tag}]\t Average Loss: {avg_loss:.6f} (RMS = {np.sqrt(avg_loss):.6f})")
    return avg_loss


def get_dataloaders(conf):
    """Construct the train, validation and test dataloaders.

    For each split the inclusive file index range ``conf.data.<split>_file_range``
    is expanded into ``<energy>keV_Moench040_150V_<i>.npz`` paths below
    ``conf.data.sample_folder``. Only the training loader shuffles.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    loaders = {}
    for split in ('Train', 'Val', 'Test'):
        key = split.lower()
        file_range = conf.data[f'{key}_file_range']
        files = [
            f"{conf.data.sample_folder}/{conf.data.energy}keV_Moench040_150V_{i}.npz"
            for i in range(file_range[0], file_range[1] + 1)  # range is inclusive
        ]

        dataset = doublePhotonDataset(
            files,
            sampleRatio=1.0,
            datasetName=split,
            noiseKeV=conf.data.noise_keV,
            nSize=conf.data.n_size,
            noiseThreshold=conf.data.noise_threshold * conf.data.noise_keV,
            normalize=conf.data.normalize,
        )

        loaders[split] = torch.utils.data.DataLoader(
            dataset,
            batch_size=conf.data[f'batch_size_{key}'],
            shuffle=(split == 'Train'),
            num_workers=conf.data.num_workers,
            pin_memory=True,
        )
    return loaders['Train'], loaders['Val'], loaders['Test']


def plot_loss_curves(train_losses, val_losses, test_loss, exp_name, conf):
    """Save the train/val loss curves (log scale) to the experiment's Plots folder.

    Args:
        train_losses: per-epoch training losses.
        val_losses: per-epoch validation losses.
        test_loss: drawn as a horizontal line when positive; pass a
            non-positive value to omit it.
        exp_name: experiment directory name below ``Results/``.
        conf: configuration providing ``model.version`` for the file name.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(8, 6))
    try:
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val_losses, label='Val Loss')
        if test_loss > 0:
            plt.axhline(y=test_loss, color='green', linestyle='--', label='Test Loss')
        plt.xlabel('Epoch')
        plt.ylabel('MSE Loss')
        plt.yscale('log')
        plt.legend()
        plt.grid()
        plotName = f'loss_curve_doublePhoton_{conf.model.version}.png'
        plt.savefig(f'Results/{exp_name}/Plots/{plotName}')
    finally:
        # This function is called repeatedly during training; without closing,
        # every call leaves another figure alive in pyplot's registry.
        plt.close(fig)


def get_model_name(conf) -> str:
    """Build the checkpoint base name from model version, energy and noise."""
    modelName = f'doublePhoton{conf.model.version}_{conf.data.energy}keV_Noise{conf.data.noise_keV}keV'
    if conf.data.normalize:
        modelName += '_normalized'
    return modelName


def main():
    """Train, checkpoint, plot and finally test using the module-level ``conf``."""
    exp_name = prepare_output_folder(conf)
    model = models.get_double_photon_model_class(conf.model.version)().cuda()
    # Optional architecture overview:
    # summary(model, input_size=(128, 1, conf.data.n_size, conf.data.n_size))

    loss_fn = get_loss_function(conf)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=conf.training.learning_rate,
        weight_decay=conf.training.weight_decay,
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        'min',
        factor=conf.training.scheduler_factor,
        patience=conf.training.scheduler_patience,
    )

    trainLoader, valLoader, testLoader = get_dataloaders(conf)
    modelName = get_model_name(conf)  # does not change between epochs

    TrainLosses, ValLosses = [], []
    for epoch in tqdm(range(1, conf.training.epochs + 1)):
        train_loss = train(model, trainLoader, optimizer, loss_fn)
        val_loss = evaluate(model, valLoader, loss_fn)
        # Single place where the history is recorded: one entry per epoch.
        TrainLosses.append(train_loss)
        ValLosses.append(val_loss)
        scheduler.step(val_loss)
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
        if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs:
            torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}_E{epoch}.pth')
            print(f"Saved model checkpoint: {modelName}_E{epoch}.pth")
        # NOTE(review): the nesting of this call was ambiguous in the flattened
        # source; refreshing the plot every epoch is the conservative choice.
        plot_loss_curves(TrainLosses, ValLosses, test_loss=-1, exp_name=exp_name, conf=conf)

    test_loss = evaluate(model, testLoader, loss_fn, tag="Test")
    plot_loss_curves(TrainLosses, ValLosses, test_loss=test_loss, exp_name=exp_name, conf=conf)


if __name__ == "__main__":
    main()