Add rms curve

This commit is contained in:
2026-05-13 14:30:14 +02:00
parent f06e301ca7
commit 28e122d21a
2 changed files with 63 additions and 17 deletions
+59 -9
View File
@@ -55,7 +55,28 @@ def get_loss_function(conf):
loss = weights * torch.abs(pred - target)
return loss.mean()
return weighted_loss
if conf.loss.type == "weightedSumL1":
alpha = conf.loss.alpha_weighted_sum
beta = conf.loss.beta_weighted_sum
def weighted_loss(pred, target):
pred = pred[:, :2]
target = target[:, :2]
r = torch.norm(target, dim=1, keepdim=True) # (B, 1)
weights = 1.0 + beta * r + alpha * torch.abs(target)
loss = weights * torch.abs(pred - target)
return loss.mean()
return weighted_loss
if conf.loss.type == "huber_loss":
huber_loss_fn = torch.nn.SmoothL1Loss(reduction='mean', beta = conf.loss.huber_beta)
def huber_loss(pred, target):
pred = pred[:, :2]
target = target[:, :2]
return huber_loss_fn(pred, target)
return huber_loss
def train(model, trainLoader, optimizer, loss_fn):
model.train()
batchLoss = 0
@@ -78,8 +99,8 @@ def train(model, trainLoader, optimizer, loss_fn):
rms_y = np.sqrt(rms_y / total_samples)
datasetName = trainLoader.dataset.datasetName
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
return avgLoss
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
return avgLoss, rms_x, rms_y
def evaluate(model, testLoader, loss_fn):
model.eval()
@@ -99,8 +120,8 @@ def evaluate(model, testLoader, loss_fn):
rms_y = np.sqrt(rms_y / len(testLoader.dataset))
datasetName = testLoader.dataset.datasetName
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
return avgLoss
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
return avgLoss, rms_x, rms_y
def get_dataloaders(conf):
"""construct all dataLoaders"""
@@ -146,11 +167,32 @@ def plot_loss_curve(train_losses, val_losses, test_loss, exp_name, conf):
plt.ylabel('MSE Loss')
plt.legend()
plt.grid()
### print test loss in the plot
plt.text(0.5, 0.5, f'Test Loss: {test_loss:.6f}', fontsize=12, color='green', ha='center', va='center', transform=plt.gca().transAxes)
plotName = f'loss_curve_singlePhoton_{conf.model.version}'
if conf.data.normalize:
plotName += '_normalized'
plt.savefig(f'Results/{exp_name}/Plots/{plotName}.png')
def plot_rms_curve(train_rms_x, train_rms_y, val_rms_x, val_rms_y, test_rms_x, test_rms_y, exp_name, conf):
import matplotlib.pyplot as plt
plt.figure(figsize=(8,6))
plt.plot(train_rms_x, label='Train RMS X', color='blue')
plt.plot(train_rms_y, label='Train RMS Y', color='cyan')
plt.plot(val_rms_x, label='Val RMS X', color='orange')
plt.plot(val_rms_y, label='Val RMS Y', color='magenta')
if test_rms_x > 0 and test_rms_y > 0:
plt.axhline(y=test_rms_x, color='green', linestyle='--', label='Test RMS X')
plt.axhline(y=test_rms_y, color='lime', linestyle='--', label='Test RMS Y')
plt.xlabel('Epoch')
plt.ylabel('RMS Error [pixels]')
plt.legend()
plt.grid()
plotName = f'rms_curve_singlePhoton_{conf.model.version}'
if conf.data.normalize:
plotName += '_normalized'
plt.savefig(f'Results/{exp_name}/Plots/{plotName}.png')
def get_model_name(epoch, conf):
modelName = f'singlePhoton{conf.model.version}_{conf.data.energy}keV_Noise{conf.data.noise_keV}keV_E{epoch}_aug{conf.data.num_aug_ops}'
if conf.data.normalize:
@@ -161,7 +203,7 @@ if __name__ == "__main__":
exp_name = prepare_output_folder(conf)
model = models.get_model_class(conf.model.version)().cuda()
# summary(model, input_size=(128, 1, 3, 3))
# summary(model, input_size=(128, 3, 3, 3))
loss_fn = get_loss_function(conf)
optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate)
@@ -171,12 +213,18 @@ if __name__ == "__main__":
train_losses = []
val_losses = []
train_rms_xs, train_rms_ys = [], []
val_rms_xs, val_rms_ys = [], []
for epoch in tqdm(range(1, conf.training.epochs + 1)):
t_loss = train(model, trainLoader, optimizer, loss_fn)
t_loss, train_rms_x, train_rms_y = train(model, trainLoader, optimizer, loss_fn)
train_losses.append(t_loss)
v_loss = evaluate(model, valLoader, loss_fn)
train_rms_xs.append(train_rms_x)
train_rms_ys.append(train_rms_y)
v_loss, val_rms_x, val_rms_y = evaluate(model, valLoader, loss_fn)
val_losses.append(v_loss)
val_rms_xs.append(val_rms_x)
val_rms_ys.append(val_rms_y)
scheduler.step(val_losses[-1])
print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs:
@@ -184,5 +232,7 @@ if __name__ == "__main__":
torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}.pth')
print(f"Saved model checkpoint: {modelName}.pth")
plot_loss_curve(train_losses, val_losses, test_loss=-1, exp_name=exp_name, conf=conf)
test_loss = evaluate(model, testLoader, loss_fn)
plot_loss_curve(train_losses, val_losses, test_loss=test_loss, exp_name=exp_name, conf=conf)
plot_rms_curve(train_rms_xs, train_rms_ys, val_rms_xs, val_rms_ys, test_rms_x = -1, test_rms_y = -1, exp_name=exp_name, conf=conf)
test_loss, test_rms_x, test_rms_y = evaluate(model, testLoader, loss_fn)
plot_loss_curve(train_losses, val_losses, test_loss=test_loss, exp_name=exp_name, conf=conf)
plot_rms_curve(train_rms_xs, train_rms_ys, val_rms_xs, val_rms_ys, test_rms_x, test_rms_y, exp_name=exp_name, conf=conf)
+4 -8
View File
@@ -3,15 +3,13 @@ sys.path.append('./src')
from omegaconf import OmegaConf ### for yaml config parsing
import torch
import numpy as np
import models
from datasets import *
import torch.optim as optim
from tqdm import tqdm
from torchinfo import summary
from pathlib import Path
from models import get_model_class
from datasets import singlePhotonDataset
from models import get_double_photon_model_class
from datasets import doublePhotonDataset
### random seed for reproducibility
torch.manual_seed(0)
@@ -51,13 +49,11 @@ def get_loss_function(conf):
return two_point_set_loss_l2
elif conf.loss.type == "two_point_set_loss_smooth_l1":
def two_point_set_loss_smooth_l1(pred_xy, gt_xy):
# Smooth L1 对异常值不那么敏感,但对于细小的亚像素误差能提供恒定的梯度
loss_fn = torch.nn.SmoothL1Loss(reduction='none')
p1, p2 = pred_xy[:,0], pred_xy[:,1]
g1, g2 = gt_xy[:,0], gt_xy[:,1]
# 计算两种排列组合的损失
c_a = loss_fn(p1, g1).sum(dim=-1) + loss_fn(p2, g2).sum(dim=-1)
c_b = loss_fn(p1, g2).sum(dim=-1) + loss_fn(p2, g1).sum(dim=-1)
@@ -155,8 +151,8 @@ def get_model_name(conf):
if __name__ == "__main__":
exp_name = prepare_output_folder(conf)
model = models.get_double_photon_model_class(conf.model.version)().cuda()
# summary(model, input_size=(128, 1, conf.data.n_size, conf.data.n_size))
model = get_double_photon_model_class(conf.model.version)().cuda()
# summary(model, input_size=(128, 3, conf.data.n_size, conf.data.n_size))
loss_fn = get_loss_function(conf)
optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate, weight_decay=conf.training.weight_decay)