diff --git a/Configs/infer_2photon.yaml b/Configs/infer_2photon.yaml new file mode 100644 index 0000000..4befca8 --- /dev/null +++ b/Configs/infer_2photon.yaml @@ -0,0 +1,53 @@ +# configs/infer_1photon.yaml +experiment: + name: "FlatField_3filters_pos0" # options: SiemenStarLowerLeft, SiemenStarLowerRight + task: "2Photon" + output_base: "./InferenceResults" + +data: + sample_folder: "/home/xie_x1/MLXID/DataProcess/Samples" + names: + SiemenStarLowerLeft: + file_pattern: "SiemenStarLowerLeft/2Photon_CS7_chunk{}.h5" ### 15 keV + NX: 400 + NY: 400 + file_range: [0, 1] # [0, 160) + roi: [140, 230, 120, 210] # [x_min, x_max, y_min, y_max] + nSize: 7 + SiemenStarLowerRight: + file_pattern: "SiemenStarLowerRight/2Photon_CS7_chunk{}.h5" ### 15 keV + NX: 400 + NY: 400 + file_range: [0, 320] # [0, 320) + roi: [235, 345, 110, 220] + nSize: 7 + KnifeEdge_3filters_pos0: + file_pattern: "2603MaxIV_Edge3Filters_pos0_12keV/2Photon_CS7_chunk{}.h5" ### 12 keV + NX: 101 + NY: 101 + file_range: [0, 16] # [0, 16) + roi: [0, 101, 0, 101] + nSize: 7 + FlatField_3filters_pos0: + file_pattern: "2603MaxIV_FlatField3Filters_pos0_12keV/2Photon_CS7_chunk{}.h5" ### 12 keV + NX: 101 + NY: 101 + file_range: [0, 48] # [0, 48) + roi: [0, 101, 0, 101] + nSize: 7 + + energy: 12 # keV + normalize: false # for inference dataloaders + batch_size: 8192 + num_workers: 16 + +model: + version: "251124" + base_dir: "/home/xie_x1/MLXID/DeepLearning/Results/" + experiment_name: "260506_2ph_12keV_v251124_03" + name: "doublePhoton251124_12keV_Noise0.13keV_E150.pth" + +inference: + binning_factor: 10 + chunk_size: 16 # + num_aug_ops: 1 # augX8 not implemented for 2-photon inference yet diff --git a/Infer_2Photon.py b/Infer_2Photon.py new file mode 100644 index 0000000..695a5de --- /dev/null +++ b/Infer_2Photon.py @@ -0,0 +1,194 @@ +import sys +sys.path.append('./src') +from pathlib import Path +from omegaconf import OmegaConf +import torch +from tqdm import tqdm +from matplotlib import pyplot as plt +import numpy as np +import h5py + +from models import get_double_photon_model_class +from datasets import doublePhotonInferenceDataset + +torch.manual_seed(42) +torch.cuda.manual_seed(42) +np.random.seed(42) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +conf = OmegaConf.load("Configs/infer_2photon.yaml") +NX, NY = conf.data.names[conf.experiment.name].NX, conf.data.names[conf.experiment.name].NY + +def inv0(p): return p +def inv1(p): return torch.stack([-p[..., 0], p[..., 1]], dim=-1) +def inv2(p): return torch.stack([p[..., 0], -p[..., 1]], dim=-1) +def inv3(p): return -p +def inv4(p): return torch.stack([p[..., 1], p[..., 0]], dim=-1) +def inv5(p): return torch.stack([p[..., 1], -p[..., 0]], dim=-1) +def inv6(p): return torch.stack([-p[..., 1], p[..., 0]], dim=-1) +def inv7(p): return torch.stack([-p[..., 1], -p[..., 0]], dim=-1) + +INVERSE_TRANSFORMS = { + 0: inv0, 1: inv1, 2: inv2, 3: inv3, + 4: inv4, 5: inv5, 6: inv6, 7: inv7, +} + +def apply_inverse_transforms(predictions: torch.Tensor, numberOfAugOps: int) -> torch.Tensor: + N = predictions.shape[0] // numberOfAugOps + preds = predictions.view(N, numberOfAugOps, 2) + corrected = torch.zeros_like(preds) + for idx in range(numberOfAugOps): + corrected[:, idx, :] = INVERSE_TRANSFORMS[idx](preds[:, idx, :]) + return corrected.mean(dim=1) + +def prepare_output_folder(conf): + if conf.data.normalize: + normalize_suffix = '_normalized' + else: + normalize_suffix = '' + output_base = Path(conf.experiment.output_base) / conf.experiment.name / conf.model.experiment_name / f'augX{conf.inference.num_aug_ops}{normalize_suffix}' + output_base.mkdir(parents=True, exist_ok=True) + OmegaConf.save(conf, output_base / 'config.yaml') + return output_base + +def get_files_list(conf): + task_conf = conf.data.names[conf.experiment.name] + file_pattern = task_conf.file_pattern + start, end = task_conf.file_range + files = [str(Path(conf.data.sample_folder) / file_pattern.format(i)) for i in range(start, end)] + return files + +def run_inference(model, data_loader, conf): + all_predictions = [] + with torch.no_grad(): + for batch_idx, batch in enumerate(tqdm(data_loader, desc="Inferring")): + inputs, _ = batch + inputs = inputs.cuda() + outputs = model(inputs).view(-1, 2) # 2B x 2 + all_predictions.append(outputs.cpu()) + + all_predictions = torch.cat(all_predictions, dim=0) + # all_predictions = apply_inverse_transforms(all_predictions, conf.inference.num_aug_ops) + all_predictions += torch.tensor([conf.data.names[conf.experiment.name].nSize/2., conf.data.names[conf.experiment.name].nSize/2.]).unsqueeze(0) # adjust back to original coordinate system + print(f'mean x = {torch.mean(all_predictions[:, 0])}, std x = {torch.std(all_predictions[:, 0])}') + print(f'mean y = {torch.mean(all_predictions[:, 1])}, std y = {torch.std(all_predictions[:, 1])}') + referencePoints = data_loader.dataset.referencePoint ### the lower-left corner of the cluster in absolute coordinate + referencePoints = np.repeat(referencePoints, 2, axis=0) ### duplicate reference points for 2-photon clusters + return all_predictions.numpy(), referencePoints + +def accumulate_hits(predictions: np.ndarray, reference_points: np.ndarray, + binning_factor: int): + ### ret + ml_super_frame = np.zeros((NY*binning_factor, NX*binning_factor), dtype=np.int32) + count_frame = np.zeros((NY, NX), dtype=np.int32) + subpixel_dist = np.zeros((binning_factor, binning_factor), dtype=np.int32) + + ### absolute coordinate = predicted subpixel + reference point + absolute_positions = predictions + reference_points + + hit_x_superpixel_idx = np.floor(absolute_positions[:, 0] * binning_factor).astype(int) + hit_x_superpixel_idx = np.clip(hit_x_superpixel_idx, 0, NX*binning_factor-1) + hit_y_superpixel_idx = np.floor(absolute_positions[:, 1] * binning_factor).astype(int) + hit_y_superpixel_idx = np.clip(hit_y_superpixel_idx, 0, NY*binning_factor-1) + np.add.at(ml_super_frame, (hit_y_superpixel_idx, hit_x_superpixel_idx), 1) + + hit_x_pixel_idx = np.floor(absolute_positions[:, 0]).astype(int) + hit_x_pixel_idx = np.clip(hit_x_pixel_idx, 0, NX-1) + hit_y_pixel_idx = np.floor(absolute_positions[:, 1]).astype(int) + hit_y_pixel_idx = np.clip(hit_y_pixel_idx, 0, NY-1) + np.add.at(count_frame, (hit_y_pixel_idx, hit_x_pixel_idx), 1) + + subpixel_x_idx = np.floor((absolute_positions[:, 0] % 1) * binning_factor).astype(int) + subpixel_y_idx = np.floor((absolute_positions[:, 1] % 1) * binning_factor).astype(int) + np.add.at(subpixel_dist, (subpixel_y_idx, subpixel_x_idx), 1) + + return ml_super_frame, count_frame, subpixel_dist + +def save_results(ml_super_frame, count_frame, subpixel_dist, + roi: list, binning_factor: int, output_dir: Path): + x_st, x_ed, y_st, y_ed = roi + + # 1. super-resolution frame + plt.figure(figsize=(8, 8)) + plt.imshow(ml_super_frame[y_st*binning_factor:y_ed*binning_factor, x_st*binning_factor:x_ed*binning_factor], origin='lower', extent=[x_st, x_ed, y_st, y_ed]) + plt.colorbar(label='Counts') + plt.title('ML Super-Resolution Frame') + plt.xlabel('X (pixel)') + plt.ylabel('Y (pixel)') + plt.savefig(output_dir / '2Photon_ML_superFrame.png', dpi=300, bbox_inches='tight') + plt.clf() + np.save(output_dir / '2Photon_ML_superFrame.npy', ml_super_frame) + + # 2. count frame + plt.imshow(count_frame[y_st:y_ed, x_st:x_ed], origin='lower', extent=[x_st, x_ed, y_st, y_ed]) + plt.colorbar(label='Counts') + plt.title('Photon Count Frame') + plt.xlabel('X (pixel)') + plt.ylabel('Y (pixel)') + plt.savefig(output_dir / '2Photon_count_Frame.png', dpi=300, bbox_inches='tight') + plt.clf() + np.save(output_dir / '2Photon_count_Frame.npy', count_frame) + + # 3. subpixel distribution + plt.imshow(subpixel_dist, origin='lower', extent=[0, 1, 0, 1]) + plt.colorbar(label='Counts') + plt.title('Subpixel Distribution') + plt.xlabel('Subpixel X') + plt.ylabel('Subpixel Y') + plt.savefig(output_dir / '2Photon_subpixel_Distribution.png', dpi=300, bbox_inches='tight') + plt.close() + np.save(output_dir / '2Photon_subpixel_Distribution.npy', subpixel_dist) + std, mean = np.std(subpixel_dist), np.mean(subpixel_dist) + print(f"[Plotting]: Sub-pixel distribution: Std/Mean: {std/mean:.4f}") + + print(f"Results saved to: {output_dir}") + +if __name__ == "__main__": + ### output folder preparation + output_dir = prepare_output_folder(conf) + + ### model loading + model = get_double_photon_model_class(conf.model.version)().cuda() + model.load_state_dict(torch.load(f'{conf.model.base_dir}/{conf.model.experiment_name}/Models/{conf.model.name}', weights_only=True)) + # model.eval() + + ### data loading + files_list = get_files_list(conf) + roi = conf.data.names[conf.experiment.name].roi + BinningFactor = conf.inference.binning_factor + numberOfAugOps = conf.inference.num_aug_ops + flag_normalize = conf.data.normalize + nChunks = np.ceil(len(files_list) / 16).astype(int) + + ml_super_frame = np.zeros((NY*BinningFactor, NX*BinningFactor), dtype=np.int32) + count_frame = np.zeros((NY, NX), dtype=np.int32) + subpixel_dist = np.zeros((BinningFactor, BinningFactor), dtype=np.int32) + + for idxChunk in range(nChunks): + start_idx = idxChunk * conf.inference.chunk_size + end_idx = min(start_idx + conf.inference.chunk_size, len(files_list)) + chunk_files = files_list[start_idx:end_idx] + print(f'[Inferring] Chunk {idxChunk+1}/{nChunks}: Loading files {start_idx} to {end_idx}...') + + dataset = doublePhotonInferenceDataset( + chunk_files, + sampleRatio=1.0, + datasetName=f'Inference_Chunk{idxChunk+1}', + # numberOfAugOps=numberOfAugOps, + nSize=conf.data.names[conf.experiment.name].nSize, + ) + dataLoader = torch.utils.data.DataLoader( + dataset, + batch_size=8192, + shuffle=False, + num_workers=16, + pin_memory=True, + ) + + predictions, reference_points = run_inference(model, dataLoader, conf) + ml_super_frame_chunk, count_frame_chunk, subpixel_dist_chunk = accumulate_hits(predictions, reference_points, BinningFactor) + ml_super_frame += ml_super_frame_chunk + count_frame += count_frame_chunk + subpixel_dist += subpixel_dist_chunk + save_results(ml_super_frame, count_frame, subpixel_dist, roi, BinningFactor, output_dir) diff --git a/Inference_2Photon.py b/Inference_2Photon.py deleted file mode 100644 index 5a7dc35..0000000 --- a/Inference_2Photon.py +++ /dev/null @@ -1,114 +0,0 @@ -import torch -import sys -sys.path.append('./src') -import models -from datasets import * -from tqdm import tqdm -from matplotlib import pyplot as plt - -torch.manual_seed(42) -torch.cuda.manual_seed(42) -np.random.seed(42) -torch.backends.cudnn.deterministic = True -torch.backends.cudnn.benchmark = False - -configs = {} -configs['SiemenStarLowerLeft'] = { - 'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/SiemenStarLowerLeft_old/Clusters_2Photon_CS7_chunk{i}.h5' for i in range(200)], - 'modelVersion': '251124', - 'roi': [140, 230, 120, 210], # x_min, x_max, y_min, y_max, - 'nSize': 7, -} - -configs['SiemenStarLowerRight'] = { - 'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/SiemenStarLowerRight/2Photon_CS7_chunk{i}.h5' for i in range(320)], ### 320 files - 'modelVersion': '251124', - 'roi': [235, 345, 110, 220], # x_min, x_max, y_min, y_max, - 'nSize': 7, -} - -task = 'SiemenStarLowerRight' -config = configs[task] - -BinningFactor = 10 -Roi = configs[task]['roi'] -X_st, X_ed, Y_st, Y_ed = Roi -mlSuperFrame = np.zeros(((Y_ed-Y_st)*BinningFactor, (X_ed-X_st)*BinningFactor)) -countFrame = np.zeros((Y_ed-Y_st, X_ed-X_st)) -subpixelDistribution = np.zeros((BinningFactor, BinningFactor)) - -if __name__ == "__main__": - model = models.get_double_photon_model_class(config['modelVersion'])().cuda() - model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/doublePhoton{config["modelVersion"]}_15keV_Noise0.13keV_E300.pth', weights_only=True)) - nChunks = np.ceil(len(config['dataFiles']) / 16).astype(int) - for idxChunk in range(nChunks): - predictions = [] - referencePoints = [] - stFileIdx = idxChunk * 16 - edFileIdx = min((idxChunk + 1) * 16, len(config['dataFiles'])) - sampleFiles = config['dataFiles'][stFileIdx : edFileIdx] - print(f'Processing files {stFileIdx} to {edFileIdx}...') - dataset = doublePhotonInferenceDataset( - sampleFiles, - sampleRatio=1., - datasetName='Inference', - nSize=config['nSize'] - ) - dataLoader = torch.utils.data.DataLoader( - dataset, - batch_size=8192, - shuffle=False, - num_workers=16, - pin_memory=True, - ) - - referencePoints.append(dataset.referencePoint) - - with torch.no_grad(): - for batch in tqdm(dataLoader): - inputs, _ = batch - inputs = inputs.cuda() - outputs = model(inputs).view(-1, 2) # 2B x 2 - predictions.append(outputs.cpu()) - - predictions = torch.cat(predictions, dim=0) - predictions += torch.tensor([config['nSize']/2., config['nSize']/2.]).unsqueeze(0) # adjust back to original coordinate system - print(f'mean x = {torch.mean(predictions[:, 0])}, std x = {torch.std(predictions[:, 0])}') - print(f'mean y = {torch.mean(predictions[:, 1])}, std y = {torch.std(predictions[:, 1])}') - referencePoints = np.concatenate(referencePoints, axis=0) ### the lower-left corner of the cluster in absolute coordinate - ### duplicate reference points for 2-photon clusters - referencePoints = np.repeat(referencePoints, 2, axis=0) - absolutePositions = predictions.numpy() + referencePoints - - hit_x = np.floor((absolutePositions[:, 0] - Roi[0]) * BinningFactor).astype(int) - hit_x = np.clip(hit_x, 0, mlSuperFrame.shape[1]-1) - hit_y = np.floor((absolutePositions[:, 1] - Roi[2]) * BinningFactor).astype(int) - hit_y = np.clip(hit_y, 0, mlSuperFrame.shape[0]-1) - np.add.at(mlSuperFrame, (hit_y, hit_x), 1) - - np.add.at(countFrame, ((referencePoints[:, 1] - Roi[2]).astype(int), - (referencePoints[:, 0] - Roi[0]).astype(int)), 1) - - np.add.at(subpixelDistribution, - (np.floor((absolutePositions[:, 1] % 1) * BinningFactor).astype(int), - np.floor((absolutePositions[:, 0] % 1) * BinningFactor).astype(int)), 1) - - import os - os.makedirs(f'InferenceResults/{task}', exist_ok=True) - - plt.imshow(mlSuperFrame, origin='lower', extent=[Y_st, Y_ed, X_st, X_ed]) - plt.colorbar() - plt.savefig(f'InferenceResults/{task}/ML_2Photon_superFrame.png', dpi=300) - np.save(f'InferenceResults/{task}/ML_2Photon_superFrame.npy', mlSuperFrame) - plt.clf() - - plt.imshow(countFrame, origin='lower', extent=[Y_st, Y_ed, X_st, X_ed]) - plt.colorbar() - plt.savefig(f'InferenceResults/{task}/count_2Photon_Frame.png', dpi=300) - np.save(f'InferenceResults/{task}/count_2Photon_Frame.npy', countFrame) - - plt.clf() - plt.imshow(subpixelDistribution, origin='lower') - plt.colorbar() - plt.savefig(f'InferenceResults/{task}/subpixel_2Photon_Distribution.png', dpi=300) - np.save(f'InferenceResults/{task}/subpixel_2Photon_Distribution.npy', subpixelDistribution) \ No newline at end of file