e04e783fb6
Co-authored-by: Copilot <copilot@github.com>
181 lines
7.8 KiB
Python
181 lines
7.8 KiB
Python
import sys
|
|
sys.path.append('./src')
|
|
from omegaconf import OmegaConf ### for yaml config parsing
|
|
import torch
|
|
import numpy as np
|
|
import models
|
|
from datasets import *
|
|
import torch.optim as optim
|
|
from tqdm import tqdm
|
|
from torchinfo import summary
|
|
from pathlib import Path
|
|
|
|
from models import get_model_class
|
|
from datasets import singlePhotonDataset
|
|
|
|
### Reproducibility: seed NumPy and PyTorch (CPU + CUDA) with the same value
### and switch cuDNN to deterministic mode with autotuning (benchmark) disabled.
_SEED = 0
np.random.seed(_SEED)
torch.manual_seed(_SEED)
torch.cuda.manual_seed(_SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

### Experiment configuration, parsed once at module load; the functions below
### receive it explicitly through their ``conf`` argument.
conf = OmegaConf.load("Configs/train_2photon.yaml")
|
|
|
|
def prepare_output_folder(conf):
    """Create a fresh ``Results/<experiment>`` directory tree and store the config in it.

    The experiment name has the form
    ``<YYMMDD>_2ph_<energy>keV_v<version>_<NN>``, where ``NN`` is the lowest
    zero-padded index (starting at 00) for which no entry exists yet under
    ``Results``. The sub-folders ``Models`` and ``Plots`` are created inside
    it and ``conf`` is written to ``config.yaml``.

    Args:
        conf: Configuration providing ``conf.data.energy`` and
            ``conf.model.version``; it is also what gets saved through
            ``OmegaConf.save``.

    Returns:
        The experiment name (the directory name below ``Results``).
    """
    from datetime import datetime

    date = datetime.now().strftime("%y%m%d")  ## YYMMDD format
    prefix = f'{date}_2ph_{conf.data.energy}keV_v{conf.model.version}'

    # Claim the experiment directory atomically: mkdir(exist_ok=False) either
    # creates it or raises, so two runs started at the same time can never end
    # up sharing one index (a separate exists() check followed by a tolerant
    # mkdir could).
    exp_index = 0
    while True:
        exp_name = f'{prefix}_{exp_index:02d}'
        exp_dir = Path('Results') / exp_name
        try:
            exp_dir.mkdir(parents=True, exist_ok=False)
        except FileExistsError:
            exp_index += 1
            continue
        break

    (exp_dir / 'Models').mkdir(parents=True, exist_ok=True)
    (exp_dir / 'Plots').mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, f'Results/{exp_name}/config.yaml')
    return exp_name
|
|
|
|
def get_loss_function(conf):
    """Return the two-photon position loss selected by ``conf.loss.type``.

    Both supported losses compare two predicted points with two ground-truth
    points per sample. Because the order of the two points is arbitrary, the
    cost of both possible prediction-to-truth assignments is computed and the
    cheaper one is kept for each sample.

    Args:
        conf: Configuration whose ``conf.loss.type`` is either
            ``"two_point_set_loss_l2"`` or ``"two_point_set_loss_smooth_l1"``.

    Returns:
        A callable ``loss(pred_xy, gt_xy)`` that returns a scalar tensor: the
        mean over the batch of the per-sample minimum assignment cost. Both
        arguments are indexed as ``[:, 0]`` and ``[:, 1]`` for the two points,
        with the coordinates on the last axis.

    Raises:
        ValueError: If ``conf.loss.type`` is not a supported name. (Falling
            through used to return ``None``, which only failed later when the
            loss was first called.)
    """

    def _min_assignment_cost(pred_xy, gt_xy, pair_cost):
        # Batch mean of min(cost of direct pairing, cost of swapped pairing).
        p1, p2 = pred_xy[:, 0], pred_xy[:, 1]
        g1, g2 = gt_xy[:, 0], gt_xy[:, 1]
        direct = pair_cost(p1, g1) + pair_cost(p2, g2)
        swapped = pair_cost(p1, g2) + pair_cost(p2, g1)
        return torch.minimum(direct, swapped).mean()

    loss_type = conf.loss.type

    if loss_type == "two_point_set_loss_l2":
        def pair_cost_l2sq(p, q):  # p, q: (..., 2)
            return ((p - q) ** 2).sum(dim=-1)  # squared L2 distance

        def two_point_set_loss_l2(pred_xy, gt_xy):
            return _min_assignment_cost(pred_xy, gt_xy, pair_cost_l2sq)

        return two_point_set_loss_l2

    if loss_type == "two_point_set_loss_smooth_l1":
        # Smooth L1 (default beta) is quadratic for small residuals and linear
        # for large ones, which makes it less sensitive to outliers than
        # squared L2. reduction='none' keeps the element-wise values so they
        # can be summed per point.
        smooth_l1 = torch.nn.SmoothL1Loss(reduction='none')

        def pair_cost_smooth_l1(p, q):
            return smooth_l1(p, q).sum(dim=-1)

        def two_point_set_loss_smooth_l1(pred_xy, gt_xy):
            return _min_assignment_cost(pred_xy, gt_xy, pair_cost_smooth_l1)

        return two_point_set_loss_smooth_l1

    raise ValueError(
        f"Unsupported loss type {loss_type!r}; expected "
        "'two_point_set_loss_l2' or 'two_point_set_loss_smooth_l1'"
    )
|
|
|
|
def train(model, trainLoader, optimizer, loss_fn):
    """Run one optimisation pass over ``trainLoader`` and return the scaled mean loss.

    For every batch the ground-truth points are taken from label columns
    0-1 (x1, y1) and 4-5 (x2, y2); the model output columns 0-1 and 2-3 are
    taken as the two predicted points. Label columns 2-3 and 6-7 are not used
    here.

    Args:
        model: Network to train; it is switched to training mode.
        trainLoader: Iterable yielding ``(sample, label)`` batches.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(pred_xy, gt_xy)`` returning a scalar
            tensor holding the batch-mean loss.

    Returns:
        The sample-weighted mean of the batch losses divided by 4 (two
        photons x two axes).

    Raises:
        ValueError: If ``trainLoader`` yields no samples.
    """
    model.train()

    # Follow the device the model lives on instead of hard-coding CUDA, so the
    # loop also works on CPU; a parameter-free model keeps the CUDA default.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    weighted_loss_sum = 0.0
    n_samples = 0
    for sample, label in trainLoader:
        sample, label = sample.to(device), label.to(device)
        gt_xy = torch.stack((label[:, 0:2], label[:, 4:6]), dim=1)

        optimizer.zero_grad()
        output = model(sample)
        pred_xy = torch.stack((output[:, 0:2], output[:, 2:4]), dim=1)
        loss = loss_fn(pred_xy, gt_xy)
        loss.backward()
        optimizer.step()

        # loss is a batch mean, so weight it by the batch size before summing.
        batch_size = sample.shape[0]
        weighted_loss_sum += loss.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("trainLoader yielded no samples")

    # Normalise by the samples actually seen (len(dataset) is wrong with
    # drop_last or a subset sampler).
    avgLoss = weighted_loss_sum / n_samples / 4  ### divide by 4 to get the average loss per photon per axis
    # NOTE: the printed "RMS" is sqrt(avgLoss); it is a true RMS error only
    # when loss_fn is a squared-error loss.
    print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
    return avgLoss
|
|
|
|
def evaluate(model, valLoader, loss_fn):
    """Compute the scaled mean loss of ``model`` over ``valLoader`` without gradients.

    For every batch the ground-truth points are taken from label columns
    0-1 (x1, y1) and 4-5 (x2, y2); the model output columns 0-1 and 2-3 are
    taken as the two predicted points. Label columns 2-3 and 6-7 are not used
    here.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        valLoader: Iterable yielding ``(sample, label)`` batches.
        loss_fn: Callable ``loss_fn(pred_xy, gt_xy)`` returning a scalar
            tensor holding the batch-mean loss.

    Returns:
        The sample-weighted mean of the batch losses divided by 4 (two
        photons x two axes).

    Raises:
        ValueError: If ``valLoader`` yields no samples.
    """
    model.eval()

    # Follow the device the model lives on instead of hard-coding CUDA, so the
    # loop also works on CPU; a parameter-free model keeps the CUDA default.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    weighted_loss_sum = 0.0
    n_samples = 0
    with torch.no_grad():
        for sample, label in valLoader:
            sample, label = sample.to(device), label.to(device)
            gt_xy = torch.stack((label[:, 0:2], label[:, 4:6]), dim=1)

            output = model(sample)
            pred_xy = torch.stack((output[:, 0:2], output[:, 2:4]), dim=1)
            loss = loss_fn(pred_xy, gt_xy)

            # loss is a batch mean, so weight it by the batch size before summing.
            batch_size = sample.shape[0]
            weighted_loss_sum += loss.item() * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("valLoader yielded no samples")

    # Normalise by the samples actually seen (len(dataset) is wrong with
    # drop_last or a subset sampler).
    avgLoss = weighted_loss_sum / n_samples / 4  ### divide by 4 to get the average loss per photon per axis
    # NOTE: the printed "RMS" is sqrt(avgLoss); it is a true RMS error only
    # when loss_fn is a squared-error loss.
    print(f"[Val]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
    return avgLoss
|
|
|
|
def get_dataloaders(conf):
    """Build the train, validation and test ``DataLoader`` objects from ``conf.data``.

    For each split the inclusive file-index range is read from
    ``conf.data.<split>_file_range`` and expanded into ``.npz`` paths below
    ``conf.data.sample_folder``; those files back a ``doublePhotonDataset``
    which is then wrapped in a ``DataLoader``. Only the training loader
    shuffles.

    Returns:
        The tuple ``(train_loader, val_loader, test_loader)``.
    """
    # (split name, batch-size key, file-range key) for every split
    split_table = (
        ('Train', 'batch_size_train', 'train_file_range'),
        ('Val', 'batch_size_val', 'val_file_range'),
        ('Test', 'batch_size_test', 'test_file_range'),
    )

    loaders = {}
    for split, batch_key, file_range_key in split_table:
        first = conf.data[file_range_key][0]
        last = conf.data[file_range_key][1]  # inclusive upper bound
        files = [
            f"{conf.data.sample_folder}/{conf.data.energy}keV_Moench040_150V_{i}.npz"
            for i in range(first, last + 1)
        ]

        dataset = doublePhotonDataset(
            files,
            sampleRatio=1.0,
            datasetName=split,
            noiseKeV=conf.data.noise_keV,
            nSize=conf.data.n_size,
            # the configured threshold is a multiple of the noise level
            noiseThreshold=conf.data.noise_threshold * conf.data.noise_keV,
            normalize=conf.data.normalize,
        )

        loaders[split] = torch.utils.data.DataLoader(
            dataset,
            batch_size=conf.data[batch_key],
            shuffle=(split == 'Train'),
            num_workers=conf.data.num_workers,
            pin_memory=True,
        )

    return loaders['Train'], loaders['Val'], loaders['Test']
|
|
|
|
def plot_loss_curves(train_losses, val_losses, test_loss, exp_name, conf):
    """Plot the train/val loss history on a log axis and save it as a PNG.

    The image is written to
    ``Results/<exp_name>/Plots/loss_curve_doublePhoton_<version>.png``,
    overwriting any previous file of that name.

    Args:
        train_losses: Sequence of training losses, plotted against its index.
        val_losses: Sequence of validation losses, plotted against its index.
        test_loss: Drawn as a dashed horizontal line only when it is > 0;
            pass a non-positive value (e.g. -1) to omit it.
        exp_name: Experiment directory name below ``Results``.
        conf: Configuration providing ``conf.model.version`` for the file name.
    """
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.plot(train_losses, label='Train Loss')
        ax.plot(val_losses, label='Val Loss')
        if test_loss > 0:
            ax.axhline(y=test_loss, color='green', linestyle='--', label='Test Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('MSE Loss')
        ax.set_yscale('log')
        ax.legend()
        ax.grid()

        plotName = f'loss_curve_doublePhoton_{conf.model.version}.png'
        fig.savefig(f'Results/{exp_name}/Plots/{plotName}')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly; this
        # function is called repeatedly, so release the figure each time.
        plt.close(fig)
|
|
|
|
def get_model_name(conf):
    """Compose the checkpoint base name from model version, energy and noise level.

    The result is ``doublePhoton<version>_<energy>keV_Noise<noise_keV>keV``,
    with ``_normalized`` appended when ``conf.data.normalize`` is truthy.
    """
    suffix = '_normalized' if conf.data.normalize else ''
    base = f'doublePhoton{conf.model.version}_{conf.data.energy}keV_Noise{conf.data.noise_keV}keV'
    return base + suffix
|
|
|
|
if __name__ == "__main__":
    # Output folder first, so checkpoints and plots have a destination.
    exp_name = prepare_output_folder(conf)

    model_cls = models.get_double_photon_model_class(conf.model.version)
    model = model_cls().cuda()
    # summary(model, input_size=(128, 1, conf.data.n_size, conf.data.n_size))

    loss_fn = get_loss_function(conf)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=conf.training.learning_rate,
        weight_decay=conf.training.weight_decay,
    )
    # Lower the learning rate when the validation loss stops decreasing.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        'min',
        factor=conf.training.scheduler_factor,
        patience=conf.training.scheduler_patience,
    )

    trainLoader, valLoader, testLoader = get_dataloaders(conf)

    # The checkpoint base name depends only on the config, so build it once.
    modelName = get_model_name(conf)

    TrainLosses, ValLosses = [], []
    for epoch in tqdm(range(1, conf.training.epochs + 1)):
        train_loss = train(model, trainLoader, optimizer, loss_fn)
        val_loss = evaluate(model, valLoader, loss_fn)
        TrainLosses.append(train_loss)
        ValLosses.append(val_loss)

        scheduler.step(val_loss)
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")

        # Checkpoint at the configured epochs and always at the final epoch.
        if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs:
            torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}_E{epoch}.pth')
            print(f"Saved model checkpoint: {modelName}_E{epoch}.pth")
            # test_loss=-1 suppresses the test-loss line in the intermediate plot
            plot_loss_curves(TrainLosses, ValLosses, test_loss=-1, exp_name=exp_name, conf=conf)

    # Final evaluation on the test split, then redraw the curves with it.
    test_loss = evaluate(model, testLoader, loss_fn)
    plot_loss_curves(TrainLosses, ValLosses, test_loss=test_loss, exp_name=exp_name, conf=conf)