Reframe the training codes, add yaml config
This commit is contained in:
@@ -0,0 +1,40 @@
|
||||
# configs/train_1photon.yaml
|
||||
experiment:
|
||||
name: "1photon_15keV"
|
||||
|
||||
data:
|
||||
sample_folder: "/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040"
|
||||
energy: 15 ### in keV
|
||||
noise_keV: 0.13
|
||||
noise_threshold: 0.0 ### set values below (noise * noise_threshold) to zero
|
||||
num_aug_ops: 1
|
||||
normalize: true
|
||||
|
||||
batch_size_train: 4096
|
||||
batch_size_val: 1024
|
||||
batch_size_test: 1024
|
||||
num_workers: 32
|
||||
train_file_range: [0, 12]
|
||||
val_file_range: [13, 14]
|
||||
test_file_range: [15, 15]
|
||||
|
||||
model:
|
||||
version: "251022"
|
||||
|
||||
training:
|
||||
epochs: 15
|
||||
learning_rate: 1.0e-3
|
||||
loss_alpha: 7.0
|
||||
loss_beta: 3.0
|
||||
scheduler_factor: 0.7
|
||||
scheduler_patience: 3
|
||||
checkpoint_epochs: [10, 30, 50, 100, 150]
|
||||
|
||||
loss:
|
||||
type: "weightedL1"
|
||||
alpha: 7.0
|
||||
beta: 3.0
|
||||
|
||||
logging:
|
||||
save_dir: "Results"
|
||||
model_dir: "Models"
|
||||
+125
-119
@@ -1,5 +1,6 @@
|
||||
import sys
|
||||
sys.path.append('./src')
|
||||
from omegaconf import OmegaConf ### for yaml config parsing
|
||||
import torch
|
||||
import numpy as np
|
||||
import models
|
||||
@@ -7,6 +8,10 @@ from datasets import *
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
from torchinfo import summary
|
||||
from pathlib import Path
|
||||
|
||||
from models import get_model_class
|
||||
from datasets import singlePhotonDataset
|
||||
|
||||
### random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
@@ -15,168 +20,169 @@ np.random.seed(0)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
modelVersion = '251022' # '250909' or '251020'
|
||||
Energy = '15keV'
|
||||
TrainLosses, ValLosses = [], []
|
||||
LearningRates = []
|
||||
TestLoss = -1
|
||||
model = models.get_model_class(modelVersion)().cuda()
|
||||
# summary(model, input_size=(128, 1, 3, 3))
|
||||
LearningRate = 1e-3
|
||||
Noise = 0.23 # in keV
|
||||
NoiseThreshold = 0 * Noise # in keV, set values below this threshold to zero
|
||||
numberOfAugOps = 1 # 1 (no augmentation) or (1,8] (with augmentation)
|
||||
flag_normalize = False
|
||||
conf = OmegaConf.load("Configs/train_1photon.yaml")
|
||||
|
||||
TrainLosses, TestLosses = [], []
|
||||
def weighted_loss(pred, target, alpha=7.0):
|
||||
# weighted L1 loss for x,y position
|
||||
pred = pred[:, :2]
|
||||
target = target[:, :2]
|
||||
def prepare_output_folder(conf):
|
||||
from datetime import datetime
|
||||
date = datetime.now().strftime("%y%m%d") ## YYMMDD format
|
||||
# find the next index for experiment name
|
||||
exp_index = 0
|
||||
while True:
|
||||
exp_name = f'{date}_1ph_{conf.data.energy}keV_v{conf.model.version}_{exp_index:02d}'
|
||||
if not Path(f'Results/{exp_name}').exists():
|
||||
break
|
||||
exp_index += 1
|
||||
Path(f'Results/{exp_name}').mkdir(parents=True, exist_ok=True)
|
||||
Path(f'Results/{exp_name}/Models').mkdir(parents=True, exist_ok=True)
|
||||
Path(f'Results/{exp_name}/Plots').mkdir(parents=True, exist_ok=True)
|
||||
OmegaConf.save(conf, f'Results/{exp_name}/config.yaml')
|
||||
return exp_name
|
||||
|
||||
# weights = 1.0 + alpha * torch.abs(target)
|
||||
direction_weight = 1.0 + alpha * torch.abs(target) # (B, 2)
|
||||
beta = 3.
|
||||
r = torch.norm(target, dim=1, keepdim=True)
|
||||
radial_weight = 1.0 + beta * r # (B, 1) →
|
||||
weights = radial_weight * direction_weight
|
||||
def get_loss_function(conf):
|
||||
if conf.loss.type == "weightedL1":
|
||||
alpha = conf.loss.alpha
|
||||
beta = conf.loss.beta
|
||||
def weighted_loss(pred, target):
|
||||
# weighted L1 loss for x,y position
|
||||
pred = pred[:, :2]
|
||||
target = target[:, :2]
|
||||
|
||||
loss = weights * torch.abs(pred - target)
|
||||
return loss.mean()
|
||||
LossFunction = weighted_loss
|
||||
direction_weight = 1.0 + alpha * torch.abs(target) # (B, 2)
|
||||
r = torch.norm(target, dim=1, keepdim=True)
|
||||
radial_weight = 1.0 + beta * r # (B, 1) →
|
||||
weights = radial_weight * direction_weight
|
||||
|
||||
def train(model, trainLoader, optimizer):
|
||||
loss = weights * torch.abs(pred - target)
|
||||
return loss.mean()
|
||||
return weighted_loss
|
||||
|
||||
def train(model, trainLoader, optimizer, loss_fn):
|
||||
model.train()
|
||||
batchLoss = 0
|
||||
total_samples = 0
|
||||
rms_x, rms_y = 0, 0
|
||||
for batch_idx, (sample, label) in enumerate(trainLoader):
|
||||
sample, label = sample.cuda(), label.cuda()
|
||||
x, y, z, e = label[:,0], label[:,1], label[:,2], label[:,3]
|
||||
optimizer.zero_grad()
|
||||
output = model(sample)
|
||||
loss = LossFunction(output, torch.stack((x, y, z), axis=1))
|
||||
loss = loss_fn(output, torch.stack((x, y, z), axis=1))
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
batchLoss += loss.item() * sample.shape[0]
|
||||
total_samples += sample.shape[0]
|
||||
rms_x += torch.sum((output[:,0] - x)**2).item()
|
||||
rms_y += torch.sum((output[:,1] - y)**2).item()
|
||||
avgLoss = batchLoss / len(trainLoader.dataset)
|
||||
rms_x = np.sqrt(rms_x / len(trainLoader.dataset))
|
||||
rms_y = np.sqrt(rms_y / len(trainLoader.dataset))
|
||||
avgLoss = batchLoss / total_samples
|
||||
rms_x = np.sqrt(rms_x / total_samples)
|
||||
rms_y = np.sqrt(rms_y / total_samples)
|
||||
|
||||
datasetName = trainLoader.dataset.datasetName
|
||||
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
|
||||
TrainLosses.append(avgLoss)
|
||||
return avgLoss
|
||||
|
||||
def test(model, testLoader):
|
||||
def evaluate(model, testLoader, loss_fn):
|
||||
model.eval()
|
||||
batchLoss = 0
|
||||
rms_x, rms_y = 0, 0
|
||||
with torch.no_grad():
|
||||
for batch_idx, (sample, label) in enumerate(testLoader):
|
||||
sample, label = sample.cuda(), label.cuda()
|
||||
x, y, z, e = label[:,0], label[:,1], label[:,2], label[:,3]
|
||||
output = model(sample)
|
||||
loss = LossFunction(output, torch.stack((x, y, z), axis=1))
|
||||
loss = loss_fn(output, torch.stack((x, y, z), axis=1))
|
||||
batchLoss += loss.item() * sample.shape[0]
|
||||
rms_x += torch.sum((output[:,0] - x)**2).item()
|
||||
rms_y += torch.sum((output[:,1] - y)**2).item()
|
||||
avgLoss = batchLoss / len(testLoader.dataset)
|
||||
rms_x = np.sqrt(rms_x / len(testLoader.dataset))
|
||||
rms_y = np.sqrt(rms_y / len(testLoader.dataset))
|
||||
|
||||
datasetName = testLoader.dataset.datasetName
|
||||
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f})")
|
||||
if datasetName == 'Val':
|
||||
ValLosses.append(avgLoss)
|
||||
else:
|
||||
global TestLoss
|
||||
TestLoss = avgLoss
|
||||
print(f"[{datasetName}]\t Average Loss: {avgLoss:.6f} (sigma = {np.sqrt(avgLoss):.6f}) \t RMS X: {rms_x:.6f} \t RMS Y: {rms_y:.6f}")
|
||||
return avgLoss
|
||||
|
||||
sampleFolder = '/mnt/sls_det_storage/moench_data/MLXID/Samples/Simulation/Moench040'
|
||||
trainDataset = singlePhotonDataset(
|
||||
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13)],
|
||||
sampleRatio=1.0,
|
||||
datasetName='Train',
|
||||
noiseKeV = Noise,
|
||||
numberOfAugOps=numberOfAugOps,
|
||||
normalize=flag_normalize,
|
||||
noiseThreshold=NoiseThreshold
|
||||
)
|
||||
valDataset = singlePhotonDataset(
|
||||
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(13,14)],
|
||||
sampleRatio=1.0,
|
||||
datasetName='Val',
|
||||
noiseKeV = Noise,
|
||||
numberOfAugOps=numberOfAugOps,
|
||||
normalize=flag_normalize,
|
||||
noiseThreshold=NoiseThreshold
|
||||
)
|
||||
testDataset = singlePhotonDataset(
|
||||
[f'{sampleFolder}/{Energy}_Moench040_150V_{i}.npz' for i in range(15,16)],
|
||||
sampleRatio=1.0,
|
||||
datasetName='Test',
|
||||
noiseKeV = Noise,
|
||||
numberOfAugOps=numberOfAugOps,
|
||||
normalize=flag_normalize,
|
||||
noiseThreshold=NoiseThreshold
|
||||
)
|
||||
trainLoader = torch.utils.data.DataLoader(
|
||||
trainDataset,
|
||||
batch_size=4096,
|
||||
shuffle=True,
|
||||
num_workers=32,
|
||||
pin_memory=True,
|
||||
)
|
||||
valLoader = torch.utils.data.DataLoader(
|
||||
valDataset,
|
||||
batch_size=1024,
|
||||
shuffle=False,
|
||||
num_workers=32,
|
||||
pin_memory=True,
|
||||
)
|
||||
testLoader = torch.utils.data.DataLoader(
|
||||
testDataset,
|
||||
batch_size=1024,
|
||||
shuffle=4096,
|
||||
num_workers=32,
|
||||
pin_memory=True,
|
||||
)
|
||||
def get_dataloaders(conf):
|
||||
"""construct all dataLoaders"""
|
||||
datasets = {}
|
||||
loaders = {}
|
||||
splits = ['train', 'val', 'test']
|
||||
keys = ['train_files', 'val_files', 'test_files']
|
||||
batch_keys = ['batch_size_train', 'batch_size_val', 'batch_size_test']
|
||||
file_range_keys = ['train_file_range', 'val_file_range', 'test_file_range']
|
||||
|
||||
for split, f_key, b_key, fr_key in zip(splits, keys, batch_keys, file_range_keys):
|
||||
file_paths = [f"{conf.data.sample_folder}/{conf.data.energy}keV_Moench040_150V_{i}.npz" for i in range(conf.data[fr_key][0], conf.data[fr_key][1] + 1)]
|
||||
|
||||
dataset = singlePhotonDataset(
|
||||
file_paths,
|
||||
sampleRatio=1.0,
|
||||
datasetName=split.capitalize(),
|
||||
noiseKeV=conf.data.noise_keV,
|
||||
numberOfAugOps=conf.data.num_aug_ops,
|
||||
normalize=conf.data.normalize,
|
||||
noiseThreshold=conf.data.noise_threshold * conf.data.noise_keV
|
||||
)
|
||||
datasets[split] = dataset
|
||||
|
||||
loaders[split] = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=conf.data[b_key],
|
||||
shuffle=(split == 'train'),
|
||||
num_workers=conf.data.num_workers,
|
||||
pin_memory=True
|
||||
)
|
||||
return loaders['train'], loaders['val'], loaders['test']
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=LearningRate)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.7, patience = 3)
|
||||
|
||||
def plot_loss_curve(TrainLosses, ValLosses, modelVersion, TestLoss=0):
|
||||
def plot_loss_curve(train_losses, val_losses, test_loss, exp_name, conf):
|
||||
import matplotlib.pyplot as plt
|
||||
plt.figure(figsize=(8,6))
|
||||
plt.plot(TrainLosses, label='Train Loss', color='blue')
|
||||
plt.plot(ValLosses, label='Validation Loss', color='orange')
|
||||
if TestLoss > 0:
|
||||
plt.axhline(y=TestLoss, color='green', linestyle='--', label='Test Loss')
|
||||
plt.plot(train_losses, label='Train Loss', color='blue')
|
||||
plt.plot(val_losses, label='Validation Loss', color='orange')
|
||||
if test_loss > 0:
|
||||
plt.axhline(y=test_loss, color='green', linestyle='--', label='Test Loss')
|
||||
plt.yscale('log')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('MSE Loss')
|
||||
plt.legend()
|
||||
plt.grid()
|
||||
plotName = f'loss_curve_singlePhoton_{modelVersion}'
|
||||
if flag_normalize:
|
||||
plotName = f'loss_curve_singlePhoton_{conf.model.version}'
|
||||
if conf.data.normalize:
|
||||
plotName += '_normalized'
|
||||
plt.savefig(f'Results/{plotName}.png')
|
||||
plt.savefig(f'Results/{exp_name}/Plots/{plotName}.png')
|
||||
|
||||
def get_model_name(epoch, conf):
|
||||
modelName = f'singlePhoton{conf.model.version}_{conf.data.energy}keV_Noise{conf.data.noise_keV}keV_E{epoch}_aug{conf.data.num_aug_ops}'
|
||||
if conf.data.normalize:
|
||||
modelName += '_normalized'
|
||||
return modelName
|
||||
|
||||
if __name__ == "__main__":
|
||||
for epoch in tqdm(range(1, 151)):
|
||||
train(model, trainLoader, optimizer)
|
||||
test(model, valLoader)
|
||||
scheduler.step(ValLosses[-1])
|
||||
print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
|
||||
if epoch in [20, 30, 50, 100, 150]:
|
||||
modelName = f'singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}_aug{numberOfAugOps}'
|
||||
if flag_normalize == True:
|
||||
modelName += '_normalized'
|
||||
torch.save(model.state_dict(), f'Models/{modelName}.pth')
|
||||
print(f"Saved model checkpoint: {modelName}.pth")
|
||||
plot_loss_curve(TrainLosses, ValLosses, modelVersion=modelVersion)
|
||||
|
||||
test(model, testLoader)
|
||||
modelName = f'singlePhoton{modelVersion}_{Energy}_Noise{Noise}keV_E{epoch}_aug{numberOfAugOps}'
|
||||
if flag_normalize == True:
|
||||
modelName += '_normalized'
|
||||
torch.save(model.state_dict(), f'Models/{modelName}.pth')
|
||||
print(f"Saved final model checkpoint: {modelName}.pth")
|
||||
plot_loss_curve(TrainLosses, ValLosses, modelVersion=modelVersion, TestLoss=TestLoss)
|
||||
exp_name = prepare_output_folder(conf)
|
||||
model = models.get_model_class(conf.model.version)().cuda()
|
||||
# summary(model, input_size=(128, 1, 3, 3))
|
||||
|
||||
loss_fn = get_loss_function(conf)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=conf.training.scheduler_factor, patience = conf.training.scheduler_patience)
|
||||
|
||||
trainLoader, valLoader, testLoader = get_dataloaders(conf)
|
||||
|
||||
train_losses = []
|
||||
val_losses = []
|
||||
|
||||
for epoch in tqdm(range(1, conf.training.epochs + 1)):
|
||||
t_loss = train(model, trainLoader, optimizer, loss_fn)
|
||||
train_losses.append(t_loss)
|
||||
v_loss = evaluate(model, valLoader, loss_fn)
|
||||
val_losses.append(v_loss)
|
||||
scheduler.step(val_losses[-1])
|
||||
print(f"Learning Rate: {optimizer.param_groups[0]['lr']}")
|
||||
if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs:
|
||||
modelName = get_model_name(epoch, conf)
|
||||
torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}.pth')
|
||||
print(f"Saved model checkpoint: {modelName}.pth")
|
||||
plot_loss_curve(train_losses, val_losses, test_loss=-1, exp_name=exp_name, conf=conf)
|
||||
test_loss = evaluate(model, testLoader, loss_fn)
|
||||
plot_loss_curve(train_losses, val_losses, test_loss=test_loss, exp_name=exp_name, conf=conf)
|
||||
Reference in New Issue
Block a user