New model and loss_fn to improve uniformity
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
+15
-2
@@ -49,6 +49,21 @@ def get_loss_function(conf):
|
|||||||
c_b = pair_cost_l2sq(p1,g2) + pair_cost_l2sq(p2,g1)
|
c_b = pair_cost_l2sq(p1,g2) + pair_cost_l2sq(p2,g1)
|
||||||
return torch.minimum(c_a, c_b).mean()
|
return torch.minimum(c_a, c_b).mean()
|
||||||
return two_point_set_loss_l2
|
return two_point_set_loss_l2
|
||||||
|
elif conf.loss.type == "two_point_set_loss_smooth_l1":
|
||||||
|
def two_point_set_loss_smooth_l1(pred_xy, gt_xy):
    """Permutation-invariant Smooth L1 loss for a set of two predicted points.

    Both possible assignments of the two predictions to the two targets are
    scored, the cheaper assignment is kept per sample, and the result is
    averaged over the leading (batch) dimension.

    Args:
        pred_xy: predicted points; index 1 selects the point and the last
            dimension is summed over (assumes shape [B, 2, 2] — TODO confirm
            against the caller).
        gt_xy: ground-truth points with the same layout as ``pred_xy``.

    Returns:
        Scalar tensor: mean over samples of the minimum assignment cost.
    """
    # Smooth L1 (default beta=1.0) grows linearly for large residuals, so it is
    # less sensitive to outliers than squared error; below beta it is quadratic.
    smooth_l1 = torch.nn.SmoothL1Loss(reduction='none')

    def pair_cost(pred_pt, gt_pt):
        # Element-wise loss summed over the last (coordinate) axis.
        return smooth_l1(pred_pt, gt_pt).sum(dim=-1)

    pred_first, pred_second = pred_xy[:, 0], pred_xy[:, 1]
    gt_first, gt_second = gt_xy[:, 0], gt_xy[:, 1]

    # Cost of each of the two possible prediction/target pairings.
    cost_direct = pair_cost(pred_first, gt_first) + pair_cost(pred_second, gt_second)
    cost_swapped = pair_cost(pred_first, gt_second) + pair_cost(pred_second, gt_first)

    return torch.minimum(cost_direct, cost_swapped).mean()
|
||||||
|
|
||||||
|
return two_point_set_loss_smooth_l1
|
||||||
|
|
||||||
def train(model, trainLoader, optimizer, loss_fn):
|
def train(model, trainLoader, optimizer, loss_fn):
|
||||||
model.train()
|
model.train()
|
||||||
@@ -67,7 +82,6 @@ def train(model, trainLoader, optimizer, loss_fn):
|
|||||||
batchLoss += loss.item() * sample.shape[0]
|
batchLoss += loss.item() * sample.shape[0]
|
||||||
avgLoss = batchLoss / len(trainLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
|
avgLoss = batchLoss / len(trainLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
|
||||||
print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
|
print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
|
||||||
TrainLosses.append(avgLoss)
|
|
||||||
return avgLoss
|
return avgLoss
|
||||||
|
|
||||||
def evaluate(model, valLoader, loss_fn):
|
def evaluate(model, valLoader, loss_fn):
|
||||||
@@ -85,7 +99,6 @@ def evaluate(model, valLoader, loss_fn):
|
|||||||
batchLoss += loss.item() * sample.shape[0]
|
batchLoss += loss.item() * sample.shape[0]
|
||||||
avgLoss = batchLoss / len(valLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
|
avgLoss = batchLoss / len(valLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
|
||||||
print(f"[Val]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
|
print(f"[Val]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
|
||||||
ValLosses.append(avgLoss)
|
|
||||||
return avgLoss
|
return avgLoss
|
||||||
|
|
||||||
def get_dataloaders(conf):
|
def get_dataloaders(conf):
|
||||||
|
|||||||
+52
-1
@@ -326,4 +326,55 @@ class doublePhotonNet_251124(nn.Module): ### adapted for 7x7 input from 251001_2
|
|||||||
|
|
||||||
# Regression
|
# Regression
|
||||||
coords = self.fc(global_feat)
|
coords = self.fc(global_feat)
|
||||||
return coords # [B, 4]
|
return coords # [B, 4]
|
||||||
|
|
||||||
|
class doublePhotonNet_260507(nn.Module):
    """Regress four coordinates from a small multi-channel image patch.

    Adapted from ``doublePhotonNet_251124``: max pooling was removed and the
    fully connected head was given more capacity. Three 3x3 convolutions with
    ``padding=1`` keep the spatial size, a 1x1-conv + sigmoid spatial attention
    map re-weights the last feature map, and the flattened (un-pooled) features
    feed the fully connected regression head.

    Args:
        in_channels: number of input channels (default 3).
        input_size: height and width of the square input patch (default 7).
            The first linear layer is sized as ``128 * input_size ** 2``.
        dropout: dropout probability used between the FC layers (default 0.3).

    With the defaults the layer shapes and attribute names match the original
    hard-coded definition.
    """

    def __init__(self, in_channels=3, input_size=7, dropout=0.3):
        super().__init__()
        # Backbone (same layout as the previous model revision).
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Spatial attention module: one weight in (0, 1) per spatial location.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(128, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # No global pooling: the 128-channel feature map is flattened so the
        # FC layers see the full spatial layout. For the default 7x7 input
        # that is 128 * 7 * 7 = 6272 features.
        self.fc = nn.Sequential(
            nn.Linear(128 * input_size * input_size, 512),  # wide first layer to fit the flattened input
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 4)
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for conv weights; Xavier-uniform weights and zero biases for linear layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # Guard so a bias-free Linear would not crash initialization.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Map ``x`` of shape [B, in_channels, S, S] to coordinates of shape [B, 4].

        ``S`` must equal the ``input_size`` given at construction, otherwise
        the first linear layer rejects the flattened feature vector.
        """
        c1 = F.relu(self.conv1(x))    # [B, 32, S, S]
        c2 = F.relu(self.conv2(c1))   # [B, 64, S, S]
        c3 = F.relu(self.conv3(c2))   # [B, 128, S, S]

        attn = self.spatial_attn(c3)  # [B, 1, S, S]
        c3 = c3 * attn                # broadcast over channels -> [B, 128, S, S]

        # Flatten instead of pooling. flatten() (unlike view()) also works when
        # the activations are not contiguous, e.g. for channels_last inputs.
        flat_feat = c3.flatten(1)     # [B, 128 * S * S]

        coords = self.fc(flat_feat)   # [B, 4]
        return coords
|
||||||
Reference in New Issue
Block a user