#!/usr/bin/env python3
"""Export Depth Anything V1/V2 Hugging Face models to ONNX.

With ``--build-trt`` the exported ONNX files are additionally compiled into
TensorRT engines using ``trtexec``.
"""

from __future__ import annotations

import argparse
from pathlib import Path
import shlex
import shutil
import subprocess

import torch

# Fallback trtexec location used when no explicit path is given and
# ``trtexec`` is not found on PATH.
DEFAULT_TRTEXEC = "/usr/src/tensorrt/bin/trtexec"

# Symbolic ONNX dimension names for the exported graph. The depth output's
# spatial dims deliberately get their own names: reusing "height"/"width"
# would declare them equal to the input's dims in the ONNX graph, and the
# export does not guarantee that (e.g. a patch-based backbone may produce a
# map sized to a multiple of its patch size rather than the raw input size).
ONNX_DYNAMIC_AXES = {
    "pixel_values": {0: "batch", 2: "height", 3: "width"},
    "predicted_depth": {0: "batch", 1: "depth_height", 2: "depth_width"},
}


class DepthWrapper(torch.nn.Module):
    """Expose predicted_depth directly for ONNX export.

    The wrapped model is called as ``model(pixel_values=...)`` and only the
    ``.predicted_depth`` attribute of its output is returned, so the traced
    graph has a single tensor input and a single tensor output.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.model(pixel_values=pixel_values).predicted_depth


def export_model(model_id: str, out_path: Path, opset: int, size: int) -> None:
    """Load ``model_id`` with ``from_pretrained`` and export it to ONNX.

    The graph input is ``pixel_values`` (traced with a float32 tensor of shape
    ``(1, 3, size, size)``) and the output is ``predicted_depth``; the batch
    and spatial dims are marked dynamic via ``ONNX_DYNAMIC_AXES``.

    Args:
        model_id: Hugging Face model id passed to ``from_pretrained``.
        out_path: Destination ``.onnx`` file; parent directories are created.
        opset: ONNX opset version.
        size: Height and width of the square dummy input used for tracing.
    """
    # Imported lazily so that --help, argument validation and the TensorRT
    # helpers neither pay for nor require importing transformers.
    from transformers import AutoModelForDepthEstimation

    out_path.parent.mkdir(parents=True, exist_ok=True)
    model = AutoModelForDepthEstimation.from_pretrained(model_id).eval()
    wrapped = DepthWrapper(model).eval()
    dummy = torch.randn(1, 3, size, size, dtype=torch.float32)
    torch.onnx.export(
        wrapped,
        (dummy,),
        str(out_path),
        input_names=["pixel_values"],
        output_names=["predicted_depth"],
        dynamic_axes=ONNX_DYNAMIC_AXES,
        opset_version=opset,
        do_constant_folding=True,
    )
    print(f"Exported {model_id} -> {out_path}", flush=True)


def validate_trt_shapes(min_hw: int, opt_hw: int, max_hw: int) -> None:
    """Check a square TensorRT shape profile.

    An invalid profile would otherwise only surface when trtexec runs, i.e.
    after the (slow) model download and ONNX export.

    Raises:
        ValueError: unless ``0 < min_hw <= opt_hw <= max_hw``.
    """
    if min_hw <= 0:
        raise ValueError(f"TRT min shape must be positive, got {min_hw}.")
    if not min_hw <= opt_hw <= max_hw:
        raise ValueError(
            "TRT shapes must satisfy min <= opt <= max, got "
            f"min={min_hw}, opt={opt_hw}, max={max_hw}."
        )


def resolve_trtexec(trtexec: str = "") -> str:
    """Return a usable trtexec path.

    A non-empty ``trtexec`` (surrounding whitespace ignored) may be either a
    path or a bare executable name found on PATH. When empty, ``trtexec`` is
    looked up on PATH, falling back to ``DEFAULT_TRTEXEC``.

    Raises:
        FileNotFoundError: if the resolved path does not exist.
    """
    requested = trtexec.strip()
    if requested:
        # which() also accepts paths with a directory part; fall back to the
        # literal value so an existing (e.g. non-executable) path is still
        # reported/used exactly as given.
        candidate = shutil.which(requested) or requested
    else:
        candidate = shutil.which("trtexec") or DEFAULT_TRTEXEC
    if not Path(candidate).exists():
        raise FileNotFoundError(
            f"trtexec not found at '{candidate}'. "
            "Provide --trtexec /path/to/trtexec."
        )
    return candidate


def trtexec_command(
    onnx_path: Path,
    engine_path: Path,
    trtexec_path: str,
    min_hw: int,
    opt_hw: int,
    max_hw: int,
    fp16: bool,
) -> list[str]:
    """Build the trtexec argv for a square, batch-1 dynamic-shape profile."""
    cmd = [
        trtexec_path,
        f"--onnx={onnx_path}",
        f"--saveEngine={engine_path}",
        f"--minShapes=pixel_values:1x3x{min_hw}x{min_hw}",
        f"--optShapes=pixel_values:1x3x{opt_hw}x{opt_hw}",
        f"--maxShapes=pixel_values:1x3x{max_hw}x{max_hw}",
    ]
    if fp16:
        cmd.append("--fp16")
    return cmd


def build_trt_engine(
    onnx_path: Path,
    engine_path: Path,
    trtexec_path: str,
    min_hw: int,
    opt_hw: int,
    max_hw: int,
    fp16: bool,
) -> None:
    """Compile ``onnx_path`` into a TensorRT engine at ``engine_path``.

    Runs trtexec with a square ``pixel_values`` profile spanning
    ``min_hw``..``max_hw`` (optimized for ``opt_hw``).

    Raises:
        subprocess.CalledProcessError: if trtexec exits non-zero.
    """
    engine_path.parent.mkdir(parents=True, exist_ok=True)
    cmd = trtexec_command(
        onnx_path, engine_path, trtexec_path, min_hw, opt_hw, max_hw, fp16
    )
    # shlex.join quotes arguments so the echoed command can be copy-pasted.
    print("Running:", shlex.join(cmd), flush=True)
    subprocess.run(cmd, check=True)
    print(f"Built TensorRT engine -> {engine_path}", flush=True)


def main() -> None:
    """Parse CLI arguments, export the selected model(s), optionally build engines."""
    project_root = Path(__file__).resolve().parents[2]
    default_models_dir = project_root / "models"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--v1-model-id",
        default="LiheYoung/depth-anything-small-hf",
        help="HF model id for Depth Anything V1.",
    )
    parser.add_argument(
        "--v2-model-id",
        default="depth-anything/Depth-Anything-V2-Small-hf",
        help="HF model id for Depth Anything V2.",
    )
    parser.add_argument(
        "--v1-out",
        default=str(default_models_dir / "depth_anything_v1_small.onnx"),
        help="Output ONNX path for V1.",
    )
    parser.add_argument(
        "--v2-out",
        default=str(default_models_dir / "depth_anything_v2_small.onnx"),
        help="Output ONNX path for V2.",
    )
    parser.add_argument(
        "--which",
        choices=["v1", "v2", "both"],
        default="both",
        help="Which model(s) to export.",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=384,
        help="Dummy export size (HxW) for tracing.",
    )
    parser.add_argument(
        "--opset",
        type=int,
        default=17,
        help="ONNX opset version.",
    )
    parser.add_argument(
        "--build-trt",
        action="store_true",
        help="Also build TensorRT engine(s) via trtexec after ONNX export.",
    )
    parser.add_argument(
        "--v1-engine-out",
        default=str(default_models_dir / "depth_anything_v1_small.engine"),
        help="Output TensorRT engine path for V1.",
    )
    parser.add_argument(
        "--v2-engine-out",
        default=str(default_models_dir / "depth_anything_v2_small.engine"),
        help="Output TensorRT engine path for V2.",
    )
    parser.add_argument(
        "--trtexec",
        default="",
        help=(
            "Path or PATH-resolvable name of the trtexec binary (if empty: "
            f"trtexec on PATH, else {DEFAULT_TRTEXEC})."
        ),
    )
    parser.add_argument("--min-shape", type=int, default=224, help="TRT min H/W.")
    parser.add_argument("--opt-shape", type=int, default=384, help="TRT opt H/W.")
    parser.add_argument("--max-shape", type=int, default=768, help="TRT max H/W.")
    parser.add_argument(
        "--no-fp16",
        action="store_true",
        help="Disable FP16 when building TensorRT engines.",
    )
    args = parser.parse_args()

    want_v1 = args.which in {"v1", "both"}
    want_v2 = args.which in {"v2", "both"}
    v1_onnx = Path(args.v1_out)
    v2_onnx = Path(args.v2_out)

    # Check TensorRT settings before exporting, so a misconfiguration fails
    # immediately instead of after the model download and ONNX export.
    trtexec_path = ""
    if args.build_trt:
        try:
            validate_trt_shapes(args.min_shape, args.opt_shape, args.max_shape)
        except ValueError as err:
            parser.error(str(err))
        trtexec_path = resolve_trtexec(args.trtexec)

    if want_v1:
        export_model(args.v1_model_id, v1_onnx, args.opset, args.size)
    if want_v2:
        export_model(args.v2_model_id, v2_onnx, args.opset, args.size)

    if not args.build_trt:
        return

    fp16 = not args.no_fp16
    shapes = (args.min_shape, args.opt_shape, args.max_shape)
    if want_v1:
        build_trt_engine(
            v1_onnx, Path(args.v1_engine_out), trtexec_path, *shapes, fp16
        )
    if want_v2:
        build_trt_engine(
            v2_onnx, Path(args.v2_engine_out), trtexec_path, *shapes, fp16
        )


if __name__ == "__main__":
    main()