205 lines
8.3 KiB
Python
205 lines
8.3 KiB
Python
import sys
|
|
sys.path.append('./src')
|
|
from pathlib import Path
|
|
from omegaconf import OmegaConf
|
|
import torch
|
|
from tqdm import tqdm
|
|
from matplotlib import pyplot as plt
|
|
import numpy as np
|
|
import h5py
|
|
|
|
from models import get_model_class
|
|
from datasets import singlePhotonDataset
|
|
|
|
# Reproducibility: pin every RNG this script touches (NumPy, torch CPU,
# torch CUDA) to seed 0 and ask cuDNN for deterministic, non-autotuned kernels.
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Inference configuration. The path is relative, so it is resolved against
# the current working directory at launch time.
conf = OmegaConf.load("Configs/infer_1photon.yaml")

# Frame width / height (pixels) of the selected experiment; read as module
# globals by `accumulate_hits` and by the script entry point.
NX = conf.data.names[conf.experiment.name].NX
NY = conf.data.names[conf.experiment.name].NY
|
|
|
|
# Inverse mappings for the test-time augmentations, keyed by augmentation
# index. Each takes a tensor whose last dimension holds (x, y) and maps it back
# to the un-augmented frame. Which forward augmentation each index undoes is
# decided by the dataset (not visible here) — confirm the pairing against
# singlePhotonDataset.

def inv0(p):
    """Identity: (x, y) -> (x, y)."""
    return p


def inv1(p):
    """Mirror x: (x, y) -> (-x, y)."""
    x, y = p[..., 0], p[..., 1]
    return torch.stack([-x, y], dim=-1)


def inv2(p):
    """Mirror y: (x, y) -> (x, -y)."""
    x, y = p[..., 0], p[..., 1]
    return torch.stack([x, -y], dim=-1)


def inv3(p):
    """Point reflection: (x, y) -> (-x, -y)."""
    return torch.neg(p)


def inv4(p):
    """Swap axes: (x, y) -> (y, x)."""
    x, y = p[..., 0], p[..., 1]
    return torch.stack([y, x], dim=-1)


def inv5(p):
    """Quarter turn: (x, y) -> (y, -x). This is the inverse of inv6."""
    x, y = p[..., 0], p[..., 1]
    return torch.stack([y, -x], dim=-1)


def inv6(p):
    """Quarter turn: (x, y) -> (-y, x). This is the inverse of inv5."""
    x, y = p[..., 0], p[..., 1]
    return torch.stack([-y, x], dim=-1)


def inv7(p):
    """Anti-diagonal reflection: (x, y) -> (-y, -x)."""
    x, y = p[..., 0], p[..., 1]
    return torch.stack([-y, -x], dim=-1)


# Augmentation index -> inverse mapping (keys 0..7, in order).
INVERSE_TRANSFORMS = dict(enumerate((inv0, inv1, inv2, inv3, inv4, inv5, inv6, inv7)))
|
|
|
|
def apply_inverse_transforms(predictions: torch.Tensor, numberOfAugOps: int) -> torch.Tensor:
    """Undo test-time augmentation and average the de-augmented predictions.

    Args:
        predictions: Tensor of shape (N * numberOfAugOps, 2). Rows must be
            grouped per sample: row ``i * numberOfAugOps + k`` holds sample
            ``i`` under augmentation ``k`` (the reshape below relies on this,
            so the data loader must not shuffle).
        numberOfAugOps: Augmentations per sample. Every index in
            ``range(numberOfAugOps)`` must have an entry in ``INVERSE_TRANSFORMS``.

    Returns:
        Tensor of shape (N, 2): the mean over augmentations of the predictions
        mapped back to the un-augmented frame.

    Raises:
        ValueError: if ``numberOfAugOps`` is < 1 or has no inverse transform,
            or if the number of rows is not a multiple of ``numberOfAugOps``.
    """
    unsupported = [k for k in range(numberOfAugOps) if k not in INVERSE_TRANSFORMS]
    if numberOfAugOps < 1 or unsupported:
        raise ValueError(
            f'numberOfAugOps={numberOfAugOps} is not supported; '
            f'inverse transforms exist for indices {sorted(INVERSE_TRANSFORMS)}.'
        )

    n_rows = predictions.shape[0]
    if n_rows % numberOfAugOps:
        raise ValueError(
            f'Got {n_rows} predictions, which is not a multiple of '
            f'numberOfAugOps={numberOfAugOps}; samples and augmentations are misaligned.'
        )

    # reshape (rather than view) also accepts non-contiguous inputs.
    grouped = predictions.reshape(n_rows // numberOfAugOps, numberOfAugOps, 2)
    corrected = torch.stack(
        [INVERSE_TRANSFORMS[k](grouped[:, k, :]) for k in range(numberOfAugOps)],
        dim=1,
    )
    return corrected.mean(dim=1)
|
|
|
|
def prepare_output_folder(conf):
    """Create (if needed) and return the output directory for this run.

    The directory is
    ``<experiment.output_base>/<experiment.name>/<model.experiment_name>/augX<num_aug_ops>[_normalized]``,
    where the ``_normalized`` suffix is added when ``conf.data.normalize`` is
    truthy. The full configuration is saved there as ``config.yaml``.
    """
    suffix = '_normalized' if conf.data.normalize else ''
    run_folder = Path(conf.experiment.output_base).joinpath(
        conf.experiment.name,
        conf.model.experiment_name,
        f'augX{conf.inference.num_aug_ops}{suffix}',
    )
    run_folder.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, run_folder / 'config.yaml')
    return run_folder
|
|
|
|
def get_files_list(conf):
    """Return the current experiment's input file paths as strings.

    File names are ``file_pattern.format(i)`` for ``i`` in ``range(start, end)``,
    where ``(start, end) = file_range`` (so ``end`` is exclusive). Each name is
    joined onto ``conf.data.sample_folder``.
    """
    experiment_cfg = conf.data.names[conf.experiment.name]
    pattern = experiment_cfg.file_pattern
    first, stop = experiment_cfg.file_range

    def _path_for(index):
        return str(Path(conf.data.sample_folder) / pattern.format(index))

    return list(map(_path_for, range(first, stop)))
|
|
|
|
def run_inference(model, data_loader, conf):
    """Run the model over ``data_loader`` and return de-augmented hit positions.

    Args:
        model: Network whose output's first two columns are taken as the
            (x, y) prediction for each input patch.
        data_loader: Loader over a dataset that exposes ``referencePoint``.
            It must not shuffle, because ``apply_inverse_transforms`` regroups
            consecutive ``conf.inference.num_aug_ops`` rows per sample.
        conf: Config; ``conf.inference.num_aug_ops`` is read.

    Returns:
        ``(predictions, reference_points)``.
        ``predictions`` is a NumPy array of shape (N, 2) after half the patch
        size has been added (presumably converting from patch-centred to
        patch-corner coordinates — confirm against the training targets).
        ``reference_points`` is ``data_loader.dataset.referencePoint``,
        returned unchanged. An empty loader yields a (0, 2) array.
    """
    model.eval()
    reference_points = data_loader.dataset.referencePoint

    # Run on whichever device the model lives on. Fall back to CUDA for a
    # model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    batch_outputs = []
    patch_shape = None
    with torch.no_grad():
        for inputs, _ in tqdm(data_loader, desc="Inferring"):
            patch_shape = inputs.shape[-2:]
            outputs = model(inputs.to(device, non_blocking=True))
            batch_outputs.append(outputs[:, :2].cpu())  # keep only x, y

    if not batch_outputs:
        # Empty chunk: return an empty (0, 2) array so downstream accumulation
        # is a no-op, instead of crashing in torch.cat / on an unbound `inputs`.
        return np.zeros((0, 2), dtype=np.float32), reference_points

    predictions = apply_inverse_transforms(
        torch.cat(batch_outputs, dim=0), conf.inference.num_aug_ops
    )

    # Shift by half the patch size. x is paired with the last (width) axis and
    # y with the second-to-last (height) axis, matching the [y, x] frame
    # indexing in accumulate_hits. For square patches (required by the
    # transposing augmentations) both orderings coincide.
    height, width = patch_shape
    predictions += torch.tensor([[width / 2., height / 2.]])

    pred_x, pred_y = predictions[:, 0], predictions[:, 1]
    print(f'mean x = {torch.mean(pred_x):.4f}, std x = {torch.std(pred_x):.4f}')
    print(f'mean y = {torch.mean(pred_y):.4f}, std y = {torch.std(pred_y):.4f}')

    return predictions.numpy(), reference_points
|
|
|
|
def accumulate_hits(predictions: np.ndarray, reference_points: np.ndarray,
                    binning_factor: int, nx=None, ny=None):
    """Histogram one chunk of hits into super-resolution, count and sub-pixel maps.

    Args:
        predictions: (N, 2) array of predicted (x, y) positions relative to
            each patch's reference point (as returned by ``run_inference``).
        reference_points: Array whose first two columns give each patch's
            reference (x, y). These are added to ``predictions`` to get
            absolute frame positions.
        binning_factor: Number of sub-divisions per pixel along each axis.
        nx, ny: Frame width and height in pixels. They default to the
            module-level ``NX`` / ``NY``.

    Returns:
        ``(ml_super_frame, count_frame, subpixel_dist)``, float64 arrays of
        shapes ``(ny*b, nx*b)``, ``(ny, nx)`` and ``(b, b)``, where
        ``b = binning_factor``. All are indexed ``[y, x]``.

    Hits with a non-finite position or outside the frame are dropped and
    counted, rather than wrapping around via negative NumPy indices or
    raising IndexError. Reference pixels outside the frame are dropped the
    same way.
    """
    nx = NX if nx is None else nx
    ny = NY if ny is None else ny

    ml_super_frame = np.zeros((ny * binning_factor, nx * binning_factor))
    count_frame = np.zeros((ny, nx))
    subpixel_dist = np.zeros((binning_factor, binning_factor))

    # Absolute position = predicted sub-pixel position + reference point
    # (the pixel's lower-left corner).
    absolute_positions = predictions + reference_points[:, :2]
    # Remove NaN/inf before any float -> int cast.
    valid = absolute_positions[np.isfinite(absolute_positions).all(axis=1)]

    # 1. Super-resolution frame: bin hits on a 1/binning_factor pixel grid.
    hit_x = np.floor(valid[:, 0] * binning_factor).astype(int)
    hit_y = np.floor(valid[:, 1] * binning_factor).astype(int)
    in_frame = ((hit_x >= 0) & (hit_x < nx * binning_factor)
                & (hit_y >= 0) & (hit_y < ny * binning_factor))
    np.add.at(ml_super_frame, (hit_y[in_frame], hit_x[in_frame]), 1)

    # 2. Count frame, indexed by the reference point's pixel. The reference
    # point is the lower-left corner and +1 gives the pixel index (presumably
    # the centre pixel of a 3x3 patch — confirm against the dataset).
    ref_x = (reference_points[:, 0] + 1).astype(int)
    ref_y = (reference_points[:, 1] + 1).astype(int)
    ref_in_frame = (ref_x >= 0) & (ref_x < nx) & (ref_y >= 0) & (ref_y < ny)
    np.add.at(count_frame, (ref_y[ref_in_frame], ref_x[ref_in_frame]), 1)

    # 3. Sub-pixel distribution: where inside its pixel each hit landed.
    # `% 1` maps negatives into [0, 1] but can round up to exactly 1.0
    # (e.g. -1e-20 % 1), so clip into the last bin.
    frac = valid % 1
    sub_x = np.clip(np.floor(frac[:, 0] * binning_factor).astype(int), 0, binning_factor - 1)
    sub_y = np.clip(np.floor(frac[:, 1] * binning_factor).astype(int), 0, binning_factor - 1)
    np.add.at(subpixel_dist, (sub_y, sub_x), 1)

    n_dropped_hits = len(absolute_positions) - int(in_frame.sum())
    n_dropped_refs = len(ref_x) - int(ref_in_frame.sum())
    if n_dropped_hits or n_dropped_refs:
        print(f'[Accumulate] Dropped {n_dropped_hits} hit(s) (non-finite or outside the frame) '
              f'and {n_dropped_refs} reference pixel(s) outside the frame')

    return ml_super_frame, count_frame, subpixel_dist
|
|
|
|
def _save_image(image, extent, title, xlabel, ylabel, png_path, **imshow_kwargs):
    """Render ``image`` (origin lower-left) with a 'Counts' colorbar and save it as a PNG."""
    fig = plt.figure(figsize=(8, 8))
    try:
        plt.imshow(image, origin='lower', extent=extent, **imshow_kwargs)
        plt.colorbar(label='Counts')
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.savefig(png_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if savefig fails.
        plt.close(fig)


def save_results(ml_super_frame, count_frame, subpixel_dist,
                 roi: list, binning_factor: int, output_dir: Path):
    """Save the three accumulated maps as PNG plots and raw ``.npy`` arrays.

    Args:
        ml_super_frame: Super-resolution hit map of shape (NY*b, NX*b).
        count_frame: Per-pixel count map of shape (NY, NX).
        subpixel_dist: Sub-pixel hit distribution of shape (b, b).
        roi: ``[x_min, x_max, y_min, y_max]`` in pixels. Only the two frame
            plots are cropped to it; the ``.npy`` files hold the full arrays.
        binning_factor: Sub-divisions per pixel (``b`` above).
        output_dir: Existing directory that receives all files.
    """
    x_min, x_max, y_min, y_max = roi
    roi_extent = [x_min, x_max, y_min, y_max]
    b = binning_factor

    # 1. Super-resolution frame: the ROI is scaled into the binned grid.
    _save_image(
        ml_super_frame[y_min * b:y_max * b, x_min * b:x_max * b], roi_extent,
        'ML Super-Resolution Frame', 'X (pixel)', 'Y (pixel)',
        output_dir / '1Photon_ML_superFrame.png',
    )
    np.save(output_dir / '1Photon_ML_superFrame.npy', ml_super_frame)

    # 2. Count frame.
    _save_image(
        count_frame[y_min:y_max, x_min:x_max], roi_extent,
        'Photon Count Frame', 'X (pixel)', 'Y (pixel)',
        output_dir / '1Photon_count_Frame.png',
    )
    np.save(output_dir / '1Photon_count_Frame.npy', count_frame)

    # 3. Sub-pixel distribution.
    _save_image(
        subpixel_dist, [0, binning_factor, 0, binning_factor],
        'Sub-pixel Distribution',
        f'Sub-pixel X (1/{binning_factor} pixel)',
        f'Sub-pixel Y (1/{binning_factor} pixel)',
        output_dir / '1Photon_subpixel_Distribution.png',
        cmap='viridis',
    )
    np.save(output_dir / '1Photon_subpixel_Distribution.npy', subpixel_dist)

    # Uniformity figure of merit: std / mean over the sub-pixel bins. It is
    # undefined (division by zero) when no hits were accumulated.
    rms, mean = np.std(subpixel_dist), np.mean(subpixel_dist)
    if mean > 0:
        print(f"[Plotting]: Sub-pixel distribution: RMS/Mean: {rms/mean:.4f}")
    else:
        print("[Plotting]: Sub-pixel distribution is empty; RMS/Mean is undefined")

    print(f"Results saved to: {output_dir}")
|
|
|
|
if __name__ == "__main__":
    ### Output preparation
    output_dir = prepare_output_folder(conf)

    ### Model loading
    model = get_model_class(conf.model.version)().cuda()
    # map_location='cpu' lets the checkpoint load regardless of which GPU it
    # was saved from; load_state_dict then copies it onto the model's device.
    model.load_state_dict(torch.load(
        f'{conf.model.base_dir}/{conf.model.experiment_name}/Models/{conf.model.name}',
        map_location='cpu',
        weights_only=True,
    ))
    model.eval()

    ### Data loading
    files_list = get_files_list(conf)
    roi = conf.data.names[conf.experiment.name].roi
    BinningFactor = conf.inference.binning_factor
    numberOfAugOps = conf.inference.num_aug_ops
    flag_normalize = conf.data.normalize
    chunk_size = conf.inference.chunk_size
    nChunks = int(np.ceil(len(files_list) / chunk_size))

    # Running totals across all chunks.
    ml_super_frame = np.zeros((NY * BinningFactor, NX * BinningFactor))
    count_frame = np.zeros((NY, NX))
    subpixel_dist = np.zeros((BinningFactor, BinningFactor))

    ### Inference loop: files are processed chunk by chunk (presumably to bound
    ### memory, since each chunk's dataset replaces the previous one).
    for idxChunk in range(nChunks):
        start_idx = idxChunk * chunk_size
        end_idx = min(start_idx + chunk_size, len(files_list))
        chunk_files = files_list[start_idx:end_idx]
        print(f'[Inferring] Chunk {idxChunk+1}/{nChunks}: Loading files {start_idx} to {end_idx}...')

        dataset = singlePhotonDataset(
            chunk_files,
            sampleRatio=1.0,
            datasetName='Inference',
            numberOfAugOps=numberOfAugOps,
            normalize=flag_normalize,
        )

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=conf.data.batch_size,
            shuffle=False,  # required: predictions are regrouped by position
            num_workers=conf.data.num_workers,
            pin_memory=True,
        )

        predictions, ref_points = run_inference(model, data_loader, conf)

        chunk_super, chunk_count, chunk_subpixel = accumulate_hits(
            predictions, ref_points, binning_factor=BinningFactor
        )
        ml_super_frame += chunk_super
        count_frame += chunk_count
        subpixel_dist += chunk_subpixel

    save_results(ml_super_frame, count_frame, subpixel_dist,
                 roi=roi, binning_factor=BinningFactor,
                 output_dir=output_dir)