From b3c08ad2e50cf323549e55b5e6bb69baba0ba388 Mon Sep 17 00:00:00 2001 From: "xiangyu.xie" Date: Mon, 1 Jun 2026 14:21:37 +0200 Subject: [PATCH] Add inference codes for 3-ph category Co-authored-by: Copilot --- Configs/infer_3photon.yaml | 54 ++++++++++ Configs/train_3photon.yaml | 4 +- Infer_3Photon.py | 195 +++++++++++++++++++++++++++++++++++++ src/datasets.py | 34 +++++++ 4 files changed, 285 insertions(+), 2 deletions(-) create mode 100644 Configs/infer_3photon.yaml create mode 100644 Infer_3Photon.py diff --git a/Configs/infer_3photon.yaml b/Configs/infer_3photon.yaml new file mode 100644 index 0000000..0f80d45 --- /dev/null +++ b/Configs/infer_3photon.yaml @@ -0,0 +1,54 @@ +# configs/infer_1photon.yaml +experiment: + name: "FlatField_3filters_pos0" # options: SiemenStarLowerLeft, SiemenStarLowerRight + task: "3Photon" + output_base: "./InferenceResults" + +data: + sample_folder: "/home/xie_x1/MLXID/DataProcess/Samples" + names: + KnifeEdge_3filters_pos0: + file_pattern: "2603MaxIV_Edge3Filters_pos0_12keV/3Photon_CS9_chunk{}.h5" ### 12 keV + NX: 101 + NY: 101 + file_range: [0, 16] # [0, 16) + roi: [0, 101, 0, 101] + FlatField_3filters_pos0: + file_pattern: "2603MaxIV_FlatField3Filters_pos0_12keV/3Photon_CS9_chunk{}.h5" ### 12 keV + NX: 101 + NY: 101 + file_range: [0, 16] # [0, 16) + roi: [0, 101, 0, 101] + KnifeEdge_2filters: + file_pattern: "2603MaxIV_Edge2Filters_12keV/3Photon_CS9_chunk{}.h5" ### 12 keV + NX: 101 + NY: 101 + file_range: [0, 16] # [0, 16) + roi: [0, 101, 0, 101] + FlatField_2filters: + file_pattern: "2603MaxIV_Flat2Filters_12keV/3Photon_CS9_chunk{}.h5" ### 12 keV + NX: 101 + NY: 101 + file_range: [0, 16] # [0, 16) + roi: [0, 101, 0, 101] + Uniformity_test: + file_pattern: "2603MaxIV_Flat2Filters_12keV/3Photon_CS9_chunk{}.h5" ### 12 keV + NX: 101 + NY: 101 + file_range: [0, 1] + roi: [0, 101, 0, 101] + + energy: 12 # keV + normalize: false # for inference dataloaders + batch_size: 8192 + num_workers: 16 + +model: + base_dir: "/home/xie_x1/MLXID/DeepLearning/Results/" + experiment_name: "260529_3ph_12keV_v260529_01" + name: "triplePhoton260529_12keV_E1000.pth" + +inference: + binning_factor: 10 + chunk_size: 16 # + num_aug_ops: 1 # augX8 not implemented for 3-photon inference yet diff --git a/Configs/train_3photon.yaml b/Configs/train_3photon.yaml index f24c05c..db7760b 100644 --- a/Configs/train_3photon.yaml +++ b/Configs/train_3photon.yaml @@ -13,14 +13,14 @@ data: train_file_range: [0, 12] val_file_range: [13, 14] test_file_range: [15, 15] - sample_ratio: 0.1 + sample_ratio: 1.0 n_size: 9 ### size of sub-images containing 3 photons model: version: "260529" training: - epochs: 100 + epochs: 1000 learning_rate: 1.0e-3 weight_decay: 1.0e-4 scheduler_factor: 0.7 diff --git a/Infer_3Photon.py b/Infer_3Photon.py new file mode 100644 index 0000000..3c85066 --- /dev/null +++ b/Infer_3Photon.py @@ -0,0 +1,195 @@ +import sys +sys.path.append('./src') +from pathlib import Path +from omegaconf import OmegaConf +import torch +from tqdm import tqdm +from matplotlib import pyplot as plt +import numpy as np +import h5py + +from models import get_triple_photon_model_class +from datasets import triplePhotonInferenceDataset + +torch.manual_seed(42) +torch.cuda.manual_seed(42) +np.random.seed(42) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +conf = OmegaConf.load("Configs/infer_3photon.yaml") +NX, NY = conf.data.names[conf.experiment.name].NX, conf.data.names[conf.experiment.name].NY +NSIZE = 9 + +def inv0(p): return p +def inv1(p): return torch.stack([-p[..., 0], p[..., 1]], dim=-1) +def inv2(p): return torch.stack([p[..., 0], -p[..., 1]], dim=-1) +def inv3(p): return -p +def inv4(p): return torch.stack([p[..., 1], p[..., 0]], dim=-1) +def inv5(p): return torch.stack([p[..., 1], -p[..., 0]], dim=-1) +def inv6(p): return torch.stack([-p[..., 1], p[..., 0]], dim=-1) +def inv7(p): return torch.stack([-p[..., 1], -p[..., 0]], dim=-1) + +INVERSE_TRANSFORMS = { + 0: inv0, 1: inv1, 2: inv2, 3: inv3, + 4: inv4, 5: inv5, 6: inv6, 7: inv7, +} + +def apply_inverse_transforms(predictions: torch.Tensor, numberOfAugOps: int) -> torch.Tensor: + N = predictions.shape[0] // numberOfAugOps + preds = predictions.view(N, numberOfAugOps, 2) + corrected = torch.zeros_like(preds) + for idx in range(numberOfAugOps): + corrected[:, idx, :] = INVERSE_TRANSFORMS[idx](preds[:, idx, :]) + return corrected.mean(dim=1) + +def prepare_output_folder(conf): + if conf.data.normalize: + normalize_suffix = '_normalized' + else: + normalize_suffix = '' + output_base = Path(conf.experiment.output_base) / conf.experiment.name / conf.model.experiment_name / f'augX{conf.inference.num_aug_ops}{normalize_suffix}' + output_base.mkdir(parents=True, exist_ok=True) + OmegaConf.save(conf, output_base / 'config.yaml') + return output_base + +def get_files_list(conf): + task_conf = conf.data.names[conf.experiment.name] + file_pattern = task_conf.file_pattern + start, end = task_conf.file_range + files = [str(Path(conf.data.sample_folder) / file_pattern.format(i)) for i in range(start, end)] + return files + +def run_inference(model, data_loader, conf): + all_predictions = [] + with torch.no_grad(): + for batch_idx, batch in enumerate(tqdm(data_loader, desc="Inferring")): + inputs, _ = batch + inputs = inputs.cuda() + outputs = model(inputs).view(-1, 2) # 3B x 2 + all_predictions.append(outputs.cpu()) + + all_predictions = torch.cat(all_predictions, dim=0) + # all_predictions = apply_inverse_transforms(all_predictions, conf.inference.num_aug_ops) + all_predictions += torch.tensor([NSIZE/2., NSIZE/2.]).unsqueeze(0) # adjust back to original coordinate system + print(f'mean x = {torch.mean(all_predictions[:, 0])}, std x = {torch.std(all_predictions[:, 0])}') + print(f'mean y = {torch.mean(all_predictions[:, 1])}, std y = {torch.std(all_predictions[:, 1])}') + referencePoints = data_loader.dataset.referencePoint ### the lower-left corner of the cluster in absolute coordinate + referencePoints = np.repeat(referencePoints, 3, axis=0) ### duplicate reference points for 3-photon clusters + return all_predictions.numpy(), referencePoints + +def accumulate_hits(predictions: np.ndarray, reference_points: np.ndarray, + binning_factor: int): + ### ret + ml_super_frame = np.zeros((NY*binning_factor, NX*binning_factor), dtype=np.int32) + count_frame = np.zeros((NY, NX), dtype=np.int32) + subpixel_dist = np.zeros((binning_factor, binning_factor), dtype=np.int32) + + ### absolute coordinate = predicted subpixel + reference point + absolute_positions = predictions + reference_points + + hit_x_superpixel_idx = np.floor(absolute_positions[:, 0] * binning_factor).astype(int) + hit_x_superpixel_idx = np.clip(hit_x_superpixel_idx, 0, NX*binning_factor-1) + hit_y_superpixel_idx = np.floor(absolute_positions[:, 1] * binning_factor).astype(int) + hit_y_superpixel_idx = np.clip(hit_y_superpixel_idx, 0, NY*binning_factor-1) + np.add.at(ml_super_frame, (hit_y_superpixel_idx, hit_x_superpixel_idx), 1) + + hit_x_pixel_idx = np.floor(absolute_positions[:, 0]).astype(int) + hit_x_pixel_idx = np.clip(hit_x_pixel_idx, 0, NX-1) + hit_y_pixel_idx = np.floor(absolute_positions[:, 1]).astype(int) + hit_y_pixel_idx = np.clip(hit_y_pixel_idx, 0, NY-1) + np.add.at(count_frame, (hit_y_pixel_idx, hit_x_pixel_idx), 1) + + subpixel_x_idx = np.floor((absolute_positions[:, 0] % 1) * binning_factor).astype(int) + subpixel_y_idx = np.floor((absolute_positions[:, 1] % 1) * binning_factor).astype(int) + np.add.at(subpixel_dist, (subpixel_y_idx, subpixel_x_idx), 1) + + return ml_super_frame, count_frame, subpixel_dist + +def save_results(ml_super_frame, count_frame, subpixel_dist, + roi: list, binning_factor: int, output_dir: Path): + x_st, x_ed, y_st, y_ed = roi + + # 1. super-resolution frame + plt.figure(figsize=(8, 8)) + plt.imshow(ml_super_frame[y_st*binning_factor:y_ed*binning_factor, x_st*binning_factor:x_ed*binning_factor], origin='lower', extent=[x_st, x_ed, y_st, y_ed]) + plt.colorbar(label='Counts') + plt.title('ML Super-Resolution Frame') + plt.xlabel('X (pixel)') + plt.ylabel('Y (pixel)') + plt.savefig(output_dir / '3Photon_ML_superFrame.png', dpi=300, bbox_inches='tight') + plt.clf() + np.save(output_dir / '3Photon_ML_superFrame.npy', ml_super_frame) + + # 2. count frame + plt.imshow(count_frame[y_st:y_ed, x_st:x_ed], origin='lower', extent=[x_st, x_ed, y_st, y_ed]) + plt.colorbar(label='Counts') + plt.title('Photon Count Frame') + plt.xlabel('X (pixel)') + plt.ylabel('Y (pixel)') + plt.savefig(output_dir / '3Photon_count_Frame.png', dpi=300, bbox_inches='tight') + plt.clf() + np.save(output_dir / '3Photon_count_Frame.npy', count_frame) + + # 3. subpixel distribution + plt.imshow(subpixel_dist, origin='lower', extent=[0, 1, 0, 1]) + plt.colorbar(label='Counts') + plt.title('Subpixel Distribution') + plt.xlabel('Subpixel X') + plt.ylabel('Subpixel Y') + plt.savefig(output_dir / '3Photon_subpixel_Distribution.png', dpi=300, bbox_inches='tight') + plt.close() + np.save(output_dir / '3Photon_subpixel_Distribution.npy', subpixel_dist) + std, mean = np.std(subpixel_dist), np.mean(subpixel_dist) + print(f"[Plotting]: Sub-pixel distribution: RMS/Mean: {std/mean:.4f}, expected value = {1/np.sqrt(mean):.4f} for uniform distribution") + + print(f"Results saved to: {output_dir}") + +if __name__ == "__main__": + ### output folder preparation + output_dir = prepare_output_folder(conf) + + ### model loading + model_version = conf.model.experiment_name.split('_v')[-1][:6] + model = get_triple_photon_model_class(model_version)().cuda() + model.load_state_dict(torch.load(f'{conf.model.base_dir}/{conf.model.experiment_name}/Models/{conf.model.name}', weights_only=True)) + # model.eval() + + ### data loading + files_list = get_files_list(conf) + roi = conf.data.names[conf.experiment.name].roi + BinningFactor = conf.inference.binning_factor + numberOfAugOps = conf.inference.num_aug_ops + flag_normalize = conf.data.normalize + nChunks = np.ceil(len(files_list) / 16).astype(int) + + ml_super_frame = np.zeros((NY*BinningFactor, NX*BinningFactor), dtype=np.int32) + count_frame = np.zeros((NY, NX), dtype=np.int32) + subpixel_dist = np.zeros((BinningFactor, BinningFactor), dtype=np.int32) + + for idxChunk in range(nChunks): + start_idx = idxChunk * conf.inference.chunk_size + end_idx = min(start_idx + conf.inference.chunk_size, len(files_list)) + chunk_files = files_list[start_idx:end_idx] + print(f'[Inferring] Chunk {idxChunk+1}/{nChunks}: Loading files {start_idx} to {end_idx}...') + + dataset = triplePhotonInferenceDataset( + chunk_files, + sampleRatio=1.0, + datasetName=f'Inference_Chunk{idxChunk+1}', + # numberOfAugOps=numberOfAugOps, + ) + dataLoader = torch.utils.data.DataLoader( + dataset, + batch_size=8192, + shuffle=False, + num_workers=16, + pin_memory=True, + ) + + predictions, reference_points = run_inference(model, dataLoader, conf) + ml_super_frame_chunk, count_frame_chunk, subpixel_dist_chunk = accumulate_hits(predictions, reference_points, BinningFactor) + ml_super_frame += ml_super_frame_chunk + count_frame += count_frame_chunk + subpixel_dist += subpixel_dist_chunk + save_results(ml_super_frame, count_frame, subpixel_dist, roi, BinningFactor, output_dir) diff --git a/src/datasets.py b/src/datasets.py index 90baf66..9a181ee 100644 --- a/src/datasets.py +++ b/src/datasets.py @@ -272,5 +272,39 @@ class triplePhotonDataset(Dataset): sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0) label = torch.tensor(label, dtype=torch.float32) return sample, label + def __len__(self): + return self.length + +class triplePhotonInferenceDataset(Dataset): + def __init__(self, sampleList, sampleRatio, datasetName): + self.sampleFileList = sampleList + self.sampleRatio = sampleRatio + self.datasetName = datasetName + all_samples = [] + all_ref_pts = [] + for idx, sampleFile in enumerate(self.sampleFileList): + if '.npz' in sampleFile: + data = np.load(sampleFile) + all_samples.append(data['samples']) + all_ref_pts.append(data['referencePoint']) + elif '.h5' in sampleFile: + import h5py + with h5py.File(sampleFile, 'r') as f: + samples = f['clusters'][:] + ref_pts = f['referencePoint'][:] + all_samples.append(samples) + all_ref_pts.append(ref_pts) + self.samples = np.concatenate(all_samples, axis=0) if all_samples else None + self.referencePoint = np.concatenate(all_ref_pts, axis=0) if all_ref_pts else None + ### total number of samples + self.length = int(self.samples.shape[0] * self.sampleRatio) + self.referencePoint = self.referencePoint[:self.length] + print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}") + def __getitem__(self, index): + sample = self.samples[index] + # sample[sample == 0] += np.random.normal(loc=0.0, scale=0.13, size=sample[sample == 0].shape) ### add noise to zero pixels + sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0) + dummy_label = np.zeros((3, 4), dtype=np.float32) ### dummy label for 3 photons + return sample, torch.tensor(dummy_label, dtype=torch.float32) def __len__(self): return self.length \ No newline at end of file