diff --git a/Train_1Photon.py b/Train_1Photon.py index fe4de1c..58e70e9 100644 --- a/Train_1Photon.py +++ b/Train_1Photon.py @@ -55,7 +55,28 @@ def get_loss_function(conf): loss = weights * torch.abs(pred - target) return loss.mean() return weighted_loss + if conf.loss.type == "weightedSumL1": + alpha = conf.loss.alpha_weighted_sum + beta = conf.loss.beta_weighted_sum + def weighted_loss(pred, target): + pred = pred[:, :2] + target = target[:, :2] + r = torch.norm(target, dim=1, keepdim=True) # (B, 1) + + weights = 1.0 + beta * r + alpha * torch.abs(target) + + loss = weights * torch.abs(pred - target) + return loss.mean() + return weighted_loss + if conf.loss.type == "huber_loss": + huber_loss_fn = torch.nn.SmoothL1Loss(reduction='mean', beta = conf.loss.huber_beta) + def huber_loss(pred, target): + pred = pred[:, :2] + target = target[:, :2] + return huber_loss_fn(pred, target) + return huber_loss + def train(model, trainLoader, optimizer, loss_fn): model.train() batchLoss = 0 @@ -78,8 +99,8 @@ def train(model, trainLoader, optimizer, loss_fn): rms_y = np.sqrt(rms_y / total_samples) datasetName = trainLoader.dataset.datasetName - print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}") - return avgLoss + print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}") + return avgLoss, rms_x, rms_y def evaluate(model, testLoader, loss_fn): model.eval() @@ -99,8 +120,8 @@ def evaluate(model, testLoader, loss_fn): rms_y = np.sqrt(rms_y / len(testLoader.dataset)) datasetName = testLoader.dataset.datasetName - print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}") - return avgLoss + print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}") + return avgLoss, rms_x, rms_y def get_dataloaders(conf): """construct all dataLoaders""" @@ -146,11 +167,32 @@ def plot_loss_curve(train_losses, val_losses, test_loss, exp_name, conf): plt.ylabel('MSE Loss') plt.legend() plt.grid() + ### print test loss in the plot + plt.text(0.5, 0.5, f'Test Loss: {test_loss:.6f}', fontsize=12, color='green', ha='center', va='center', transform=plt.gca().transAxes) plotName = f'loss_curve_singlePhoton_{conf.model.version}' if conf.data.normalize: plotName += '_normalized' plt.savefig(f'Results/{exp_name}/Plots/{plotName}.png') +def plot_rms_curve(train_rms_x, train_rms_y, val_rms_x, val_rms_y, test_rms_x, test_rms_y, exp_name, conf): + import matplotlib.pyplot as plt + plt.figure(figsize=(8,6)) + plt.plot(train_rms_x, label='Train RMS X', color='blue') + plt.plot(train_rms_y, label='Train RMS Y', color='cyan') + plt.plot(val_rms_x, label='Val RMS X', color='orange') + plt.plot(val_rms_y, label='Val RMS Y', color='magenta') + if test_rms_x > 0 and test_rms_y > 0: + plt.axhline(y=test_rms_x, color='green', linestyle='--', label='Test RMS X') + plt.axhline(y=test_rms_y, color='lime', linestyle='--', label='Test RMS Y') + plt.xlabel('Epoch') + plt.ylabel('RMS Error [pixels]') + plt.legend() + plt.grid() + plotName = f'rms_curve_singlePhoton_{conf.model.version}' + if conf.data.normalize: + plotName += '_normalized' + plt.savefig(f'Results/{exp_name}/Plots/{plotName}.png') + def get_model_name(epoch, conf): modelName = f'singlePhoton{conf.model.version}_{conf.data.energy}keV_Noise{conf.data.noise_keV}keV_E{epoch}_aug{conf.data.num_aug_ops}' if conf.data.normalize: @@ -161,7 +203,7 @@ if __name__ == "__main__": exp_name = prepare_output_folder(conf) model = models.get_model_class(conf.model.version)().cuda() - # summary(model, input_size=(128, 1, 3, 3)) + # summary(model, input_size=(128, 3, 3, 3)) loss_fn = get_loss_function(conf) optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate) @@ -171,12 +213,18 @@ if __name__ == "__main__": train_losses = [] val_losses = [] + train_rms_xs, train_rms_ys = [], [] + val_rms_xs, val_rms_ys = [], [] for epoch in tqdm(range(1, conf.training.epochs + 1)): - t_loss = train(model, trainLoader, optimizer, loss_fn) + t_loss, train_rms_x, train_rms_y = train(model, trainLoader, optimizer, loss_fn) train_losses.append(t_loss) - v_loss = evaluate(model, valLoader, loss_fn) + train_rms_xs.append(train_rms_x) + train_rms_ys.append(train_rms_y) + v_loss, val_rms_x, val_rms_y = evaluate(model, valLoader, loss_fn) val_losses.append(v_loss) + val_rms_xs.append(val_rms_x) + val_rms_ys.append(val_rms_y) scheduler.step(val_losses[-1]) print(f"Learning Rate: {optimizer.param_groups[0]['lr']}") if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs: @@ -184,5 +232,7 @@ if __name__ == "__main__": torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}.pth') print(f"Saved model checkpoint: {modelName}.pth") plot_loss_curve(train_losses, val_losses, test_loss=-1, exp_name=exp_name, conf=conf) - test_loss = evaluate(model, testLoader, loss_fn) - plot_loss_curve(train_losses, val_losses, test_loss=test_loss, exp_name=exp_name, conf=conf) \ No newline at end of file + plot_rms_curve(train_rms_xs, train_rms_ys, val_rms_xs, val_rms_ys, test_rms_x = -1, test_rms_y = -1, exp_name=exp_name, conf=conf) + test_loss, test_rms_x, test_rms_y = evaluate(model, testLoader, loss_fn) + plot_loss_curve(train_losses, val_losses, test_loss=test_loss, exp_name=exp_name, conf=conf) + plot_rms_curve(train_rms_xs, train_rms_ys, val_rms_xs, val_rms_ys, test_rms_x, test_rms_y, exp_name=exp_name, conf=conf) \ No newline at end of file diff --git a/Train_2Photon.py b/Train_2Photon.py index bce77c1..a0daadd 100644 --- a/Train_2Photon.py +++ b/Train_2Photon.py @@ -3,15 +3,13 @@ sys.path.append('./src') from omegaconf import OmegaConf ### for yaml config parsing import torch import numpy as np -import models -from datasets import * import torch.optim as optim from tqdm import tqdm from torchinfo import summary from pathlib import Path -from models import get_model_class -from datasets import singlePhotonDataset +from models import get_double_photon_model_class +from datasets import doublePhotonDataset ### random seed for reproducibility torch.manual_seed(0) @@ -51,13 +49,11 @@ def get_loss_function(conf): return two_point_set_loss_l2 elif conf.loss.type == "two_point_set_loss_smooth_l1": def two_point_set_loss_smooth_l1(pred_xy, gt_xy): - # Smooth L1 对异常值不那么敏感,但对于细小的亚像素误差能提供恒定的梯度 loss_fn = torch.nn.SmoothL1Loss(reduction='none') p1, p2 = pred_xy[:,0], pred_xy[:,1] g1, g2 = gt_xy[:,0], gt_xy[:,1] - # 计算两种排列组合的损失 c_a = loss_fn(p1, g1).sum(dim=-1) + loss_fn(p2, g2).sum(dim=-1) c_b = loss_fn(p1, g2).sum(dim=-1) + loss_fn(p2, g1).sum(dim=-1) @@ -155,8 +151,8 @@ def get_model_name(conf): if __name__ == "__main__": exp_name = prepare_output_folder(conf) - model = models.get_double_photon_model_class(conf.model.version)().cuda() - # summary(model, input_size=(128, 1, conf.data.n_size, conf.data.n_size)) + model = get_double_photon_model_class(conf.model.version)().cuda() + # summary(model, input_size=(128, 3, conf.data.n_size, conf.data.n_size)) loss_fn = get_loss_function(conf) optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate, weight_decay=conf.training.weight_decay)