Add rms curve
This commit is contained in:
+59
-9
@@ -55,7 +55,28 @@ def get_loss_function(conf):
|
||||
loss = weights * torch.abs(pred - target)
|
||||
return loss.mean()
|
||||
return weighted_loss
|
||||
if conf.loss.type == "weightedSumL1":
|
||||
alpha = conf.loss.alpha_weighted_sum
|
||||
beta = conf.loss.beta_weighted_sum
|
||||
def weighted_loss(pred, target):
|
||||
pred = pred[:, :2]
|
||||
target = target[:, :2]
|
||||
|
||||
r = torch.norm(target, dim=1, keepdim=True) # (B, 1)
|
||||
|
||||
weights = 1.0 + beta * r + alpha * torch.abs(target)
|
||||
|
||||
loss = weights * torch.abs(pred - target)
|
||||
return loss.mean()
|
||||
return weighted_loss
|
||||
if conf.loss.type == "huber_loss":
|
||||
huber_loss_fn = torch.nn.SmoothL1Loss(reduction='mean', beta = conf.loss.huber_beta)
|
||||
def huber_loss(pred, target):
|
||||
pred = pred[:, :2]
|
||||
target = target[:, :2]
|
||||
return huber_loss_fn(pred, target)
|
||||
return huber_loss
|
||||
|
||||
def train(model, trainLoader, optimizer, loss_fn):
|
||||
model.train()
|
||||
batchLoss = 0
|
||||
@@ -78,8 +99,8 @@ def train(model, trainLoader, optimizer, loss_fn):
|
||||
rms_y = np.sqrt(rms_y / total_samples)
|
||||
|
||||
datasetName = trainLoader.dataset.datasetName
|
||||
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
|
||||
return avgLoss
|
||||
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
|
||||
return avgLoss, rms_x, rms_y
|
||||
|
||||
def evaluate(model, testLoader, loss_fn):
|
||||
model.eval()
|
||||
@@ -99,8 +120,8 @@ def evaluate(model, testLoader, loss_fn):
|
||||
rms_y = np.sqrt(rms_y / len(testLoader.dataset))
|
||||
|
||||
datasetName = testLoader.dataset.datasetName
|
||||
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
|
||||
return avgLoss
|
||||
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
|
||||
return avgLoss, rms_x, rms_y
|
||||
|
||||
def get_dataloaders(conf):
|
||||
"""construct all dataLoaders"""
|
||||
@@ -146,11 +167,32 @@ def plot_loss_curve(train_losses, val_losses, test_loss, exp_name, conf):
|
||||
plt.ylabel('MSE Loss')
|
||||
plt.legend()
|
||||
plt.grid()
|
||||
### print test loss in the plot
|
||||
plt.text(0.5, 0.5, f'Test Loss: {test_loss:.6f}', fontsize=12, color='green', ha='center', va='center', transform=plt.gca().transAxes)
|
||||
plotName = f'loss_curve_singlePhoton_{conf.model.version}'
|
||||
if conf.data.normalize:
|
||||
plotName += '_normalized'
|
||||
plt.savefig(f'Results/{exp_name}/Plots/{plotName}.png')
|
||||
|
||||
def plot_rms_curve(train_rms_x, train_rms_y, val_rms_x, val_rms_y, test_rms_x, test_rms_y, exp_name, conf):
|
||||
import matplotlib.pyplot as plt
|
||||
plt.figure(figsize=(8,6))
|
||||
plt.plot(train_rms_x, label='Train RMS X', color='blue')
|
||||
plt.plot(train_rms_y, label='Train RMS Y', color='cyan')
|
||||
plt.plot(val_rms_x, label='Val RMS X', color='orange')
|
||||
plt.plot(val_rms_y, label='Val RMS Y', color='magenta')
|
||||
if test_rms_x > 0 and test_rms_y > 0:
|
||||
plt.axhline(y=test_rms_x, color='green', linestyle='--', label='Test RMS X')
|
||||
plt.axhline(y=test_rms_y, color='lime', linestyle='--', label='Test RMS Y')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('RMS Error [pixels]')
|
||||
plt.legend()
|
||||
plt.grid()
|
||||
plotName = f'rms_curve_singlePhoton_{conf.model.version}'
|
||||
if conf.data.normalize:
|
||||
plotName += '_normalized'
|
||||
plt.savefig(f'Results/{exp_name}/Plots/{plotName}.png')
|
||||
|
||||
def get_model_name(epoch, conf):
|
||||
modelName = f'singlePhoton{conf.model.version}_{conf.data.energy}keV_Noise{conf.data.noise_keV}keV_E{epoch}_aug{conf.data.num_aug_ops}'
|
||||
if conf.data.normalize:
|
||||
@@ -161,7 +203,7 @@ if __name__ == "__main__":
|
||||
|
||||
exp_name = prepare_output_folder(conf)
|
||||
model = models.get_model_class(conf.model.version)().cuda()
|
||||
# summary(model, input_size=(128, 1, 3, 3))
|
||||
# summary(model, input_size=(128, 3, 3, 3))
|
||||
|
||||
loss_fn = get_loss_function(conf)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate)
|
||||
@@ -171,12 +213,18 @@ if __name__ == "__main__":
|
||||
|
||||
train_losses = []
|
||||
val_losses = []
|
||||
train_rms_xs, train_rms_ys = [], []
|
||||
val_rms_xs, val_rms_ys = [], []
|
||||
|
||||
for epoch in tqdm(range(1, conf.training.epochs + 1)):
|
||||
t_loss = train(model, trainLoader, optimizer, loss_fn)
|
||||
t_loss, train_rms_x, train_rms_y = train(model, trainLoader, optimizer, loss_fn)
|
||||
train_losses.append(t_loss)
|
||||
v_loss = evaluate(model, valLoader, loss_fn)
|
||||
train_rms_xs.append(train_rms_x)
|
||||
train_rms_ys.append(train_rms_y)
|
||||
v_loss, val_rms_x, val_rms_y = evaluate(model, valLoader, loss_fn)
|
||||
val_losses.append(v_loss)
|
||||
val_rms_xs.append(val_rms_x)
|
||||
val_rms_ys.append(val_rms_y)
|
||||
scheduler.step(val_losses[-1])
|
||||
print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
|
||||
if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs:
|
||||
@@ -184,5 +232,7 @@ if __name__ == "__main__":
|
||||
torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}.pth')
|
||||
print(f"Saved model checkpoint: {modelName}.pth")
|
||||
plot_loss_curve(train_losses, val_losses, test_loss=-1, exp_name=exp_name, conf=conf)
|
||||
test_loss = evaluate(model, testLoader, loss_fn)
|
||||
plot_loss_curve(train_losses, val_losses, test_loss=test_loss, exp_name=exp_name, conf=conf)
|
||||
plot_rms_curve(train_rms_xs, train_rms_ys, val_rms_xs, val_rms_ys, test_rms_x = -1, test_rms_y = -1, exp_name=exp_name, conf=conf)
|
||||
test_loss, test_rms_x, test_rms_y = evaluate(model, testLoader, loss_fn)
|
||||
plot_loss_curve(train_losses, val_losses, test_loss=test_loss, exp_name=exp_name, conf=conf)
|
||||
plot_rms_curve(train_rms_xs, train_rms_ys, val_rms_xs, val_rms_ys, test_rms_x, test_rms_y, exp_name=exp_name, conf=conf)
|
||||
+4
-8
@@ -3,15 +3,13 @@ sys.path.append('./src')
|
||||
from omegaconf import OmegaConf ### for yaml config parsing
|
||||
import torch
|
||||
import numpy as np
|
||||
import models
|
||||
from datasets import *
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
from torchinfo import summary
|
||||
from pathlib import Path
|
||||
|
||||
from models import get_model_class
|
||||
from datasets import singlePhotonDataset
|
||||
from models import get_double_photon_model_class
|
||||
from datasets import doublePhotonDataset
|
||||
|
||||
### random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
@@ -51,13 +49,11 @@ def get_loss_function(conf):
|
||||
return two_point_set_loss_l2
|
||||
elif conf.loss.type == "two_point_set_loss_smooth_l1":
|
||||
def two_point_set_loss_smooth_l1(pred_xy, gt_xy):
|
||||
# Smooth L1 对异常值不那么敏感,但对于细小的亚像素误差能提供恒定的梯度
|
||||
loss_fn = torch.nn.SmoothL1Loss(reduction='none')
|
||||
|
||||
p1, p2 = pred_xy[:,0], pred_xy[:,1]
|
||||
g1, g2 = gt_xy[:,0], gt_xy[:,1]
|
||||
|
||||
# 计算两种排列组合的损失
|
||||
c_a = loss_fn(p1, g1).sum(dim=-1) + loss_fn(p2, g2).sum(dim=-1)
|
||||
c_b = loss_fn(p1, g2).sum(dim=-1) + loss_fn(p2, g1).sum(dim=-1)
|
||||
|
||||
@@ -155,8 +151,8 @@ def get_model_name(conf):
|
||||
|
||||
if __name__ == "__main__":
|
||||
exp_name = prepare_output_folder(conf)
|
||||
model = models.get_double_photon_model_class(conf.model.version)().cuda()
|
||||
# summary(model, input_size=(128, 1, conf.data.n_size, conf.data.n_size))
|
||||
model = get_double_photon_model_class(conf.model.version)().cuda()
|
||||
# summary(model, input_size=(128, 3, conf.data.n_size, conf.data.n_size))
|
||||
|
||||
loss_fn = get_loss_function(conf)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate, weight_decay=conf.training.weight_decay)
|
||||
|
||||
Reference in New Issue
Block a user