"""3-photon cluster inference: predict hit positions and accumulate hit maps.

Loads ``Configs/infer_3photon.yaml``, runs the triple-photon model over the
configured sample files chunk by chunk, converts the predicted in-cluster
positions to absolute detector coordinates and accumulates three histograms
(super-resolution frame, per-pixel count frame, sub-pixel distribution) that
are plotted and saved (``.png`` / ``.npy``) into the output folder.
"""
import sys

# Project modules (models, datasets) live under ./src.
sys.path.append('./src')

from pathlib import Path

from omegaconf import OmegaConf
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
import numpy as np
import h5py

from models import get_triple_photon_model_class
from datasets import triplePhotonInferenceDataset

# Fixed seeds + deterministic cuDNN so repeated runs give identical results.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

conf = OmegaConf.load("Configs/infer_3photon.yaml")
# Detector frame size (pixels) for the selected experiment.
NX, NY = (
    conf.data.names[conf.experiment.name].NX,
    conf.data.names[conf.experiment.name].NY,
)
# Cluster window size; predictions are shifted by NSIZE/2 to go back to
# coordinates relative to the cluster's reference point (see run_inference).
NSIZE = 9


def inv0(p):
    """Identity: (x, y) -> (x, y)."""
    return p


def inv1(p):
    """Negate x: (x, y) -> (-x, y)."""
    return torch.stack([-p[..., 0], p[..., 1]], dim=-1)


def inv2(p):
    """Negate y: (x, y) -> (x, -y)."""
    return torch.stack([p[..., 0], -p[..., 1]], dim=-1)


def inv3(p):
    """Negate both axes: (x, y) -> (-x, -y)."""
    return -p


def inv4(p):
    """Swap axes: (x, y) -> (y, x)."""
    return torch.stack([p[..., 1], p[..., 0]], dim=-1)


def inv5(p):
    """(x, y) -> (y, -x)."""
    return torch.stack([p[..., 1], -p[..., 0]], dim=-1)


def inv6(p):
    """(x, y) -> (-y, x)."""
    return torch.stack([-p[..., 1], p[..., 0]], dim=-1)


def inv7(p):
    """(x, y) -> (-y, -x)."""
    return torch.stack([-p[..., 1], -p[..., 0]], dim=-1)


# Maps augmentation-op index -> transform that undoes it on a predicted (x, y).
# NOTE(review): correctness depends on the forward augmentation order used by
# the dataset, which is not visible here — confirm the index pairing.
INVERSE_TRANSFORMS = {
    0: inv0,
    1: inv1,
    2: inv2,
    3: inv3,
    4: inv4,
    5: inv5,
    6: inv6,
    7: inv7,
}


def apply_inverse_transforms(predictions: torch.Tensor, numberOfAugOps: int) -> torch.Tensor:
    """Undo test-time augmentation and average the de-augmented predictions.

    Args:
        predictions: tensor viewable as ``(N, numberOfAugOps, 2)``, i.e. each
            consecutive run of ``numberOfAugOps`` rows belongs to one sample.
        numberOfAugOps: number of augmentation ops per sample (at most 8, the
            size of ``INVERSE_TRANSFORMS``).

    Returns:
        Tensor of shape ``(N, 2)``: mean over the inverse-transformed copies.
    """
    N = predictions.shape[0] // numberOfAugOps
    preds = predictions.view(N, numberOfAugOps, 2)
    corrected = torch.zeros_like(preds)
    for idx in range(numberOfAugOps):
        corrected[:, idx, :] = INVERSE_TRANSFORMS[idx](preds[:, idx, :])
    return corrected.mean(dim=1)


def prepare_output_folder(conf):
    """Create the output directory for this run and save the config into it.

    The path is ``<output_base>/<experiment>/<model experiment>/augX<n>[_normalized]``.

    Returns:
        The created ``Path``.
    """
    normalize_suffix = '_normalized' if conf.data.normalize else ''
    output_base = (
        Path(conf.experiment.output_base)
        / conf.experiment.name
        / conf.model.experiment_name
        / f'augX{conf.inference.num_aug_ops}{normalize_suffix}'
    )
    output_base.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, output_base / 'config.yaml')
    return output_base


def get_files_list(conf):
    """Build the list of sample file paths for the selected experiment.

    Uses ``file_pattern.format(i)`` for ``i`` in ``range(*file_range)``
    (end exclusive), joined onto ``conf.data.sample_folder``.
    """
    task_conf = conf.data.names[conf.experiment.name]
    file_pattern = task_conf.file_pattern
    start, end = task_conf.file_range
    files = [str(Path(conf.data.sample_folder) / file_pattern.format(i)) for i in range(start, end)]
    return files


def run_inference(model, data_loader, conf):
    """Predict three (x, y) photon positions per cluster for one data loader.

    Args:
        model: triple-photon model; its output for a batch of B clusters is
            reshaped to ``(3B, 2)``.
        data_loader: non-shuffled loader whose dataset exposes
            ``referencePoint`` (the lower-left corner of each cluster in
            absolute coordinates).
        conf: run configuration (currently unused here; kept for the
            interface and for re-enabling test-time augmentation).

    Returns:
        ``(predictions, reference_points)`` as numpy arrays with matching row
        counts: predictions are cluster-relative positions, and each cluster's
        reference point is repeated 3 times to line up with its 3 photons.

    Raises:
        ValueError: if the number of predictions is not 3x the number of
            reference points.
    """
    batch_predictions = []
    with torch.no_grad():
        for inputs, _ in tqdm(data_loader, desc="Inferring"):
            # non_blocking pairs with pin_memory=True in the DataLoader.
            inputs = inputs.cuda(non_blocking=True)
            outputs = model(inputs).view(-1, 2)  # 3B x 2
            batch_predictions.append(outputs.cpu())

    if not batch_predictions:
        # torch.cat would raise on an empty list (chunk without clusters).
        print('[Inferring] No clusters in this chunk; nothing to accumulate.')
        return np.zeros((0, 2), dtype=np.float32), np.zeros((0, 2))

    all_predictions = torch.cat(batch_predictions, dim=0)
    # NOTE: test-time augmentation is currently disabled (the dataset is built
    # without numberOfAugOps), so apply_inverse_transforms is not applied here.
    all_predictions += NSIZE / 2.  # adjust back to original coordinate system

    print(f'mean x = {torch.mean(all_predictions[:, 0])}, std x = {torch.std(all_predictions[:, 0])}')
    print(f'mean y = {torch.mean(all_predictions[:, 1])}, std y = {torch.std(all_predictions[:, 1])}')

    # The lower-left corner of the cluster in absolute coordinates; repeated
    # 3x (consecutively) to match the 3 photon predictions per cluster.
    referencePoints = np.repeat(data_loader.dataset.referencePoint, 3, axis=0)
    if referencePoints.shape[0] != all_predictions.shape[0]:
        raise ValueError(
            f'Got {all_predictions.shape[0]} predictions but '
            f'{referencePoints.shape[0]} (3x) reference points'
        )
    return all_predictions.numpy(), referencePoints


def accumulate_hits(predictions: np.ndarray, reference_points: np.ndarray, binning_factor: int):
    """Histogram predicted hits into super-resolution, pixel and sub-pixel maps.

    Args:
        predictions: ``(M, 2)`` cluster-relative (x, y) positions.
        reference_points: ``(M, 2)`` absolute offsets added to ``predictions``.
        binning_factor: number of sub-pixel bins per pixel along each axis.

    Returns:
        ``(ml_super_frame, count_frame, subpixel_dist)``:
        int32 arrays of shapes ``(NY*bf, NX*bf)``, ``(NY, NX)`` and
        ``(bf, bf)``, indexed ``[y, x]``. Hits outside the frame are clipped
        onto the border bins of the two frames.
    """
    ml_super_frame = np.zeros((NY * binning_factor, NX * binning_factor), dtype=np.int32)
    count_frame = np.zeros((NY, NX), dtype=np.int32)
    subpixel_dist = np.zeros((binning_factor, binning_factor), dtype=np.int32)

    # Absolute coordinate = predicted sub-pixel position + reference point.
    absolute_positions = predictions + reference_points
    abs_x = absolute_positions[:, 0]
    abs_y = absolute_positions[:, 1]

    # 1. Super-resolution frame.
    hit_x_superpixel_idx = np.clip(np.floor(abs_x * binning_factor).astype(int), 0, NX * binning_factor - 1)
    hit_y_superpixel_idx = np.clip(np.floor(abs_y * binning_factor).astype(int), 0, NY * binning_factor - 1)
    np.add.at(ml_super_frame, (hit_y_superpixel_idx, hit_x_superpixel_idx), 1)

    # 2. Per-pixel count frame.
    hit_x_pixel_idx = np.clip(np.floor(abs_x).astype(int), 0, NX - 1)
    hit_y_pixel_idx = np.clip(np.floor(abs_y).astype(int), 0, NY - 1)
    np.add.at(count_frame, (hit_y_pixel_idx, hit_x_pixel_idx), 1)

    # 3. Sub-pixel distribution. The fractional part is in [0, 1), but float
    # rounding of frac * binning_factor can land exactly on binning_factor, so
    # clip to keep the index in range.
    subpixel_x_idx = np.clip(np.floor((abs_x % 1) * binning_factor).astype(int), 0, binning_factor - 1)
    subpixel_y_idx = np.clip(np.floor((abs_y % 1) * binning_factor).astype(int), 0, binning_factor - 1)
    np.add.at(subpixel_dist, (subpixel_y_idx, subpixel_x_idx), 1)

    return ml_super_frame, count_frame, subpixel_dist


def save_results(ml_super_frame, count_frame, subpixel_dist, roi: list, binning_factor: int, output_dir: Path):
    """Plot and save the accumulated maps, then report sub-pixel uniformity.

    Args:
        ml_super_frame: super-resolution histogram, ``(NY*bf, NX*bf)``.
        count_frame: per-pixel histogram, ``(NY, NX)``.
        subpixel_dist: sub-pixel histogram, ``(bf, bf)``.
        roi: ``[x_st, x_ed, y_st, y_ed]`` in pixels; only used for plotting,
            the full arrays are saved.
        binning_factor: sub-pixel bins per pixel.
        output_dir: directory receiving the ``.png`` and ``.npy`` files.
    """
    x_st, x_ed, y_st, y_ed = roi

    # 1. super-resolution frame
    plt.figure(figsize=(8, 8))
    plt.imshow(
        ml_super_frame[y_st * binning_factor:y_ed * binning_factor, x_st * binning_factor:x_ed * binning_factor],
        origin='lower',
        extent=[x_st, x_ed, y_st, y_ed],
    )
    plt.colorbar(label='Counts')
    plt.title('ML Super-Resolution Frame')
    plt.xlabel('X (pixel)')
    plt.ylabel('Y (pixel)')
    plt.savefig(output_dir / '3Photon_ML_superFrame.png', dpi=300, bbox_inches='tight')
    plt.clf()
    np.save(output_dir / '3Photon_ML_superFrame.npy', ml_super_frame)

    # 2. count frame
    plt.imshow(count_frame[y_st:y_ed, x_st:x_ed], origin='lower', extent=[x_st, x_ed, y_st, y_ed])
    plt.colorbar(label='Counts')
    plt.title('Photon Count Frame')
    plt.xlabel('X (pixel)')
    plt.ylabel('Y (pixel)')
    plt.savefig(output_dir / '3Photon_count_Frame.png', dpi=300, bbox_inches='tight')
    plt.clf()
    np.save(output_dir / '3Photon_count_Frame.npy', count_frame)

    # 3. subpixel distribution
    plt.imshow(subpixel_dist, origin='lower', extent=[0, 1, 0, 1])
    plt.colorbar(label='Counts')
    plt.title('Subpixel Distribution')
    plt.xlabel('Subpixel X')
    plt.ylabel('Subpixel Y')
    plt.savefig(output_dir / '3Photon_subpixel_Distribution.png', dpi=300, bbox_inches='tight')
    plt.close()
    np.save(output_dir / '3Photon_subpixel_Distribution.npy', subpixel_dist)

    # For Poisson-distributed counts in a uniform distribution, RMS/mean of the
    # bin counts is expected to be about 1/sqrt(mean).
    std, mean = np.std(subpixel_dist), np.mean(subpixel_dist)
    if mean > 0:
        print(f"[Plotting]: Sub-pixel distribution: RMS/Mean: {std/mean:.4f}, expected value = {1/np.sqrt(mean):.4f} for uniform distribution")
    else:
        # Avoid division by zero when no hits were accumulated.
        print("[Plotting]: Sub-pixel distribution is empty; RMS/Mean is undefined.")
    print(f"Results saved to: {output_dir}")


if __name__ == "__main__":
    ### output folder preparation
    output_dir = prepare_output_folder(conf)

    ### model loading
    model_version = conf.model.experiment_name.split('_v')[-1][:6]
    model = get_triple_photon_model_class(model_version)().cuda()
    model_path = Path(conf.model.base_dir) / conf.model.experiment_name / 'Models' / conf.model.name
    model.load_state_dict(torch.load(model_path, weights_only=True))
    # Inference mode: without this, any BatchNorm/Dropout layers behave as in
    # training (batch statistics, random dropout, running-stat updates).
    model.eval()

    ### data loading
    files_list = get_files_list(conf)
    roi = conf.data.names[conf.experiment.name].roi
    BinningFactor = conf.inference.binning_factor
    chunk_size = conf.inference.chunk_size
    n_files = len(files_list)
    # Number of chunks must be derived from the same chunk_size used for
    # slicing; otherwise files get skipped or processed twice.
    nChunks = (n_files + chunk_size - 1) // chunk_size

    ml_super_frame = np.zeros((NY * BinningFactor, NX * BinningFactor), dtype=np.int32)
    count_frame = np.zeros((NY, NX), dtype=np.int32)
    subpixel_dist = np.zeros((BinningFactor, BinningFactor), dtype=np.int32)

    for idxChunk, start_idx in enumerate(range(0, n_files, chunk_size)):
        end_idx = min(start_idx + chunk_size, n_files)
        chunk_files = files_list[start_idx:end_idx]
        print(f'[Inferring] Chunk {idxChunk+1}/{nChunks}: Loading files {start_idx} to {end_idx}...')

        dataset = triplePhotonInferenceDataset(
            chunk_files,
            sampleRatio=1.0,
            datasetName=f'Inference_Chunk{idxChunk+1}',
        )
        dataLoader = torch.utils.data.DataLoader(
            dataset,
            batch_size=8192,
            shuffle=False,  # order must match dataset.referencePoint
            num_workers=16,
            pin_memory=True,
        )

        predictions, reference_points = run_inference(model, dataLoader, conf)
        ml_super_frame_chunk, count_frame_chunk, subpixel_dist_chunk = accumulate_hits(
            predictions, reference_points, BinningFactor
        )
        ml_super_frame += ml_super_frame_chunk
        count_frame += count_frame_chunk
        subpixel_dist += subpixel_dist_chunk

    save_results(ml_super_frame, count_frame, subpixel_dist, roi, BinningFactor, output_dir)