From bee7a30a8860b51555f218418600fa20d3d4b7c4 Mon Sep 17 00:00:00 2001 From: "xiangyu.xie" Date: Thu, 6 Nov 2025 10:54:09 +0100 Subject: [PATCH] Add 2photon inference --- Inference_2Photon.py | 104 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 Inference_2Photon.py diff --git a/Inference_2Photon.py b/Inference_2Photon.py new file mode 100644 index 0000000..b6a780a --- /dev/null +++ b/Inference_2Photon.py @@ -0,0 +1,104 @@ +import torch +import sys +sys.path.append('./src') +import models +from datasets import * +from tqdm import tqdm +from matplotlib import pyplot as plt + +torch.manual_seed(42) +torch.cuda.manual_seed(42) +np.random.seed(42) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +configs = {} +configs['SiemenStar'] = { + 'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/Clusters_2Photon_chunk{i}.h5' for i in range(16)], + 'modelVersion': '251001_2', + 'roi': [140, 230, 120, 210], # x_min, x_max, y_min, y_max, + 'noise': 0.13 # in keV +} +BinningFactor = 10 +Roi = configs['SiemenStar']['roi'] +X_st, X_ed, Y_st, Y_ed = Roi +mlSuperFrame = np.zeros(((Y_ed-Y_st)*BinningFactor, (X_ed-X_st)*BinningFactor)) +countFrame = np.zeros((Y_ed-Y_st, X_ed-X_st)) +subpixelDistribution = np.zeros((BinningFactor, BinningFactor)) + +if __name__ == "__main__": + task = 'SiemenStar' + config = configs[task] + + model = models.get_double_photon_model_class(config['modelVersion'])().cuda() + model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/doublePhoton{config["modelVersion"]}_15.3keV_Noise0.13keV_E300.pth', weights_only=True)) + predictions = [] + referencePoints = [] + nChunks = len(config['dataFiles']) // 16 + for idxChunk in range(nChunks): + stFileIdx = idxChunk * 16 + edFileIdx = min((idxChunk + 1) * 16, len(config['dataFiles'])) + sampleFiles = config['dataFiles'][stFileIdx : edFileIdx] + print(f'Processing files {stFileIdx} to {edFileIdx}...') + dataset = doublePhotonInferenceDataset( + sampleFiles, + sampleRatio=1, + datasetName='Inference', + # noiseKeV=0.13 + ) + dataLoader = torch.utils.data.DataLoader( + dataset, + batch_size=8192, + shuffle=False, + num_workers=32, + pin_memory=True, + ) + + referencePoints.append(dataset.referencePoint) + + with torch.no_grad(): + for batch in tqdm(dataLoader): + inputs, _ = batch + inputs = inputs.cuda() + outputs = model(inputs).view(-1, 2) # 2B x 2 + predictions.append(outputs.cpu()) + + predictions = torch.cat(predictions, dim=0) + print(f'mean x = {torch.mean(predictions[:, 0])}, std x = {torch.std(predictions[:, 0])}') + print(f'mean y = {torch.mean(predictions[:, 1])}, std y = {torch.std(predictions[:, 1])}') + referencePoints = np.concatenate(referencePoints, axis=0) + ### duplicate reference points for 2-photon clusters + referencePoints = np.repeat(referencePoints, 2, axis=0) + absolutePositions = predictions.numpy() + referencePoints + print(absolutePositions[:5, 0] - Roi[0], absolutePositions[:5, 1] - Roi[2]) + + hit_x = np.floor((absolutePositions[:, 0] - Roi[0]) * BinningFactor).astype(int) + hit_x = np.clip(hit_x, 0, mlSuperFrame.shape[1]-1) + hit_y = np.floor((absolutePositions[:, 1] - Roi[2]) * BinningFactor).astype(int) + hit_y = np.clip(hit_y, 0, mlSuperFrame.shape[0]-1) + print(hit_x[:5], hit_y[:5]) + np.add.at(mlSuperFrame, (hit_y, hit_x), 1) + + np.add.at(countFrame, ((referencePoints[:, 1] - Roi[2]).astype(int), + (referencePoints[:, 0] - Roi[0]).astype(int)), 1) + + np.add.at(subpixelDistribution, + (np.floor((absolutePositions[:, 1] % 1) * BinningFactor).astype(int), + np.floor((absolutePositions[:, 0] % 1) * BinningFactor).astype(int)), 1) + + plt.imshow(mlSuperFrame, origin='lower', extent=[Y_st, Y_ed, X_st, X_ed]) + plt.colorbar() + plt.savefig('InferenceResults/SiemenStar_ML_2Photon_superFrame.png', dpi=300) + np.save('InferenceResults/SiemenStar_ML_2Photon_superFrame.npy', mlSuperFrame) + plt.clf() + + plt.imshow(countFrame, origin='lower', extent=[Y_st, Y_ed, X_st, X_ed]) + plt.colorbar() + plt.savefig('InferenceResults/SiemenStar_count_2Photon_Frame.png', dpi=300) + np.save('InferenceResults/SiemenStar_count_2Photon_Frame.npy', countFrame) + + plt.clf() + plt.imshow(subpixelDistribution, origin='lower') + plt.colorbar() + plt.savefig('InferenceResults/SiemenStar_subpixel_2Photon_Distribution.png', dpi=300) + np.save('InferenceResults/SiemenStar_subpixel_2Photon_Distribution.npy', subpixelDistribution) \ No newline at end of file