import sys
sys.path.append('./src')
import torch
import numpy as np
from models import *
from datasets import *
import torch.optim as optim
from tqdm import tqdm
from torchinfo import summary

### random seed for reproducibility
torch.manual_seed(0)
np.random.seed(0)

modelVersion = '251001_2'  # '250910', '251001', or '251001_2'
TrainLosses, ValLosses = [], []
TestLoss = -1

def two_point_set_loss_l2(pred_xy, gt_xy):
    """Per-sample minimum matching: for each sample, take the cheaper of the
    two possible point-to-point assignments, then average over the batch."""
    def pair_cost_l2sq(p, q):  # p, q: (..., 2)
        return ((p - q)**2).sum(dim=-1)  # squared L2
    p1, p2 = pred_xy[:, 0], pred_xy[:, 1]
    g1, g2 = gt_xy[:, 0], gt_xy[:, 1]
    c_a = pair_cost_l2sq(p1, g1) + pair_cost_l2sq(p2, g2)  # identity assignment
    c_b = pair_cost_l2sq(p1, g2) + pair_cost_l2sq(p2, g1)  # swapped assignment
    return torch.minimum(c_a, c_b).mean()

def min_matching_loss(pred, target):
    """
    pred:   [B, 4] or [B, 2, 2] -> (x1, y1, x2, y2)
    target: [B, 4] or [B, 2, 2] -> (x1, y1, x2, y2)
    Note: unlike two_point_set_loss_l2, the minimum is taken over the two
    batch-averaged assignments, i.e. one assignment is chosen for the whole
    batch rather than per sample.
    """
    pred = pred.view(-1, 2, 2)      # [B, 2, 2]
    target = target.view(-1, 2, 2)  # [B, 2, 2]
    # compute the MSE of both possible point-to-point assignments
    loss1 = torch.mean((pred[:, 0] - target[:, 0])**2 + (pred[:, 1] - target[:, 1])**2)
    loss2 = torch.mean((pred[:, 0] - target[:, 1])**2 + (pred[:, 1] - target[:, 0])**2)
    return torch.min(loss1, loss2)

### select the model and loss function by modelVersion
if modelVersion == '250910':
    loss_fn = two_point_set_loss_l2
    model = doublePhotonNet_250910().cuda()
elif modelVersion == '251001':
    loss_fn = min_matching_loss
    model = doublePhotonNet_251001().cuda()
elif modelVersion == '251001_2':
    loss_fn = min_matching_loss
    model = doublePhotonNet_251001_2().cuda()
else:
    raise ValueError(f'unknown modelVersion: {modelVersion}')

# summary(model, input_size=(128, 1, 6, 6))  ### print model summary

def train(model, trainLoader, optimizer):
    model.train()
    batchLoss = 0
    for batch_idx, (sample, label) in enumerate(trainLoader):
        sample, label = sample.cuda(), label.cuda()
        x1, y1, z1, e1 = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
        x2, y2, z2, e2 = label[:, 4], label[:, 5], label[:, 6], label[:, 7]
        gt_xy = torch.stack((torch.stack((x1, y1), dim=1),
                             torch.stack((x2, y2), dim=1)), dim=1)
        optimizer.zero_grad()
        output = model(sample)
        pred_xy = torch.stack((output[:, 0:2], output[:, 2:4]), dim=1)
        loss = loss_fn(pred_xy, gt_xy)
        loss.backward()
        optimizer.step()
        batchLoss += loss.item() * sample.shape[0]
    avgLoss = batchLoss / len(trainLoader.dataset) / 2  ### divide by 2 to get the average loss per photon
    print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
    TrainLosses.append(avgLoss)

def test(model, testLoader):
    model.eval()
    batchLoss = 0
    residuals_x, residuals_y = np.array([]), np.array([])
    with torch.no_grad():
        for batch_idx, (sample, label) in enumerate(testLoader):
            sample, label = sample.cuda(), label.cuda()
            x1, y1, z1, e1 = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
            x2, y2, z2, e2 = label[:, 4], label[:, 5], label[:, 6], label[:, 7]
            gt_xy = torch.stack((torch.stack((x1, y1), dim=1),
                                 torch.stack((x2, y2), dim=1)), dim=1)
            output = model(sample)
            pred_xy = torch.stack((output[:, 0:2], output[:, 2:4]), dim=1)
            loss = loss_fn(pred_xy, gt_xy)
            batchLoss += loss.item() * sample.shape[0]
            ### collect residuals for analysis
            residuals_x = np.concatenate((residuals_x,
                                          (pred_xy[:, 0, 0] - gt_xy[:, 0, 0]).cpu().numpy(),
                                          (pred_xy[:, 1, 0] - gt_xy[:, 1, 0]).cpu().numpy()))
            residuals_y = np.concatenate((residuals_y,
                                          (pred_xy[:, 0, 1] - gt_xy[:, 0, 1]).cpu().numpy(),
                                          (pred_xy[:, 1, 1] - gt_xy[:, 1, 1]).cpu().numpy()))
    avgLoss = batchLoss / len(testLoader.dataset)  ### note: per photon pair, not divided by 2 as in train()
    datasetName = testLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
    print(f" Residuals X: mean={np.mean(residuals_x):.4f}, std={np.std(residuals_x):.4f}")
    print(f" Residuals Y: mean={np.mean(residuals_y):.4f}, std={np.std(residuals_y):.4f}")
    if datasetName == 'Val':
        ValLosses.append(avgLoss)
    else:
        global TestLoss
        TestLoss = avgLoss
    return avgLoss
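### Hypothetical alternative (not used below): a per-sample variant of
### min_matching_loss. min_matching_loss above picks ONE assignment for the
### whole batch (the min of two batch means), while two_point_set_loss_l2
### matches per sample; this sketch applies per-sample matching with
### min_matching_loss's normalisation, in case that behaviour is preferred.
def min_matching_loss_per_sample(pred, target):
    pred, target = pred.view(-1, 2, 2), target.view(-1, 2, 2)
    c_a = ((pred - target)**2).sum(dim=(-1, -2))          # keep photon order
    c_b = ((pred - target.flip(1))**2).sum(dim=(-1, -2))  # swap the two photons
    return torch.minimum(c_a, c_b).mean() / 2             # mean over batch and the 2 coordinates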
trainDataset = doublePhotonDataset(
    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_samples.npy' for i in range(13)],
    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_labels.npy' for i in range(13)],
    sampleRatio=1.0,
    datasetName='Train',
    reuselFactor=1,
)
valDataset = doublePhotonDataset(
    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_samples.npy' for i in range(13, 14)],
    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_labels.npy' for i in range(13, 14)],
    sampleRatio=1.0,
    datasetName='Val',
    reuselFactor=1,
)
testDataset = doublePhotonDataset(
    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_samples.npy' for i in range(15, 16)],
    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_labels.npy' for i in range(15, 16)],
    sampleRatio=1.0,
    datasetName='Test',
    reuselFactor=1,
)

trainLoader = torch.utils.data.DataLoader(
    trainDataset, batch_size=1024, pin_memory=True, shuffle=True, num_workers=16
)
valLoader = torch.utils.data.DataLoader(
    valDataset, batch_size=4096, shuffle=False, num_workers=16
)
testLoader = torch.utils.data.DataLoader(
    testDataset, batch_size=4096, shuffle=False, num_workers=16
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience=5)

### training loop (commented out: this run only evaluates a saved checkpoint)
# if __name__ == "__main__":
#     for epoch in tqdm(range(1, 301)):
#         train(model, trainLoader, optimizer)
#         test(model, valLoader)
#         scheduler.step(TrainLosses[-1])

### load a previously trained checkpoint, evaluate it on the test set, and re-save
model.load_state_dict(torch.load(f'doublePhotonNet_{modelVersion}.pth', weights_only=True))
test(model, testLoader)
torch.save(model.state_dict(), f'doublePhotonNet_{modelVersion}.pth')

def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
    import matplotlib.pyplot as plt
    plt.figure(figsize=(8, 6))
    plt.plot(TrainLosses, label='Train Loss', color='blue')
    plt.plot(ValLosses, label='Validation Loss', color='orange')
    if TestLoss > 0:
        plt.axhline(y=TestLoss, color='green', linestyle='--', label='Test Loss')
    plt.yscale('log')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid()
    plt.savefig(f'loss_curve_doublePhoton_{modelVersion}.png', dpi=300)

# plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion)
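### Optional sanity check (hypothetical, not part of the pipeline): both loss
### functions should be invariant to swapping the two ground-truth photons,
### since each minimises over the two possible assignments.
# def _check_loss_symmetry():
#     pred, gt = torch.randn(8, 2, 2), torch.randn(8, 2, 2)
#     gt_swapped = gt.flip(1)  # swap photon 1 and photon 2
#     assert torch.allclose(two_point_set_loss_l2(pred, gt),
#                           two_point_set_loss_l2(pred, gt_swapped))
#     assert torch.allclose(min_matching_loss(pred, gt),
#                           min_matching_loss(pred, gt_swapped))
# _check_loss_symmetry()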