# Per-epoch average losses; train() appends to TrainLosses after every epoch.
TrainLosses, TestLosses = [], []


def train(model, trainLoader, optimizer):
    """Run one training epoch of (x, y) regression and record the average loss.

    Each batch is moved to the device that already holds the model's
    parameters, so the same function works on CPU and on any GPU instead of
    requiring CUDA.

    Args:
        model: network mapping a batch of samples to a ``(B, 2)`` prediction.
        trainLoader: iterable of ``(sample, label)`` batches exposing a
            ``dataset`` attribute with a length. The first two label columns
            are used as the regression target.
        optimizer: optimizer stepping the parameters of ``model``.

    Side effects:
        Prints the epoch's average loss and appends it to ``TrainLosses``.
    """
    # Follow the model's device rather than assuming a GPU is present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    batchLoss = 0
    for sample, label in trainLoader:
        sample, label = sample.to(device), label.to(device)
        target_xy = label[:, :2]  # columns 0 and 1 are the regression target
        optimizer.zero_grad()
        output = model(sample)
        loss = torch.nn.functional.mse_loss(output, target_xy)
        loss.backward()
        optimizer.step()
        # mse_loss is a batch mean; weight by batch size so that the final
        # division yields a per-sample average even with a short last batch.
        batchLoss += loss.item() * sample.shape[0]
    avgLoss = batchLoss / len(trainLoader.dataset)
    print(f"[Train] Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
    TrainLosses.append(avgLoss)
def two_point_set_loss_l2(pred_xy, gt_xy):
    """Permutation-invariant squared-L2 loss between two predicted and two true points.

    Args:
        pred_xy: tensor indexed as ``[:, point, coord]`` with two points per sample.
        gt_xy: tensor with the same layout holding the ground truth.

    Returns:
        Scalar tensor: for every sample the cheaper of the two possible
        point assignments (summed squared distance over both points),
        averaged over the batch.
    """
    def pair_cost_l2sq(p, q):  # p, q: (..., 2)
        return ((p - q) ** 2).sum(dim=-1)  # squared L2 distance

    p1, p2 = pred_xy[:, 0], pred_xy[:, 1]
    g1, g2 = gt_xy[:, 0], gt_xy[:, 1]
    c_a = pair_cost_l2sq(p1, g1) + pair_cost_l2sq(p2, g2)  # direct assignment
    c_b = pair_cost_l2sq(p1, g2) + pair_cost_l2sq(p2, g1)  # swapped assignment
    return torch.minimum(c_a, c_b).mean()


def min_matching_loss(pred, target):
    """Permutation-invariant MSE between two predicted and two true points.

    pred:   [B, 4] -> (x1,y1,x2,y2)
    target: [B, 4] -> (x1,y1,x2,y2)

    Inputs already shaped ``[B, 2, 2]`` are accepted as well.

    The cheaper assignment is selected independently for every sample and
    only then averaged over the batch. Taking the minimum of two
    batch-averaged losses would force one assignment onto the whole batch
    and penalise samples whose points are predicted in the other order.

    Returns:
        Scalar tensor. When the direct assignment is the cheaper one for
        every sample, the value equals the plain mean of
        ``(p1 - g1)**2 + (p2 - g2)**2`` over batch and coordinates.
    """
    pred = pred.reshape(-1, 2, 2)      # [B, 2, 2]
    target = target.reshape(-1, 2, 2)  # [B, 2, 2]

    # Per-sample cost of both possible matchings (mean over the x/y axis).
    direct = ((pred[:, 0] - target[:, 0]) ** 2
              + (pred[:, 1] - target[:, 1]) ** 2).mean(dim=-1)   # [B]
    swapped = ((pred[:, 0] - target[:, 1]) ** 2
               + (pred[:, 1] - target[:, 0]) ** 2).mean(dim=-1)  # [B]

    return torch.minimum(direct, swapped).mean()
def _makeDoublePhotonDataset(datasetName, fileIndices):
    """Build a doublePhotonDataset from the numbered sample/label file pairs."""
    prefix = '/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_'
    return doublePhotonDataset(
        [f'{prefix}{i}_samples.npy' for i in fileIndices],
        [f'{prefix}{i}_labels.npy' for i in fileIndices],
        sampleRatio=1.0,
        datasetName=datasetName,
        reuselFactor=1,
    )


# Files 0-12 are used for training, file 13 for validation, file 15 for testing.
# NOTE(review): file index 14 is not used by any split - confirm this is intended.
trainDataset = _makeDoublePhotonDataset('Train', range(13))
valDataset = _makeDoublePhotonDataset('Val', range(13, 14))
testDataset = _makeDoublePhotonDataset('Test', range(15, 16))
class singlePhotonDataset(Dataset):
    """Map-style dataset of single-photon samples loaded from ``.npy`` files.

    All sample files and all label files are loaded eagerly and concatenated
    along axis 0. Only the first ``int(N * sampleRatio)`` entries are exposed
    through ``len()`` and indexing.

    Args:
        sampleList: paths of the ``.npy`` sample files.
        labelList: paths of the ``.npy`` label files, in the same order as
            ``sampleList`` (entry ``i`` labels sample file ``i``).
        sampleRatio: fraction of the loaded samples to expose.

    Raises:
        ValueError: if no file is given, if the two lists differ in length,
            or if the loaded samples and labels differ in number of rows.
    """

    def __init__(self, sampleList, labelList, sampleRatio):
        self.sampleFileList = sampleList
        self.labelFileList = labelList
        self.sampleRatio = sampleRatio

        if len(self.sampleFileList) == 0:
            raise ValueError("sampleList is empty; at least one sample file is required")
        if len(self.sampleFileList) != len(self.labelFileList):
            raise ValueError(
                f"Got {len(self.sampleFileList)} sample files but "
                f"{len(self.labelFileList)} label files"
            )

        # Load everything first and concatenate once; concatenating inside the
        # loop re-copies the accumulated array for every additional file.
        self.samples = np.concatenate([np.load(f) for f in self.sampleFileList], axis=0)
        self.labels = np.concatenate([np.load(f) for f in self.labelFileList], axis=0)
        if self.samples.shape[0] != self.labels.shape[0]:
            raise ValueError(
                f"Loaded {self.samples.shape[0]} samples but {self.labels.shape[0]} labels"
            )

        ### total number of samples
        self.length = int(self.samples.shape[0] * self.sampleRatio)
        print(f"Total number of samples: {self.length}")

    def __getitem__(self, index):
        """Return ``(sample, label)`` as float32 tensors; the sample gains a leading axis.

        Raises:
            IndexError: if ``index`` lies outside the exposed range, so that
                indexing agrees with ``len()`` and plain iteration terminates
                at the subsampled length.
        """
        if not -self.length <= index < self.length:
            raise IndexError(f"index {index} out of range for dataset of length {self.length}")
        if index < 0:
            index += self.length  # negative indices count from the exposed end
        sample = np.expand_dims(self.samples[index], axis=0)
        label = self.labels[index]
        return torch.tensor(sample, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)

    def __len__(self):
        """Return the number of exposed samples."""
        return self.length
class singlePhotonNet_250909(nn.Module):
    """Three-layer CNN regressing two values from a single-channel 5x5 input.

    The three 3x3 convolutions use padding 1 and therefore keep the spatial
    size; the linear head consumes the flattened 20 x 5 x 5 feature map and
    returns a ``(B, 2)`` tensor.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 5, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=1)
        self.fc = nn.Linear(20 * 5 * 5, 2)

    def weight_init(self):
        """Re-initialise all layers in place.

        Linear layers get N(0, 0.01) weights and zero bias; convolutions get
        Kaiming-normal (fan_out, ReLU) weights and zero bias. This method is
        not invoked by ``__init__``; it only takes effect when called.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Map a ``(B, 1, 5, 5)`` batch to ``(B, 2)`` predictions."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        flat = x.view(x.size(0), -1)  # (B, 20*5*5)
        return self.fc(flat)
class doublePhotonNet_250910(nn.Module):
    """CNN regressing four values from a single-channel 6x6 input.

    Two padded 5x5 convolutions keep the 6x6 size, an unpadded 3x3
    convolution shrinks it to 4x4, and a linear head maps the flattened
    20 x 4 x 4 features to four outputs, which are multiplied by 6.
    """

    def __init__(self):
        super().__init__()
        # Input shape: (B, 1, 6, 6)
        self.conv1 = nn.Conv2d(1, 5, kernel_size=5, padding=2)    # (B, 5, 6, 6)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=5, padding=2)   # (B, 10, 6, 6)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=0)  # (B, 20, 4, 4)
        self.fc1 = nn.Linear(20 * 4 * 4, 4)

        # Kaiming-normal conv weights; Xavier-uniform linear weights with zero bias.
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
            elif isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Map a ``(B, 1, 6, 6)`` batch to ``(B, 4)`` outputs scaled by 6."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        flat = features.view(features.size(0), -1)  # (B, 20*4*4)
        return self.fc1(flat) * 6
class doublePhotonNet_251001_2(nn.Module):
    """CNN with spatial attention regressing four values in the range [0, 6].

    Three padded 3x3 convolutions produce a 128-channel map that is
    reweighted by a 1x1-conv + sigmoid attention mask, pooled globally with
    both max and average pooling, and passed through an MLP head ending in
    a sigmoid. The head output is multiplied by 6.
    """

    def __init__(self):
        super().__init__()
        # Backbone
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Spatial attention: one weight in (0, 1) per spatial location.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(128, 1, kernel_size=1),
            nn.Sigmoid(),
        )

        # Multi-scale fusion layers. They are registered (and therefore part
        # of the state dict) but forward() does not currently call them.
        self.reduce1 = nn.Conv2d(32, 32, kernel_size=1)
        self.reduce2 = nn.Conv2d(64, 32, kernel_size=1)
        self.fuse = nn.Conv2d(32 * 3, 128, kernel_size=1)

        # Global pooling to a single value per channel.
        self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Regression head on the concatenated max + avg descriptors.
        self.fc = nn.Sequential(
            nn.Linear(128 * 2, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 4),
            nn.Sigmoid(),  # output in [0, 1]
        )

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out) conv weights; Xavier-uniform linear weights, zero bias."""
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Map a ``(B, 1, H, W)`` batch to ``(B, 4)`` outputs."""
        feat = F.relu(self.conv1(x))     # 32 channels
        feat = F.relu(self.conv2(feat))  # 64 channels
        feat = F.relu(self.conv3(feat))  # 128 channels

        # Reweight every spatial location by its attention score.
        feat = feat * self.spatial_attn(feat)

        pooled_max = self.global_max_pool(feat).flatten(1)  # [B, 128]
        pooled_avg = self.global_avg_pool(feat).flatten(1)  # [B, 128]
        descriptor = torch.cat([pooled_max, pooled_avg], dim=1)  # [B, 256]

        # The sigmoid output is stretched from [0, 1] by a factor of 6.
        return self.fc(descriptor) * 6.0