Files
DeepLearning/Train_1Photon.py
T
2026-05-13 14:30:14 +02:00

238 lines
9.8 KiB
Python

import sys
sys.path.append('./src')
from omegaconf import OmegaConf ### for yaml config parsing
import torch
import numpy as np
import models
from datasets import *
import torch.optim as optim
from tqdm import tqdm
from torchinfo import summary
from pathlib import Path
from models import get_model_class
from datasets import singlePhotonDataset
### Reproducibility: seed PyTorch (CPU and CUDA) and NumPy with the same value,
### and make cuDNN pick deterministic kernels instead of autotuned ones.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(0)
del _seed_fn
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
### Training configuration, loaded once at import time (read by the __main__ section).
conf = OmegaConf.load("Configs/train_1photon.yaml")
def prepare_output_folder(conf):
    """Create a fresh experiment folder under ``Results/`` and return its name.

    The folder is named ``<YYMMDD>_1ph_<energy>keV_v<version>_<NN>``, where
    ``NN`` is the first two-digit index whose folder does not exist yet.
    ``Models/`` and ``Plots/`` sub-folders are created inside it and ``conf``
    is saved there as ``config.yaml``.

    Args:
        conf: OmegaConf config; reads ``conf.data.energy`` and
            ``conf.model.version``.

    Returns:
        The experiment name (folder name relative to ``Results/``).
    """
    from datetime import datetime
    date = datetime.now().strftime("%y%m%d")  # YYMMDD format
    results_root = Path('Results')
    results_root.mkdir(parents=True, exist_ok=True)

    exp_index = 0
    while True:
        exp_name = f'{date}_1ph_{conf.data.energy}keV_v{conf.model.version}_{exp_index:02d}'
        exp_dir = results_root / exp_name
        try:
            # mkdir() without exist_ok claims the name atomically: two runs
            # started at the same moment can no longer both pass an exists()
            # check and end up writing into the same folder.
            exp_dir.mkdir()
        except FileExistsError:
            exp_index += 1
            continue
        break

    (exp_dir / 'Models').mkdir(exist_ok=True)
    (exp_dir / 'Plots').mkdir(exist_ok=True)
    OmegaConf.save(conf, str(exp_dir / 'config.yaml'))
    return exp_name
def get_loss_function(conf):
    """Build the training loss selected by ``conf.loss.type``.

    Every variant compares only the first two columns (x, y) of ``pred`` and
    ``target``; any further columns (e.g. z) are ignored. All variants return
    the mean over the batch and both columns.

    Supported ``conf.loss.type`` values:
        * ``"weightedL1"``: ``|pred - target|`` weighted by
          ``(1 + beta * r) * (1 + alpha * |target|)``, where ``r`` is the
          Euclidean norm of the target (x, y). Reads ``conf.loss.alpha`` and
          ``conf.loss.beta``.
        * ``"weightedSumL1"``: ``|pred - target|`` weighted by
          ``1 + beta * r + alpha * |target|``. Reads
          ``conf.loss.alpha_weighted_sum`` and ``conf.loss.beta_weighted_sum``.
        * ``"huber_loss"``: ``torch.nn.SmoothL1Loss`` (mean reduction) with
          ``beta=conf.loss.huber_beta``.

    Returns:
        A callable ``loss_fn(pred, target) -> scalar tensor``.

    Raises:
        ValueError: if ``conf.loss.type`` is not one of the names above
            (previously ``None`` was returned silently, failing much later).
    """
    loss_type = conf.loss.type

    if loss_type == "weightedL1":
        alpha = conf.loss.alpha
        beta = conf.loss.beta

        def weighted_loss(pred, target):
            # weighted L1 loss for x,y position
            pred, target = pred[:, :2], target[:, :2]
            direction_weight = 1.0 + alpha * torch.abs(target)  # (B, 2)
            r = torch.norm(target, dim=1, keepdim=True)           # (B, 1)
            radial_weight = 1.0 + beta * r  # (B, 1), broadcast over x and y
            weights = radial_weight * direction_weight
            return (weights * torch.abs(pred - target)).mean()
        return weighted_loss

    if loss_type == "weightedSumL1":
        alpha = conf.loss.alpha_weighted_sum
        beta = conf.loss.beta_weighted_sum

        def weighted_loss(pred, target):
            pred, target = pred[:, :2], target[:, :2]
            r = torch.norm(target, dim=1, keepdim=True)  # (B, 1)
            # additive (not multiplicative) combination of the two weights
            weights = 1.0 + beta * r + alpha * torch.abs(target)
            return (weights * torch.abs(pred - target)).mean()
        return weighted_loss

    if loss_type == "huber_loss":
        huber_loss_fn = torch.nn.SmoothL1Loss(reduction='mean', beta=conf.loss.huber_beta)

        def huber_loss(pred, target):
            return huber_loss_fn(pred[:, :2], target[:, :2])
        return huber_loss

    raise ValueError(
        f"Unknown loss type {loss_type!r}; expected one of "
        "'weightedL1', 'weightedSumL1', 'huber_loss'"
    )
def train(model, trainLoader, optimizer, loss_fn):
    """Run one training epoch over ``trainLoader``.

    Each batch ``label`` is indexed as columns (x, y, z, e); the loss receives
    the (x, y, z) columns, and RMS errors are computed on x and y (output
    columns 0 and 1).

    Args:
        model: network to train; batches are moved to the device of its
            parameters (CUDA when the model has no parameters, as before).
        trainLoader: yields ``(sample, label)`` batches; its dataset must
            expose ``datasetName`` (used in the printed summary).
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: callable ``(output, target_xyz) -> scalar tensor``.

    Returns:
        ``(avgLoss, rms_x, rms_y)`` averaged over all samples seen.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    # Follow the model's device instead of hard-coding .cuda(): a CUDA model
    # still gets CUDA batches, and CPU-only runs now work too.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    loss_sum = 0.0
    sq_err_x = 0.0
    sq_err_y = 0.0
    total_samples = 0
    for sample, label in trainLoader:
        sample, label = sample.to(device), label.to(device)
        target = label[:, :3]  # (x, y, z); the energy column is not used
        optimizer.zero_grad()
        output = model(sample)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

        batch_size = sample.shape[0]
        loss_sum += loss.item() * batch_size
        total_samples += batch_size
        with torch.no_grad():  # metrics only; keep them out of autograd
            sq_err_x += torch.sum((output[:, 0] - target[:, 0]) ** 2).item()
            sq_err_y += torch.sum((output[:, 1] - target[:, 1]) ** 2).item()

    datasetName = trainLoader.dataset.datasetName
    if total_samples == 0:
        raise ValueError(f"[{datasetName}] training loader yielded no samples")
    avgLoss = loss_sum / total_samples
    rms_x = np.sqrt(sq_err_x / total_samples)
    rms_y = np.sqrt(sq_err_y / total_samples)
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
    return avgLoss, rms_x, rms_y
def evaluate(model, testLoader, loss_fn):
    """Evaluate ``model`` on ``testLoader`` without gradient tracking.

    Each batch ``label`` is indexed as columns (x, y, z, e); the loss receives
    the (x, y, z) columns, and RMS errors are computed on x and y (output
    columns 0 and 1).

    Args:
        model: network to evaluate; batches are moved to the device of its
            parameters (CUDA when the model has no parameters, as before).
        testLoader: yields ``(sample, label)`` batches; its dataset must
            expose ``datasetName`` (used in the printed summary).
        loss_fn: callable ``(output, target_xyz) -> scalar tensor``.

    Returns:
        ``(avgLoss, rms_x, rms_y)`` averaged over the samples the loader
        actually yielded (not ``len(dataset)``, which over-counts when the
        loader drops a last partial batch or uses a subset sampler).

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    loss_sum = 0.0
    sq_err_x = 0.0
    sq_err_y = 0.0
    n_samples = 0
    with torch.no_grad():
        for sample, label in testLoader:
            sample, label = sample.to(device), label.to(device)
            target = label[:, :3]  # (x, y, z); the energy column is not used
            output = model(sample)
            loss = loss_fn(output, target)

            batch_size = sample.shape[0]
            loss_sum += loss.item() * batch_size
            n_samples += batch_size
            sq_err_x += torch.sum((output[:, 0] - target[:, 0]) ** 2).item()
            sq_err_y += torch.sum((output[:, 1] - target[:, 1]) ** 2).item()

    datasetName = testLoader.dataset.datasetName
    if n_samples == 0:
        raise ValueError(f"[{datasetName}] evaluation loader yielded no samples")
    avgLoss = loss_sum / n_samples
    rms_x = np.sqrt(sq_err_x / n_samples)
    rms_y = np.sqrt(sq_err_y / n_samples)
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
    return avgLoss, rms_x, rms_y
def get_dataloaders(conf):
    """Construct the train, validation and test DataLoaders.

    For each split ``s`` in ``('train', 'val', 'test')`` this reads
    ``conf.data[f'{s}_file_range']`` (first and last file index, inclusive)
    and ``conf.data[f'batch_size_{s}']``. It then builds a
    ``singlePhotonDataset`` over
    ``<sample_folder>/<energy>keV_Moench040_150V_<i>.npz``. The noise,
    augmentation and normalization settings are shared by all splits, and
    only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    data_conf = conf.data

    def _file_paths(split):
        # only the first two entries of the range are used: [first, last]
        file_range = data_conf[f'{split}_file_range']
        first, last = file_range[0], file_range[1]
        return [
            f"{data_conf.sample_folder}/{data_conf.energy}keV_Moench040_150V_{i}.npz"
            for i in range(first, last + 1)
        ]

    loaders = {}
    for split in ('train', 'val', 'test'):
        dataset = singlePhotonDataset(
            _file_paths(split),
            sampleRatio=1.0,
            datasetName=split.capitalize(),
            noiseKeV=data_conf.noise_keV,
            numberOfAugOps=data_conf.num_aug_ops,
            normalize=data_conf.normalize,
            noiseThreshold=data_conf.noise_threshold * data_conf.noise_keV,
        )
        loaders[split] = torch.utils.data.DataLoader(
            dataset,
            batch_size=data_conf[f'batch_size_{split}'],
            shuffle=(split == 'train'),
            num_workers=data_conf.num_workers,
            pin_memory=True,
        )
    return loaders['train'], loaders['val'], loaders['test']
def plot_loss_curve(train_losses, val_losses, test_loss, exp_name, conf):
    """Save the per-epoch train/validation loss curves (log y-axis) as a PNG.

    Args:
        train_losses: per-epoch average training losses.
        val_losses: per-epoch average validation losses.
        test_loss: final test loss. Pass a non-positive value (the script uses
            -1) while it is not available; then no test line or annotation
            is drawn.
        exp_name: experiment folder under ``Results/``; the file goes to
            ``Results/<exp_name>/Plots/``.
        conf: reads ``conf.model.version`` and ``conf.data.normalize`` for
            the file name.
    """
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.plot(train_losses, label='Train Loss', color='blue')
        ax.plot(val_losses, label='Validation Loss', color='orange')
        if test_loss > 0:
            ax.axhline(y=test_loss, color='green', linestyle='--', label='Test Loss')
            # Annotate only with a real value; previously "Test Loss: -1.000000"
            # was printed on every intermediate plot.
            ax.text(0.5, 0.5, f'Test Loss: {test_loss:.6f}', fontsize=12, color='green',
                    ha='center', va='center', transform=ax.transAxes)
        ax.set_yscale('log')
        ax.set_xlabel('Epoch')
        # The configured losses are weighted-L1 / Huber variants, not MSE.
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid()
        plotName = f'loss_curve_singlePhoton_{conf.model.version}'
        if conf.data.normalize:
            plotName += '_normalized'
        fig.savefig(f'Results/{exp_name}/Plots/{plotName}.png')
    finally:
        # Called once per epoch: without closing, figures accumulate in memory.
        plt.close(fig)
def plot_rms_curve(train_rms_x, train_rms_y, val_rms_x, val_rms_y, test_rms_x, test_rms_y, exp_name, conf):
    """Save the per-epoch train/validation RMS position errors as a PNG.

    Args:
        train_rms_x, train_rms_y: per-epoch training RMS errors in x and y.
        val_rms_x, val_rms_y: per-epoch validation RMS errors in x and y.
        test_rms_x, test_rms_y: final test RMS errors. Horizontal test lines
            are drawn only when both are positive; the script passes -1 while
            they are not available.
        exp_name: experiment folder under ``Results/``; the file goes to
            ``Results/<exp_name>/Plots/``.
        conf: reads ``conf.model.version`` and ``conf.data.normalize`` for
            the file name.
    """
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.plot(train_rms_x, label='Train RMS X', color='blue')
        ax.plot(train_rms_y, label='Train RMS Y', color='cyan')
        ax.plot(val_rms_x, label='Val RMS X', color='orange')
        ax.plot(val_rms_y, label='Val RMS Y', color='magenta')
        if test_rms_x > 0 and test_rms_y > 0:
            ax.axhline(y=test_rms_x, color='green', linestyle='--', label='Test RMS X')
            ax.axhline(y=test_rms_y, color='lime', linestyle='--', label='Test RMS Y')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('RMS Error [pixels]')
        ax.legend()
        ax.grid()
        plotName = f'rms_curve_singlePhoton_{conf.model.version}'
        if conf.data.normalize:
            plotName += '_normalized'
        fig.savefig(f'Results/{exp_name}/Plots/{plotName}.png')
    finally:
        # Called once per epoch: without closing, figures accumulate in memory.
        plt.close(fig)
def get_model_name(epoch, conf):
    """Return the checkpoint base name (without extension) for ``epoch``.

    The name encodes the model version, energy, noise level, epoch and
    number of augmentation ops. ``_normalized`` is appended when
    ``conf.data.normalize`` is truthy.
    """
    data = conf.data
    parts = [
        f'singlePhoton{conf.model.version}',
        f'{data.energy}keV',
        f'Noise{data.noise_keV}keV',
        f'E{epoch}',
        f'aug{data.num_aug_ops}',
    ]
    if data.normalize:
        parts.append('normalized')
    return '_'.join(parts)
if __name__ == "__main__":
exp_name = prepare_output_folder(conf)
model = models.get_model_class(conf.model.version)().cuda()
# summary(model, input_size=(128, 3, 3, 3))
loss_fn = get_loss_function(conf)
optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=conf.training.scheduler_factor, patience = conf.training.scheduler_patience)
trainLoader, valLoader, testLoader = get_dataloaders(conf)
train_losses = []
val_losses = []
train_rms_xs, train_rms_ys = [], []
val_rms_xs, val_rms_ys = [], []
for epoch in tqdm(range(1, conf.training.epochs + 1)):
t_loss, train_rms_x, train_rms_y = train(model, trainLoader, optimizer, loss_fn)
train_losses.append(t_loss)
train_rms_xs.append(train_rms_x)
train_rms_ys.append(train_rms_y)
v_loss, val_rms_x, val_rms_y = evaluate(model, valLoader, loss_fn)
val_losses.append(v_loss)
val_rms_xs.append(val_rms_x)
val_rms_ys.append(val_rms_y)
scheduler.step(val_losses[-1])
print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs:
modelName = get_model_name(epoch, conf)
torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}.pth')
print(f"Saved model checkpoint: {modelName}.pth")
plot_loss_curve(train_losses, val_losses, test_loss=-1, exp_name=exp_name, conf=conf)
plot_rms_curve(train_rms_xs, train_rms_ys, val_rms_xs, val_rms_ys, test_rms_x = -1, test_rms_y = -1, exp_name=exp_name, conf=conf)
test_loss, test_rms_x, test_rms_y = evaluate(model, testLoader, loss_fn)
plot_loss_curve(train_losses, val_losses, test_loss=test_loss, exp_name=exp_name, conf=conf)
plot_rms_curve(train_rms_xs, train_rms_ys, val_rms_xs, val_rms_ys, test_rms_x, test_rms_y, exp_name=exp_name, conf=conf)