diff --git a/Infer_1Photon.py b/Infer_1Photon.py index d175934..16b5f1a 100644 --- a/Infer_1Photon.py +++ b/Infer_1Photon.py @@ -1,50 +1,24 @@ -import torch import sys sys.path.append('./src') -import models -from datasets import * +from pathlib import Path +from omegaconf import OmegaConf +import torch from tqdm import tqdm from matplotlib import pyplot as plt import numpy as np import h5py +from models import get_model_class +from datasets import singlePhotonDataset + torch.manual_seed(0) torch.cuda.manual_seed(0) np.random.seed(0) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False -NX, NY = 400, 400 - -configs = {} -configs['SiemenStarLowerLeft'] = { - # 'dataFiles': [f'/mnt/sls_det_storage/moench_data/MLXID/Samples/Measurement/2504_SOLEIL_SiemenStarClusters_MOENCH040_150V/SiemenStarLowerLeft/clusters_chunk{i}.h5' for i in range(1)], # 200 files, no zeroing pixels outside the cluster - 'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/SiemenStarLowerLeft/1Photon_CS3_chunk{i}.h5' for i in range(160)], # 160 files, no zeroing pixels outside the cluster - 'modelVersion': '251022', - 'energy': 15, # keV - 'roi': [140, 230, 120, 210], # x_min, x_max, y_min, y_max, - 'noise': 0.13, # keV; for the model selection -} - -configs['SiemenStarLowerRight'] = { - 'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/SiemenStarLowerRight/1Photon_CS3_chunk{i}.h5' for i in range(320)], # 320 files. !!! zeroed pixels outside the cluster, to be fixed - 'modelVersion': '251022', - 'energy': 15, # keV - 'roi': [235, 345, 110, 220], # x_min, x_max, y_min, y_max, - 'noise': 0.13, # keV -} - -task = 'SiemenStarLowerLeft' -config = configs[task] - -BinningFactor = 10 -numberOfAugOps = 8 -flag_normalize = False -Roi = config['roi'] -X_st, X_ed, Y_st, Y_ed = Roi -mlSuperFrame = np.zeros((NY*BinningFactor, NX*BinningFactor)) -countFrame = np.zeros((NY, NX)) -subpixelDistribution = np.zeros((BinningFactor, BinningFactor)) +conf = OmegaConf.load("Configs/infer_1photon.yaml") +NX, NY = conf.data.names[conf.experiment.name].NX, conf.data.names[conf.experiment.name].NY def inv0(p): return p def inv1(p): return torch.stack([-p[..., 0], p[..., 1]], dim=-1) @@ -56,14 +30,8 @@ def inv6(p): return torch.stack([-p[..., 1], p[..., 0]], dim=-1) def inv7(p): return torch.stack([-p[..., 1], -p[..., 0]], dim=-1) INVERSE_TRANSFORMS = { - 0: inv0, - 1: inv1, - 2: inv2, - 3: inv3, - 4: inv4, - 5: inv5, - 6: inv6, - 7: inv7, + 0: inv0, 1: inv1, 2: inv2, 3: inv3, + 4: inv4, 5: inv5, 6: inv6, 7: inv7, } def apply_inverse_transforms(predictions: torch.Tensor, numberOfAugOps: int) -> torch.Tensor: @@ -74,89 +42,164 @@ def apply_inverse_transforms(predictions: torch.Tensor, numberOfAugOps: int) -> corrected[:, idx, :] = INVERSE_TRANSFORMS[idx](preds[:, idx, :]) return corrected.mean(dim=1) +def prepare_output_folder(conf): + if conf.data.normalize: + normalize_suffix = '_normalized' + else: + normalize_suffix = '' + output_base = Path(conf.experiment.output_base) / conf.experiment.name / conf.model.experiment_name / f'augX{conf.inference.num_aug_ops}{normalize_suffix}' + output_base.mkdir(parents=True, exist_ok=True) + OmegaConf.save(conf, output_base / 'config.yaml') + return output_base + +def get_files_list(conf): + task_conf = conf.data.names[conf.experiment.name] + file_pattern = task_conf.file_pattern + start, end = task_conf.file_range + files = [str(Path(conf.data.sample_folder) / file_pattern.format(i)) for i in range(start, end)] + return files + +def run_inference(model, data_loader, conf): + model.eval() + all_predictions = [] + all_reference_points = data_loader.dataset.referencePoint + with torch.no_grad(): + for batch in tqdm(data_loader, desc="Inferring"): + inputs, _ = batch + inputs = inputs.to('cuda') + outputs = model(inputs)[:, :2].cpu() # 只取 x, y + all_predictions.append(outputs) + + all_predictions = torch.cat(all_predictions, dim=0) + all_predictions = apply_inverse_transforms(all_predictions, conf.inference.num_aug_ops) + offset = [inputs.shape[-2] / 2., inputs.shape[-1] / 2.] + offset = torch.tensor(offset).unsqueeze(0) # (1, 2) + all_predictions += offset + print(f'mean x = {torch.mean(all_predictions[:, 0]):.4f}, std x = {torch.std(all_predictions[:, 0]):.4f}') + print(f'mean y = {torch.mean(all_predictions[:, 1]):.4f}, std y = {torch.std(all_predictions[:, 1]):.4f}') + return all_predictions.numpy(), all_reference_points + +def accumulate_hits(predictions: np.ndarray, reference_points: np.ndarray, + binning_factor: int): + ml_super_frame = np.zeros((NY * binning_factor, NX * binning_factor)) + count_frame = np.zeros((NY, NX)) + subpixel_dist = np.zeros((binning_factor, binning_factor)) + + # 绝对坐标 = 预测亚像素 + 参考点(像素左下角) + absolute_positions = predictions + reference_points[:, :2] + + # 超分辨帧 (binning) + hit_x = np.floor(absolute_positions[:, 0] * binning_factor).astype(int) + hit_y = np.floor(absolute_positions[:, 1] * binning_factor).astype(int) + np.add.at(ml_super_frame, (hit_y, hit_x), 1) + + # 计数帧 (按参考点像素索引) + ref_x = (reference_points[:, 0] + 1).astype(int) # 参考点是左下角,+1 得像素索引 + ref_y = (reference_points[:, 1] + 1).astype(int) + np.add.at(count_frame, (ref_y, ref_x), 1) + + # 亚像素分布 + sub_x = np.floor((absolute_positions[:, 0] % 1) * binning_factor).astype(int) + sub_y = np.floor((absolute_positions[:, 1] % 1) * binning_factor).astype(int) + np.add.at(subpixel_dist, (sub_y, sub_x), 1) + + return ml_super_frame, count_frame, subpixel_dist + +def save_results(ml_super_frame, count_frame, subpixel_dist, + roi: list, binning_factor: int, output_dir: Path): + x_min, x_max, y_min, y_max = roi + + # 1. super-resolution frame + plt.figure(figsize=(8, 8)) + plt.imshow(ml_super_frame[y_min*binning_factor:y_max*binning_factor, x_min*binning_factor:x_max*binning_factor], origin='lower', extent=[x_min, x_max, y_min, y_max]) + plt.colorbar(label='Counts') + plt.title('ML Super-Resolution Frame') + plt.xlabel('X (pixel)') + plt.ylabel('Y (pixel)') + plt.savefig(output_dir / '1Photon_ML_superFrame.png', dpi=300, bbox_inches='tight') + plt.close() + np.save(output_dir / '1Photon_ML_superFrame.npy', ml_super_frame) + + # 2. count frame + plt.figure(figsize=(8, 8)) + plt.imshow(count_frame[y_min:y_max, x_min:x_max], origin='lower', extent=[x_min, x_max, y_min, y_max]) + plt.colorbar(label='Counts') + plt.title('Photon Count Frame') + plt.xlabel('X (pixel)') + plt.ylabel('Y (pixel)') + plt.savefig(output_dir / '1Photon_count_Frame.png', dpi=300, bbox_inches='tight') + plt.close() + np.save(output_dir / '1Photon_count_Frame.npy', count_frame) + + # 3. sub-pixel distribution + plt.figure(figsize=(8, 8)) + plt.imshow(subpixel_dist, origin='lower', + extent=[0, binning_factor, 0, binning_factor], + cmap='viridis') + plt.colorbar(label='Counts') + plt.title('Sub-pixel Distribution') + plt.xlabel(f'Sub-pixel X (1/{binning_factor} pixel)') + plt.ylabel(f'Sub-pixel Y (1/{binning_factor} pixel)') + plt.savefig(output_dir / '1Photon_subpixel_Distribution.png', dpi=300, bbox_inches='tight') + plt.close() + np.save(output_dir / '1Photon_subpixel_Distribution.npy', subpixel_dist) + rms, mean = np.std(subpixel_dist), np.mean(subpixel_dist) + print(f"[Plotting]: Sub-pixel distribution: RMS/Mean: {rms/mean:.4f}") + + print(f"Results saved to: {output_dir}") + if __name__ == "__main__": - model = models.get_model_class(config['modelVersion'])().cuda() - modelName = f'singlePhoton{config["modelVersion"]}_{config["energy"]}keV_Noise{config["noise"]}keV_E150_aug1' - if flag_normalize: - modelName += '_normalized' - model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/{modelName}.pth', weights_only=True)) - nChunks = np.ceil(len(config['dataFiles']) / 16).astype(int) + ### output preparation + output_dir = prepare_output_folder(conf) + + ### model loading + model = get_model_class(conf.model.version)().cuda() + model.load_state_dict(torch.load(f'{conf.model.base_dir}/{conf.model.experiment_name}/Models/{conf.model.name}', weights_only=True)) + model.eval() + + ### data loading + files_list = get_files_list(conf) + roi = conf.data.names[conf.experiment.name].roi + BinningFactor = conf.inference.binning_factor + numberOfAugOps = conf.inference.num_aug_ops + flag_normalize = conf.data.normalize + nChunks = int(np.ceil(len(files_list) / conf.inference.chunk_size)) + + ml_super_frame = np.zeros((NY * BinningFactor, NX * BinningFactor)) + count_frame = np.zeros((NY, NX)) + subpixel_dist = np.zeros((BinningFactor, BinningFactor)) + + ### Inference loop for idxChunk in range(nChunks): - predictions = [] - referencePoints = [] - stFileIdx = idxChunk * 16 - edFileIdx = min((idxChunk + 1) * 16, len(config['dataFiles'])) - sampleFiles = config['dataFiles'][stFileIdx : edFileIdx] - print(f'Processing files {stFileIdx} to {edFileIdx}...') + start_idx = idxChunk * conf.inference.chunk_size + end_idx = min(start_idx + conf.inference.chunk_size, len(files_list)) + chunk_files = files_list[start_idx:end_idx] + print(f'[Inferring] Chunk {idxChunk+1}/{nChunks}: Loading files {start_idx} to {end_idx}...') + dataset = singlePhotonDataset( - sampleFiles, - sampleRatio=1, + chunk_files, + sampleRatio=1.0, datasetName='Inference', - numberOfAugOps=numberOfAugOps, - normalize=flag_normalize - ) - dataLoader = torch.utils.data.DataLoader( - dataset, - batch_size=8192, - shuffle=False, - num_workers=16, - pin_memory=True, + numberOfAugOps=conf.inference.num_aug_ops, + normalize=conf.data.normalize ) - - referencePoints.append(dataset.referencePoint) - - _chunk_predictions = [] - with torch.no_grad(): - for batch in tqdm(dataLoader): - inputs, _ = batch - inputs_cuda = inputs.cuda() - outputs = model(inputs_cuda)[:, :2].cpu() # only x and y - _chunk_predictions.append(outputs) - predictions.extend(_chunk_predictions) - - predictions = torch.cat(predictions, dim=0) - predictions = apply_inverse_transforms(predictions, numberOfAugOps) - predictions += torch.tensor([1.5, 1.5]).unsqueeze(0) # adjust back to original coordinate system - referencePoints = np.concatenate(referencePoints, axis=0) - print(f'mean x = {torch.mean(predictions[:, 0])}, std x = {torch.std(predictions[:, 0])}') - print(f'mean y = {torch.mean(predictions[:, 1])}, std y = {torch.std(predictions[:, 1])}') - absolutePositions = predictions.numpy() + referencePoints[:, :2] - hit_x = np.floor(absolutePositions[:, 0] * BinningFactor).astype(int) - hit_y = np.floor(absolutePositions[:, 1] * BinningFactor).astype(int) - np.add.at(mlSuperFrame, (hit_y, hit_x), 1) + data_loader = torch.utils.data.DataLoader( + dataset, + batch_size=conf.data.batch_size, + shuffle=False, + num_workers=conf.data.num_workers, + pin_memory=True + ) + + predictions, ref_points = run_inference(model, data_loader, conf) - np.add.at(countFrame, ((referencePoints[:, 1] + 1).astype(int), - (referencePoints[:, 0] + 1).astype(int)), 1) ### the reference points refer to the lower-left corner of the pixel, so add 1 to get the pixel index - - np.add.at(subpixelDistribution, - (np.floor((absolutePositions[:, 1] % 1) * BinningFactor).astype(int), - np.floor((absolutePositions[:, 0] % 1) * BinningFactor).astype(int)), 1) - - ### Save results and plots - import os - outputDir = f'InferenceResults/{task}/{config["modelVersion"]}_{config["energy"]}keV_Noise{config["noise"]}keV_augX{numberOfAugOps}' - if flag_normalize: - outputDir += '_normalized' - os.makedirs(outputDir, exist_ok=True) - - plt.clf() - mlSuperFrame = mlSuperFrame[config['roi'][2]*BinningFactor : config['roi'][3]*BinningFactor, config['roi'][0]*BinningFactor : config['roi'][1]*BinningFactor] - average = np.mean(mlSuperFrame) - plt.imshow(mlSuperFrame, origin='lower', extent=[X_st, X_ed, Y_st, Y_ed]) - plt.colorbar() - plt.savefig(f'{outputDir}/1Photon_ML_superFrame.png', dpi=300) - np.save(f'{outputDir}/1Photon_ML_superFrame.npy', mlSuperFrame) - - plt.clf() - countFrame = countFrame[config['roi'][2] : config['roi'][3], config['roi'][0] : config['roi'][1]] - plt.imshow(countFrame, origin='lower', extent=[X_st, X_ed, Y_st, Y_ed]) - plt.colorbar() - plt.savefig(f'{outputDir}/1Photon_count_Frame.png', dpi=300) - np.save(f'{outputDir}/1Photon_count_Frame.npy', countFrame) - - plt.clf() - plt.imshow(subpixelDistribution, origin='lower', extent=[0, BinningFactor, 0, BinningFactor]) - plt.colorbar() - plt.savefig(f'{outputDir}/1Photon_subpixel_Distribution.png', dpi=300) - np.save(f'{outputDir}/1Photon_subpixel_Distribution.npy', subpixelDistribution) \ No newline at end of file + chunk_super, chunk_count, chunk_subpixel = accumulate_hits( + predictions, ref_points, binning_factor=BinningFactor + ) + ml_super_frame += chunk_super + count_frame += chunk_count + subpixel_dist += chunk_subpixel + save_results(ml_super_frame, count_frame, subpixel_dist, + roi=roi, binning_factor=BinningFactor, + output_dir=output_dir) \ No newline at end of file