"""Single-photon super-resolution inference.

Loads a trained model, runs it (with optional test-time augmentation) over chunks of
sample files, and accumulates the reconstructed photon positions into a
super-resolution frame, a per-pixel count frame and a sub-pixel distribution, which
are plotted and saved to the experiment's output folder.
"""
import sys
sys.path.append('./src')  # make the project-local `models` / `datasets` packages importable

from pathlib import Path

from omegaconf import OmegaConf
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
import numpy as np
import h5py  # noqa: F401  (not used in this script)

from models import get_model_class
from datasets import singlePhotonDataset

# Seed every RNG and force deterministic cuDNN kernels so repeated runs are reproducible.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

conf = OmegaConf.load("Configs/infer_1photon.yaml")
# Detector size in pixels for the selected experiment.
NX, NY = conf.data.names[conf.experiment.name].NX, conf.data.names[conf.experiment.name].NY


# Inverse coordinate transforms for test-time augmentation, indexed by augmentation id.
# Each maps an (x, y) prediction made on an augmented input back to the original frame.
# NOTE(review): the id order is assumed to match the augmentation order produced by
# singlePhotonDataset when numberOfAugOps > 1 -- confirm against the dataset code.
def inv0(p):
    """Identity: (x, y) -> (x, y)."""
    return p


def inv1(p):
    """Negate x: (x, y) -> (-x, y)."""
    return torch.stack([-p[..., 0], p[..., 1]], dim=-1)


def inv2(p):
    """Negate y: (x, y) -> (x, -y)."""
    return torch.stack([p[..., 0], -p[..., 1]], dim=-1)


def inv3(p):
    """Negate both axes: (x, y) -> (-x, -y)."""
    return -p


def inv4(p):
    """Swap axes: (x, y) -> (y, x)."""
    return torch.stack([p[..., 1], p[..., 0]], dim=-1)


def inv5(p):
    """(x, y) -> (y, -x)."""
    return torch.stack([p[..., 1], -p[..., 0]], dim=-1)


def inv6(p):
    """(x, y) -> (-y, x)."""
    return torch.stack([-p[..., 1], p[..., 0]], dim=-1)


def inv7(p):
    """Swap and negate: (x, y) -> (-y, -x)."""
    return torch.stack([-p[..., 1], -p[..., 0]], dim=-1)


INVERSE_TRANSFORMS = {
    0: inv0,
    1: inv1,
    2: inv2,
    3: inv3,
    4: inv4,
    5: inv5,
    6: inv6,
    7: inv7,
}


def apply_inverse_transforms(predictions: torch.Tensor, numberOfAugOps: int) -> torch.Tensor:
    """Undo the test-time augmentations and average the predictions of each sample.

    Args:
        predictions: ``(N * numberOfAugOps, 2)`` tensor of (x, y) predictions in which the
            ``numberOfAugOps`` rows of one sample are contiguous, ordered by augmentation id.
        numberOfAugOps: augmentations per sample (1 .. ``len(INVERSE_TRANSFORMS)``).

    Returns:
        ``(N, 2)`` tensor of de-augmented predictions averaged over the augmentations.

    Raises:
        ValueError: if ``numberOfAugOps`` has no matching inverse transforms or the number
            of rows is not a multiple of it.
    """
    if not 1 <= numberOfAugOps <= len(INVERSE_TRANSFORMS):
        raise ValueError(
            f'numberOfAugOps must be in [1, {len(INVERSE_TRANSFORMS)}], got {numberOfAugOps}')
    if predictions.shape[0] % numberOfAugOps != 0:
        raise ValueError(
            f'{predictions.shape[0]} predictions is not a multiple of numberOfAugOps={numberOfAugOps}')
    n_samples = predictions.shape[0] // numberOfAugOps
    preds = predictions.view(n_samples, numberOfAugOps, 2)
    corrected = torch.stack(
        [INVERSE_TRANSFORMS[idx](preds[:, idx, :]) for idx in range(numberOfAugOps)],
        dim=1,
    )
    return corrected.mean(dim=1)


def prepare_output_folder(conf):
    """Create the output folder for this run and save the resolved config into it.

    The folder is ``<output_base>/<experiment>/<model experiment>/augX<num_aug_ops>``,
    with a ``_normalized`` suffix when ``conf.data.normalize`` is set.

    Returns:
        Path of the created folder.
    """
    normalize_suffix = '_normalized' if conf.data.normalize else ''
    output_base = (Path(conf.experiment.output_base) / conf.experiment.name / conf.model.experiment_name
                   / f'augX{conf.inference.num_aug_ops}{normalize_suffix}')
    output_base.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, output_base / 'config.yaml')
    return output_base


def get_files_list(conf):
    """Return the sample file paths for the current experiment.

    Paths are ``sample_folder / file_pattern.format(i)`` for ``i`` in the half-open
    range ``[start, end)`` given by the experiment's ``file_range``.
    """
    task_conf = conf.data.names[conf.experiment.name]
    file_pattern = task_conf.file_pattern
    start, end = task_conf.file_range
    return [str(Path(conf.data.sample_folder) / file_pattern.format(i)) for i in range(start, end)]


def run_inference(model, data_loader, conf):
    """Predict photon positions for every sample in ``data_loader``.

    Runs the model on CUDA without gradients. It keeps the first two output channels
    as (x, y), undoes/averages the test-time augmentations, and shifts the result by
    half the input size.

    Returns:
        ``(predictions, reference_points)``. ``predictions`` is an ``(N, 2)`` NumPy array
        of (x, y) positions, or an empty ``(0, 2)`` array if the loader yields no batches.
        ``reference_points`` is ``data_loader.dataset.referencePoint`` unchanged.
    """
    model.eval()
    all_predictions = []
    all_reference_points = data_loader.dataset.referencePoint
    input_hw = None  # spatial (height, width) of the model inputs, taken from the batches
    with torch.no_grad():
        for inputs, _ in tqdm(data_loader, desc="Inferring"):
            input_hw = tuple(inputs.shape[-2:])
            inputs = inputs.to('cuda', non_blocking=True)  # loader uses pin_memory=True
            outputs = model(inputs)[:, :2].cpu()  # keep only x, y
            all_predictions.append(outputs)

    if input_hw is None:
        # Empty loader: nothing to predict; avoid crashing on torch.cat([]) / undefined inputs.
        print('[Inferring] Warning: data loader produced no batches.')
        return np.zeros((0, 2), dtype=np.float32), all_reference_points

    all_predictions = torch.cat(all_predictions, dim=0)
    all_predictions = apply_inverse_transforms(all_predictions, conf.inference.num_aug_ops)

    # Shift by half the input size. x pairs with the width (last) axis and y with the
    # height axis, matching the [y, x] frame indexing in accumulate_hits.
    # NOTE(review): assumes inputs are laid out (..., H, W); for square inputs order is moot.
    height, width = input_hw
    offset = torch.tensor([width / 2., height / 2.]).unsqueeze(0)  # (1, 2)
    all_predictions += offset
    print(f'mean x = {torch.mean(all_predictions[:, 0]):.4f}, std x = {torch.std(all_predictions[:, 0]):.4f}')
    print(f'mean y = {torch.mean(all_predictions[:, 1]):.4f}, std y = {torch.std(all_predictions[:, 1]):.4f}')
    return all_predictions.numpy(), all_reference_points


def _in_bounds(xs, ys, shape):
    """Return a boolean mask of entries where ``(ys, xs)`` is a valid index into a 2-D array of ``shape``."""
    n_rows, n_cols = shape
    return (xs >= 0) & (xs < n_cols) & (ys >= 0) & (ys < n_rows)


def accumulate_hits(predictions: np.ndarray, reference_points: np.ndarray, binning_factor: int,
                    nx=None, ny=None):
    """Histogram one chunk of reconstructed photon positions.

    Args:
        predictions: ``(N, 2)`` array of (x, y) positions relative to each reference point.
        reference_points: array whose first two columns are each cluster's (x, y)
            reference point; added to ``predictions`` to get absolute detector coordinates.
        binning_factor: sub-pixel bins per pixel along each axis.
        nx, ny: detector size in pixels; default to the module-level ``NX`` / ``NY``.

    Returns:
        ``(ml_super_frame, count_frame, subpixel_dist)`` with shapes
        ``(ny*binning_factor, nx*binning_factor)``, ``(ny, nx)`` and
        ``(binning_factor, binning_factor)``; rows are y, columns are x.

    Hits that are non-finite or fall outside a frame are skipped, with a printed count,
    instead of wrapping around through negative indexing or raising ``IndexError``.
    """
    nx = NX if nx is None else nx
    ny = NY if ny is None else ny
    ml_super_frame = np.zeros((ny * binning_factor, nx * binning_factor))
    count_frame = np.zeros((ny, nx))
    subpixel_dist = np.zeros((binning_factor, binning_factor))

    # Absolute position = predicted sub-pixel position + reference point (pixel's lower-left corner).
    absolute_positions = predictions + reference_points[:, :2]
    finite = np.isfinite(absolute_positions).all(axis=1)
    n_nonfinite = int(np.count_nonzero(~finite))
    absolute_positions = absolute_positions[finite]

    # Super-resolution frame: each pixel split into binning_factor x binning_factor bins.
    hit_x = np.floor(absolute_positions[:, 0] * binning_factor).astype(int)
    hit_y = np.floor(absolute_positions[:, 1] * binning_factor).astype(int)
    in_super = _in_bounds(hit_x, hit_y, ml_super_frame.shape)
    np.add.at(ml_super_frame, (hit_y[in_super], hit_x[in_super]), 1)

    # Count frame, indexed by the reference point's pixel
    # (the reference point is the lower-left corner; +1 gives the pixel index).
    ref_x = (reference_points[:, 0] + 1).astype(int)
    ref_y = (reference_points[:, 1] + 1).astype(int)
    in_count = _in_bounds(ref_x, ref_y, count_frame.shape)
    np.add.at(count_frame, (ref_y[in_count], ref_x[in_count]), 1)

    # Sub-pixel distribution. `x % 1` can round up to exactly 1.0 for tiny negative x,
    # so cap the bin index at binning_factor - 1.
    top_bin = binning_factor - 1
    sub_x = np.minimum(np.floor((absolute_positions[:, 0] % 1) * binning_factor).astype(int), top_bin)
    sub_y = np.minimum(np.floor((absolute_positions[:, 1] % 1) * binning_factor).astype(int), top_bin)
    np.add.at(subpixel_dist, (sub_y, sub_x), 1)

    n_out_super = int(np.count_nonzero(~in_super)) + n_nonfinite
    n_out_count = int(np.count_nonzero(~in_count))
    if n_out_super or n_out_count:
        print(f'[Accumulate] Warning: skipped {n_out_super} hit(s) outside the super-resolution frame '
              f'({n_nonfinite} non-finite) and {n_out_count} reference point(s) outside the count frame.')

    return ml_super_frame, count_frame, subpixel_dist


def _plot_and_save(shown, saved, extent, title, xlabel, ylabel, stem, output_dir, **imshow_kwargs):
    """Plot ``shown`` with a colorbar to ``<stem>.png`` and save ``saved`` to ``<stem>.npy`` in ``output_dir``."""
    plt.figure(figsize=(8, 8))
    plt.imshow(shown, origin='lower', extent=extent, **imshow_kwargs)
    plt.colorbar(label='Counts')
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.savefig(output_dir / f'{stem}.png', dpi=300, bbox_inches='tight')
    plt.close()
    np.save(output_dir / f'{stem}.npy', saved)


def save_results(ml_super_frame, count_frame, subpixel_dist, roi: list, binning_factor: int, output_dir: Path):
    """Plot and save the accumulated histograms.

    The frame plots are cropped to ``roi = [x_min, x_max, y_min, y_max]`` (in pixels), while
    the saved ``.npy`` files contain the full, uncropped arrays. Also prints the
    RMS/mean uniformity of the sub-pixel distribution.
    """
    x_min, x_max, y_min, y_max = roi
    roi_extent = [x_min, x_max, y_min, y_max]

    # 1. super-resolution frame (cropped in super-pixel units, labelled in pixel units)
    _plot_and_save(
        ml_super_frame[y_min * binning_factor:y_max * binning_factor, x_min * binning_factor:x_max * binning_factor],
        ml_super_frame, roi_extent, 'ML Super-Resolution Frame', 'X (pixel)', 'Y (pixel)',
        '1Photon_ML_superFrame', output_dir,
    )
    # 2. count frame
    _plot_and_save(
        count_frame[y_min:y_max, x_min:x_max], count_frame, roi_extent,
        'Photon Count Frame', 'X (pixel)', 'Y (pixel)', '1Photon_count_Frame', output_dir,
    )
    # 3. sub-pixel distribution
    _plot_and_save(
        subpixel_dist, subpixel_dist, [0, binning_factor, 0, binning_factor], 'Sub-pixel Distribution',
        f'Sub-pixel X (1/{binning_factor} pixel)', f'Sub-pixel Y (1/{binning_factor} pixel)',
        '1Photon_subpixel_Distribution', output_dir, cmap='viridis',
    )

    rms, mean = np.std(subpixel_dist), np.mean(subpixel_dist)
    if mean > 0:
        print(f"[Plotting]: Sub-pixel distribution: RMS/Mean: {rms/mean:.4f}")
    else:
        # No hits: RMS/mean is undefined, so report that instead of dividing by zero.
        print("[Plotting]: Sub-pixel distribution is empty; RMS/Mean undefined.")
    print(f"Results saved to: {output_dir}")


if __name__ == "__main__":
    ### output preparation
    output_dir = prepare_output_folder(conf)

    ### model loading
    model = get_model_class(conf.model.version)().cuda()
    model.load_state_dict(torch.load(f'{conf.model.base_dir}/{conf.model.experiment_name}/Models/{conf.model.name}', weights_only=True))
    model.eval()

    ### data loading
    files_list = get_files_list(conf)
    roi = conf.data.names[conf.experiment.name].roi
    BinningFactor = conf.inference.binning_factor
    numberOfAugOps = conf.inference.num_aug_ops
    flag_normalize = conf.data.normalize
    chunk_size = conf.inference.chunk_size
    nChunks = (len(files_list) + chunk_size - 1) // chunk_size  # integer ceil division

    ml_super_frame = np.zeros((NY * BinningFactor, NX * BinningFactor))
    count_frame = np.zeros((NY, NX))
    subpixel_dist = np.zeros((BinningFactor, BinningFactor))

    ### Inference loop: files are processed in chunks to bound memory use
    for idxChunk, start_idx in enumerate(range(0, len(files_list), chunk_size)):
        end_idx = min(start_idx + chunk_size, len(files_list))
        chunk_files = files_list[start_idx:end_idx]
        print(f'[Inferring] Chunk {idxChunk+1}/{nChunks}: Loading files {start_idx} to {end_idx}...')
        dataset = singlePhotonDataset(
            chunk_files,
            sampleRatio=1.0,
            datasetName='Inference',
            numberOfAugOps=numberOfAugOps,
            normalize=flag_normalize
        )
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=conf.data.batch_size,
            shuffle=False,
            num_workers=conf.data.num_workers,
            pin_memory=True
        )
        predictions, ref_points = run_inference(model, data_loader, conf)
        chunk_super, chunk_count, chunk_subpixel = accumulate_hits(
            predictions, ref_points, binning_factor=BinningFactor
        )
        ml_super_frame += chunk_super
        count_frame += chunk_count
        subpixel_dist += chunk_subpixel
        # Release this chunk's data before the next chunk is loaded to keep peak memory down.
        del dataset, data_loader, predictions, ref_points

    save_results(ml_super_frame, count_frame, subpixel_dist, roi=roi, binning_factor=BinningFactor, output_dir=output_dir)