Add inference
This commit is contained in:
83
Inference_SinglePhoton.py
Normal file
83
Inference_SinglePhoton.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import torch
|
||||
import sys
|
||||
sys.path.append('./src')
|
||||
from models import *
|
||||
from datasets import *
|
||||
from tqdm import tqdm
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
configs = {}
|
||||
configs['SiemenStar'] = {
|
||||
'dataFiles': [f'/home/xie_x1/MLXID/DeepLearning/SiemenStarClusters/clusters_chunk{i}.h5' for i in range(100)],
|
||||
'modelVersion': 'singlePhotonNet_251020',
|
||||
'roi': [140, 230, 120, 210], # x_min, x_max, y_min, y_max,
|
||||
'noise': 0.13 # in keV
|
||||
}
|
||||
BinningFactor = 10
|
||||
Roi = configs['SiemenStar']['roi']
|
||||
X_st, X_ed, Y_st, Y_ed = Roi
|
||||
mlSuperFrame = np.zeros(((Y_ed-Y_st)*BinningFactor, (X_ed-X_st)*BinningFactor))
|
||||
countFrame = np.zeros((Y_ed-Y_st, X_ed-X_st))
|
||||
subpixelDistribution = np.zeros((BinningFactor, BinningFactor))
|
||||
|
||||
if __name__ == "__main__":
|
||||
task = 'SiemenStar'
|
||||
config = configs[task]
|
||||
|
||||
if config['modelVersion'] == 'singlePhotonNet_251020':
|
||||
model = singlePhotonNet_251020().cuda()
|
||||
model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/singlePhotonNet_Noise{config["noise"]}keV_251020.pth', weights_only=True))
|
||||
|
||||
dataset = singlePhotonDataset(config['dataFiles'], sampleRatio=1.0, datasetName='Inference')
|
||||
dataLoader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=4096,
|
||||
shuffle=False,
|
||||
num_workers=16,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
referencePoints = dataset.referencePoint
|
||||
predictions = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataLoader):
|
||||
inputs, _ = batch
|
||||
inputs = inputs.cuda()
|
||||
outputs = model(inputs)
|
||||
predictions.append(outputs.cpu())
|
||||
|
||||
predictions = torch.cat(predictions, dim=0)
|
||||
print(f'mean x = {torch.mean(predictions[:, 0])}, std x = {torch.std(predictions[:, 0])}')
|
||||
print(f'mean y = {torch.mean(predictions[:, 1])}, std y = {torch.std(predictions[:, 1])}')
|
||||
absolutePositions = predictions.numpy() + referencePoints[:, :2] - 1
|
||||
|
||||
hit_x = np.floor((absolutePositions[:, 0] - Roi[0]) * BinningFactor).astype(int)
|
||||
hit_x = np.clip(hit_x, 0, mlSuperFrame.shape[1]-1)
|
||||
hit_y = np.floor((absolutePositions[:, 1] - Roi[2]) * BinningFactor).astype(int)
|
||||
hit_y = np.clip(hit_y, 0, mlSuperFrame.shape[0]-1)
|
||||
np.add.at(mlSuperFrame, (hit_y, hit_x), 1)
|
||||
|
||||
np.add.at(countFrame, ((referencePoints[:, 1] - Roi[2]).astype(int),
|
||||
(referencePoints[:, 0] - Roi[0]).astype(int)), 1)
|
||||
|
||||
np.add.at(subpixelDistribution,
|
||||
(np.floor((absolutePositions[:, 1] % 1) * BinningFactor).astype(int),
|
||||
np.floor((absolutePositions[:, 0] % 1) * BinningFactor).astype(int)), 1)
|
||||
|
||||
plt.imshow(mlSuperFrame, origin='lower', extent=[Y_st, Y_ed, X_st, X_ed])
|
||||
plt.colorbar()
|
||||
plt.savefig('InferenceResults/SiemenStar_ML_superFrame.png', dpi=300)
|
||||
np.save('InferenceResults/SiemenStar_ML_superFrame.npy', mlSuperFrame)
|
||||
plt.clf()
|
||||
|
||||
plt.imshow(countFrame, origin='lower', extent=[Y_st, Y_ed, X_st, X_ed])
|
||||
plt.colorbar()
|
||||
plt.savefig('InferenceResults/SiemenStar_count_Frame.png', dpi=300)
|
||||
np.save('InferenceResults/SiemenStar_count_Frame.npy', countFrame)
|
||||
|
||||
plt.clf()
|
||||
plt.imshow(subpixelDistribution, origin='lower')
|
||||
plt.colorbar()
|
||||
plt.savefig('InferenceResults/SiemenStar_subpixel_Distribution.png', dpi=300)
|
||||
np.save('InferenceResults/SiemenStar_subpixel_Distribution.npy', subpixelDistribution)
|
||||
Reference in New Issue
Block a user