Reformat the 2ph inference using conf style
This commit is contained in:
@@ -0,0 +1,53 @@
|
||||
# configs/infer_1photon.yaml
|
||||
experiment:
|
||||
name: "FlatField_3filters_pos0" # options: SiemenStarLowerLeft, SiemenStarLowerRight
|
||||
task: "2Photon"
|
||||
output_base: "./InferenceResults"
|
||||
|
||||
data:
|
||||
sample_folder: "/home/xie_x1/MLXID/DataProcess/Samples"
|
||||
names:
|
||||
SiemenStarLowerLeft:
|
||||
file_pattern: "SiemenStarLowerLeft/2Photon_CS7_chunk{}.h5" ### 15 keV
|
||||
NX: 400
|
||||
NY: 400
|
||||
file_range: [0, 1] # [0, 160)
|
||||
roi: [140, 230, 120, 210] # [x_min, x_max, y_min, y_max]
|
||||
nSize: 7
|
||||
SiemenStarLowerRight:
|
||||
file_pattern: "SiemenStarLowerRight/2Photon_CS7_chunk{}.h5" ### 15 keV
|
||||
NX: 400
|
||||
NY: 400
|
||||
file_range: [0, 320] # [0, 320)
|
||||
roi: [235, 345, 110, 220]
|
||||
nSize: 7
|
||||
KnifeEdge_3filters_pos0:
|
||||
file_pattern: "2603MaxIV_Edge3Filters_pos0_12keV/2Photon_CS7_chunk{}.h5" ### 12 keV
|
||||
NX: 101
|
||||
NY: 101
|
||||
file_range: [0, 16] # [0, 16)
|
||||
roi: [0, 101, 0, 101]
|
||||
nSize: 7
|
||||
FlatField_3filters_pos0:
|
||||
file_pattern: "2603MaxIV_FlatField3Filters_pos0_12keV/2Photon_CS7_chunk{}.h5" ### 12 keV
|
||||
NX: 101
|
||||
NY: 101
|
||||
file_range: [0, 48] # [0, 48)
|
||||
roi: [0, 101, 0, 101]
|
||||
nSize: 7
|
||||
|
||||
energy: 12 # keV
|
||||
normalize: false # for inference dataloaders
|
||||
batch_size: 8192
|
||||
num_workers: 16
|
||||
|
||||
model:
|
||||
version: "251124"
|
||||
base_dir: "/home/xie_x1/MLXID/DeepLearning/Results/"
|
||||
experiment_name: "260506_2ph_12keV_v251124_03"
|
||||
name: "doublePhoton251124_12keV_Noise0.13keV_E150.pth"
|
||||
|
||||
inference:
|
||||
binning_factor: 10
|
||||
chunk_size: 16 #
|
||||
num_aug_ops: 1 # augX8 not implemented for 2-photon inference yet
|
||||
@@ -0,0 +1,194 @@
|
||||
import sys
|
||||
sys.path.append('./src')
|
||||
from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
import h5py
|
||||
|
||||
from models import get_double_photon_model_class
|
||||
from datasets import doublePhotonInferenceDataset
|
||||
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
conf = OmegaConf.load("Configs/infer_2photon.yaml")
|
||||
NX, NY = conf.data.names[conf.experiment.name].NX, conf.data.names[conf.experiment.name].NY
|
||||
|
||||
def inv0(p): return p
|
||||
def inv1(p): return torch.stack([-p[..., 0], p[..., 1]], dim=-1)
|
||||
def inv2(p): return torch.stack([p[..., 0], -p[..., 1]], dim=-1)
|
||||
def inv3(p): return -p
|
||||
def inv4(p): return torch.stack([p[..., 1], p[..., 0]], dim=-1)
|
||||
def inv5(p): return torch.stack([p[..., 1], -p[..., 0]], dim=-1)
|
||||
def inv6(p): return torch.stack([-p[..., 1], p[..., 0]], dim=-1)
|
||||
def inv7(p): return torch.stack([-p[..., 1], -p[..., 0]], dim=-1)
|
||||
|
||||
INVERSE_TRANSFORMS = {
|
||||
0: inv0, 1: inv1, 2: inv2, 3: inv3,
|
||||
4: inv4, 5: inv5, 6: inv6, 7: inv7,
|
||||
}
|
||||
|
||||
def apply_inverse_transforms(predictions: torch.Tensor, numberOfAugOps: int) -> torch.Tensor:
|
||||
N = predictions.shape[0] // numberOfAugOps
|
||||
preds = predictions.view(N, numberOfAugOps, 2)
|
||||
corrected = torch.zeros_like(preds)
|
||||
for idx in range(numberOfAugOps):
|
||||
corrected[:, idx, :] = INVERSE_TRANSFORMS[idx](preds[:, idx, :])
|
||||
return corrected.mean(dim=1)
|
||||
|
||||
def prepare_output_folder(conf):
|
||||
if conf.data.normalize:
|
||||
normalize_suffix = '_normalized'
|
||||
else:
|
||||
normalize_suffix = ''
|
||||
output_base = Path(conf.experiment.output_base) / conf.experiment.name / conf.model.experiment_name / f'augX{conf.inference.num_aug_ops}{normalize_suffix}'
|
||||
output_base.mkdir(parents=True, exist_ok=True)
|
||||
OmegaConf.save(conf, output_base / 'config.yaml')
|
||||
return output_base
|
||||
|
||||
def get_files_list(conf):
|
||||
task_conf = conf.data.names[conf.experiment.name]
|
||||
file_pattern = task_conf.file_pattern
|
||||
start, end = task_conf.file_range
|
||||
files = [str(Path(conf.data.sample_folder) / file_pattern.format(i)) for i in range(start, end)]
|
||||
return files
|
||||
|
||||
def run_inference(model, data_loader, conf):
|
||||
all_predictions = []
|
||||
with torch.no_grad():
|
||||
for batch_idx, batch in enumerate(tqdm(data_loader, desc="Inferring")):
|
||||
inputs, _ = batch
|
||||
inputs = inputs.cuda()
|
||||
outputs = model(inputs).view(-1, 2) # 2B x 2
|
||||
all_predictions.append(outputs.cpu())
|
||||
|
||||
all_predictions = torch.cat(all_predictions, dim=0)
|
||||
# all_predictions = apply_inverse_transforms(all_predictions, conf.inference.num_aug_ops)
|
||||
all_predictions += torch.tensor([conf.data.names[conf.experiment.name].nSize/2., conf.data.names[conf.experiment.name].nSize/2.]).unsqueeze(0) # adjust back to original coordinate system
|
||||
print(f'mean x = {torch.mean(all_predictions[:, 0])}, std x = {torch.std(all_predictions[:, 0])}')
|
||||
print(f'mean y = {torch.mean(all_predictions[:, 1])}, std y = {torch.std(all_predictions[:, 1])}')
|
||||
referencePoints = data_loader.dataset.referencePoint ### the lower-left corner of the cluster in absolute coordinate
|
||||
referencePoints = np.repeat(referencePoints, 2, axis=0) ### duplicate reference points for 2-photon clusters
|
||||
return all_predictions.numpy(), referencePoints
|
||||
|
||||
def accumulate_hits(predictions: np.ndarray, reference_points: np.ndarray,
|
||||
binning_factor: int):
|
||||
### ret
|
||||
ml_super_frame = np.zeros((NY*binning_factor, NX*binning_factor), dtype=np.int32)
|
||||
count_frame = np.zeros((NY, NX), dtype=np.int32)
|
||||
subpixel_dist = np.zeros((binning_factor, binning_factor), dtype=np.int32)
|
||||
|
||||
### absolute coordinate = predicted subpixel + reference point
|
||||
absolute_positions = predictions + reference_points
|
||||
|
||||
hit_x_superpixel_idx = np.floor(absolute_positions[:, 0] * binning_factor).astype(int)
|
||||
hit_x_superpixel_idx = np.clip(hit_x_superpixel_idx, 0, NX*binning_factor-1)
|
||||
hit_y_superpixel_idx = np.floor(absolute_positions[:, 1] * binning_factor).astype(int)
|
||||
hit_y_superpixel_idx = np.clip(hit_y_superpixel_idx, 0, NY*binning_factor-1)
|
||||
np.add.at(ml_super_frame, (hit_y_superpixel_idx, hit_x_superpixel_idx), 1)
|
||||
|
||||
hit_x_pixel_idx = np.floor(absolute_positions[:, 0]).astype(int)
|
||||
hit_x_pixel_idx = np.clip(hit_x_pixel_idx, 0, NX-1)
|
||||
hit_y_pixel_idx = np.floor(absolute_positions[:, 1]).astype(int)
|
||||
hit_y_pixel_idx = np.clip(hit_y_pixel_idx, 0, NY-1)
|
||||
np.add.at(count_frame, (hit_y_pixel_idx, hit_x_pixel_idx), 1)
|
||||
|
||||
subpixel_x_idx = np.floor((absolute_positions[:, 0] % 1) * binning_factor).astype(int)
|
||||
subpixel_y_idx = np.floor((absolute_positions[:, 1] % 1) * binning_factor).astype(int)
|
||||
np.add.at(subpixel_dist, (subpixel_y_idx, subpixel_x_idx), 1)
|
||||
|
||||
return ml_super_frame, count_frame, subpixel_dist
|
||||
|
||||
def save_results(ml_super_frame, count_frame, subpixel_dist,
|
||||
roi: list, binning_factor: int, output_dir: Path):
|
||||
x_st, x_ed, y_st, y_ed = roi
|
||||
|
||||
# 1. super-resolution frame
|
||||
plt.figure(figsize=(8, 8))
|
||||
plt.imshow(ml_super_frame[y_st*binning_factor:y_ed*binning_factor, x_st*binning_factor:x_ed*binning_factor], origin='lower', extent=[x_st, x_ed, y_st, y_ed])
|
||||
plt.colorbar(label='Counts')
|
||||
plt.title('ML Super-Resolution Frame')
|
||||
plt.xlabel('X (pixel)')
|
||||
plt.ylabel('Y (pixel)')
|
||||
plt.savefig(output_dir / '2Photon_ML_superFrame.png', dpi=300, bbox_inches='tight')
|
||||
plt.clf()
|
||||
np.save(output_dir / '2Photon_ML_superFrame.npy', ml_super_frame)
|
||||
|
||||
# 2. count frame
|
||||
plt.imshow(count_frame[y_st:y_ed, x_st:x_ed], origin='lower', extent=[x_st, x_ed, y_st, y_ed])
|
||||
plt.colorbar(label='Counts')
|
||||
plt.title('Photon Count Frame')
|
||||
plt.xlabel('X (pixel)')
|
||||
plt.ylabel('Y (pixel)')
|
||||
plt.savefig(output_dir / '2Photon_count_Frame.png', dpi=300, bbox_inches='tight')
|
||||
plt.clf()
|
||||
np.save(output_dir / '2Photon_count_Frame.npy', count_frame)
|
||||
|
||||
# 3. subpixel distribution
|
||||
plt.imshow(subpixel_dist, origin='lower', extent=[0, 1, 0, 1])
|
||||
plt.colorbar(label='Counts')
|
||||
plt.title('Subpixel Distribution')
|
||||
plt.xlabel('Subpixel X')
|
||||
plt.ylabel('Subpixel Y')
|
||||
plt.savefig(output_dir / '2Photon_subpixel_Distribution.png', dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
np.save(output_dir / '2Photon_subpixel_Distribution.npy', subpixel_dist)
|
||||
std, mean = np.std(subpixel_dist), np.mean(subpixel_dist)
|
||||
print(f"[Plotting]: Sub-pixel distribution: Std/Mean: {std/mean:.4f}")
|
||||
|
||||
print(f"Results saved to: {output_dir}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
### output folder preparation
|
||||
output_dir = prepare_output_folder(conf)
|
||||
|
||||
### model loading
|
||||
model = get_double_photon_model_class(conf.model.version)().cuda()
|
||||
model.load_state_dict(torch.load(f'{conf.model.base_dir}/{conf.model.experiment_name}/Models/{conf.model.name}', weights_only=True))
|
||||
# model.eval()
|
||||
|
||||
### data loading
|
||||
files_list = get_files_list(conf)
|
||||
roi = conf.data.names[conf.experiment.name].roi
|
||||
BinningFactor = conf.inference.binning_factor
|
||||
numberOfAugOps = conf.inference.num_aug_ops
|
||||
flag_normalize = conf.data.normalize
|
||||
nChunks = np.ceil(len(files_list) / 16).astype(int)
|
||||
|
||||
ml_super_frame = np.zeros((NY*BinningFactor, NX*BinningFactor), dtype=np.int32)
|
||||
count_frame = np.zeros((NY, NX), dtype=np.int32)
|
||||
subpixel_dist = np.zeros((BinningFactor, BinningFactor), dtype=np.int32)
|
||||
|
||||
for idxChunk in range(nChunks):
|
||||
start_idx = idxChunk * conf.inference.chunk_size
|
||||
end_idx = min(start_idx + conf.inference.chunk_size, len(files_list))
|
||||
chunk_files = files_list[start_idx:end_idx]
|
||||
print(f'[Inferring] Chunk {idxChunk+1}/{nChunks}: Loading files {start_idx} to {end_idx}...')
|
||||
|
||||
dataset = doublePhotonInferenceDataset(
|
||||
chunk_files,
|
||||
sampleRatio=1.0,
|
||||
datasetName=f'Inference_Chunk{idxChunk+1}',
|
||||
# numberOfAugOps=numberOfAugOps,
|
||||
nSize=conf.data.names[conf.experiment.name].nSize,
|
||||
)
|
||||
dataLoader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=8192,
|
||||
shuffle=False,
|
||||
num_workers=16,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
predictions, reference_points = run_inference(model, dataLoader, conf)
|
||||
ml_super_frame_chunk, count_frame_chunk, subpixel_dist_chunk = accumulate_hits(predictions, reference_points, BinningFactor)
|
||||
ml_super_frame += ml_super_frame_chunk
|
||||
count_frame += count_frame_chunk
|
||||
subpixel_dist += subpixel_dist_chunk
|
||||
save_results(ml_super_frame, count_frame, subpixel_dist, roi, BinningFactor, output_dir)
|
||||
@@ -1,114 +0,0 @@
|
||||
import torch
|
||||
import sys
|
||||
sys.path.append('./src')
|
||||
import models
|
||||
from datasets import *
|
||||
from tqdm import tqdm
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
configs = {}
|
||||
configs['SiemenStarLowerLeft'] = {
|
||||
'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/SiemenStarLowerLeft_old/Clusters_2Photon_CS7_chunk{i}.h5' for i in range(200)],
|
||||
'modelVersion': '251124',
|
||||
'roi': [140, 230, 120, 210], # x_min, x_max, y_min, y_max,
|
||||
'nSize': 7,
|
||||
}
|
||||
|
||||
configs['SiemenStarLowerRight'] = {
|
||||
'dataFiles': [f'/home/xie_x1/MLXID/DataProcess/Samples/SiemenStarLowerRight/2Photon_CS7_chunk{i}.h5' for i in range(320)], ### 320 files
|
||||
'modelVersion': '251124',
|
||||
'roi': [235, 345, 110, 220], # x_min, x_max, y_min, y_max,
|
||||
'nSize': 7,
|
||||
}
|
||||
|
||||
task = 'SiemenStarLowerRight'
|
||||
config = configs[task]
|
||||
|
||||
BinningFactor = 10
|
||||
Roi = configs[task]['roi']
|
||||
X_st, X_ed, Y_st, Y_ed = Roi
|
||||
mlSuperFrame = np.zeros(((Y_ed-Y_st)*BinningFactor, (X_ed-X_st)*BinningFactor))
|
||||
countFrame = np.zeros((Y_ed-Y_st, X_ed-X_st))
|
||||
subpixelDistribution = np.zeros((BinningFactor, BinningFactor))
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = models.get_double_photon_model_class(config['modelVersion'])().cuda()
|
||||
model.load_state_dict(torch.load(f'/home/xie_x1/MLXID/DeepLearning/Models/doublePhoton{config["modelVersion"]}_15keV_Noise0.13keV_E300.pth', weights_only=True))
|
||||
nChunks = np.ceil(len(config['dataFiles']) / 16).astype(int)
|
||||
for idxChunk in range(nChunks):
|
||||
predictions = []
|
||||
referencePoints = []
|
||||
stFileIdx = idxChunk * 16
|
||||
edFileIdx = min((idxChunk + 1) * 16, len(config['dataFiles']))
|
||||
sampleFiles = config['dataFiles'][stFileIdx : edFileIdx]
|
||||
print(f'Processing files {stFileIdx} to {edFileIdx}...')
|
||||
dataset = doublePhotonInferenceDataset(
|
||||
sampleFiles,
|
||||
sampleRatio=1.,
|
||||
datasetName='Inference',
|
||||
nSize=config['nSize']
|
||||
)
|
||||
dataLoader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=8192,
|
||||
shuffle=False,
|
||||
num_workers=16,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
referencePoints.append(dataset.referencePoint)
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataLoader):
|
||||
inputs, _ = batch
|
||||
inputs = inputs.cuda()
|
||||
outputs = model(inputs).view(-1, 2) # 2B x 2
|
||||
predictions.append(outputs.cpu())
|
||||
|
||||
predictions = torch.cat(predictions, dim=0)
|
||||
predictions += torch.tensor([config['nSize']/2., config['nSize']/2.]).unsqueeze(0) # adjust back to original coordinate system
|
||||
print(f'mean x = {torch.mean(predictions[:, 0])}, std x = {torch.std(predictions[:, 0])}')
|
||||
print(f'mean y = {torch.mean(predictions[:, 1])}, std y = {torch.std(predictions[:, 1])}')
|
||||
referencePoints = np.concatenate(referencePoints, axis=0) ### the lower-left corner of the cluster in absolute coordinate
|
||||
### duplicate reference points for 2-photon clusters
|
||||
referencePoints = np.repeat(referencePoints, 2, axis=0)
|
||||
absolutePositions = predictions.numpy() + referencePoints
|
||||
|
||||
hit_x = np.floor((absolutePositions[:, 0] - Roi[0]) * BinningFactor).astype(int)
|
||||
hit_x = np.clip(hit_x, 0, mlSuperFrame.shape[1]-1)
|
||||
hit_y = np.floor((absolutePositions[:, 1] - Roi[2]) * BinningFactor).astype(int)
|
||||
hit_y = np.clip(hit_y, 0, mlSuperFrame.shape[0]-1)
|
||||
np.add.at(mlSuperFrame, (hit_y, hit_x), 1)
|
||||
|
||||
np.add.at(countFrame, ((referencePoints[:, 1] - Roi[2]).astype(int),
|
||||
(referencePoints[:, 0] - Roi[0]).astype(int)), 1)
|
||||
|
||||
np.add.at(subpixelDistribution,
|
||||
(np.floor((absolutePositions[:, 1] % 1) * BinningFactor).astype(int),
|
||||
np.floor((absolutePositions[:, 0] % 1) * BinningFactor).astype(int)), 1)
|
||||
|
||||
import os
|
||||
os.makedirs(f'InferenceResults/{task}', exist_ok=True)
|
||||
|
||||
plt.imshow(mlSuperFrame, origin='lower', extent=[Y_st, Y_ed, X_st, X_ed])
|
||||
plt.colorbar()
|
||||
plt.savefig(f'InferenceResults/{task}/ML_2Photon_superFrame.png', dpi=300)
|
||||
np.save(f'InferenceResults/{task}/ML_2Photon_superFrame.npy', mlSuperFrame)
|
||||
plt.clf()
|
||||
|
||||
plt.imshow(countFrame, origin='lower', extent=[Y_st, Y_ed, X_st, X_ed])
|
||||
plt.colorbar()
|
||||
plt.savefig(f'InferenceResults/{task}/count_2Photon_Frame.png', dpi=300)
|
||||
np.save(f'InferenceResults/{task}/count_2Photon_Frame.npy', countFrame)
|
||||
|
||||
plt.clf()
|
||||
plt.imshow(subpixelDistribution, origin='lower')
|
||||
plt.colorbar()
|
||||
plt.savefig(f'InferenceResults/{task}/subpixel_2Photon_Distribution.png', dpi=300)
|
||||
np.save(f'InferenceResults/{task}/subpixel_2Photon_Distribution.npy', subpixelDistribution)
|
||||
Reference in New Issue
Block a user