diff --git a/Inference_SinglePhoton.py b/Inference_SinglePhoton.py index 67cdf97..b610d7c 100644 --- a/Inference_SinglePhoton.py +++ b/Inference_SinglePhoton.py @@ -20,18 +20,47 @@ configs['SiemenStar'] = { 'noise': 0.13 # in keV } BinningFactor = 10 +numberOfAugOps = 6 Roi = configs['SiemenStar']['roi'] X_st, X_ed, Y_st, Y_ed = Roi mlSuperFrame = np.zeros(((Y_ed-Y_st)*BinningFactor, (X_ed-X_st)*BinningFactor)) countFrame = np.zeros((Y_ed-Y_st, X_ed-X_st)) subpixelDistribution = np.zeros((BinningFactor, BinningFactor)) +def inv0(p): return p +def inv1(p): return torch.stack([-p[..., 0], p[..., 1]], dim=-1) +def inv2(p): return torch.stack([p[..., 0], -p[..., 1]], dim=-1) +def inv3(p): return -p +def inv4(p): return torch.stack([p[..., 1], p[..., 0]], dim=-1) +def inv5(p): return torch.stack([p[..., 1], -p[..., 0]], dim=-1) +def inv6(p): return torch.stack([-p[..., 1], p[..., 0]], dim=-1) +def inv7(p): return torch.stack([-p[..., 1], -p[..., 0]], dim=-1) + +INVERSE_TRANSFORMS = { + 0: inv0, + 1: inv1, + 2: inv2, + 3: inv3, + 4: inv4, + 5: inv5, + 6: inv6, + 7: inv7, +} + +def apply_inverse_transforms(predictions: torch.Tensor, numberOfAugOps: int) -> torch.Tensor: + N = predictions.shape[0] // numberOfAugOps + preds = predictions.view(N, numberOfAugOps, 2) + corrected = torch.zeros_like(preds) + for idx in range(numberOfAugOps): + corrected[:, idx, :] = INVERSE_TRANSFORMS[idx](preds[:, idx, :]) + return corrected.mean(dim=1) + if __name__ == "__main__": task = 'SiemenStar' config = configs[task] model = models.get_model_class(config['modelVersion'])().cuda() - model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/singlePhoton{config["modelVersion"]}_15.3keV_Noise{config["noise"]}keV_E500.pth', weights_only=True)) + model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/singlePhoton{config["modelVersion"]}_15.3keV_Noise{config["noise"]}keV_E500_aug8.pth', weights_only=True)) predictions = [] referencePoints = [] nChunks = len(config['dataFiles']) // 32 + 1 @@ -40,12 +69,17 @@ if __name__ == "__main__": edFileIdx = min((idxChunk + 1) * 32, len(config['dataFiles'])) sampleFiles = config['dataFiles'][stFileIdx : edFileIdx] print(f'Processing files {stFileIdx} to {edFileIdx}...') - dataset = singlePhotonDataset(sampleFiles, sampleRatio=1.0, datasetName='Inference') + dataset = singlePhotonDataset( + sampleFiles, + sampleRatio=1.0, + datasetName='Inference', + numberOfAugOps=numberOfAugOps + ) dataLoader = torch.utils.data.DataLoader( dataset, batch_size=8192, shuffle=False, - num_workers=16, + num_workers=32, pin_memory=True, ) @@ -59,6 +93,7 @@ if __name__ == "__main__": predictions.append(outputs.cpu()) predictions = torch.cat(predictions, dim=0) + predictions = apply_inverse_transforms(predictions, numberOfAugOps) predictions += torch.tensor([1.5, 1.5]).unsqueeze(0) # adjust back to original coordinate system referencePoints = np.concatenate(referencePoints, axis=0) print(f'mean x = {torch.mean(predictions[:, 0])}, std x = {torch.std(predictions[:, 0])}')