"""Run single-photon position inference on measured clusters and histogram the results.

The script loads a trained model, predicts an (x, y) position for every
augmented view of every sample, undoes the augmentation, averages the views,
and accumulates three histograms that are saved as ``.png`` and ``.npy``:

* ``mlSuperFrame``         -- hits binned ``BinningFactor`` times finer than a pixel
* ``countFrame``           -- hits per pixel
* ``subpixelDistribution`` -- hits folded into a single pixel
"""
import sys

import torch

sys.path.append('./src')  # project-local modules (models, datasets) live in ./src
import models
from datasets import *

# Kept after the star import on purpose: whatever `datasets` exports must not
# shadow these names.
import os
from tqdm import tqdm
from matplotlib import pyplot as plt
import numpy as np
import h5py

# Fix every random source and force deterministic cuDNN kernels.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

NX, NY = 400, 400  # frame size in pixels

configs = {}
configs['SiemenStarLowerLeft'] = {
    # Alternative input (200 files, no zeroing pixels outside the cluster):
    # 'dataFiles': [f'/mnt/sls_det_storage/moench_data/MLXID/Samples/Measurement/2504_SOLEIL_SiemenStarClusters_MOENCH040_150V/SiemenStarLowerLeft/clusters_chunk{i}.h5' for i in range(1)],
    # 160 files, no zeroing pixels outside the cluster.
    # NOTE(review): range(1) loads only chunk 0 -- confirm this is intended.
    'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/SiemenStarLowerLeft/1Photon_CS3_chunk{i}.h5' for i in range(1)],
    'modelVersion': '251022',
    'energy': 15,  # keV
    'roi': [140, 230, 120, 210],  # x_min, x_max, y_min, y_max
    'noise': 0.13,  # keV; for the model selection
}
configs['SiemenStarLowerRight'] = {
    # 320 files. !!! zeroed pixels outside the cluster, to be fixed
    'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/SiemenStarLowerRight/1Photon_CS3_chunk{i}.h5' for i in range(320)],
    'modelVersion': '251022',
    'energy': 15,  # keV
    'roi': [235, 345, 110, 220],  # x_min, x_max, y_min, y_max
    'noise': 0.13,  # keV
}

task = 'SiemenStarLowerLeft'
config = configs[task]

BinningFactor = 10   # sub-pixel bins per pixel along each axis
numberOfAugOps = 8   # augmented views per sample
flag_normalize = False
Roi = config['roi']
X_st, X_ed, Y_st, Y_ed = Roi

# Accumulators, indexed [y, x].
mlSuperFrame = np.zeros((NY * BinningFactor, NX * BinningFactor))
countFrame = np.zeros((NY, NX))
subpixelDistribution = np.zeros((BinningFactor, BinningFactor))


# Each invK maps a prediction whose last axis holds two components (p0, p1)
# back to the un-augmented frame. Which augmentation each index undoes is
# defined by the dataset and is not visible here.
def inv0(p):
    """Return ``p`` unchanged."""
    return p


def inv1(p):
    """Return ``(-p0, p1)``."""
    return torch.stack([-p[..., 0], p[..., 1]], dim=-1)


def inv2(p):
    """Return ``(p0, -p1)``."""
    return torch.stack([p[..., 0], -p[..., 1]], dim=-1)


def inv3(p):
    """Return ``(-p0, -p1)``."""
    return -p


def inv4(p):
    """Return ``(p1, p0)``."""
    return torch.stack([p[..., 1], p[..., 0]], dim=-1)


def inv5(p):
    """Return ``(p1, -p0)``."""
    return torch.stack([p[..., 1], -p[..., 0]], dim=-1)


def inv6(p):
    """Return ``(-p1, p0)``."""
    return torch.stack([-p[..., 1], p[..., 0]], dim=-1)


def inv7(p):
    """Return ``(-p1, -p0)``."""
    return torch.stack([-p[..., 1], -p[..., 0]], dim=-1)


INVERSE_TRANSFORMS = {
    0: inv0,
    1: inv1,
    2: inv2,
    3: inv3,
    4: inv4,
    5: inv5,
    6: inv6,
    7: inv7,
}


def apply_inverse_transforms(predictions: torch.Tensor, numberOfAugOps: int) -> torch.Tensor:
    """Undo the augmentation of each prediction and average over the views.

    Args:
        predictions: tensor with two values per row; rows are grouped so that
            ``numberOfAugOps`` consecutive rows belong to the same sample, and
            row ``k`` of a group is corrected with ``INVERSE_TRANSFORMS[k]``.
        numberOfAugOps: number of rows per sample.

    Returns:
        Tensor of shape ``(predictions.shape[0] // numberOfAugOps, 2)`` holding
        the mean corrected prediction of each sample.

    Raises:
        ValueError: if the number of rows is not a multiple of ``numberOfAugOps``.
    """
    nRows = predictions.shape[0]
    if nRows % numberOfAugOps != 0:
        raise ValueError(
            f'{nRows} predictions cannot be split into groups of {numberOfAugOps}'
        )
    # reshape (not view) so a non-contiguous input is accepted as well.
    grouped = predictions.reshape(nRows // numberOfAugOps, numberOfAugOps, 2)
    corrected = torch.zeros_like(grouped)
    for idx in range(numberOfAugOps):
        corrected[:, idx, :] = INVERSE_TRANSFORMS[idx](grouped[:, idx, :])
    return corrected.mean(dim=1)


def _save_frame(frame, extent, outputDir, stem):
    """Plot ``frame`` to ``<outputDir>/<stem>.png`` and store it as ``<stem>.npy``."""
    plt.clf()
    plt.imshow(frame, origin='lower', extent=extent)
    plt.colorbar()
    plt.savefig(f'{outputDir}/{stem}.png', dpi=300)
    np.save(f'{outputDir}/{stem}.npy', frame)


if __name__ == "__main__":
    model = models.get_model_class(config['modelVersion'])().cuda()
    modelName = f'singlePhoton{config["modelVersion"]}_{config["energy"]}keV_Noise{config["noise"]}keV_E500_aug1'
    if flag_normalize:
        modelName += '_normalized'
    model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/{modelName}.pth', weights_only=True))
    # no_grad() only disables autograd; eval() is what puts dropout/batch-norm
    # layers into inference behaviour.
    model.eval()

    filesPerChunk = 16  # files loaded into memory at once
    dataFiles = config['dataFiles']
    for stFileIdx in range(0, len(dataFiles), filesPerChunk):
        edFileIdx = min(stFileIdx + filesPerChunk, len(dataFiles))
        sampleFiles = dataFiles[stFileIdx:edFileIdx]
        print(f'Processing files {stFileIdx} to {edFileIdx}...')

        dataset = singlePhotonDataset(
            sampleFiles,
            sampleRatio=1,
            datasetName='Inference',
            numberOfAugOps=numberOfAugOps,
            normalize=flag_normalize
        )
        dataLoader = torch.utils.data.DataLoader(
            dataset,
            batch_size=8192,
            shuffle=False,  # order must be kept so views of one sample stay adjacent
            num_workers=16,
            pin_memory=True,
        )
        referencePoints = np.concatenate([dataset.referencePoint], axis=0)

        batchOutputs = []
        with torch.no_grad():
            for inputs, _ in tqdm(dataLoader):
                batchOutputs.append(model(inputs.cuda())[:, :2].cpu())  # only x and y
        predictions = torch.cat(batchOutputs, dim=0)
        predictions = apply_inverse_transforms(predictions, numberOfAugOps)
        predictions += torch.tensor([1.5, 1.5]).unsqueeze(0)  # adjust back to original coordinate system

        print(f'mean x = {torch.mean(predictions[:, 0])}, std x = {torch.std(predictions[:, 0])}')
        print(f'mean y = {torch.mean(predictions[:, 1])}, std y = {torch.std(predictions[:, 1])}')

        absolutePositions = predictions.numpy() + referencePoints[:, :2]
        hit_x = np.floor(absolutePositions[:, 0] * BinningFactor).astype(int)
        hit_y = np.floor(absolutePositions[:, 1] * BinningFactor).astype(int)

        # Negative indices would silently wrap to the opposite edge and too
        # large ones would raise, so only hits inside the frame are counted.
        inFrame = (
            (hit_x >= 0) & (hit_x < NX * BinningFactor)
            & (hit_y >= 0) & (hit_y < NY * BinningFactor)
        )
        nDropped = int(inFrame.size - np.count_nonzero(inFrame))
        if nDropped:
            print(f'Dropped {nDropped} hits outside the frame')
        np.add.at(mlSuperFrame, (hit_y[inFrame], hit_x[inFrame]), 1)

        ### the reference points refer to the lower-left corner of the pixel, so add 1 to get the pixel index
        np.add.at(
            countFrame,
            ((referencePoints[:, 1] + 1).astype(int), (referencePoints[:, 0] + 1).astype(int)),
            1,
        )
        np.add.at(
            subpixelDistribution,
            (
                np.floor((absolutePositions[:, 1] % 1) * BinningFactor).astype(int),
                np.floor((absolutePositions[:, 0] % 1) * BinningFactor).astype(int),
            ),
            1,
        )

    ### Save results and plots
    outputDir = f'InferenceResults/{task}/{config["modelVersion"]}_{config["energy"]}keV_Noise{config["noise"]}keV_augX{numberOfAugOps}'
    if flag_normalize:
        outputDir += '_normalized'
    os.makedirs(outputDir, exist_ok=True)

    # Frames are indexed [y, x]; imshow's extent is (left, right, bottom, top),
    # i.e. x limits first.
    roiExtent = [X_st, X_ed, Y_st, Y_ed]

    mlSuperFrame = mlSuperFrame[
        Y_st * BinningFactor:Y_ed * BinningFactor,
        X_st * BinningFactor:X_ed * BinningFactor,
    ]
    _save_frame(mlSuperFrame, roiExtent, outputDir, '1Photon_ML_superFrame')

    countFrame = countFrame[Y_st:Y_ed, X_st:X_ed]
    _save_frame(countFrame, roiExtent, outputDir, '1Photon_count_Frame')

    _save_frame(
        subpixelDistribution,
        [0, BinningFactor, 0, BinningFactor],
        outputDir,
        '1Photon_subpixel_Distribution',
    )