Add inference codes for 3-ph category

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
2026-06-01 14:21:37 +02:00
parent 2e4f22d062
commit b3c08ad2e5
4 changed files with 285 additions and 2 deletions
+54
View File
@@ -0,0 +1,54 @@
# configs/infer_1photon.yaml
experiment:
name: "FlatField_3filters_pos0" # options: SiemenStarLowerLeft, SiemenStarLowerRight
task: "3Photon"
output_base: "./InferenceResults"
data:
sample_folder: "/home/xie_x1/MLXID/DataProcess/Samples"
names:
KnifeEdge_3filters_pos0:
file_pattern: "2603MaxIV_Edge3Filters_pos0_12keV/3Photon_CS9_chunk{}.h5" ### 12 keV
NX: 101
NY: 101
file_range: [0, 16] # [0, 16)
roi: [0, 101, 0, 101]
FlatField_3filters_pos0:
file_pattern: "2603MaxIV_FlatField3Filters_pos0_12keV/3Photon_CS9_chunk{}.h5" ### 12 keV
NX: 101
NY: 101
file_range: [0, 16] # [0, 16)
roi: [0, 101, 0, 101]
KnifeEdge_2filters:
file_pattern: "2603MaxIV_Edge2Filters_12keV/3Photon_CS9_chunk{}.h5" ### 12 keV
NX: 101
NY: 101
file_range: [0, 16] # [0, 16)
roi: [0, 101, 0, 101]
FlatField_2filters:
file_pattern: "2603MaxIV_Flat2Filters_12keV/3Photon_CS9_chunk{}.h5" ### 12 keV
NX: 101
NY: 101
file_range: [0, 16] # [0, 16)
roi: [0, 101, 0, 101]
Uniformity_test:
file_pattern: "2603MaxIV_Flat2Filters_12keV/3Photon_CS9_chunk{}.h5" ### 12 keV
NX: 101
NY: 101
file_range: [0, 1]
roi: [0, 101, 0, 101]
energy: 12 # keV
normalize: false # for inference dataloaders
batch_size: 8192
num_workers: 16
model:
base_dir: "/home/xie_x1/MLXID/DeepLearning/Results/"
experiment_name: "260529_3ph_12keV_v260529_01"
name: "triplePhoton260529_12keV_E1000.pth"
inference:
binning_factor: 10
chunk_size: 16 #
num_aug_ops: 1 # augX8 not implemented for 3-photon inference yet
+2 -2
View File
@@ -13,14 +13,14 @@ data:
train_file_range: [0, 12]
val_file_range: [13, 14]
test_file_range: [15, 15]
sample_ratio: 0.1
sample_ratio: 1.0
n_size: 9 ### size of sub-images containing 3 photons
model:
version: "260529"
training:
epochs: 100
epochs: 1000
learning_rate: 1.0e-3
weight_decay: 1.0e-4
scheduler_factor: 0.7
+195
View File
@@ -0,0 +1,195 @@
import sys
sys.path.append('./src')
from pathlib import Path
from omegaconf import OmegaConf
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
import numpy as np
import h5py
from models import get_triple_photon_model_class
from datasets import triplePhotonInferenceDataset
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
conf = OmegaConf.load("Configs/infer_3photon.yaml")
NX, NY = conf.data.names[conf.experiment.name].NX, conf.data.names[conf.experiment.name].NY
NSIZE = 9
def inv0(p): return p
def inv1(p): return torch.stack([-p[..., 0], p[..., 1]], dim=-1)
def inv2(p): return torch.stack([p[..., 0], -p[..., 1]], dim=-1)
def inv3(p): return -p
def inv4(p): return torch.stack([p[..., 1], p[..., 0]], dim=-1)
def inv5(p): return torch.stack([p[..., 1], -p[..., 0]], dim=-1)
def inv6(p): return torch.stack([-p[..., 1], p[..., 0]], dim=-1)
def inv7(p): return torch.stack([-p[..., 1], -p[..., 0]], dim=-1)
INVERSE_TRANSFORMS = {
0: inv0, 1: inv1, 2: inv2, 3: inv3,
4: inv4, 5: inv5, 6: inv6, 7: inv7,
}
def apply_inverse_transforms(predictions: torch.Tensor, numberOfAugOps: int) -> torch.Tensor:
N = predictions.shape[0] // numberOfAugOps
preds = predictions.view(N, numberOfAugOps, 2)
corrected = torch.zeros_like(preds)
for idx in range(numberOfAugOps):
corrected[:, idx, :] = INVERSE_TRANSFORMS[idx](preds[:, idx, :])
return corrected.mean(dim=1)
def prepare_output_folder(conf):
if conf.data.normalize:
normalize_suffix = '_normalized'
else:
normalize_suffix = ''
output_base = Path(conf.experiment.output_base) / conf.experiment.name / conf.model.experiment_name / f'augX{conf.inference.num_aug_ops}{normalize_suffix}'
output_base.mkdir(parents=True, exist_ok=True)
OmegaConf.save(conf, output_base / 'config.yaml')
return output_base
def get_files_list(conf):
task_conf = conf.data.names[conf.experiment.name]
file_pattern = task_conf.file_pattern
start, end = task_conf.file_range
files = [str(Path(conf.data.sample_folder) / file_pattern.format(i)) for i in range(start, end)]
return files
def run_inference(model, data_loader, conf):
all_predictions = []
with torch.no_grad():
for batch_idx, batch in enumerate(tqdm(data_loader, desc="Inferring")):
inputs, _ = batch
inputs = inputs.cuda()
outputs = model(inputs).view(-1, 2) # 3B x 2
all_predictions.append(outputs.cpu())
all_predictions = torch.cat(all_predictions, dim=0)
# all_predictions = apply_inverse_transforms(all_predictions, conf.inference.num_aug_ops)
all_predictions += torch.tensor([NSIZE/2., NSIZE/2.]).unsqueeze(0) # adjust back to original coordinate system
print(f'mean x = {torch.mean(all_predictions[:, 0])}, std x = {torch.std(all_predictions[:, 0])}')
print(f'mean y = {torch.mean(all_predictions[:, 1])}, std y = {torch.std(all_predictions[:, 1])}')
referencePoints = data_loader.dataset.referencePoint ### the lower-left corner of the cluster in absolute coordinate
referencePoints = np.repeat(referencePoints, 3, axis=0) ### duplicate reference points for 3-photon clusters
return all_predictions.numpy(), referencePoints
def accumulate_hits(predictions: np.ndarray, reference_points: np.ndarray,
binning_factor: int):
### ret
ml_super_frame = np.zeros((NY*binning_factor, NX*binning_factor), dtype=np.int32)
count_frame = np.zeros((NY, NX), dtype=np.int32)
subpixel_dist = np.zeros((binning_factor, binning_factor), dtype=np.int32)
### absolute coordinate = predicted subpixel + reference point
absolute_positions = predictions + reference_points
hit_x_superpixel_idx = np.floor(absolute_positions[:, 0] * binning_factor).astype(int)
hit_x_superpixel_idx = np.clip(hit_x_superpixel_idx, 0, NX*binning_factor-1)
hit_y_superpixel_idx = np.floor(absolute_positions[:, 1] * binning_factor).astype(int)
hit_y_superpixel_idx = np.clip(hit_y_superpixel_idx, 0, NY*binning_factor-1)
np.add.at(ml_super_frame, (hit_y_superpixel_idx, hit_x_superpixel_idx), 1)
hit_x_pixel_idx = np.floor(absolute_positions[:, 0]).astype(int)
hit_x_pixel_idx = np.clip(hit_x_pixel_idx, 0, NX-1)
hit_y_pixel_idx = np.floor(absolute_positions[:, 1]).astype(int)
hit_y_pixel_idx = np.clip(hit_y_pixel_idx, 0, NY-1)
np.add.at(count_frame, (hit_y_pixel_idx, hit_x_pixel_idx), 1)
subpixel_x_idx = np.floor((absolute_positions[:, 0] % 1) * binning_factor).astype(int)
subpixel_y_idx = np.floor((absolute_positions[:, 1] % 1) * binning_factor).astype(int)
np.add.at(subpixel_dist, (subpixel_y_idx, subpixel_x_idx), 1)
return ml_super_frame, count_frame, subpixel_dist
def save_results(ml_super_frame, count_frame, subpixel_dist,
roi: list, binning_factor: int, output_dir: Path):
x_st, x_ed, y_st, y_ed = roi
# 1. super-resolution frame
plt.figure(figsize=(8, 8))
plt.imshow(ml_super_frame[y_st*binning_factor:y_ed*binning_factor, x_st*binning_factor:x_ed*binning_factor], origin='lower', extent=[x_st, x_ed, y_st, y_ed])
plt.colorbar(label='Counts')
plt.title('ML Super-Resolution Frame')
plt.xlabel('X (pixel)')
plt.ylabel('Y (pixel)')
plt.savefig(output_dir / '3Photon_ML_superFrame.png', dpi=300, bbox_inches='tight')
plt.clf()
np.save(output_dir / '3Photon_ML_superFrame.npy', ml_super_frame)
# 2. count frame
plt.imshow(count_frame[y_st:y_ed, x_st:x_ed], origin='lower', extent=[x_st, x_ed, y_st, y_ed])
plt.colorbar(label='Counts')
plt.title('Photon Count Frame')
plt.xlabel('X (pixel)')
plt.ylabel('Y (pixel)')
plt.savefig(output_dir / '3Photon_count_Frame.png', dpi=300, bbox_inches='tight')
plt.clf()
np.save(output_dir / '3Photon_count_Frame.npy', count_frame)
# 3. subpixel distribution
plt.imshow(subpixel_dist, origin='lower', extent=[0, 1, 0, 1])
plt.colorbar(label='Counts')
plt.title('Subpixel Distribution')
plt.xlabel('Subpixel X')
plt.ylabel('Subpixel Y')
plt.savefig(output_dir / '3Photon_subpixel_Distribution.png', dpi=300, bbox_inches='tight')
plt.close()
np.save(output_dir / '3Photon_subpixel_Distribution.npy', subpixel_dist)
std, mean = np.std(subpixel_dist), np.mean(subpixel_dist)
print(f"[Plotting]: Sub-pixel distribution: RMS/Mean: {std/mean:.4f}, expected value = {1/np.sqrt(mean):.4f} for uniform distribution")
print(f"Results saved to: {output_dir}")
if __name__ == "__main__":
### output folder preparation
output_dir = prepare_output_folder(conf)
### model loading
model_version = conf.model.experiment_name.split('_v')[-1][:6]
model = get_triple_photon_model_class(model_version)().cuda()
model.load_state_dict(torch.load(f'{conf.model.base_dir}/{conf.model.experiment_name}/Models/{conf.model.name}', weights_only=True))
# model.eval()
### data loading
files_list = get_files_list(conf)
roi = conf.data.names[conf.experiment.name].roi
BinningFactor = conf.inference.binning_factor
numberOfAugOps = conf.inference.num_aug_ops
flag_normalize = conf.data.normalize
nChunks = np.ceil(len(files_list) / 16).astype(int)
ml_super_frame = np.zeros((NY*BinningFactor, NX*BinningFactor), dtype=np.int32)
count_frame = np.zeros((NY, NX), dtype=np.int32)
subpixel_dist = np.zeros((BinningFactor, BinningFactor), dtype=np.int32)
for idxChunk in range(nChunks):
start_idx = idxChunk * conf.inference.chunk_size
end_idx = min(start_idx + conf.inference.chunk_size, len(files_list))
chunk_files = files_list[start_idx:end_idx]
print(f'[Inferring] Chunk {idxChunk+1}/{nChunks}: Loading files {start_idx} to {end_idx}...')
dataset = triplePhotonInferenceDataset(
chunk_files,
sampleRatio=1.0,
datasetName=f'Inference_Chunk{idxChunk+1}',
# numberOfAugOps=numberOfAugOps,
)
dataLoader = torch.utils.data.DataLoader(
dataset,
batch_size=8192,
shuffle=False,
num_workers=16,
pin_memory=True,
)
predictions, reference_points = run_inference(model, dataLoader, conf)
ml_super_frame_chunk, count_frame_chunk, subpixel_dist_chunk = accumulate_hits(predictions, reference_points, BinningFactor)
ml_super_frame += ml_super_frame_chunk
count_frame += count_frame_chunk
subpixel_dist += subpixel_dist_chunk
save_results(ml_super_frame, count_frame, subpixel_dist, roi, BinningFactor, output_dir)
+34
View File
@@ -272,5 +272,39 @@ class triplePhotonDataset(Dataset):
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
label = torch.tensor(label, dtype=torch.float32)
return sample, label
def __len__(self):
return self.length
class triplePhotonInferenceDataset(Dataset):
def __init__(self, sampleList, sampleRatio, datasetName):
self.sampleFileList = sampleList
self.sampleRatio = sampleRatio
self.datasetName = datasetName
all_samples = []
all_ref_pts = []
for idx, sampleFile in enumerate(self.sampleFileList):
if '.npz' in sampleFile:
data = np.load(sampleFile)
all_samples.append(data['samples'])
all_ref_pts.append(data['referencePoint'])
elif '.h5' in sampleFile:
import h5py
with h5py.File(sampleFile, 'r') as f:
samples = f['clusters'][:]
ref_pts = f['referencePoint'][:]
all_samples.append(samples)
all_ref_pts.append(ref_pts)
self.samples = np.concatenate(all_samples, axis=0) if all_samples else None
self.referencePoint = np.concatenate(all_ref_pts, axis=0) if all_ref_pts else None
### total number of samples
self.length = int(self.samples.shape[0] * self.sampleRatio)
self.referencePoint = self.referencePoint[:self.length]
print(f"[{self.datasetName} dataset] \t Total number of samples: {self.length}")
def __getitem__(self, index):
sample = self.samples[index]
# sample[sample == 0] += np.random.normal(loc=0.0, scale=0.13, size=sample[sample == 0].shape) ### add noise to zero pixels
sample = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
dummy_label = np.zeros((3, 4), dtype=np.float32) ### dummy label for 3 photons
return sample, torch.tensor(dummy_label, dtype=torch.float32)
def __len__(self):
return self.length