New model and loss_fn to improve uniformity

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-05-07 13:49:10 +02:00
parent 3cf381b46d
commit e04e783fb6
2 changed files with 67 additions and 3 deletions
+15 -2
View File
@@ -49,6 +49,21 @@ def get_loss_function(conf):
c_b = pair_cost_l2sq(p1,g2) + pair_cost_l2sq(p2,g1)
return torch.minimum(c_a, c_b).mean()
return two_point_set_loss_l2
elif conf.loss.type == "two_point_set_loss_smooth_l1":
def two_point_set_loss_smooth_l1(pred_xy, gt_xy):
# Smooth L1 对异常值不那么敏感,但对于细小的亚像素误差能提供恒定的梯度
loss_fn = torch.nn.SmoothL1Loss(reduction='none')
p1, p2 = pred_xy[:,0], pred_xy[:,1]
g1, g2 = gt_xy[:,0], gt_xy[:,1]
# 计算两种排列组合的损失
c_a = loss_fn(p1, g1).sum(dim=-1) + loss_fn(p2, g2).sum(dim=-1)
c_b = loss_fn(p1, g2).sum(dim=-1) + loss_fn(p2, g1).sum(dim=-1)
return torch.minimum(c_a, c_b).mean()
return two_point_set_loss_smooth_l1
def train(model, trainLoader, optimizer, loss_fn):
model.train()
@@ -67,7 +82,6 @@ def train(model, trainLoader, optimizer, loss_fn):
batchLoss += loss.item() * sample.shape[0]
avgLoss = batchLoss / len(trainLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
TrainLosses.append(avgLoss)
return avgLoss
def evaluate(model, valLoader, loss_fn):
@@ -85,7 +99,6 @@ def evaluate(model, valLoader, loss_fn):
batchLoss += loss.item() * sample.shape[0]
avgLoss = batchLoss / len(valLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
print(f"[Val]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
ValLosses.append(avgLoss)
return avgLoss
def get_dataloaders(conf):