Add DL codes

Train.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import sys
sys.path.append('./src')
import torch
import numpy as np
from models import *
from datasets import *
import torch.optim as optim
from tqdm import tqdm

TrainLosses, TestLosses = [], []

def train(model, trainLoader, optimizer):
    model.train()
    batchLoss = 0
    for batch_idx, (sample, label) in enumerate(trainLoader):
        sample, label = sample.cuda(), label.cuda()
        # labels carry (x, y, z, e); only the (x, y) position is regressed here
        x, y, z, e = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
        optimizer.zero_grad()
        output = model(sample)
        loss = torch.nn.functional.mse_loss(output, torch.stack((x, y), dim=1))
        loss.backward()
        optimizer.step()
        batchLoss += loss.item() * sample.shape[0]
    avgLoss = batchLoss / len(trainLoader.dataset)
    print(f"[Train] Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
    TrainLosses.append(avgLoss)

def test(model, testLoader):
    model.eval()
    batchLoss = 0
    with torch.no_grad():
        for batch_idx, (sample, label) in enumerate(testLoader):
            sample, label = sample.cuda(), label.cuda()
            x, y, z, e = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
            output = model(sample)
            loss = torch.nn.functional.mse_loss(output, torch.stack((x, y), dim=1))
            batchLoss += loss.item() * sample.shape[0]
    avgLoss = batchLoss / len(testLoader.dataset)
    print(f"[Test] Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
    TestLosses.append(avgLoss)
    return avgLoss

model = singlePhotonNet_250909().cuda()

from glob import glob
sampleFileList = glob('/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_*_samples.npy')
labelFileList = glob('/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_*_labels.npy')

trainDataset = singlePhotonDataset(sampleFileList, labelFileList, sampleRatio=0.1)
nTrainSamples = int(0.8 * len(trainDataset))
nTestSamples = len(trainDataset) - nTrainSamples
trainDataset, testDataset = torch.utils.data.random_split(trainDataset, [nTrainSamples, nTestSamples])

trainLoader = torch.utils.data.DataLoader(
    trainDataset,
    batch_size=1024,
    pin_memory=True,
    shuffle=True,
    num_workers=16
)
testLoader = torch.utils.data.DataLoader(
    testDataset,
    batch_size=4096,
    shuffle=False,
    num_workers=16
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience=5, verbose=True)
if __name__ == "__main__":
    for epoch in tqdm(range(1, 201)):
        train(model, trainLoader, optimizer)
        test(model, testLoader)
        scheduler.step(TestLosses[-1])

    torch.save(model.state_dict(), 'singlePhotonNet_250909.pth')

    import matplotlib.pyplot as plt
    plt.figure(figsize=(8, 6))
    plt.plot(TrainLosses, label='Train Loss')
    plt.plot(TestLosses, label='Test Loss')
    plt.yscale('log')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid()
    plt.savefig('loss_curve.png', dpi=300)

Train_doublePhoton.py (new file, 163 lines)
@@ -0,0 +1,163 @@
import sys
sys.path.append('./src')
import torch
import numpy as np
from models import *
from datasets import *
import torch.optim as optim
from tqdm import tqdm

modelVersion = '251001_2'  # one of '250910', '251001', '251001_2'
TrainLosses, ValLosses = [], []
TestLoss = -1

def two_point_set_loss_l2(pred_xy, gt_xy):
    """Permutation-invariant squared-L2 loss between two predicted and two ground-truth points."""
    def pair_cost_l2sq(p, q):  # p, q: (..., 2)
        return ((p - q)**2).sum(dim=-1)  # squared L2
    p1, p2 = pred_xy[:, 0], pred_xy[:, 1]
    g1, g2 = gt_xy[:, 0], gt_xy[:, 1]
    c_a = pair_cost_l2sq(p1, g1) + pair_cost_l2sq(p2, g2)
    c_b = pair_cost_l2sq(p1, g2) + pair_cost_l2sq(p2, g1)
    return torch.minimum(c_a, c_b).mean()

def min_matching_loss(pred, target):
    """
    pred:   [B, 4] -> (x1, y1, x2, y2)
    target: [B, 4] -> (x1, y1, x2, y2)
    """
    pred = pred.view(-1, 2, 2)      # [B, 2, 2]
    target = target.view(-1, 2, 2)  # [B, 2, 2]

    # compute the MSE for both possible photon-to-label pairings
    loss1 = torch.mean((pred[:, 0] - target[:, 0])**2 + (pred[:, 1] - target[:, 1])**2)
    loss2 = torch.mean((pred[:, 0] - target[:, 1])**2 + (pred[:, 1] - target[:, 0])**2)

    # note: the min is taken after averaging over the batch (one global pairing),
    # unlike two_point_set_loss_l2, which picks the best pairing per sample
    return torch.min(loss1, loss2)

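
# Tiny illustration of the min-matching idea (editor's sketch; the tensors below are made up,
# not part of the original training flow). pred = [[0, 0, 3, 3]] and target = [[3, 3, 0, 0]]
# describe the same two points in swapped order, so the swapped pairing has zero cost:
#   min_matching_loss(torch.tensor([[0., 0., 3., 3.]]),
#                     torch.tensor([[3., 3., 0., 0.]]))   # -> tensor(0.)
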
# select the loss function and model according to modelVersion
if modelVersion == '250910':
    loss_fn = two_point_set_loss_l2
    model = doublePhotonNet_250910().cuda()
elif modelVersion == '251001':
    loss_fn = min_matching_loss
    model = doublePhotonNet_251001().cuda()
elif modelVersion == '251001_2':
    loss_fn = min_matching_loss
    model = doublePhotonNet_251001_2().cuda()
else:
    raise ValueError(f"Unknown modelVersion: {modelVersion}")

def train(model, trainLoader, optimizer):
    model.train()
    batchLoss = 0
    for batch_idx, (sample, label) in enumerate(trainLoader):
        sample, label = sample.cuda(), label.cuda()
        x1, y1, z1, e1 = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
        x2, y2, z2, e2 = label[:, 4], label[:, 5], label[:, 6], label[:, 7]
        gt_xy = torch.stack((torch.stack((x1, y1), dim=1), torch.stack((x2, y2), dim=1)), dim=1)
        optimizer.zero_grad()
        output = model(sample)
        pred_xy = torch.stack((output[:, 0:2], output[:, 2:4]), dim=1)
        loss = loss_fn(pred_xy, gt_xy)
        loss.backward()
        optimizer.step()
        batchLoss += loss.item() * sample.shape[0]
    avgLoss = batchLoss / len(trainLoader.dataset) / 2  ### divide by 2 to get the average loss per photon
    print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
    TrainLosses.append(avgLoss)

def test(model, testLoader):
    model.eval()
    batchLoss = 0
    residuals_x, residuals_y = np.array([]), np.array([])
    with torch.no_grad():
        for batch_idx, (sample, label) in enumerate(testLoader):
            sample, label = sample.cuda(), label.cuda()
            x1, y1, z1, e1 = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
            x2, y2, z2, e2 = label[:, 4], label[:, 5], label[:, 6], label[:, 7]
            gt_xy = torch.stack((torch.stack((x1, y1), dim=1), torch.stack((x2, y2), dim=1)), dim=1)
            output = model(sample)
            pred_xy = torch.stack((output[:, 0:2], output[:, 2:4]), dim=1)
            loss = loss_fn(pred_xy, gt_xy)
            batchLoss += loss.item() * sample.shape[0]
            ### Collect residuals for analysis
            residuals_x = np.concatenate((residuals_x, (pred_xy[:, 0, 0] - gt_xy[:, 0, 0]).cpu().numpy(), (pred_xy[:, 1, 0] - gt_xy[:, 1, 0]).cpu().numpy()))
            residuals_y = np.concatenate((residuals_y, (pred_xy[:, 0, 1] - gt_xy[:, 0, 1]).cpu().numpy(), (pred_xy[:, 1, 1] - gt_xy[:, 1, 1]).cpu().numpy()))
    avgLoss = batchLoss / len(testLoader.dataset)

    datasetName = testLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
    print(f"    Residuals X: mean={np.mean(residuals_x):.4f}, std={np.std(residuals_x):.4f}")
    print(f"    Residuals Y: mean={np.mean(residuals_y):.4f}, std={np.std(residuals_y):.4f}")
    if datasetName == 'Val':
        ValLosses.append(avgLoss)
    else:
        global TestLoss
        TestLoss = avgLoss
    return avgLoss

trainDataset = doublePhotonDataset(
    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_samples.npy' for i in range(13)],
    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_labels.npy' for i in range(13)],
    sampleRatio=1.0,
    datasetName='Train',
    reuselFactor=1,
)
valDataset = doublePhotonDataset(
    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_samples.npy' for i in range(13, 14)],
    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_labels.npy' for i in range(13, 14)],
    sampleRatio=1.0,
    datasetName='Val',
    reuselFactor=1,
)
testDataset = doublePhotonDataset(
    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_samples.npy' for i in range(15, 16)],
    [f'/home/xie_x1/MLXID/McGeneration/Samples/15keV_Moench040_150V_{i}_labels.npy' for i in range(15, 16)],
    sampleRatio=1.0,
    datasetName='Test',
    reuselFactor=1,
)
trainLoader = torch.utils.data.DataLoader(
    trainDataset,
    batch_size=1024,
    pin_memory=True,
    shuffle=True,
    num_workers=16
)
valLoader = torch.utils.data.DataLoader(
    valDataset,
    batch_size=4096,
    shuffle=False,
    num_workers=16
)
testLoader = torch.utils.data.DataLoader(
    testDataset,
    batch_size=4096,
    shuffle=False,
    num_workers=16
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience=5)
# if __name__ == "__main__":
#     for epoch in tqdm(range(1, 301)):
#         train(model, trainLoader, optimizer)
#         test(model, valLoader)
#         scheduler.step(TrainLosses[-1])

# Training is currently disabled above; a previously trained checkpoint is evaluated instead
model.load_state_dict(torch.load('doublePhotonNet_251001_2.pth', weights_only=True))
test(model, testLoader)
torch.save(model.state_dict(), f'doublePhotonNet_{modelVersion}.pth')

def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
    import matplotlib.pyplot as plt
    plt.figure(figsize=(8, 6))
    plt.plot(TrainLosses, label='Train Loss', color='blue')
    plt.plot(ValLosses, label='Validation Loss', color='orange')
    if TestLoss > 0:
        plt.axhline(y=TestLoss, color='green', linestyle='--', label='Test Loss')
    plt.yscale('log')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid()
    plt.savefig(f'loss_curve_doublePhoton_{modelVersion}.png', dpi=300)

# plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion)

src/datasets.py (new file, 78 lines)
@@ -0,0 +1,78 @@
from torch.utils.data import Dataset
import torch
import numpy as np

class singlePhotonDataset(Dataset):
    def __init__(self, sampleList, labelList, sampleRatio):
        self.sampleFileList = sampleList
        self.labelFileList = labelList
        self.sampleRatio = sampleRatio

        # load and concatenate all sample/label files
        for idx, sampleFile in enumerate(self.sampleFileList):
            if idx == 0:
                self.samples = np.load(sampleFile)
                self.labels = np.load(self.labelFileList[idx])
            else:
                self.samples = np.concatenate((self.samples, np.load(sampleFile)), axis=0)
                self.labels = np.concatenate((self.labels, np.load(self.labelFileList[idx])), axis=0)

        ### total number of samples
        self.length = int(self.samples.shape[0] * self.sampleRatio)
        print(f"Total number of samples: {self.length}")

    def __getitem__(self, index):
        sample = self.samples[index]
        sample = np.expand_dims(sample, axis=0)
        label = self.labels[index]
        return torch.tensor(sample, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)

    def __len__(self):
        return self.length

class doublePhotonDataset(Dataset):
    def __init__(self, sampleList, labelList, sampleRatio, datasetName, reuselFactor=1):
        self.sampleFileList = sampleList
        self.labelFileList = labelList
        self.sampleRatio = sampleRatio
        self.datasetName = datasetName

        for idx, sampleFile in enumerate(self.sampleFileList):
            if idx == 0:
                self.samples = np.load(sampleFile)
                self.labels = np.load(self.labelFileList[idx])
            else:
                self.samples = np.concatenate((self.samples, np.load(sampleFile)), axis=0)
                self.labels = np.concatenate((self.labels, np.load(self.labelFileList[idx])), axis=0)

        ### total number of samples: each double-photon event consumes two single-photon samples
        self.length = int(self.samples.shape[0] * self.sampleRatio) // 2 * reuselFactor
        print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")

    def __getitem__(self, index):
        # synthesize a double-photon event: overlay two random single-photon 5x5 patches
        # on an 8x8 canvas, then crop the central 6x6 region
        sample = np.zeros((8, 8), dtype=np.float32)
        idx1 = np.random.randint(0, self.samples.shape[0])
        idx2 = np.random.randint(0, self.samples.shape[0])
        photon1 = self.samples[idx1]
        photon2 = self.samples[idx2]

        ### random position for each photon's 5x5 patch within the 8x8 canvas
        pos_x1 = np.random.randint(1, 4)
        pos_y1 = np.random.randint(1, 4)
        sample[pos_y1:pos_y1+5, pos_x1:pos_x1+5] += photon1
        pos_x2 = np.random.randint(1, 4)
        pos_y2 = np.random.randint(1, 4)
        sample[pos_y2:pos_y2+5, pos_x2:pos_x2+5] += photon2
        sample = sample[1:-1, 1:-1]  ### sample size: 6x6
        sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0) / 1000.  ### to keV

        # shift the (x, y) labels by the patch offset and by -1 for the crop; z and e are unchanged
        label1 = self.labels[idx1] + np.array([pos_x1 - 1, pos_y1 - 1, 0, 0])
        label2 = self.labels[idx2] + np.array([pos_x2 - 1, pos_y2 - 1, 0, 0])
        # order the two photons by x so the label layout is deterministic
        if label1[0] < label2[0]:
            label = np.concatenate((label1, label2), axis=0)
        else:
            label = np.concatenate((label2, label1), axis=0)
        return sample, torch.tensor(label, dtype=torch.float32)
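
    # Worked example of the label bookkeeping above (editor's note; the numbers are illustrative
    # and assume the stored label x is measured inside the 5x5 patch): if a photon's label has
    # x = 2.3 and its patch is pasted at pos_x1 = 2, its column in the 8x8 canvas is 2.3 + 2 = 4.3;
    # after dropping the first row/column (sample[1:-1, 1:-1]) the coordinate in the 6x6 training
    # sample is 4.3 - 1 = 3.3, i.e. exactly 2.3 + (pos_x1 - 1). The same shift applies to y.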

    def __len__(self):
        return self.length

src/models.py (new file, 190 lines)
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class singlePhotonNet_250909(nn.Module):
    def weight_init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def __init__(self):
        super(singlePhotonNet_250909, self).__init__()
        self.conv1 = nn.Conv2d(1, 5, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=1)
        self.fc = nn.Linear(20*5*5, 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class doublePhotonNet_250909(nn.Module):
    def __init__(self):
        super(doublePhotonNet_250909, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(3, 5, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(5, 5, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(5*6*6, 4)
        # self.fc2 = nn.Linear(50, 4)
        # more stable initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        # x = F.relu(self.fc1(x))
        # x = self.fc2(x)
        x = self.fc1(x)
        return x

class doublePhotonNet_250910(nn.Module):
    def __init__(self):
        super(doublePhotonNet_250910, self).__init__()
        ### x shape: (B, 1, 6, 6)
        self.conv1 = nn.Conv2d(1, 5, kernel_size=5, padding=2)    # (B, 5, 6, 6)
        self.conv2 = nn.Conv2d(5, 10, kernel_size=5, padding=2)   # (B, 10, 6, 6)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=3, padding=0)  # (B, 20, 4, 4)
        self.fc1 = nn.Linear(20*4*4, 4)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        # x = F.relu(self.fc1(x))
        # x = self.fc2(x)
        x = self.fc1(x) * 6  # scale predictions to the 6x6 pixel coordinate range
        return x

class doublePhotonNet_251001(nn.Module):
    def __init__(self):
        super().__init__()
        # preserve spatial resolution: small kernels, no pooling
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)   # 6x6 -> 6x6
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)  # 6x6 -> 6x6
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 6x6 -> 6x6

        # global feature extraction (replaces a large fully connected layer)
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # 64x1x1
        self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1))  # 64x1x1

        # regression head: outputs the 4 coordinates (x1, y1, x2, y2)
        self.fc = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 4),  # direct coordinate output
            # nn.Sigmoid()  # sigmoid leads to overfitting
        )
        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        #     if isinstance(m, nn.Linear):
        #         nn.init.xavier_uniform_(m.weight)
        #         nn.init.zeros_(m.bias)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        # x = self.global_avg_pool(x).view(x.size(0), -1)
        x = self.global_max_pool(x).view(x.size(0), -1)
        coords = self.fc(x) * 6
        return coords  # shape: [B, 4]

class doublePhotonNet_251001_2(nn.Module):
    def __init__(self):
        super().__init__()
        # Backbone: deeper + residual-like blocks
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Spatial attention module (lightweight but effective)
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(128, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Multi-scale feature fusion (optional but helpful)
        self.reduce1 = nn.Conv2d(32, 32, kernel_size=1)  # from conv1
        self.reduce2 = nn.Conv2d(64, 32, kernel_size=1)  # from conv2
        self.fuse = nn.Conv2d(32*3, 128, kernel_size=1)

        # Global context with both max and avg pooling (better than GAP alone)
        self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Enhanced regression head
        self.fc = nn.Sequential(
            nn.Linear(128 * 2, 256),  # concat max + avg
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 4),
            nn.Sigmoid()  # output in [0, 1]
        )

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        # Feature extraction
        c1 = F.relu(self.conv1(x))   # [B, 32, 6, 6]
        c2 = F.relu(self.conv2(c1))  # [B, 64, 6, 6]
        c3 = F.relu(self.conv3(c2))  # [B, 128, 6, 6]

        # Spatial attention: highlight photon peaks
        attn = self.spatial_attn(c3)  # [B, 1, 6, 6]
        c3 = c3 * attn                # reweight features

        # (Optional) Multi-scale fusion; uncomment if needed
        # r1 = F.interpolate(self.reduce1(c1), size=(6, 6), mode='nearest')
        # r2 = self.reduce2(c2)
        # fused = torch.cat([r1, r2, c3], dim=1)
        # c3 = self.fuse(fused)

        # Global context: max pooling better captures the peaks, average pooling adds context
        g_max = self.global_max_pool(c3).flatten(1)     # [B, 128]
        g_avg = self.global_avg_pool(c3).flatten(1)     # [B, 128]
        global_feat = torch.cat([g_max, g_avg], dim=1)  # [B, 256]

        # Regression
        coords = self.fc(global_feat) * 6.0  # scale sigmoid output to [0, 6)
        return coords  # [B, 4]
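
# Quick shape check (editor's sketch, not part of the original training flow): the
# double-photon nets take a (B, 1, 6, 6) charge map and return (B, 4) coordinates.
if __name__ == "__main__":
    net = doublePhotonNet_251001_2()
    dummy = torch.randn(2, 1, 6, 6)  # two fake 6x6 samples
    out = net(dummy)
    print(out.shape)  # expected: torch.Size([2, 4]), values in (0, 6) from sigmoid * 6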