Reframe the training code, add a YAML config

This commit is contained in:
2026-03-20 11:06:53 +01:00
parent 96e78f8ab5
commit 90cb2f5a47
2 changed files with 165 additions and 119 deletions
+40
View File
@@ -0,0 +1,40 @@
# configs/train_1photon.yaml
# Training configuration for the single-photon (1ph) position-regression model.
experiment:
  name: "1photon_15keV"

data:
  sample_folder: "/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040"
  energy: 15                  ### in keV
  noise_keV: 0.13             ### readout noise sigma, in keV
  noise_threshold: 0.0        ### set values below (noise * noise_threshold) to zero
  num_aug_ops: 1              ### 1 = no augmentation
  normalize: true
  batch_size_train: 4096
  batch_size_val: 1024
  batch_size_test: 1024
  num_workers: 32
  train_file_range: [0, 12]   ### inclusive [first, last] file indices
  val_file_range: [13, 14]
  test_file_range: [15, 15]

model:
  version: "251022"

training:
  epochs: 15
  learning_rate: 1.0e-3
  # NOTE(review): loss_alpha/loss_beta duplicate loss.alpha/loss.beta below;
  # the training script reads conf.loss.alpha / conf.loss.beta — confirm
  # nothing else reads these before removing them.
  loss_alpha: 7.0
  loss_beta: 3.0
  scheduler_factor: 0.7
  scheduler_patience: 3
  # NOTE(review): with epochs: 15 only the epoch-10 and final-epoch
  # checkpoints are reachable; entries beyond 15 are ignored.
  checkpoint_epochs: [10, 30, 50, 100, 150]

loss:
  type: "weightedL1"
  alpha: 7.0
  beta: 3.0

logging:
  save_dir: "Results"
  model_dir: "Models"
+125 -119
View File
@@ -1,5 +1,6 @@
import sys
sys.path.append('./src')
from omegaconf import OmegaConf ### for yaml config parsing
import torch
import numpy as np
import models
@@ -7,6 +8,10 @@ from datasets import *
import torch.optim as optim
from tqdm import tqdm
from torchinfo import summary
from pathlib import Path
from models import get_model_class
from datasets import singlePhotonDataset
### random seed for reproducibility
torch.manual_seed(0)
@@ -15,168 +20,169 @@ np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# --- configuration -----------------------------------------------------------
# All hyper-parameters (model version, energy, noise, learning rate, ...) now
# live in the YAML config; OmegaConf gives dotted attribute access to it.
conf = OmegaConf.load("Configs/train_1photon.yaml")

# Legacy global loss accumulators. The __main__ block also collects the
# per-epoch losses from the train/evaluate return values.
TrainLosses, TestLosses = [], []
def prepare_output_folder(conf):
    """Create a fresh, uniquely-numbered experiment folder under Results/.

    The folder name encodes date, photon energy, and model version; the
    resolved config is written into the folder for reproducibility.
    Returns the experiment folder name (not the full path).
    """
    from datetime import datetime
    stamp = datetime.now().strftime("%y%m%d")  # YYMMDD
    idx = 0
    exp_name = f'{stamp}_1ph_{conf.data.energy}keV_v{conf.model.version}_{idx:02d}'
    # bump the two-digit suffix until we hit an unused folder name
    while Path(f'Results/{exp_name}').exists():
        idx += 1
        exp_name = f'{stamp}_1ph_{conf.data.energy}keV_v{conf.model.version}_{idx:02d}'
    root = Path(f'Results/{exp_name}')
    for sub in ('', 'Models', 'Plots'):
        (root / sub).mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, f'Results/{exp_name}/config.yaml')
    return exp_name
def get_loss_function(conf):
    """Build the loss callable described by ``conf.loss``.

    Currently supports ``type: weightedL1`` — an L1 loss on the (x, y)
    position channels, weighted both per-direction (by |target|) and
    radially (by the distance from the pixel center).

    Raises:
        ValueError: for an unknown ``conf.loss.type`` (instead of the
            previous behavior of silently returning None, which only
            crashed later at call time).
    """
    if conf.loss.type != "weightedL1":
        raise ValueError(f"Unknown loss type: {conf.loss.type}")
    alpha = conf.loss.alpha
    beta = conf.loss.beta

    def weighted_loss(pred, target):
        # weighted L1 loss for the x,y position channels only
        pred = pred[:, :2]
        target = target[:, :2]
        direction_weight = 1.0 + alpha * torch.abs(target)  # (B, 2)
        r = torch.norm(target, dim=1, keepdim=True)          # radial distance
        radial_weight = 1.0 + beta * r                       # (B, 1)
        weights = radial_weight * direction_weight
        loss = weights * torch.abs(pred - target)
        return loss.mean()

    return weighted_loss
def train(model, trainLoader, optimizer, loss_fn):
    """Run one training epoch and return the average per-sample loss.

    The model regresses (x, y, z) from each sample; the energy channel
    ``e`` of the label is not used as a target here. Losses are weighted
    by batch size and normalized by the number of samples actually seen
    (robust to a loader that drops the last partial batch).
    """
    model.train()
    running_loss = 0.0
    total_samples = 0
    sq_err_x, sq_err_y = 0.0, 0.0
    for sample, label in trainLoader:
        sample, label = sample.cuda(), label.cuda()
        # label columns: x, y, z, energy (energy unused as a target)
        x, y, z, e = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
        optimizer.zero_grad()
        output = model(sample)
        loss = loss_fn(output, torch.stack((x, y, z), axis=1))
        loss.backward()
        optimizer.step()
        batch = sample.shape[0]
        running_loss += loss.item() * batch
        total_samples += batch
        sq_err_x += torch.sum((output[:, 0] - x) ** 2).item()
        sq_err_y += torch.sum((output[:, 1] - y) ** 2).item()
    avgLoss = running_loss / total_samples
    rms_x = np.sqrt(sq_err_x / total_samples)
    rms_y = np.sqrt(sq_err_y / total_samples)
    datasetName = trainLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
    # legacy global accumulator; __main__ also collects the returned value
    TrainLosses.append(avgLoss)
    return avgLoss
def evaluate(model, testLoader, loss_fn):
    """Evaluate the model on a loader (no gradients) and return the average loss.

    Mirrors ``train``: per-sample weighted loss average plus RMS of the x
    and y position errors, normalized by the number of samples actually
    iterated (consistent with train, instead of ``len(dataset)``).
    """
    model.eval()
    running_loss = 0.0
    total_samples = 0
    sq_err_x, sq_err_y = 0.0, 0.0
    with torch.no_grad():
        for sample, label in testLoader:
            sample, label = sample.cuda(), label.cuda()
            # label columns: x, y, z, energy (energy unused as a target)
            x, y, z, e = label[:, 0], label[:, 1], label[:, 2], label[:, 3]
            output = model(sample)
            loss = loss_fn(output, torch.stack((x, y, z), axis=1))
            batch = sample.shape[0]
            running_loss += loss.item() * batch
            total_samples += batch
            sq_err_x += torch.sum((output[:, 0] - x) ** 2).item()
            sq_err_y += torch.sum((output[:, 1] - y) ** 2).item()
    avgLoss = running_loss / total_samples
    rms_x = np.sqrt(sq_err_x / total_samples)
    rms_y = np.sqrt(sq_err_y / total_samples)
    datasetName = testLoader.dataset.datasetName
    print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
    return avgLoss
def get_dataloaders(conf):
    """Construct the train/val/test DataLoaders from the config.

    File ranges in the config are inclusive [first, last] indices into
    the simulated .npz sample files. Only the train loader shuffles.
    Returns (trainLoader, valLoader, testLoader).
    """
    loaders = {}
    splits = ['train', 'val', 'test']
    batch_keys = ['batch_size_train', 'batch_size_val', 'batch_size_test']
    file_range_keys = ['train_file_range', 'val_file_range', 'test_file_range']
    for split, b_key, fr_key in zip(splits, batch_keys, file_range_keys):
        first, last = conf.data[fr_key]  # inclusive range
        file_paths = [
            f"{conf.data.sample_folder}/{conf.data.energy}keV_Moench040_150V_{i}.npz"
            for i in range(first, last + 1)
        ]
        dataset = singlePhotonDataset(
            file_paths,
            sampleRatio=1.0,
            datasetName=split.capitalize(),
            noiseKeV=conf.data.noise_keV,
            numberOfAugOps=conf.data.num_aug_ops,
            normalize=conf.data.normalize,
            # threshold is expressed as a multiple of the noise sigma
            noiseThreshold=conf.data.noise_threshold * conf.data.noise_keV,
        )
        loaders[split] = torch.utils.data.DataLoader(
            dataset,
            batch_size=conf.data[b_key],
            shuffle=(split == 'train'),
            num_workers=conf.data.num_workers,
            pin_memory=True,
        )
    return loaders['train'], loaders['val'], loaders['test']
def plot_loss_curve(train_losses, val_losses, test_loss, exp_name, conf):
    """Save a log-scale train/val loss curve into the experiment's Plots folder.

    A test_loss > 0 is drawn as a horizontal reference line (pass -1 to
    skip it, e.g. while training is still in progress).
    """
    import matplotlib.pyplot as plt
    plt.figure(figsize=(8, 6))
    plt.plot(train_losses, label='Train Loss', color='blue')
    plt.plot(val_losses, label='Validation Loss', color='orange')
    if test_loss > 0:
        plt.axhline(y=test_loss, color='green', linestyle='--', label='Test Loss')
    plt.yscale('log')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid()
    plotName = f'loss_curve_singlePhoton_{conf.model.version}'
    if conf.data.normalize:
        plotName += '_normalized'
    plt.savefig(f'Results/{exp_name}/Plots/{plotName}.png')
    # close the figure so per-epoch calls do not accumulate open figures
    plt.close()
def get_model_name(epoch, conf):
    """Compose the checkpoint file stem for the given epoch from the config."""
    suffix = '_normalized' if conf.data.normalize else ''
    return (
        f'singlePhoton{conf.model.version}_{conf.data.energy}keV'
        f'_Noise{conf.data.noise_keV}keV_E{epoch}_aug{conf.data.num_aug_ops}'
        + suffix
    )
if __name__ == "__main__":
    # Everything below is driven by the YAML config loaded at module level.
    exp_name = prepare_output_folder(conf)
    model = models.get_model_class(conf.model.version)().cuda()
    # summary(model, input_size=(128, 1, 3, 3))
    loss_fn = get_loss_function(conf)
    optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min',
        factor=conf.training.scheduler_factor,
        patience=conf.training.scheduler_patience,
    )
    trainLoader, valLoader, testLoader = get_dataloaders(conf)

    train_losses, val_losses = [], []
    for epoch in tqdm(range(1, conf.training.epochs + 1)):
        train_losses.append(train(model, trainLoader, optimizer, loss_fn))
        val_losses.append(evaluate(model, valLoader, loss_fn))
        # plateau scheduler is driven by the validation loss
        scheduler.step(val_losses[-1])
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
        # checkpoint at the configured epochs and always at the final epoch
        if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs:
            modelName = get_model_name(epoch, conf)
            torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}.pth')
            print(f"Saved model checkpoint: {modelName}.pth")
        # refresh the loss curve every epoch (no test line yet)
        plot_loss_curve(train_losses, val_losses, test_loss=-1, exp_name=exp_name, conf=conf)

    # final held-out evaluation, then redraw the curve with the test line
    test_loss = evaluate(model, testLoader, loss_fn)
    plot_loss_curve(train_losses, val_losses, test_loss=test_loss, exp_name=exp_name, conf=conf)