Old inference_2Photon replaced by new one
This commit is contained in:
@@ -1,196 +0,0 @@
|
||||
import sys
|
||||
sys.path.append('./src')
|
||||
import torch
|
||||
import numpy as np
|
||||
import models
|
||||
from datasets import *
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
from torchinfo import summary
|
||||
|
||||
### random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
modelVersion = '251124' # '250910' or '251001' or '251124'
|
||||
model = models.get_double_photon_model_class(modelVersion)().cuda()
|
||||
Energy = '12keV'
|
||||
TrainLosses, ValLosses = [], []
|
||||
TestLoss = -1
|
||||
LearningRate = 1e-3
|
||||
Noise = 0.13 # in keV
|
||||
NoiseThreshold = 3*Noise
|
||||
Normalize = False
|
||||
|
||||
def two_point_set_loss_l2(pred_xy, gt_xy):
|
||||
def pair_cost_l2sq(p, q): # p,q: (...,2)
|
||||
return ((p - q)**2).sum(dim=-1) # squared L2
|
||||
p1, p2 = pred_xy[:,0], pred_xy[:,1]
|
||||
g1, g2 = gt_xy[:,0], gt_xy[:,1]
|
||||
c_a = pair_cost_l2sq(p1,g1) + pair_cost_l2sq(p2,g2)
|
||||
c_b = pair_cost_l2sq(p1,g2) + pair_cost_l2sq(p2,g1)
|
||||
return torch.minimum(c_a, c_b).mean()
|
||||
|
||||
# summary(model, input_size=(128, 1, 6, 6)) ### print model summary
|
||||
loss_fn = two_point_set_loss_l2
|
||||
|
||||
def train(model, trainLoader, optimizer):
|
||||
model.train()
|
||||
batchLoss = 0
|
||||
for batch_idx, (sample, label) in enumerate(trainLoader):
|
||||
sample, label = sample.cuda(), label.cuda()
|
||||
x1, y1, z1, e1 = label[:,0], label[:,1], label[:,2], label[:,3]
|
||||
x2, y2, z2, e2 = label[:,4], label[:,5], label[:,6], label[:,7]
|
||||
gt_xy = torch.stack((torch.stack((x1, y1), axis=1), torch.stack((x2, y2), axis=1)), axis=1)
|
||||
optimizer.zero_grad()
|
||||
output = model(sample)
|
||||
pred_xy = torch.stack((output[:,0:2], output[:,2:4]), axis=1)
|
||||
loss = loss_fn(pred_xy, gt_xy)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
batchLoss += loss.item() * sample.shape[0]
|
||||
avgLoss = batchLoss / len(trainLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
|
||||
print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
|
||||
TrainLosses.append(avgLoss)
|
||||
|
||||
def test(model, testLoader):
|
||||
model.eval()
|
||||
batchLoss = 0
|
||||
gt_xy, out_xy = [], []
|
||||
with torch.no_grad():
|
||||
for batch_idx, (sample, label) in enumerate(testLoader):
|
||||
sample, label = sample.cuda(), label.cuda()
|
||||
x1, y1, z1, e1 = label[:,0], label[:,1], label[:,2], label[:,3]
|
||||
x2, y2, z2, e2 = label[:,4], label[:,5], label[:,6], label[:,7]
|
||||
_gt_xy = torch.stack((torch.stack((x1, y1), axis=1), torch.stack((x2, y2), axis=1)), axis=1)
|
||||
output = model(sample)
|
||||
_pred_xy = torch.stack((output[:,0:2], output[:,2:4]), axis=1)
|
||||
loss = loss_fn(_pred_xy, _gt_xy)
|
||||
batchLoss += loss.item() * sample.shape[0]
|
||||
gt_xy.append(_gt_xy.cpu())
|
||||
out_xy.append(_pred_xy.cpu())
|
||||
gt_xy = torch.cat(gt_xy, dim=0)
|
||||
out_xy = torch.cat(out_xy, dim=0)
|
||||
avgLoss = batchLoss / len(testLoader.dataset) / 4 ### divide by 4 to get the average loss per photon per axis
|
||||
|
||||
datasetName = testLoader.dataset.datasetName
|
||||
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
|
||||
calculate_residuals(gt_xy, out_xy)
|
||||
if datasetName == 'Val':
|
||||
ValLosses.append(avgLoss)
|
||||
else:
|
||||
global TestLoss
|
||||
TestLoss = avgLoss
|
||||
return avgLoss
|
||||
|
||||
def calculate_residuals(gt_xy, out_xy):
|
||||
"""
|
||||
gt_xy: (N, 2, 2) — [ [x1, y1], [x2, y2] ]
|
||||
out_xy: (N, 2, 2) — [ [x1', y1'], [x2', y2'] ]
|
||||
"""
|
||||
# Option A: match (p1->g1, p2->g2)
|
||||
cost_a = (out_xy - gt_xy).pow(2).sum(dim=-1).sum(dim=-1) # (N,)
|
||||
|
||||
# Option B: match (p1->g2, p2->g1) → swap out_xy
|
||||
out_swapped = out_xy[:, [1, 0], :] # swap the two points: (N, 2, 2)
|
||||
cost_b = (out_swapped - gt_xy).pow(2).sum(dim=-1).sum(dim=-1) # (N,)
|
||||
|
||||
# Choose best assignment per sample
|
||||
swap_mask = cost_b < cost_a # (N,)
|
||||
|
||||
# Apply swapping to get optimally matched predictions
|
||||
out_matched = out_xy.clone()
|
||||
out_matched[swap_mask] = out_xy[swap_mask][:, [1, 0], :]
|
||||
|
||||
# Compute residuals
|
||||
residuals = out_matched - gt_xy # (N, 2, 2)
|
||||
|
||||
# Flatten to get all residuals (2N points)
|
||||
residuals_x = residuals[:, :, 0].flatten().cpu().numpy()
|
||||
residuals_y = residuals[:, :, 1].flatten().cpu().numpy()
|
||||
|
||||
# Print statistics
|
||||
print(f"\t\tResiduals X: mean={np.mean(residuals_x):.4f}, std={np.std(residuals_x):.4f}")
|
||||
print(f"\t\tResiduals Y: mean={np.mean(residuals_y):.4f}, std={np.std(residuals_y):.4f}")
|
||||
|
||||
sampleFolder = '/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040'
|
||||
trainDataset = doublePhotonDataset(
|
||||
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13)],
|
||||
sampleRatio=1.0,
|
||||
datasetName='Train',
|
||||
reuselFactor=1,
|
||||
noiseKeV = Noise,
|
||||
nSize=7,
|
||||
noiseThreshold = NoiseThreshold,
|
||||
normalize = Normalize
|
||||
)
|
||||
valDataset = doublePhotonDataset(
|
||||
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13,14)],
|
||||
sampleRatio=1.0,
|
||||
datasetName='Val',
|
||||
reuselFactor=1,
|
||||
noiseKeV = Noise,
|
||||
nSize=7,
|
||||
noiseThreshold = NoiseThreshold,
|
||||
normalize = Normalize
|
||||
)
|
||||
testDataset = doublePhotonDataset(
|
||||
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15,16)],
|
||||
sampleRatio=1.0,
|
||||
datasetName='Test',
|
||||
reuselFactor=1,
|
||||
noiseKeV = Noise,
|
||||
nSize=7,
|
||||
noiseThreshold = NoiseThreshold,
|
||||
normalize = Normalize
|
||||
)
|
||||
trainLoader = torch.utils.data.DataLoader(
|
||||
trainDataset,
|
||||
batch_size=1024,
|
||||
pin_memory = True,
|
||||
shuffle=True,
|
||||
num_workers=16
|
||||
)
|
||||
valLoader = torch.utils.data.DataLoader(
|
||||
valDataset,
|
||||
batch_size=4096,
|
||||
shuffle=False,
|
||||
num_workers=16
|
||||
)
|
||||
testLoader = torch.utils.data.DataLoader(
|
||||
testDataset,
|
||||
batch_size=4096,
|
||||
shuffle=False,
|
||||
num_workers=16
|
||||
)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=LearningRate, weight_decay=1e-4)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience = 5)
|
||||
if __name__ == "__main__":
|
||||
for epoch in tqdm(range(1, 301)):
|
||||
train(model, trainLoader, optimizer)
|
||||
test(model, valLoader)
|
||||
scheduler.step(ValLosses[-1])
|
||||
print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
|
||||
if epoch in [20, 50, 100, 200, 300, 500, 750, 1000]:
|
||||
modelName = f'doublePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}'
|
||||
if Normalize:
|
||||
modelName += '_normalized'
|
||||
torch.save(model.state_dict(), f'Models/{modelName}.pth')
|
||||
|
||||
test(model, testLoader)
|
||||
|
||||
def plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion):
|
||||
import matplotlib.pyplot as plt
|
||||
plt.figure(figsize=(8,6))
|
||||
plt.plot(TrainLosses, label='Train Loss', color='blue')
|
||||
plt.plot(ValLosses, label='Validation Loss', color='orange')
|
||||
if TestLoss > 0:
|
||||
plt.axhline(y=TestLoss, color='green', linestyle='--', label='Test Loss')
|
||||
plt.yscale('log')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('MSE Loss')
|
||||
plt.legend()
|
||||
plt.grid()
|
||||
plt.savefig(f'Results/loss_curve_doublePhoton_{modelVersion}.png', dpi=300)
|
||||
|
||||
plot_loss_curve(TrainLosses, ValLosses, TestLoss, modelVersion=modelVersion)
|
||||
Reference in New Issue
Block a user