@@ -0,0 +1,31 @@
|
||||
# configs/train_3photon.yaml
|
||||
experiment:
|
||||
name: "3photon_12keV"
|
||||
|
||||
data:
|
||||
sample_folder: "/home/xie_x1/MLXID/DeepLearning/PileupSample"
|
||||
energy: 12 ### in keV
|
||||
|
||||
batch_size_train: 4096
|
||||
batch_size_val: 1024
|
||||
batch_size_test: 1024
|
||||
num_workers: 32
|
||||
train_file_range: [0, 12]
|
||||
val_file_range: [13, 14]
|
||||
test_file_range: [15, 15]
|
||||
sample_ratio: 0.1
|
||||
n_size: 9 ### size of sub-images containing 3 photons
|
||||
|
||||
model:
|
||||
version: "260529"
|
||||
|
||||
training:
|
||||
epochs: 100
|
||||
learning_rate: 1.0e-3
|
||||
weight_decay: 1.0e-4
|
||||
scheduler_factor: 0.7
|
||||
scheduler_patience: 5
|
||||
checkpoint_epochs: [10, 30, 50, 100, 150, 300, 500, 1000]
|
||||
|
||||
loss:
|
||||
type: "three_point_set_loss_smooth_l1"
|
||||
@@ -0,0 +1,167 @@
|
||||
import sys
|
||||
sys.path.append('./src')
|
||||
from omegaconf import OmegaConf ### for yaml config parsing
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
from torchinfo import summary
|
||||
from pathlib import Path
|
||||
|
||||
from models import get_triple_photon_model_class
|
||||
from datasets import triplePhotonDataset
|
||||
|
||||
### random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
torch.cuda.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
conf = OmegaConf.load("Configs/train_3photon.yaml")
|
||||
|
||||
def prepare_output_folder(conf):
|
||||
from datetime import datetime
|
||||
date = datetime.now().strftime("%y%m%d") ## YYMMDD format
|
||||
# find the next index for experiment name
|
||||
exp_index = 0
|
||||
while True:
|
||||
exp_name = f'{date}_3ph_{conf.data.energy}keV_v{conf.model.version}_{exp_index:02d}'
|
||||
if not Path(f'Results/{exp_name}').exists():
|
||||
break
|
||||
exp_index += 1
|
||||
Path(f'Results/{exp_name}').mkdir(parents=True, exist_ok=True)
|
||||
Path(f'Results/{exp_name}/Models').mkdir(parents=True, exist_ok=True)
|
||||
Path(f'Results/{exp_name}/Plots').mkdir(parents=True, exist_ok=True)
|
||||
OmegaConf.save(conf, f'Results/{exp_name}/config.yaml')
|
||||
return exp_name
|
||||
|
||||
def get_loss_function(conf):
|
||||
if conf.loss.type == "three_point_set_loss_smooth_l1":
|
||||
def three_point_set_loss_smooth_l1(pred_xy, gt_xy):
|
||||
loss_fn = torch.nn.SmoothL1Loss(reduction='none')
|
||||
|
||||
p1, p2, p3 = pred_xy[:,0], pred_xy[:,1], pred_xy[:,2]
|
||||
g1, g2, g3 = gt_xy[:,0], gt_xy[:,1], gt_xy[:,2]
|
||||
|
||||
c_a = loss_fn(p1, g1).sum(dim=-1) + loss_fn(p2, g2).sum(dim=-1) + loss_fn(p3, g3).sum(dim=-1)
|
||||
c_b = loss_fn(p1, g2).sum(dim=-1) + loss_fn(p2, g1).sum(dim=-1) + loss_fn(p3, g3).sum(dim=-1)
|
||||
c_c = loss_fn(p1, g3).sum(dim=-1) + loss_fn(p2, g2).sum(dim=-1) + loss_fn(p3, g1).sum(dim=-1)
|
||||
c_d = loss_fn(p1, g1).sum(dim=-1) + loss_fn(p2, g3).sum(dim=-1) + loss_fn(p3, g2).sum(dim=-1)
|
||||
c_e = loss_fn(p1, g2).sum(dim=-1) + loss_fn(p2, g3).sum(dim=-1) + loss_fn(p3, g1).sum(dim=-1)
|
||||
c_f = loss_fn(p1, g3).sum(dim=-1) + loss_fn(p2, g1).sum(dim=-1) + loss_fn(p3, g2).sum(dim=-1)
|
||||
|
||||
return torch.minimum(torch.minimum(torch.minimum(torch.minimum(torch.minimum(c_a, c_b), c_c), c_d), c_e), c_f).mean()
|
||||
|
||||
return three_point_set_loss_smooth_l1
|
||||
|
||||
def train(model, trainLoader, optimizer, loss_fn):
|
||||
model.train()
|
||||
batchLoss = 0
|
||||
for batch_idx, (sample, label) in enumerate(trainLoader):
|
||||
sample, label = sample.cuda(), label.cuda()
|
||||
x1, y1, z1, e1 = label[:, 0, 0], label[:, 0, 1], label[:, 0, 2], label[:, 0, 3]
|
||||
x2, y2, z2, e2 = label[:, 1, 0], label[:, 1, 1], label[:, 1, 2], label[:, 1, 3]
|
||||
x3, y3, z3, e3 = label[:, 2, 0], label[:, 2, 1], label[:, 2, 2], label[:, 2, 3]
|
||||
gt_xy = torch.stack((torch.stack((x1, y1), axis=1), torch.stack((x2, y2), axis=1), torch.stack((x3, y3), axis=1)), axis=1)
|
||||
optimizer.zero_grad()
|
||||
output = model(sample)
|
||||
pred_xy = torch.stack((output[:,0:2], output[:,2:4], output[:,4:6]), axis=1)
|
||||
loss = loss_fn(pred_xy, gt_xy)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
batchLoss += loss.item() * sample.shape[0]
|
||||
avgLoss = batchLoss / len(trainLoader.dataset) / 6 ### divide by 6 to get the average loss per photon per axis
|
||||
print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
|
||||
return avgLoss
|
||||
|
||||
def evaluate(model, valLoader, loss_fn):
|
||||
model.eval()
|
||||
batchLoss = 0
|
||||
with torch.no_grad():
|
||||
for batch_idx, (sample, label) in enumerate(valLoader):
|
||||
sample, label = sample.cuda(), label.cuda()
|
||||
x1, y1, z1, e1 = label[:, 0, 0], label[:, 0, 1], label[:, 0, 2], label[:, 0, 3]
|
||||
x2, y2, z2, e2 = label[:, 1, 0], label[:, 1, 1], label[:, 1, 2], label[:, 1, 3]
|
||||
x3, y3, z3, e3 = label[:, 2, 0], label[:, 2, 1], label[:, 2, 2], label[:, 2, 3]
|
||||
gt_xy = torch.stack((torch.stack((x1, y1), axis=1), torch.stack((x2, y2), axis=1), torch.stack((x3, y3), axis=1)), axis=1)
|
||||
output = model(sample)
|
||||
pred_xy = torch.stack((output[:,0:2], output[:,2:4], output[:,4:6]), axis=1)
|
||||
loss = loss_fn(pred_xy, gt_xy)
|
||||
batchLoss += loss.item() * sample.shape[0]
|
||||
avgLoss = batchLoss / len(valLoader.dataset) / 6 ### divide by 6 to get the average loss per photon per axis
|
||||
print(f"[Val]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})")
|
||||
return avgLoss
|
||||
|
||||
def get_dataloaders(conf):
|
||||
"""construct all dataloaders"""
|
||||
datasets = {}
|
||||
loaders = {}
|
||||
splits = ['Train', 'Val', 'Test']
|
||||
keys = ['train_files', 'val_files', 'test_files']
|
||||
batch_keys = ['batch_size_train', 'batch_size_val', 'batch_size_test']
|
||||
file_range_keys = ['train_file_range', 'val_file_range', 'test_file_range']
|
||||
|
||||
for split, key, batch_key, file_range_key in zip(splits, keys, batch_keys, file_range_keys):
|
||||
files = [f"{conf.data.sample_folder}/pileupOf3phs_sample_{i}.npz" for i in range(conf.data[file_range_key][0], conf.data[file_range_key][1] + 1)]
|
||||
|
||||
datasets[split] = triplePhotonDataset(
|
||||
files,
|
||||
sampleRatio = conf.data.sample_ratio,
|
||||
datasetName = split.capitalize(),
|
||||
)
|
||||
|
||||
loaders[split] = torch.utils.data.DataLoader(
|
||||
datasets[split],
|
||||
batch_size=conf.data[batch_key],
|
||||
shuffle=(split=='Train'),
|
||||
num_workers=conf.data.num_workers,
|
||||
pin_memory=True
|
||||
)
|
||||
return loaders['Train'], loaders['Val'], loaders['Test']
|
||||
|
||||
def plot_loss_curves(train_losses, val_losses, test_loss, exp_name, conf):
|
||||
import matplotlib.pyplot as plt
|
||||
plt.figure(figsize=(8,6))
|
||||
plt.plot(train_losses, label='Train Loss')
|
||||
plt.plot(val_losses, label='Val Loss')
|
||||
if test_loss > 0:
|
||||
plt.axhline(y=test_loss, color='green', linestyle='--', label='Test Loss')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('MSE Loss')
|
||||
plt.yscale('log')
|
||||
plt.legend()
|
||||
plt.grid()
|
||||
plotName = f'loss_curve_triplePhoton_{conf.model.version}.png'
|
||||
plt.savefig(f'Results/{exp_name}/Plots/{plotName}')
|
||||
|
||||
def get_model_name(conf):
|
||||
modelName = f'triplePhoton{conf.model.version}_{conf.data.energy}keV'
|
||||
return modelName
|
||||
|
||||
if __name__ == "__main__":
|
||||
exp_name = prepare_output_folder(conf)
|
||||
model = get_triple_photon_model_class(conf.model.version)().cuda()
|
||||
# summary(model, input_size=(128, 3, conf.data.n_size, conf.data.n_size))
|
||||
|
||||
loss_fn = get_loss_function(conf)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate, weight_decay=conf.training.weight_decay)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=conf.training.scheduler_factor, patience=conf.training.scheduler_patience)
|
||||
|
||||
trainLoader, valLoader, testLoader = get_dataloaders(conf)
|
||||
|
||||
TrainLosses, ValLosses = [], []
|
||||
for epoch in tqdm(range(1, conf.training.epochs + 1)):
|
||||
train_loss = train(model, trainLoader, optimizer, loss_fn)
|
||||
val_loss = evaluate(model, valLoader, loss_fn)
|
||||
TrainLosses.append(train_loss)
|
||||
ValLosses.append(val_loss)
|
||||
scheduler.step(val_loss)
|
||||
print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
|
||||
if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs:
|
||||
modelName = get_model_name(conf)
|
||||
torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}_E{epoch}.pth')
|
||||
print(f"Saved model checkpoint: {modelName}_E{epoch}.pth")
|
||||
plot_loss_curves(TrainLosses, ValLosses, test_loss=-1, exp_name=exp_name, conf=conf)
|
||||
test_loss = evaluate(model, testLoader, loss_fn)
|
||||
plot_loss_curves(TrainLosses, ValLosses, test_loss=test_loss, exp_name=exp_name, conf=conf)
|
||||
@@ -238,5 +238,39 @@ class doublePhotonInferenceDataset(Dataset):
|
||||
dummy_label = np.zeros((8,), dtype=np.float32) ### dummy label
|
||||
return sample, torch.tensor(dummy_label, dtype=torch.float32)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
class triplePhotonDataset(Dataset):
|
||||
def __init__(self, sampleList, sampleRatio, datasetName):
|
||||
self.sampleFileList = sampleList
|
||||
self.sampleRatio = sampleRatio
|
||||
self.datasetName = datasetName
|
||||
all_samples = []
|
||||
all_labels = []
|
||||
for sampleFile in self.sampleFileList:
|
||||
if '.npz' in sampleFile:
|
||||
data = np.load(sampleFile)
|
||||
all_samples.append(data['samples'])
|
||||
all_labels.append(data['labels'])
|
||||
elif '.h5' in sampleFile:
|
||||
import h5py
|
||||
with h5py.File(sampleFile, 'r') as f:
|
||||
samples = f['clusters'][:]
|
||||
labels = f['labels'][:]
|
||||
all_samples.append(samples)
|
||||
all_labels.append(labels)
|
||||
self.samples = np.concatenate(all_samples, axis=0)
|
||||
self.labels = np.concatenate(all_labels, axis=0)
|
||||
### total number of samples
|
||||
self.length = int(self.samples.shape[0] * self.sampleRatio)
|
||||
print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
label = self.labels[index]
|
||||
label[:, :2] -= self.samples.shape[-1] / 2. ### adjust labels to be centered at sample center
|
||||
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
|
||||
label = torch.tensor(label, dtype=torch.float32)
|
||||
return sample, label
|
||||
def __len__(self):
|
||||
return self.length
|
||||
@@ -17,6 +17,13 @@ def get_double_photon_model_class(version):
|
||||
raise ValueError(f"Model class '{class_name}' not found.")
|
||||
return cls
|
||||
|
||||
def get_triple_photon_model_class(version):
|
||||
class_name = f'triplePhotonNet_{version}'
|
||||
cls = globals().get(class_name)
|
||||
if cls is None:
|
||||
raise ValueError(f"Model class '{class_name}' not found.")
|
||||
return cls
|
||||
|
||||
class singlePhotonNet_250909(nn.Module):
|
||||
def weight_init(self):
|
||||
for m in self.modules():
|
||||
@@ -415,4 +422,65 @@ class doublePhotonNet_260507(nn.Module): ## adapted from 251124, removed max poo
|
||||
flat_feat = c3.view(c3.size(0), -1) # [B, 6272]
|
||||
|
||||
coords = self.fc(flat_feat) # [B, 4]
|
||||
return coords
|
||||
|
||||
|
||||
class triplePhotonNet_260529(nn.Module): ## adapted from doublePhotonNet_260507, add one more conv layer and increase capacity of FC layers, for 3-photon pileup with 9x9 input
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Backbone: deeper for 9x9 input containing 3 photons
|
||||
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) ### 9x9
|
||||
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) ## 9x9
|
||||
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) ## 9x9
|
||||
self.conv4 = nn.Conv2d(128, 128, kernel_size=3) ## 7x7
|
||||
|
||||
# Spatial Attention Module
|
||||
self.spatial_attn = nn.Sequential(
|
||||
nn.Conv2d(128, 1, kernel_size=1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(128 * 7 * 7, 512),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(512, 128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(128, 6)
|
||||
)
|
||||
|
||||
self._init_weights()
|
||||
self._init_coords()
|
||||
|
||||
def _init_coords(self):
|
||||
# Create a coordinate grid; moved from dataset generation to model initialization for lower traffic and more flexibility
|
||||
nSize = 9 # should match the input size of the model
|
||||
x = np.linspace(-nSize/2. + 0.5, nSize/2. - 0.5, nSize)
|
||||
y = np.linspace(-nSize/2. + 0.5, nSize/2. - 0.5, nSize)
|
||||
x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (nSize,nSize), (nSize,nSize)
|
||||
self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous().to('cuda') # (1, nSize, nSize)
|
||||
self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous().to('cuda') # (1, nSize, nSize)
|
||||
|
||||
def _init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.cat((x, self.x_grid.expand(x.size(0), -1, -1, -1), self.y_grid.expand(x.size(0), -1, -1, -1)), dim=1) # [B, 3, 9, 9]
|
||||
c1 = F.relu(self.conv1(x)) # [B, 32, 9, 9]
|
||||
c2 = F.relu(self.conv2(c1)) # [B, 64, 9, 9]
|
||||
c3 = F.relu(self.conv3(c2)) # [B, 128, 9, 9]
|
||||
c4 = F.relu(self.conv4(c3)) # [B, 128, 7, 7]
|
||||
|
||||
attn = self.spatial_attn(c4) # [B, 1, 7, 7]
|
||||
c4 = c4 * attn # [B, 128, 7, 7]
|
||||
|
||||
flat_feat = c4.view(c4.size(0), -1) # [B, 6272]
|
||||
|
||||
coords = self.fc(flat_feat) # [B, 6]
|
||||
return coords
|
||||
Reference in New Issue
Block a user