From 2e4f22d062b0a3b827913f1ccb9913a824ce868b Mon Sep 17 00:00:00 2001 From: "xiangyu.xie" Date: Fri, 29 May 2026 16:11:56 +0200 Subject: [PATCH] 3phs category Co-authored-by: Copilot --- Configs/train_3photon.yaml | 31 +++++++ Train_3Photon.py | 167 +++++++++++++++++++++++++++++++++++++ src/datasets.py | 34 ++++++++ src/models.py | 68 +++++++++++++++ 4 files changed, 300 insertions(+) create mode 100644 Configs/train_3photon.yaml create mode 100644 Train_3Photon.py diff --git a/Configs/train_3photon.yaml b/Configs/train_3photon.yaml new file mode 100644 index 0000000..f24c05c --- /dev/null +++ b/Configs/train_3photon.yaml @@ -0,0 +1,31 @@ +# configs/train_3photon.yaml +experiment: + name: "3photon_12keV" + +data: + sample_folder: "/home/xie_x1/MLXID/DeepLearning/PileupSample" + energy: 12 ### in keV + + batch_size_train: 4096 + batch_size_val: 1024 + batch_size_test: 1024 + num_workers: 32 + train_file_range: [0, 12] + val_file_range: [13, 14] + test_file_range: [15, 15] + sample_ratio: 0.1 + n_size: 9 ### size of sub-images containing 3 photons + +model: + version: "260529" + +training: + epochs: 100 + learning_rate: 1.0e-3 + weight_decay: 1.0e-4 + scheduler_factor: 0.7 + scheduler_patience: 5 + checkpoint_epochs: [10, 30, 50, 100, 150, 300, 500, 1000] + +loss: + type: "three_point_set_loss_smooth_l1" \ No newline at end of file diff --git a/Train_3Photon.py b/Train_3Photon.py new file mode 100644 index 0000000..67c91ef --- /dev/null +++ b/Train_3Photon.py @@ -0,0 +1,167 @@ +import sys +sys.path.append('./src') +from omegaconf import OmegaConf ### for yaml config parsing +import torch +import numpy as np +import torch.optim as optim +from tqdm import tqdm +from torchinfo import summary +from pathlib import Path + +from models import get_triple_photon_model_class +from datasets import triplePhotonDataset + +### random seed for reproducibility +torch.manual_seed(0) +torch.cuda.manual_seed(0) +np.random.seed(0) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +conf = OmegaConf.load("Configs/train_3photon.yaml") + +def prepare_output_folder(conf): + from datetime import datetime + date = datetime.now().strftime("%y%m%d") ## YYMMDD format + # find the next index for experiment name + exp_index = 0 + while True: + exp_name = f'{date}_3ph_{conf.data.energy}keV_v{conf.model.version}_{exp_index:02d}' + if not Path(f'Results/{exp_name}').exists(): + break + exp_index += 1 + Path(f'Results/{exp_name}').mkdir(parents=True, exist_ok=True) + Path(f'Results/{exp_name}/Models').mkdir(parents=True, exist_ok=True) + Path(f'Results/{exp_name}/Plots').mkdir(parents=True, exist_ok=True) + OmegaConf.save(conf, f'Results/{exp_name}/config.yaml') + return exp_name + +def get_loss_function(conf): + if conf.loss.type == "three_point_set_loss_smooth_l1": + def three_point_set_loss_smooth_l1(pred_xy, gt_xy): + loss_fn = torch.nn.SmoothL1Loss(reduction='none') + + p1, p2, p3 = pred_xy[:,0], pred_xy[:,1], pred_xy[:,2] + g1, g2, g3 = gt_xy[:,0], gt_xy[:,1], gt_xy[:,2] + + c_a = loss_fn(p1, g1).sum(dim=-1) + loss_fn(p2, g2).sum(dim=-1) + loss_fn(p3, g3).sum(dim=-1) + c_b = loss_fn(p1, g2).sum(dim=-1) + loss_fn(p2, g1).sum(dim=-1) + loss_fn(p3, g3).sum(dim=-1) + c_c = loss_fn(p1, g3).sum(dim=-1) + loss_fn(p2, g2).sum(dim=-1) + loss_fn(p3, g1).sum(dim=-1) + c_d = loss_fn(p1, g1).sum(dim=-1) + loss_fn(p2, g3).sum(dim=-1) + loss_fn(p3, g2).sum(dim=-1) + c_e = loss_fn(p1, g2).sum(dim=-1) + loss_fn(p2, g3).sum(dim=-1) + loss_fn(p3, g1).sum(dim=-1) + c_f = loss_fn(p1, g3).sum(dim=-1) + loss_fn(p2, g1).sum(dim=-1) + loss_fn(p3, g2).sum(dim=-1) + + return torch.minimum(torch.minimum(torch.minimum(torch.minimum(torch.minimum(c_a, c_b), c_c), c_d), c_e), c_f).mean() + + return three_point_set_loss_smooth_l1 + +def train(model, trainLoader, optimizer, loss_fn): + model.train() + batchLoss = 0 + for batch_idx, (sample, label) in enumerate(trainLoader): + sample, label = sample.cuda(), label.cuda() + x1, y1, z1, e1 = label[:, 0, 0], label[:, 0, 1], label[:, 0, 2], label[:, 0, 3] + x2, y2, z2, e2 = label[:, 1, 0], label[:, 1, 1], label[:, 1, 2], label[:, 1, 3] + x3, y3, z3, e3 = label[:, 2, 0], label[:, 2, 1], label[:, 2, 2], label[:, 2, 3] + gt_xy = torch.stack((torch.stack((x1, y1), axis=1), torch.stack((x2, y2), axis=1), torch.stack((x3, y3), axis=1)), axis=1) + optimizer.zero_grad() + output = model(sample) + pred_xy = torch.stack((output[:,0:2], output[:,2:4], output[:,4:6]), axis=1) + loss = loss_fn(pred_xy, gt_xy) + loss.backward() + optimizer.step() + batchLoss += loss.item() * sample.shape[0] + avgLoss = batchLoss / len(trainLoader.dataset) / 6 ### divide by 6 to get the average loss per photon per axis + print(f"[Train]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})") + return avgLoss + +def evaluate(model, valLoader, loss_fn): + model.eval() + batchLoss = 0 + with torch.no_grad(): + for batch_idx, (sample, label) in enumerate(valLoader): + sample, label = sample.cuda(), label.cuda() + x1, y1, z1, e1 = label[:, 0, 0], label[:, 0, 1], label[:, 0, 2], label[:, 0, 3] + x2, y2, z2, e2 = label[:, 1, 0], label[:, 1, 1], label[:, 1, 2], label[:, 1, 3] + x3, y3, z3, e3 = label[:, 2, 0], label[:, 2, 1], label[:, 2, 2], label[:, 2, 3] + gt_xy = torch.stack((torch.stack((x1, y1), axis=1), torch.stack((x2, y2), axis=1), torch.stack((x3, y3), axis=1)), axis=1) + output = model(sample) + pred_xy = torch.stack((output[:,0:2], output[:,2:4], output[:,4:6]), axis=1) + loss = loss_fn(pred_xy, gt_xy) + batchLoss += loss.item() * sample.shape[0] + avgLoss = batchLoss / len(valLoader.dataset) / 6 ### divide by 6 to get the average loss per photon per axis + print(f"[Val]\t Average Loss: {avgLoss:.6f} (RMS = {np.sqrt(avgLoss):.6f})") + return avgLoss + +def get_dataloaders(conf): + """construct all dataloaders""" + datasets = {} + loaders = {} + splits = ['Train', 'Val', 'Test'] + keys = ['train_files', 'val_files', 'test_files'] + batch_keys = ['batch_size_train', 'batch_size_val', 'batch_size_test'] + file_range_keys = ['train_file_range', 'val_file_range', 'test_file_range'] + + for split, key, batch_key, file_range_key in zip(splits, keys, batch_keys, file_range_keys): + files = [f"{conf.data.sample_folder}/pileupOf3phs_sample_{i}.npz" for i in range(conf.data[file_range_key][0], conf.data[file_range_key][1] + 1)] + + datasets[split] = triplePhotonDataset( + files, + sampleRatio = conf.data.sample_ratio, + datasetName = split.capitalize(), + ) + + loaders[split] = torch.utils.data.DataLoader( + datasets[split], + batch_size=conf.data[batch_key], + shuffle=(split=='Train'), + num_workers=conf.data.num_workers, + pin_memory=True + ) + return loaders['Train'], loaders['Val'], loaders['Test'] + +def plot_loss_curves(train_losses, val_losses, test_loss, exp_name, conf): + import matplotlib.pyplot as plt + plt.figure(figsize=(8,6)) + plt.plot(train_losses, label='Train Loss') + plt.plot(val_losses, label='Val Loss') + if test_loss > 0: + plt.axhline(y=test_loss, color='green', linestyle='--', label='Test Loss') + plt.xlabel('Epoch') + plt.ylabel('MSE Loss') + plt.yscale('log') + plt.legend() + plt.grid() + plotName = f'loss_curve_triplePhoton_{conf.model.version}.png' + plt.savefig(f'Results/{exp_name}/Plots/{plotName}') + +def get_model_name(conf): + modelName = f'triplePhoton{conf.model.version}_{conf.data.energy}keV' + return modelName + +if __name__ == "__main__": + exp_name = prepare_output_folder(conf) + model = get_triple_photon_model_class(conf.model.version)().cuda() + # summary(model, input_size=(128, 3, conf.data.n_size, conf.data.n_size)) + + loss_fn = get_loss_function(conf) + optimizer = torch.optim.Adam(model.parameters(), lr=conf.training.learning_rate, weight_decay=conf.training.weight_decay) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=conf.training.scheduler_factor, patience=conf.training.scheduler_patience) + + trainLoader, valLoader, testLoader = get_dataloaders(conf) + + TrainLosses, ValLosses = [], [] + for epoch in tqdm(range(1, conf.training.epochs + 1)): + train_loss = train(model, trainLoader, optimizer, loss_fn) + val_loss = evaluate(model, valLoader, loss_fn) + TrainLosses.append(train_loss) + ValLosses.append(val_loss) + scheduler.step(val_loss) + print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}") + if epoch in conf.training.checkpoint_epochs or epoch == conf.training.epochs: + modelName = get_model_name(conf) + torch.save(model.state_dict(), f'Results/{exp_name}/Models/{modelName}_E{epoch}.pth') + print(f"Saved model checkpoint: {modelName}_E{epoch}.pth") + plot_loss_curves(TrainLosses, ValLosses, test_loss=-1, exp_name=exp_name, conf=conf) + test_loss = evaluate(model, testLoader, loss_fn) + plot_loss_curves(TrainLosses, ValLosses, test_loss=test_loss, exp_name=exp_name, conf=conf) \ No newline at end of file diff --git a/src/datasets.py b/src/datasets.py index 71a874f..90baf66 100644 --- a/src/datasets.py +++ b/src/datasets.py @@ -238,5 +238,39 @@ class doublePhotonInferenceDataset(Dataset): dummy_label = np.zeros((8,), dtype=np.float32) ### dummy label return sample, torch.tensor(dummy_label, dtype=torch.float32) + def __len__(self): + return self.length + +class triplePhotonDataset(Dataset): + def __init__(self, sampleList, sampleRatio, datasetName): + self.sampleFileList = sampleList + self.sampleRatio = sampleRatio + self.datasetName = datasetName + all_samples = [] + all_labels = [] + for sampleFile in self.sampleFileList: + if '.npz' in sampleFile: + data = np.load(sampleFile) + all_samples.append(data['samples']) + all_labels.append(data['labels']) + elif '.h5' in sampleFile: + import h5py + with h5py.File(sampleFile, 'r') as f: + samples = f['clusters'][:] + labels = f['labels'][:] + all_samples.append(samples) + all_labels.append(labels) + self.samples = np.concatenate(all_samples, axis=0) + self.labels = np.concatenate(all_labels, axis=0) + ### total number of samples + self.length = int(self.samples.shape[0] * self.sampleRatio) + print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}") + def __getitem__(self, index): + sample = self.samples[index] + label = self.labels[index] + label[:, :2] -= self.samples.shape[-1] / 2. ### adjust labels to be centered at sample center + sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0) + label = torch.tensor(label, dtype=torch.float32) + return sample, label def __len__(self): return self.length \ No newline at end of file diff --git a/src/models.py b/src/models.py index ac4b81d..3cc2a1c 100644 --- a/src/models.py +++ b/src/models.py @@ -17,6 +17,13 @@ def get_double_photon_model_class(version): raise ValueError(f"Model class '{class_name}' not found.") return cls +def get_triple_photon_model_class(version): + class_name = f'triplePhotonNet_{version}' + cls = globals().get(class_name) + if cls is None: + raise ValueError(f"Model class '{class_name}' not found.") + return cls + class singlePhotonNet_250909(nn.Module): def weight_init(self): for m in self.modules(): @@ -415,4 +422,65 @@ class doublePhotonNet_260507(nn.Module): ## adapted from 251124, removed max poo flat_feat = c3.view(c3.size(0), -1) # [B, 6272] coords = self.fc(flat_feat) # [B, 4] + return coords + + +class triplePhotonNet_260529(nn.Module): ## adapted from doublePhotonNet_260507, add one more conv layer and increase capacity of FC layers, for 3-photon pileup with 9x9 input + def __init__(self): + super().__init__() + # Backbone: deeper for 9x9 input containing 3 photons + self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) ### 9x9 + self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) ## 9x9 + self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) ## 9x9 + self.conv4 = nn.Conv2d(128, 128, kernel_size=3) ## 7x7 + + # Spatial Attention Module + self.spatial_attn = nn.Sequential( + nn.Conv2d(128, 1, kernel_size=1), + nn.Sigmoid() + ) + + self.fc = nn.Sequential( + nn.Linear(128 * 7 * 7, 512), + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(512, 128), + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(128, 6) + ) + + self._init_weights() + self._init_coords() + + def _init_coords(self): + # Create a coordinate grid; moved from dataset generation to model initialization for lower traffic and more flexibility + nSize = 9 # should match the input size of the model + x = np.linspace(-nSize/2. + 0.5, nSize/2. - 0.5, nSize) + y = np.linspace(-nSize/2. + 0.5, nSize/2. - 0.5, nSize) + x_grid, y_grid = np.meshgrid(x, y, indexing='ij') # (nSize,nSize), (nSize,nSize) + self.x_grid = torch.tensor(np.expand_dims(x_grid, axis=0)).float().contiguous().to('cuda') # (1, nSize, nSize) + self.y_grid = torch.tensor(np.expand_dims(y_grid, axis=0)).float().contiguous().to('cuda') # (1, nSize, nSize) + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + nn.init.zeros_(m.bias) + + def forward(self, x): + x = torch.cat((x, self.x_grid.expand(x.size(0), -1, -1, -1), self.y_grid.expand(x.size(0), -1, -1, -1)), dim=1) # [B, 3, 9, 9] + c1 = F.relu(self.conv1(x)) # [B, 32, 9, 9] + c2 = F.relu(self.conv2(c1)) # [B, 64, 9, 9] + c3 = F.relu(self.conv3(c2)) # [B, 128, 9, 9] + c4 = F.relu(self.conv4(c3)) # [B, 128, 7, 7] + + attn = self.spatial_attn(c4) # [B, 1, 7, 7] + c4 = c4 * attn # [B, 128, 7, 7] + + flat_feat = c4.view(c4.size(0), -1) # [B, 6272] + + coords = self.fc(flat_feat) # [B, 6] return coords \ No newline at end of file