Add 2photon inference

This commit is contained in:
2025-11-06 10:54:09 +01:00
parent c4a86e32e3
commit bee7a30a88

104
Inference_2Photon.py Normal file
View File

@@ -0,0 +1,104 @@
import torch
import sys
sys.path.append('./src')
import models
from datasets import *
from tqdm import tqdm
from matplotlib import pyplot as plt
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
configs = {}
configs['SiemenStar'] = {
'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/Clusters_2Photon_chunk{i}.h5' for i in range(16)],
'modelVersion': '251001_2',
'roi': [140, 230, 120, 210], # x_min, x_max, y_min, y_max,
'noise': 0.13 # in keV
}
BinningFactor = 10
Roi = configs['SiemenStar']['roi']
X_st, X_ed, Y_st, Y_ed = Roi
mlSuperFrame = np.zeros(((Y_ed-Y_st)*BinningFactor, (X_ed-X_st)*BinningFactor))
countFrame = np.zeros((Y_ed-Y_st, X_ed-X_st))
subpixelDistribution = np.zeros((BinningFactor, BinningFactor))
if __name__ == "__main__":
task = 'SiemenStar'
config = configs[task]
model = models.get_double_photon_model_class(config['modelVersion'])().cuda()
model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/doublePhoton{config["modelVersion"]}_15.3keV_Noise0.13keV_E300.pth', weights_only=True))
predictions = []
referencePoints = []
nChunks = len(config['dataFiles']) // 16
for idxChunk in range(nChunks):
stFileIdx = idxChunk * 16
edFileIdx = min((idxChunk + 1) * 16, len(config['dataFiles']))
sampleFiles = config['dataFiles'][stFileIdx : edFileIdx]
print(f'Processing files {stFileIdx} to {edFileIdx}...')
dataset = doublePhotonInferenceDataset(
sampleFiles,
sampleRatio=1,
datasetName='Inference',
# noiseKeV=0.13
)
dataLoader = torch.utils.data.DataLoader(
dataset,
batch_size=8192,
shuffle=False,
num_workers=32,
pin_memory=True,
)
referencePoints.append(dataset.referencePoint)
with torch.no_grad():
for batch in tqdm(dataLoader):
inputs, _ = batch
inputs = inputs.cuda()
outputs = model(inputs).view(-1, 2) # 2B x 2
predictions.append(outputs.cpu())
predictions = torch.cat(predictions, dim=0)
print(f'mean x = {torch.mean(predictions[:, 0])}, std x = {torch.std(predictions[:, 0])}')
print(f'mean y = {torch.mean(predictions[:, 1])}, std y = {torch.std(predictions[:, 1])}')
referencePoints = np.concatenate(referencePoints, axis=0)
### duplicate reference points for 2-photon clusters
referencePoints = np.repeat(referencePoints, 2, axis=0)
absolutePositions = predictions.numpy() + referencePoints
print(absolutePositions[:5, 0] - Roi[0], absolutePositions[:5, 1] - Roi[2])
hit_x = np.floor((absolutePositions[:, 0] - Roi[0]) * BinningFactor).astype(int)
hit_x = np.clip(hit_x, 0, mlSuperFrame.shape[1]-1)
hit_y = np.floor((absolutePositions[:, 1] - Roi[2]) * BinningFactor).astype(int)
hit_y = np.clip(hit_y, 0, mlSuperFrame.shape[0]-1)
print(hit_x[:5], hit_y[:5])
np.add.at(mlSuperFrame, (hit_y, hit_x), 1)
np.add.at(countFrame, ((referencePoints[:, 1] - Roi[2]).astype(int),
(referencePoints[:, 0] - Roi[0]).astype(int)), 1)
np.add.at(subpixelDistribution,
(np.floor((absolutePositions[:, 1] % 1) * BinningFactor).astype(int),
np.floor((absolutePositions[:, 0] % 1) * BinningFactor).astype(int)), 1)
plt.imshow(mlSuperFrame, origin='lower', extent=[Y_st, Y_ed, X_st, X_ed])
plt.colorbar()
plt.savefig('InferenceResults/SiemenStar_ML_2Photon_superFrame.png', dpi=300)
np.save('InferenceResults/SiemenStar_ML_2Photon_superFrame.npy', mlSuperFrame)
plt.clf()
plt.imshow(countFrame, origin='lower', extent=[Y_st, Y_ed, X_st, X_ed])
plt.colorbar()
plt.savefig('InferenceResults/SiemenStar_count_2Photon_Frame.png', dpi=300)
np.save('InferenceResults/SiemenStar_count_2Photon_Frame.npy', countFrame)
plt.clf()
plt.imshow(subpixelDistribution, origin='lower')
plt.colorbar()
plt.savefig('InferenceResults/SiemenStar_subpixel_2Photon_Distribution.png', dpi=300)
np.save('InferenceResults/SiemenStar_subpixel_2Photon_Distribution.npy', subpixelDistribution)