"""Run the double-photon model over cluster files and histogram the predicted hits.

For the selected ``task`` the script loads the trained model, runs it over the
configured data files in groups of ``FilesPerChunk``, converts every predicted
position to absolute coordinates and accumulates three histograms:

* ``mlSuperFrame``          - hits inside the ROI, ``BinningFactor`` bins per pixel
* ``countFrame``            - entries per pixel, indexed by the cluster reference point
* ``subpixelDistribution``  - fractional (sub-pixel) part of every predicted position

Each histogram is written to ``InferenceResults/<task>/`` as a PNG and a ``.npy``.
"""
import os
import sys

import torch

sys.path.append('./src')  # project modules (models, datasets) are imported from ./src
import models
from datasets import *

# Imported after the star import so these explicit bindings are the ones in effect.
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt

# Fixed seeds and deterministic cuDNN settings for reproducible runs.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

configs = {}
configs['SiemenStarLowerLeft'] = {
    'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/SiemenStarLowerLeft_old/Clusters_2Photon_CS7_chunk{i}.h5' for i in range(200)],
    'modelVersion': '251124',
    'roi': [140, 230, 120, 210],  # x_min, x_max, y_min, y_max,
    'nSize': 7,
}
configs['SiemenStarLowerRight'] = {
    'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/SiemenStarLowerRight/2Photon_CS7_chunk{i}.h5' for i in range(320)],  ### 320 files
    'modelVersion': '251124',
    'roi': [235, 345, 110, 220],  # x_min, x_max, y_min, y_max,
    'nSize': 7,
}

task = 'SiemenStarLowerRight'
config = configs[task]

BinningFactor = 10  # super-resolution bins per pixel along each axis
FilesPerChunk = 16  # number of data files loaded into one dataset at a time
Roi = configs[task]['roi']
X_st, X_ed, Y_st, Y_ed = Roi

# All frames are indexed [row = y, column = x].
mlSuperFrame = np.zeros(((Y_ed - Y_st) * BinningFactor, (X_ed - X_st) * BinningFactor))
countFrame = np.zeros((Y_ed - Y_st, X_ed - X_st))
subpixelDistribution = np.zeros((BinningFactor, BinningFactor))


def _save_frame(frame, stem, extent=None):
    """Plot *frame* with a colorbar, save it as PNG and .npy under the task folder, clear the figure."""
    plt.imshow(frame, origin='lower', extent=extent)
    plt.colorbar()
    plt.savefig(f'InferenceResults/{task}/{stem}.png', dpi=300)
    np.save(f'InferenceResults/{task}/{stem}.npy', frame)
    plt.clf()


if __name__ == "__main__":
    model = models.get_double_photon_model_class(config['modelVersion'])().cuda()
    model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/doublePhoton{config["modelVersion"]}_15keV_Noise0.13keV_E300.pth', weights_only=True))
    # no_grad() only disables autograd; eval() is what switches dropout / batch-norm
    # layers (if the model has any) to inference behaviour.
    model.eval()

    nFiles = len(config['dataFiles'])
    nChunks = int(np.ceil(nFiles / FilesPerChunk))
    # Offset added to the model outputs to move them back into cluster coordinates.
    centreOffset = torch.tensor([config['nSize'] / 2., config['nSize'] / 2.]).unsqueeze(0)

    for idxChunk in range(nChunks):
        predictions = []
        referencePoints = []
        stFileIdx = idxChunk * FilesPerChunk
        edFileIdx = min((idxChunk + 1) * FilesPerChunk, nFiles)
        sampleFiles = config['dataFiles'][stFileIdx:edFileIdx]
        print(f'Processing files {stFileIdx} to {edFileIdx}...')

        dataset = doublePhotonInferenceDataset(
            sampleFiles,
            sampleRatio=1.,
            datasetName='Inference',
            nSize=config['nSize'],
        )
        dataLoader = torch.utils.data.DataLoader(
            dataset,
            batch_size=8192,
            shuffle=False,  # keep sample order aligned with dataset.referencePoint
            num_workers=16,
            pin_memory=True,
        )
        referencePoints.append(dataset.referencePoint)

        with torch.no_grad():
            for batch in tqdm(dataLoader):
                inputs, _ = batch
                inputs = inputs.cuda()
                outputs = model(inputs).view(-1, 2)  # 2B x 2
                predictions.append(outputs.cpu())

        if not predictions:
            # torch.cat raises on an empty list; there is nothing to accumulate here.
            print(f'No samples found in files {stFileIdx} to {edFileIdx}, skipping.')
            continue

        predictions = torch.cat(predictions, dim=0)
        predictions += centreOffset  # adjust back to original coordinate system
        print(f'mean x = {torch.mean(predictions[:, 0])}, std x = {torch.std(predictions[:, 0])}')
        print(f'mean y = {torch.mean(predictions[:, 1])}, std y = {torch.std(predictions[:, 1])}')

        referencePoints = np.concatenate(referencePoints, axis=0)  ### the lower-left corner of the cluster in absolute coordinate
        ### duplicate reference points for 2-photon clusters
        referencePoints = np.repeat(referencePoints, 2, axis=0)
        absolutePositions = predictions.numpy() + referencePoints

        # Super-resolved bin of every hit; hits outside the ROI are clamped onto the border bins.
        hit_x = np.floor((absolutePositions[:, 0] - Roi[0]) * BinningFactor).astype(int)
        hit_x = np.clip(hit_x, 0, mlSuperFrame.shape[1] - 1)
        hit_y = np.floor((absolutePositions[:, 1] - Roi[2]) * BinningFactor).astype(int)
        hit_y = np.clip(hit_y, 0, mlSuperFrame.shape[0] - 1)
        # np.add.at accumulates correctly when the same bin appears several times.
        np.add.at(mlSuperFrame, (hit_y, hit_x), 1)

        # NOTE(review): unlike hit_x / hit_y these indices are not clipped. A reference
        # point outside the ROI raises IndexError, or wraps around if it is negative.
        # Confirm that the input clusters are confined to the ROI.
        np.add.at(
            countFrame,
            (
                (referencePoints[:, 1] - Roi[2]).astype(int),
                (referencePoints[:, 0] - Roi[0]).astype(int),
            ),
            1,
        )

        # Fractional part of each position, binned into BinningFactor x BinningFactor cells.
        np.add.at(
            subpixelDistribution,
            (
                np.floor((absolutePositions[:, 1] % 1) * BinningFactor).astype(int),
                np.floor((absolutePositions[:, 0] % 1) * BinningFactor).astype(int),
            ),
            1,
        )

    os.makedirs(f'InferenceResults/{task}', exist_ok=True)

    # imshow's extent is (left, right, bottom, top). The frames are [y, x], so the
    # horizontal axis spans x and the vertical axis spans y.
    roiExtent = [X_st, X_ed, Y_st, Y_ed]
    _save_frame(mlSuperFrame, 'ML_2Photon_superFrame', extent=roiExtent)
    _save_frame(countFrame, 'count_2Photon_Frame', extent=roiExtent)
    _save_frame(subpixelDistribution, 'subpixel_2Photon_Distribution')