Files
DeepLearning/Train_1Photon.py

188 lines
7.4 KiB
Python

import sys
sys.path.append('./src')
from omegaconf import OmegaConf ### for yaml config parsing
import torch
import numpy as np
import models
from datasets import *
import torch.optim as optim
from tqdm import tqdm
from torchinfo import summary
from pathlib import Path
from models import get_model_class
from datasets import singlePhotonDataset
### random seed for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
# Force deterministic cuDNN kernels and disable the autotuner so repeated
# runs pick the same algorithms (trades some speed for reproducibility).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Experiment configuration (data paths/ranges, model version, loss and
# training hyper-parameters) loaded once at module import.
conf = OmegaConf.load("Configs/train_1photon.yaml")
def prepare_output_folder(conf):
    """Create a fresh, uniquely numbered experiment folder under Results/.

    The name encodes date (YYMMDD), photon energy, model version and a
    two-digit run index.  Creates Models/ and Plots/ subfolders, saves a
    copy of the config there, and returns the experiment name.
    """
    from datetime import datetime

    date_tag = datetime.now().strftime("%y%m%d")  ## YYMMDD format

    def candidate(idx):
        return f'{date_tag}_1ph_{conf.data.energy}keV_v{conf.model.version}_{idx:02d}'

    # Probe successive run indices until an unused experiment name is found.
    exp_index = 0
    while Path(f'Results/{candidate(exp_index)}').exists():
        exp_index += 1
    exp_name = candidate(exp_index)

    base = Path(f'Results/{exp_name}')
    for sub in ('', 'Models', 'Plots'):
        (base / sub).mkdir(parents=True, exist_ok=True)
    # Snapshot the config alongside the results for later reference.
    OmegaConf.save(conf, f'Results/{exp_name}/config.yaml')
    return exp_name
def get_loss_function(conf):
    """Build the training loss from the config.

    Supports conf.loss.type == "weightedL1": an L1 loss on the (x, y)
    position columns that up-weights hits far from the sensor centre,
    per axis (scaled by conf.loss.alpha) and radially (by conf.loss.beta).

    Returns:
        A callable loss_fn(pred, target) -> scalar tensor.

    Raises:
        ValueError: if conf.loss.type is not a recognized loss type.
    """
    if conf.loss.type == "weightedL1":
        alpha = conf.loss.alpha
        beta = conf.loss.beta

        def weighted_loss(pred, target):
            # Only the x, y position columns participate in the loss.
            pred = pred[:, :2]
            target = target[:, :2]
            # Per-axis weight grows with distance from centre along that axis: (B, 2)
            direction_weight = 1.0 + alpha * torch.abs(target)
            # Radial weight grows with r = ||(x, y)||: (B, 1), broadcasts over axes.
            r = torch.norm(target, dim=1, keepdim=True)
            radial_weight = 1.0 + beta * r
            weights = radial_weight * direction_weight
            loss = weights * torch.abs(pred - target)
            return loss.mean()

        return weighted_loss
    # Previously an unknown type silently fell through and returned None,
    # which only failed later with "'NoneType' object is not callable" at
    # the first training step; fail fast with a clear message instead.
    raise ValueError(f"Unknown loss type: {conf.loss.type!r}")
def train(model, trainLoader, optimizer, loss_fn):
    """Run one training epoch and return the sample-weighted average loss.

    Also prints the epoch's average loss and the RMS of the x/y position
    residuals over all samples seen during the epoch.
    """
    model.train()
    running_loss = 0.0
    n_seen = 0
    sq_err_x = 0.0
    sq_err_y = 0.0
    for sample, label in trainLoader:
        sample = sample.cuda()
        label = label.cuda()
        # Label columns: x, y, z position and deposited energy e.
        x, y, z, e = (label[:, i] for i in range(4))
        optimizer.zero_grad()
        output = model(sample)
        # Energy is not part of the regression target.
        loss = loss_fn(output, torch.stack((x, y, z), axis=1))
        loss.backward()
        optimizer.step()
        n_batch = sample.shape[0]
        running_loss += loss.item() * n_batch
        n_seen += n_batch
        sq_err_x += torch.sum((output[:, 0] - x) ** 2).item()
        sq_err_y += torch.sum((output[:, 1] - y) ** 2).item()
    avgLoss = running_loss / n_seen
    rms_x = np.sqrt(sq_err_x / n_seen)
    rms_y = np.sqrt(sq_err_y / n_seen)
    datasetName = trainLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
    return avgLoss
def evaluate(model, testLoader, loss_fn):
    """Evaluate the model on a loader (no gradients) and return the average loss.

    Prints the average loss and the RMS of the x/y position residuals over
    the whole dataset.
    """
    model.eval()
    running_loss = 0.0
    sq_err_x = 0.0
    sq_err_y = 0.0
    with torch.no_grad():
        for sample, label in testLoader:
            sample = sample.cuda()
            label = label.cuda()
            # Label columns: x, y, z position and deposited energy e.
            x, y, z, e = (label[:, i] for i in range(4))
            output = model(sample)
            loss = loss_fn(output, torch.stack((x, y, z), axis=1))
            running_loss += loss.item() * sample.shape[0]
            sq_err_x += torch.sum((output[:, 0] - x) ** 2).item()
            sq_err_y += torch.sum((output[:, 1] - y) ** 2).item()
    n_total = len(testLoader.dataset)
    avgLoss = running_loss / n_total
    rms_x = np.sqrt(sq_err_x / n_total)
    rms_y = np.sqrt(sq_err_y / n_total)
    datasetName = testLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
    return avgLoss
def get_dataloaders(conf):
    """Construct the train/val/test DataLoaders from the config.

    File lists are built from conf.data's per-split file ranges; only the
    training loader shuffles.  Returns (trainLoader, valLoader, testLoader).
    """
    split_specs = [
        ('train', 'batch_size_train', 'train_file_range'),
        ('val', 'batch_size_val', 'val_file_range'),
        ('test', 'batch_size_test', 'test_file_range'),
    ]
    loaders = {}
    for split, batch_key, range_key in split_specs:
        first, last = conf.data[range_key][0], conf.data[range_key][1]
        file_paths = [
            f"{conf.data.sample_folder}/{conf.data.energy}keV_Moench040_150V_{i}.npz"
            for i in range(first, last + 1)
        ]
        dataset = singlePhotonDataset(
            file_paths,
            sampleRatio=1.0,
            datasetName=split.capitalize(),
            noiseKeV=conf.data.noise_keV,
            numberOfAugOps=conf.data.num_aug_ops,
            normalize=conf.data.normalize,
            # Threshold is specified in units of the noise level.
            noiseThreshold=conf.data.noise_threshold * conf.data.noise_keV,
        )
        loaders[split] = torch.utils.data.DataLoader(
            dataset,
            batch_size=conf.data[batch_key],
            shuffle=(split == 'train'),
            num_workers=conf.data.num_workers,
            pin_memory=True,
        )
    return loaders['train'], loaders['val'], loaders['test']
def plot_loss_curve(train_losses, val_losses, test_loss, exp_name, conf):
    """Plot train/val loss curves (log y-scale) and save them as a PNG.

    The PNG lands in Results/<exp_name>/Plots/.  A test_loss <= 0 means
    "not evaluated yet" and suppresses the horizontal test-loss line.
    """
    import matplotlib.pyplot as plt
    fig = plt.figure(figsize=(8, 6))
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.plot(val_losses, label='Validation Loss', color='orange')
    if test_loss > 0:
        plt.axhline(y=test_loss, color='green', linestyle='--', label='Test Loss')
    plt.yscale('log')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid()
    plotName = f'loss_curve_singlePhoton_{conf.model.version}'
    if conf.data.normalize:
        plotName += '_normalized'
    plt.savefig(f'Results/{exp_name}/Plots/{plotName}.png')
    # This function is called at every checkpoint and again after training;
    # close the figure so repeated calls don't accumulate open figures.
    plt.close(fig)
def get_model_name(epoch, conf):
    """Return the checkpoint file stem encoding model version, energy, noise,
    epoch and augmentation count (plus a '_normalized' suffix when enabled)."""
    parts = [
        f'singlePhoton{conf.model.version}',
        f'{conf.data.energy}keV',
        f'Noise{conf.data.noise_keV}keV',
        f'E{epoch}',
        f'aug{conf.data.num_aug_ops}',
    ]
    name = '_'.join(parts)
    if conf.data.normalize:
        name += '_normalized'
    return name
if __name__ == "__main__":
    # Set up the output folder, model, loss, optimizer and data pipelines.
    exp_name = prepare_output_folder(conf)
    model = models.get_model_class(conf.model.version)().cuda()
    # summary(model, input_size=(128, 1, 3, 3))
    loss_fn = get_loss_function(conf)
    optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min',
        factor=conf.training.scheduler_factor,
        patience=conf.training.scheduler_patience,
    )
    trainLoader, valLoader, testLoader = get_dataloaders(conf)

    train_losses, val_losses = [], []
    for epoch in tqdm(range(1, conf.training.epochs + 1)):
        train_losses.append(train(model, trainLoader, optimizer, loss_fn))
        val_losses.append(evaluate(model, valLoader, loss_fn))
        # Drop the learning rate when the validation loss plateaus.
        scheduler.step(val_losses[-1])
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
        at_checkpoint = epoch in conf.training.checkpoint_epochs
        if at_checkpoint or epoch == conf.training.epochs:
            modelName = get_model_name(epoch, conf)
            torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}.pth')
            print(f"Saved model checkpoint: {modelName}.pth")
            # Interim curve: test_loss=-1 suppresses the test-loss line.
            plot_loss_curve(train_losses, val_losses, test_loss=-1, exp_name=exp_name, conf=conf)

    # Final held-out evaluation and the finished loss-curve plot.
    test_loss = evaluate(model, testLoader, loss_fn)
    plot_loss_curve(train_losses, val_losses, test_loss=test_loss, exp_name=exp_name, conf=conf)