Files
AareLC_train/experiments/export_depth_anything_onnx.py
2026-04-14 16:07:31 +02:00

187 lines
5.7 KiB
Python

#!/usr/bin/env python3
"""Export Depth Anything V1/V2 Hugging Face models to ONNX."""
from __future__ import annotations
import argparse
from pathlib import Path
import shutil
import subprocess
import torch
from transformers import AutoModelForDepthEstimation
class DepthWrapper(torch.nn.Module):
    """Adapter that unwraps the Hugging Face output object.

    ``torch.onnx.export`` needs the traced module to return a plain tensor,
    so this wrapper forwards ``pixel_values`` to the underlying model and
    returns only its ``predicted_depth`` field.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # The wrapped model returns an output object; expose the tensor only.
        outputs = self.model(pixel_values=pixel_values)
        return outputs.predicted_depth
def export_model(model_id: str, out_path: Path, opset: int, size: int) -> None:
    """Download a Hugging Face depth model and export it to ONNX.

    Args:
        model_id: Hugging Face hub id of the depth-estimation checkpoint.
        out_path: Destination ``.onnx`` file; parent directories are created.
        opset: ONNX opset version to target.
        size: Height/width of the square dummy input used for tracing.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)

    base = AutoModelForDepthEstimation.from_pretrained(model_id).eval()
    wrapped = DepthWrapper(base).eval()

    # Square float32 dummy input for tracing; real H/W stay flexible
    # through the dynamic_axes mapping below.
    example = torch.randn(1, 3, size, size, dtype=torch.float32)

    torch.onnx.export(
        wrapped,
        (example,),
        str(out_path),
        input_names=["pixel_values"],
        output_names=["predicted_depth"],
        dynamic_axes={
            "pixel_values": {0: "batch", 2: "height", 3: "width"},
            "predicted_depth": {0: "batch", 1: "height", 2: "width"},
        },
        opset_version=opset,
        do_constant_folding=True,
    )
    print(f"Exported {model_id} -> {out_path}", flush=True)
def build_trt_engine(
    onnx_path: Path,
    engine_path: Path,
    trtexec_path: str,
    min_hw: int,
    opt_hw: int,
    max_hw: int,
    fp16: bool,
) -> None:
    """Compile an ONNX model into a TensorRT engine via ``trtexec``.

    Args:
        onnx_path: Input ONNX model.
        engine_path: Output engine file; parent directories are created.
        trtexec_path: Path to the ``trtexec`` executable.
        min_hw / opt_hw / max_hw: Square H/W bounds of the dynamic-shape
            optimization profile for the ``pixel_values`` input.
        fp16: When True, pass ``--fp16`` to enable half-precision kernels.

    Raises:
        subprocess.CalledProcessError: If ``trtexec`` exits non-zero.
    """
    engine_path.parent.mkdir(parents=True, exist_ok=True)

    # Dynamic-shape profile: batch fixed at 1, square H/W from min to max.
    profile_args = [
        f"--minShapes=pixel_values:1x3x{min_hw}x{min_hw}",
        f"--optShapes=pixel_values:1x3x{opt_hw}x{opt_hw}",
        f"--maxShapes=pixel_values:1x3x{max_hw}x{max_hw}",
    ]
    cmd = [
        trtexec_path,
        f"--onnx={onnx_path}",
        f"--saveEngine={engine_path}",
        *profile_args,
    ]
    if fp16:
        cmd.append("--fp16")

    print("Running:", " ".join(cmd), flush=True)
    subprocess.run(cmd, check=True)
    print(f"Built TensorRT engine -> {engine_path}", flush=True)
def main() -> None:
    """CLI entry point: export ONNX model(s) and optionally build TRT engines."""
    # This script lives two directories below the project root
    # (AareLC_train/experiments/), so hop up twice for the models/ default.
    project_root = Path(__file__).resolve().parents[2]
    default_models_dir = project_root / "models"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--v1-model-id",
        default="LiheYoung/depth-anything-small-hf",
        help="HF model id for Depth Anything V1.",
    )
    parser.add_argument(
        "--v2-model-id",
        default="depth-anything/Depth-Anything-V2-Small-hf",
        help="HF model id for Depth Anything V2.",
    )
    parser.add_argument(
        "--v1-out",
        default=str(default_models_dir / "depth_anything_v1_small.onnx"),
        help="Output ONNX path for V1.",
    )
    parser.add_argument(
        "--v2-out",
        default=str(default_models_dir / "depth_anything_v2_small.onnx"),
        help="Output ONNX path for V2.",
    )
    parser.add_argument(
        "--which",
        choices=["v1", "v2", "both"],
        default="both",
        help="Which model(s) to export.",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=384,
        help="Dummy export size (HxW) for tracing.",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=17,
        help="ONNX opset version.",
    )
    parser.add_argument(
        "--build-trt",
        action="store_true",
        help="Also build TensorRT engine(s) via trtexec after ONNX export.",
    )
    parser.add_argument(
        "--v1-engine-out",
        default=str(default_models_dir / "depth_anything_v1_small.engine"),
        help="Output TensorRT engine path for V1.",
    )
    parser.add_argument(
        "--v2-engine-out",
        default=str(default_models_dir / "depth_anything_v2_small.engine"),
        help="Output TensorRT engine path for V2.",
    )
    parser.add_argument(
        "--trtexec",
        default="",
        help="Path to trtexec binary (auto-detects /usr/src/tensorrt/bin/trtexec if empty).",
    )
    parser.add_argument("--min-shape", type=int, default=224, help="TRT min H/W.")
    parser.add_argument("--opt-shape", type=int, default=384, help="TRT opt H/W.")
    parser.add_argument("--max-shape", type=int, default=768, help="TRT max H/W.")
    parser.add_argument(
        "--no-fp16",
        action="store_true",
        help="Disable FP16 when building TensorRT engines.",
    )
    args = parser.parse_args()

    v1_onnx = Path(args.v1_out)
    v2_onnx = Path(args.v2_out)
    want_v1 = args.which in {"v1", "both"}
    want_v2 = args.which in {"v2", "both"}

    if want_v1:
        export_model(args.v1_model_id, v1_onnx, args.opset, args.size)
    if want_v2:
        export_model(args.v2_model_id, v2_onnx, args.opset, args.size)

    if not args.build_trt:
        return

    # Resolve trtexec: explicit flag wins, then PATH, then the Jetson/TRT
    # container default location.
    trtexec_path = args.trtexec.strip()
    if not trtexec_path:
        trtexec_path = shutil.which("trtexec") or "/usr/src/tensorrt/bin/trtexec"
    if not Path(trtexec_path).exists():
        raise FileNotFoundError(
            f"trtexec not found at '{trtexec_path}'. "
            "Provide --trtexec /path/to/trtexec."
        )

    fp16 = not args.no_fp16
    if want_v1:
        build_trt_engine(
            v1_onnx,
            Path(args.v1_engine_out),
            trtexec_path,
            args.min_shape,
            args.opt_shape,
            args.max_shape,
            fp16,
        )
    if want_v2:
        build_trt_engine(
            v2_onnx,
            Path(args.v2_engine_out),
            trtexec_path,
            args.min_shape,
            args.opt_shape,
            args.max_shape,
            fp16,
        )
# Script entry point: run the export pipeline when invoked directly.
if __name__ == "__main__":
    main()