170 lines
6.5 KiB
Python
170 lines
6.5 KiB
Python
import sys
|
|
sys.path.append('./src')
|
|
import torch
|
|
import numpy as np
|
|
from models import *
|
|
from datasets import *
|
|
import torch.optim as optim
|
|
from tqdm import tqdm
|
|
from torchinfo import summary
|
|
|
|
### Fix the NumPy and PyTorch RNG seeds so repeated runs are reproducible
np.random.seed(0)
torch.manual_seed(0)

### Network/loss variant to use: '250910', '251001' or '251001_2'
modelVersion = '251001_2'

### Per-epoch loss history, appended to by train() and test()
TrainLosses = []
ValLosses = []

### Loss on the test set; -1 means "not evaluated yet" (plot_loss_curve only draws it when > 0)
TestLoss = -1
|
|
|
|
def two_point_set_loss_l2(pred_xy, gt_xy):
    """Order-invariant squared-L2 loss between two predicted and two true points.

    Both arguments are indexed as ``[:, k]`` for point ``k`` (0 or 1) and the
    last dimension holds the coordinates. For every sample the cost of the
    direct pairing (0-0, 1-1) and of the swapped pairing (0-1, 1-0) is the sum
    of the two squared distances; the smaller of the two is kept and the
    result is averaged over the batch.
    """
    pred_first, pred_second = pred_xy[:, 0], pred_xy[:, 1]
    true_first, true_second = gt_xy[:, 0], gt_xy[:, 1]

    def squared_distance(a, b):
        # Sum of squared coordinate differences along the last axis.
        return ((a - b) ** 2).sum(dim=-1)

    direct_cost = squared_distance(pred_first, true_first) + squared_distance(pred_second, true_second)
    swapped_cost = squared_distance(pred_first, true_second) + squared_distance(pred_second, true_first)

    # Element-wise minimum picks the cheaper pairing independently per sample.
    return torch.minimum(direct_cost, swapped_cost).mean()
|
|
|
|
def min_matching_loss(pred, target):
    """Order-invariant MSE between two predicted and two true (x, y) points.

    pred:   [B, 4] -> (x1, y1, x2, y2)   (anything viewable as [B, 2, 2])
    target: [B, 4] -> (x1, y1, x2, y2)   (anything viewable as [B, 2, 2])

    For each sample the cost of the direct pairing (1-1, 2-2) and of the
    swapped pairing (1-2, 2-1) is computed, the smaller one is kept, and the
    kept costs are averaged over the batch.

    The minimum is taken per sample. Taking it after averaging each pairing
    over the whole batch would force one pairing onto every sample and
    penalise samples whose points are predicted in the other order.
    The scale is unchanged: summed over the two points, averaged over the two
    coordinates.
    """
    pred = pred.view(-1, 2, 2)      # [B, 2, 2]
    target = target.view(-1, 2, 2)  # [B, 2, 2]

    # Cost of both possible pairings, one value per sample: [B]
    direct = ((pred[:, 0] - target[:, 0]) ** 2 + (pred[:, 1] - target[:, 1]) ** 2).mean(dim=-1)
    swapped = ((pred[:, 0] - target[:, 1]) ** 2 + (pred[:, 1] - target[:, 0]) ** 2).mean(dim=-1)

    # Element-wise minimum selects the best pairing independently per sample.
    return torch.minimum(direct, swapped).mean()
|
|
|
|
# Select the loss function and the network that belong to the configured
# modelVersion; the network is moved to the GPU right away.
if modelVersion == '250910':
    loss_fn = two_point_set_loss_l2
    model = doublePhotonNet_250910().cuda()
elif modelVersion == '251001':
    loss_fn = min_matching_loss
    model = doublePhotonNet_251001().cuda()
elif modelVersion == '251001_2':
    loss_fn = min_matching_loss
    model = doublePhotonNet_251001_2().cuda()
else:
    # Fail here instead of with a NameError on `model`/`loss_fn` further down.
    raise ValueError(f"Unknown modelVersion: {modelVersion!r}")
|
|
|
|
# summary(model, input_size=(128, 1, 6, 6)) ### print model summary
|
|
|
|
def train(model, trainLoader, optimizer):
    """Train ``model`` for one pass over ``trainLoader`` and record the epoch loss.

    Args:
        model: Network to optimise; ``model(sample)`` must return at least four
            columns, read as (x1, y1, x2, y2).
        trainLoader: Loader yielding ``(sample, label)`` batches. Label columns
            0-1 are used as (x, y) of the first photon and columns 4-5 as
            (x, y) of the second; the remaining columns are not used here.
        optimizer: Optimizer stepped once per batch.

    Every batch is moved to the GPU and scored with the module-level
    ``loss_fn``. The epoch average is printed and appended to ``TrainLosses``.
    """
    model.train()
    batchLoss = 0
    for sample, label in trainLoader:
        sample, label = sample.cuda(), label.cuda()
        # [B, 2, 2]: photon index on dim 1, (x, y) on dim 2.
        gt_xy = torch.stack((label[:, 0:2], label[:, 4:6]), dim=1)
        optimizer.zero_grad()
        output = model(sample)
        pred_xy = torch.stack((output[:, 0:2], output[:, 2:4]), dim=1)
        loss = loss_fn(pred_xy, gt_xy)
        loss.backward()
        optimizer.step()
        # loss is a batch mean; weight by batch size so a short last batch is averaged correctly.
        batchLoss += loss.item() * sample.shape[0]
    ### divide by 2 to get the average loss per photon
    # NOTE(review): test() does not apply this factor of 2, so the values stored
    # in TrainLosses and ValLosses are on different scales - confirm which is intended.
    avgLoss = batchLoss / len(trainLoader.dataset) / 2
    print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
    TrainLosses.append(avgLoss)
|
|
|
|
def _matched_residuals(pred_xy, gt_xy):
    """Return ``pred - gt`` per photon under the cheaper pairing, shape [B, 2, 2]."""
    direct = pred_xy - gt_xy
    swapped = pred_xy - torch.stack((gt_xy[:, 1], gt_xy[:, 0]), dim=1)
    # Per sample, compare the summed squared error of both pairings.
    use_swapped = (swapped ** 2).sum(dim=(1, 2)) < (direct ** 2).sum(dim=(1, 2))
    return torch.where(use_swapped[:, None, None], swapped, direct)


def test(model, testLoader):
    """Evaluate ``model`` on ``testLoader`` and report loss and residual statistics.

    Args:
        model: Network to evaluate; ``model(sample)`` must return at least four
            columns, read as (x1, y1, x2, y2).
        testLoader: Loader yielding ``(sample, label)`` batches whose dataset
            exposes ``datasetName``. Label columns 0-1 and 4-5 are used as the
            (x, y) positions of the two photons.

    Returns:
        The average of the module-level ``loss_fn`` over the dataset.

    Side effects: prints the loss and the mean/std of the x and y residuals;
    appends the loss to ``ValLosses`` when the dataset is named 'Val',
    otherwise stores it in the global ``TestLoss``.

    Residuals are taken under the same order-invariant pairing the loss uses.
    A fixed pred[k]-vs-gt[k] pairing would report large residuals whenever the
    network emits the two photons in the other order.
    """
    global TestLoss
    model.eval()
    batchLoss = 0
    residual_chunks = []  # one [B, 2, 2] array per batch, concatenated once at the end
    with torch.no_grad():
        for sample, label in testLoader:
            sample, label = sample.cuda(), label.cuda()
            # [B, 2, 2]: photon index on dim 1, (x, y) on dim 2.
            gt_xy = torch.stack((label[:, 0:2], label[:, 4:6]), dim=1)
            output = model(sample)
            pred_xy = torch.stack((output[:, 0:2], output[:, 2:4]), dim=1)
            loss = loss_fn(pred_xy, gt_xy)
            batchLoss += loss.item() * sample.shape[0]
            ### Collect residuals for analysis
            residual_chunks.append(_matched_residuals(pred_xy, gt_xy).cpu().numpy())
    avgLoss = batchLoss / len(testLoader.dataset)

    if residual_chunks:
        # float64 keeps the mean/std accumulation as precise as before.
        residuals = np.concatenate(residual_chunks).astype(np.float64)
    else:
        residuals = np.empty((0, 2, 2))
    residuals_x = residuals[:, :, 0].ravel()
    residuals_y = residuals[:, :, 1].ravel()

    datasetName = testLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
    print(f" Residuals X: mean={np.mean(residuals_x):.4f}, std={np.std(residuals_x):.4f}")
    print(f" Residuals Y: mean={np.mean(residuals_y):.4f}, std={np.std(residuals_y):.4f}")
    if datasetName == 'Val':
        ValLosses.append(avgLoss)
    else:
        TestLoss = avgLoss
    return avgLoss
|
|
|
|
def _sample_and_label_files(indices):
    """Return the (sample files, label files) path lists for the given file indices."""
    sample_files = [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_samples.npy' for i in indices]
    label_files = [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_labels.npy' for i in indices]
    return sample_files, label_files


def _build_dataset(name, indices):
    """Create a doublePhotonDataset called ``name`` from the files with the given indices."""
    sample_files, label_files = _sample_and_label_files(indices)
    return doublePhotonDataset(
        sample_files,
        label_files,
        sampleRatio=1.0,
        datasetName=name,
        reuselFactor=1,
    )


### Files 0-12 are used for training, 13 for validation and 15 for testing
### (file index 14 is not referenced here)
trainDataset = _build_dataset('Train', range(13))
valDataset = _build_dataset('Val', range(13, 14))
testDataset = _build_dataset('Test', range(15, 16))

### Only the training loader shuffles and pins memory
trainLoader = torch.utils.data.DataLoader(
    trainDataset,
    batch_size=1024,
    shuffle=True,
    pin_memory=True,
    num_workers=16,
)

### Validation and test loaders share the same settings
_eval_loader_options = {'batch_size': 4096, 'shuffle': False, 'num_workers': 16}
valLoader = torch.utils.data.DataLoader(valDataset, **_eval_loader_options)
testLoader = torch.utils.data.DataLoader(testDataset, **_eval_loader_options)

### Adam with weight decay; the LR is multiplied by 0.7 after 5 epochs without improvement
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=5)
|
|
### Training loop (currently disabled): 300 epochs, validating after each one and
### stepping the LR scheduler on the most recent training loss.
# if __name__ == "__main__":
#     for epoch in tqdm(range(1, 301)):
#         train(model, trainLoader, optimizer)
#         test(model, valLoader)
#         scheduler.step(TrainLosses[-1])

### Load the checkpoint for the selected modelVersion, evaluate it on the test
### set, then write the state dict back to the same file.
_checkpoint_path = f'doublePhotonNet_{modelVersion}.pth'
model.load_state_dict(torch.load(_checkpoint_path, weights_only=True))
test(model, testLoader)
torch.save(model.state_dict(), _checkpoint_path)
|
|
|
|
def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
    """Plot training and validation loss per epoch on a log scale and save it as a PNG.

    Args:
        TrainLosses: Sequence of per-epoch training losses.
        ValLosses: Sequence of per-epoch validation losses.
        TestLoss: Test loss drawn as a dashed horizontal line; skipped unless > 0.
        modelVersion: Used in the output file name
            ``loss_curve_doublePhoton_{modelVersion}.png``.

    The figure is closed after saving so repeated calls do not accumulate
    open figures.
    """
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(8, 6))
    try:
        plt.plot(TrainLosses, label='Train Loss', color='blue')
        plt.plot(ValLosses, label='Validation Loss', color='orange')
        if TestLoss > 0:
            plt.axhline(y=TestLoss, color='green', linestyle='--', label='Test Loss')
        plt.yscale('log')
        plt.xlabel('Epoch')
        plt.ylabel('MSE Loss')
        plt.legend()
        plt.grid()
        plt.savefig(f'loss_curve_doublePhoton_{modelVersion}.png', dpi=300)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
|
|
|
|
# plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion) |