New model and loss_fn to improve uniformity
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
+15
-2
@@ -49,6 +49,21 @@ def get_loss_function(conf):
|
||||
c_b = pair_cost_l2sq(p1,g2) + pair_cost_l2sq(p2,g1)
|
||||
return torch.minimum(c_a, c_b).mean()
|
||||
return two_point_set_loss_l2
|
||||
elif conf.loss.type == "two_point_set_loss_smooth_l1":
|
||||
def two_point_set_loss_smooth_l1(pred_xy, gt_xy):
|
||||
# Smooth L1 对异常值不那么敏感,但对于细小的亚像素误差能提供恒定的梯度
|
||||
loss_fn = torch.nn.SmoothL1Loss(reduction='none')
|
||||
|
||||
p1, p2 = pred_xy[:,0], pred_xy[:,1]
|
||||
g1, g2 = gt_xy[:,0], gt_xy[:,1]
|
||||
|
||||
# 计算两种排列组合的损失
|
||||
c_a = loss_fn(p1, g1).sum(dim=-1) + loss_fn(p2, g2).sum(dim=-1)
|
||||
c_b = loss_fn(p1, g2).sum(dim=-1) + loss_fn(p2, g1).sum(dim=-1)
|
||||
|
||||
return torch.minimum(c_a, c_b).mean()
|
||||
|
||||
return two_point_set_loss_smooth_l1
|
||||
|
||||
def train(model, trainLoader, optimizer, loss_fn):
|
||||
model.train()
|
||||
@@ -67,7 +82,6 @@ def train(model, trainLoader, optimizer, loss_fn):
|
||||
batchLoss += loss.item() * sample.shape[0]
|
||||
avgLoss = batchLoss / len(trainLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
|
||||
print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
|
||||
TrainLosses.append(avgLoss)
|
||||
return avgLoss
|
||||
|
||||
def evaluate(model, valLoader, loss_fn):
|
||||
@@ -85,7 +99,6 @@ def evaluate(model, valLoader, loss_fn):
|
||||
batchLoss += loss.item() * sample.shape[0]
|
||||
avgLoss = batchLoss / len(valLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
|
||||
print(f"[Val]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
|
||||
ValLosses.append(avgLoss)
|
||||
return avgLoss
|
||||
|
||||
def get_dataloaders(conf):
|
||||
|
||||
Reference in New Issue
Block a user