Reframe infer codes

This commit is contained in:
2026-03-20 16:23:39 +01:00
parent c1e2117729
commit ba52f860e6
+164 -121
View File
@@ -1,50 +1,24 @@
import torch
import sys
sys.path.append('./src')
import models
from datasets import *
from pathlib import Path
from omegaconf import OmegaConf
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
import numpy as np
import h5py
from models import get_model_class
from datasets import singlePhotonDataset
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
NX, NY = 400, 400
configs = {}
configs['SiemenStarLowerLeft'] = {
# 'dataFiles': [f'/mnt/sls_det_storage/moench_data/MLXID/Samples/Measurement/2504_SOLEIL_SiemenStarClusters_MOENCH040_150V/SiemenStarLowerLeft/clusters_chunk{i}.h5' for i in range(1)], # 200 files, no zeroing pixels outside the cluster
'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/SiemenStarLowerLeft/1Photon_CS3_chunk{i}.h5' for i in range(160)], # 160 files, no zeroing pixels outside the cluster
'modelVersion': '251022',
'energy': 15, # keV
'roi': [140, 230, 120, 210], # x_min, x_max, y_min, y_max,
'noise': 0.13, # keV; for the model selection
}
configs['SiemenStarLowerRight'] = {
'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/SiemenStarLowerRight/1Photon_CS3_chunk{i}.h5' for i in range(320)], # 320 files. !!! zeroed pixels outside the cluster, to be fixed
'modelVersion': '251022',
'energy': 15, # keV
'roi': [235, 345, 110, 220], # x_min, x_max, y_min, y_max,
'noise': 0.13, # keV
}
task = 'SiemenStarLowerLeft'
config = configs[task]
BinningFactor = 10
numberOfAugOps = 8
flag_normalize = False
Roi = config['roi']
X_st, X_ed, Y_st, Y_ed = Roi
mlSuperFrame = np.zeros((NY*BinningFactor, NX*BinningFactor))
countFrame = np.zeros((NY, NX))
subpixelDistribution = np.zeros((BinningFactor, BinningFactor))
conf = OmegaConf.load("Configs/infer_1photon.yaml")
NX, NY = conf.data.names[conf.experiment.name].NX, conf.data.names[conf.experiment.name].NY
def inv0(p): return p
def inv1(p): return torch.stack([-p[..., 0], p[..., 1]], dim=-1)
@@ -56,14 +30,8 @@ def inv6(p): return torch.stack([-p[..., 1], p[..., 0]], dim=-1)
def inv7(p): return torch.stack([-p[..., 1], -p[..., 0]], dim=-1)
INVERSE_TRANSFORMS = {
0: inv0,
1: inv1,
2: inv2,
3: inv3,
4: inv4,
5: inv5,
6: inv6,
7: inv7,
0: inv0, 1: inv1, 2: inv2, 3: inv3,
4: inv4, 5: inv5, 6: inv6, 7: inv7,
}
def apply_inverse_transforms(predictions: torch.Tensor, numberOfAugOps: int) -> torch.Tensor:
@@ -74,89 +42,164 @@ def apply_inverse_transforms(predictions: torch.Tensor, numberOfAugOps: int) ->
corrected[:, idx, :] = INVERSE_TRANSFORMS[idx](preds[:, idx, :])
return corrected.mean(dim=1)
def prepare_output_folder(conf):
if conf.data.normalize:
normalize_suffix = '_normalized'
else:
normalize_suffix = ''
output_base = Path(conf.experiment.output_base) / conf.experiment.name / conf.model.experiment_name / f'augX{conf.inference.num_aug_ops}{normalize_suffix}'
output_base.mkdir(parents=True, exist_ok=True)
OmegaConf.save(conf, output_base / 'config.yaml')
return output_base
def get_files_list(conf):
task_conf = conf.data.names[conf.experiment.name]
file_pattern = task_conf.file_pattern
start, end = task_conf.file_range
files = [str(Path(conf.data.sample_folder) / file_pattern.format(i)) for i in range(start, end)]
return files
def run_inference(model, data_loader, conf):
model.eval()
all_predictions = []
all_reference_points = data_loader.dataset.referencePoint
with torch.no_grad():
for batch in tqdm(data_loader, desc="Inferring"):
inputs, _ = batch
inputs = inputs.to('cuda')
outputs = model(inputs)[:, :2].cpu() # 只取 x, y
all_predictions.append(outputs)
all_predictions = torch.cat(all_predictions, dim=0)
all_predictions = apply_inverse_transforms(all_predictions, conf.inference.num_aug_ops)
offset = [inputs.shape[-2] / 2., inputs.shape[-1] / 2.]
offset = torch.tensor(offset).unsqueeze(0) # (1, 2)
all_predictions += offset
print(f'mean x = {torch.mean(all_predictions[:, 0]):.4f}, std x = {torch.std(all_predictions[:, 0]):.4f}')
print(f'mean y = {torch.mean(all_predictions[:, 1]):.4f}, std y = {torch.std(all_predictions[:, 1]):.4f}')
return all_predictions.numpy(), all_reference_points
def accumulate_hits(predictions: np.ndarray, reference_points: np.ndarray,
binning_factor: int):
ml_super_frame = np.zeros((NY * binning_factor, NX * binning_factor))
count_frame = np.zeros((NY, NX))
subpixel_dist = np.zeros((binning_factor, binning_factor))
# 绝对坐标 = 预测亚像素 + 参考点(像素左下角)
absolute_positions = predictions + reference_points[:, :2]
# 超分辨帧 (binning)
hit_x = np.floor(absolute_positions[:, 0] * binning_factor).astype(int)
hit_y = np.floor(absolute_positions[:, 1] * binning_factor).astype(int)
np.add.at(ml_super_frame, (hit_y, hit_x), 1)
# 计数帧 (按参考点像素索引)
ref_x = (reference_points[:, 0] + 1).astype(int) # 参考点是左下角,+1 得像素索引
ref_y = (reference_points[:, 1] + 1).astype(int)
np.add.at(count_frame, (ref_y, ref_x), 1)
# 亚像素分布
sub_x = np.floor((absolute_positions[:, 0] % 1) * binning_factor).astype(int)
sub_y = np.floor((absolute_positions[:, 1] % 1) * binning_factor).astype(int)
np.add.at(subpixel_dist, (sub_y, sub_x), 1)
return ml_super_frame, count_frame, subpixel_dist
def save_results(ml_super_frame, count_frame, subpixel_dist,
roi: list, binning_factor: int, output_dir: Path):
x_min, x_max, y_min, y_max = roi
# 1. super-resolution frame
plt.figure(figsize=(8, 8))
plt.imshow(ml_super_frame[y_min*binning_factor:y_max*binning_factor, x_min*binning_factor:x_max*binning_factor], origin='lower', extent=[x_min, x_max, y_min, y_max])
plt.colorbar(label='Counts')
plt.title('ML Super-Resolution Frame')
plt.xlabel('X (pixel)')
plt.ylabel('Y (pixel)')
plt.savefig(output_dir / '1Photon_ML_superFrame.png', dpi=300, bbox_inches='tight')
plt.close()
np.save(output_dir / '1Photon_ML_superFrame.npy', ml_super_frame)
# 2. count frame
plt.figure(figsize=(8, 8))
plt.imshow(count_frame[y_min:y_max, x_min:x_max], origin='lower', extent=[x_min, x_max, y_min, y_max])
plt.colorbar(label='Counts')
plt.title('Photon Count Frame')
plt.xlabel('X (pixel)')
plt.ylabel('Y (pixel)')
plt.savefig(output_dir / '1Photon_count_Frame.png', dpi=300, bbox_inches='tight')
plt.close()
np.save(output_dir / '1Photon_count_Frame.npy', count_frame)
# 3. sub-pixel distribution
plt.figure(figsize=(8, 8))
plt.imshow(subpixel_dist, origin='lower',
extent=[0, binning_factor, 0, binning_factor],
cmap='viridis')
plt.colorbar(label='Counts')
plt.title('Sub-pixel Distribution')
plt.xlabel(f'Sub-pixel X (1/{binning_factor} pixel)')
plt.ylabel(f'Sub-pixel Y (1/{binning_factor} pixel)')
plt.savefig(output_dir / '1Photon_subpixel_Distribution.png', dpi=300, bbox_inches='tight')
plt.close()
np.save(output_dir / '1Photon_subpixel_Distribution.npy', subpixel_dist)
rms, mean = np.std(subpixel_dist), np.mean(subpixel_dist)
print(f"[Plotting]: Sub-pixel distribution: RMS/Mean: {rms/mean:.4f}")
print(f"Results saved to: {output_dir}")
if __name__ == "__main__":
model = models.get_model_class(config['modelVersion'])().cuda()
modelName = f'singlePhoton{config["modelVersion"]}_{config["energy"]}keV_Noise{config["noise"]}keV_E150_aug1'
if flag_normalize:
modelName += '_normalized'
model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/{modelName}.pth', weights_only=True))
nChunks = np.ceil(len(config['dataFiles']) / 16).astype(int)
### output preparation
output_dir = prepare_output_folder(conf)
### model loading
model = get_model_class(conf.model.version)().cuda()
model.load_state_dict(torch.load(f'{conf.model.base_dir}/{conf.model.experiment_name}/Models/{conf.model.name}', weights_only=True))
model.eval()
### data loading
files_list = get_files_list(conf)
roi = conf.data.names[conf.experiment.name].roi
BinningFactor = conf.inference.binning_factor
numberOfAugOps = conf.inference.num_aug_ops
flag_normalize = conf.data.normalize
nChunks = int(np.ceil(len(files_list) / conf.inference.chunk_size))
ml_super_frame = np.zeros((NY * BinningFactor, NX * BinningFactor))
count_frame = np.zeros((NY, NX))
subpixel_dist = np.zeros((BinningFactor, BinningFactor))
### Inference loop
for idxChunk in range(nChunks):
predictions = []
referencePoints = []
stFileIdx = idxChunk * 16
edFileIdx = min((idxChunk + 1) * 16, len(config['dataFiles']))
sampleFiles = config['dataFiles'][stFileIdx : edFileIdx]
print(f'Processing files {stFileIdx} to {edFileIdx}...')
start_idx = idxChunk * conf.inference.chunk_size
end_idx = min(start_idx + conf.inference.chunk_size, len(files_list))
chunk_files = files_list[start_idx:end_idx]
print(f'[Inferring] Chunk {idxChunk+1}/{nChunks}: Loading files {start_idx} to {end_idx}...')
dataset = singlePhotonDataset(
sampleFiles,
sampleRatio=1,
chunk_files,
sampleRatio=1.0,
datasetName='Inference',
numberOfAugOps=numberOfAugOps,
normalize=flag_normalize
)
dataLoader = torch.utils.data.DataLoader(
dataset,
batch_size=8192,
shuffle=False,
num_workers=16,
pin_memory=True,
numberOfAugOps=conf.inference.num_aug_ops,
normalize=conf.data.normalize
)
referencePoints.append(dataset.referencePoint)
_chunk_predictions = []
with torch.no_grad():
for batch in tqdm(dataLoader):
inputs, _ = batch
inputs_cuda = inputs.cuda()
outputs = model(inputs_cuda)[:, :2].cpu() # only x and y
_chunk_predictions.append(outputs)
predictions.extend(_chunk_predictions)
predictions = torch.cat(predictions, dim=0)
predictions = apply_inverse_transforms(predictions, numberOfAugOps)
predictions += torch.tensor([1.5, 1.5]).unsqueeze(0) # adjust back to original coordinate system
referencePoints = np.concatenate(referencePoints, axis=0)
print(f'mean x = {torch.mean(predictions[:, 0])}, std x = {torch.std(predictions[:, 0])}')
print(f'mean y = {torch.mean(predictions[:, 1])}, std y = {torch.std(predictions[:, 1])}')
absolutePositions = predictions.numpy() + referencePoints[:, :2]
hit_x = np.floor(absolutePositions[:, 0] * BinningFactor).astype(int)
hit_y = np.floor(absolutePositions[:, 1] * BinningFactor).astype(int)
np.add.at(mlSuperFrame, (hit_y, hit_x), 1)
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=conf.data.batch_size,
shuffle=False,
num_workers=conf.data.num_workers,
pin_memory=True
)
predictions, ref_points = run_inference(model, data_loader, conf)
np.add.at(countFrame, ((referencePoints[:, 1] + 1).astype(int),
(referencePoints[:, 0] + 1).astype(int)), 1) ### the reference points refer to the lower-left corner of the pixel, so add 1 to get the pixel index
np.add.at(subpixelDistribution,
(np.floor((absolutePositions[:, 1] % 1) * BinningFactor).astype(int),
np.floor((absolutePositions[:, 0] % 1) * BinningFactor).astype(int)), 1)
### Save results and plots
import os
outputDir = f'InferenceResults/{task}/{config["modelVersion"]}_{config["energy"]}keV_Noise{config["noise"]}keV_augX{numberOfAugOps}'
if flag_normalize:
outputDir += '_normalized'
os.makedirs(outputDir, exist_ok=True)
plt.clf()
mlSuperFrame = mlSuperFrame[config['roi'][2]*BinningFactor : config['roi'][3]*BinningFactor, config['roi'][0]*BinningFactor : config['roi'][1]*BinningFactor]
average = np.mean(mlSuperFrame)
plt.imshow(mlSuperFrame, origin='lower', extent=[X_st, X_ed, Y_st, Y_ed])
plt.colorbar()
plt.savefig(f'{outputDir}/1Photon_ML_superFrame.png', dpi=300)
np.save(f'{outputDir}/1Photon_ML_superFrame.npy', mlSuperFrame)
plt.clf()
countFrame = countFrame[config['roi'][2] : config['roi'][3], config['roi'][0] : config['roi'][1]]
plt.imshow(countFrame, origin='lower', extent=[X_st, X_ed, Y_st, Y_ed])
plt.colorbar()
plt.savefig(f'{outputDir}/1Photon_count_Frame.png', dpi=300)
np.save(f'{outputDir}/1Photon_count_Frame.npy', countFrame)
plt.clf()
plt.imshow(subpixelDistribution, origin='lower', extent=[0, BinningFactor, 0, BinningFactor])
plt.colorbar()
plt.savefig(f'{outputDir}/1Photon_subpixel_Distribution.png', dpi=300)
np.save(f'{outputDir}/1Photon_subpixel_Distribution.npy', subpixelDistribution)
chunk_super, chunk_count, chunk_subpixel = accumulate_hits(
predictions, ref_points, binning_factor=BinningFactor
)
ml_super_frame += chunk_super
count_frame += chunk_count
subpixel_dist += chunk_subpixel
save_results(ml_super_frame, count_frame, subpixel_dist,
roi=roi, binning_factor=BinningFactor,
output_dir=output_dir)