187 lines
5.7 KiB
Python
187 lines
5.7 KiB
Python
#!/usr/bin/env python3
|
|
"""Export Depth Anything V1/V2 Hugging Face models to ONNX."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
from pathlib import Path
|
|
import shutil
|
|
import subprocess
|
|
|
|
import torch
|
|
from transformers import AutoModelForDepthEstimation
|
|
|
|
|
|
class DepthWrapper(torch.nn.Module):
    """Adapter that returns only the ``predicted_depth`` tensor.

    ONNX export needs a module whose forward pass yields a plain tensor,
    not a transformers model-output object, so this wrapper unwraps the
    ``predicted_depth`` attribute from the underlying model's output.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model and return its ``predicted_depth`` field."""
        outputs = self.model(pixel_values=pixel_values)
        return outputs.predicted_depth
|
|
|
|
|
|
def export_model(model_id: str, out_path: Path, opset: int, size: int) -> None:
    """Download a Hugging Face depth model and export it to ONNX.

    Args:
        model_id: Hugging Face hub id of the depth-estimation model.
        out_path: Destination ``.onnx`` path (parent dirs are created).
        opset: ONNX opset version to target.
        size: Square height/width of the dummy tracing input.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)

    hf_model = AutoModelForDepthEstimation.from_pretrained(model_id).eval()
    exportable = DepthWrapper(hf_model).eval()
    # Tracing input only; H/W are declared dynamic below, so this size
    # is not baked into the exported graph.
    example = torch.randn(1, 3, size, size, dtype=torch.float32)

    dynamic_axes = {
        "pixel_values": {0: "batch", 2: "height", 3: "width"},
        "predicted_depth": {0: "batch", 1: "height", 2: "width"},
    }
    torch.onnx.export(
        exportable,
        (example,),
        str(out_path),
        input_names=["pixel_values"],
        output_names=["predicted_depth"],
        dynamic_axes=dynamic_axes,
        opset_version=opset,
        do_constant_folding=True,
    )
    print(f"Exported {model_id} -> {out_path}", flush=True)
|
|
|
|
|
|
def build_trt_engine(
    onnx_path: Path,
    engine_path: Path,
    trtexec_path: str,
    min_hw: int,
    opt_hw: int,
    max_hw: int,
    fp16: bool,
) -> None:
    """Compile an ONNX model into a TensorRT engine via ``trtexec``.

    Args:
        onnx_path: Source ONNX model file.
        engine_path: Destination engine path (parent dirs are created).
        trtexec_path: Path to the ``trtexec`` binary.
        min_hw: Minimum dynamic H/W for the optimization profile.
        opt_hw: Optimal dynamic H/W for the optimization profile.
        max_hw: Maximum dynamic H/W for the optimization profile.
        fp16: Whether to enable FP16 precision.

    Raises:
        subprocess.CalledProcessError: If ``trtexec`` exits non-zero.
    """
    engine_path.parent.mkdir(parents=True, exist_ok=True)

    # Dynamic-shape optimization profile for the single "pixel_values" input.
    shape_flags = [
        f"--minShapes=pixel_values:1x3x{min_hw}x{min_hw}",
        f"--optShapes=pixel_values:1x3x{opt_hw}x{opt_hw}",
        f"--maxShapes=pixel_values:1x3x{max_hw}x{max_hw}",
    ]
    cmd = [
        trtexec_path,
        f"--onnx={onnx_path}",
        f"--saveEngine={engine_path}",
        *shape_flags,
    ]
    if fp16:
        cmd.append("--fp16")

    print("Running:", " ".join(cmd), flush=True)
    subprocess.run(cmd, check=True)
    print(f"Built TensorRT engine -> {engine_path}", flush=True)
|
|
|
|
|
|
def main() -> None:
|
|
project_root = Path(__file__).resolve().parents[2]
|
|
default_models_dir = project_root / "models"
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"--v1-model-id",
|
|
default="LiheYoung/depth-anything-small-hf",
|
|
help="HF model id for Depth Anything V1.",
|
|
)
|
|
parser.add_argument(
|
|
"--v2-model-id",
|
|
default="depth-anything/Depth-Anything-V2-Small-hf",
|
|
help="HF model id for Depth Anything V2.",
|
|
)
|
|
parser.add_argument(
|
|
"--v1-out",
|
|
default=str(default_models_dir / "depth_anything_v1_small.onnx"),
|
|
help="Output ONNX path for V1.",
|
|
)
|
|
parser.add_argument(
|
|
"--v2-out",
|
|
default=str(default_models_dir / "depth_anything_v2_small.onnx"),
|
|
help="Output ONNX path for V2.",
|
|
)
|
|
parser.add_argument(
|
|
"--which",
|
|
choices=["v1", "v2", "both"],
|
|
default="both",
|
|
help="Which model(s) to export.",
|
|
)
|
|
parser.add_argument(
|
|
"--size",
|
|
type=int,
|
|
default=384,
|
|
help="Dummy export size (HxW) for tracing.",
|
|
)
|
|
parser.add_argument(
|
|
"--opset",
|
|
type=int,
|
|
default=17,
|
|
help="ONNX opset version.",
|
|
)
|
|
parser.add_argument(
|
|
"--build-trt",
|
|
action="store_true",
|
|
help="Also build TensorRT engine(s) via trtexec after ONNX export.",
|
|
)
|
|
parser.add_argument(
|
|
"--v1-engine-out",
|
|
default=str(default_models_dir / "depth_anything_v1_small.engine"),
|
|
help="Output TensorRT engine path for V1.",
|
|
)
|
|
parser.add_argument(
|
|
"--v2-engine-out",
|
|
default=str(default_models_dir / "depth_anything_v2_small.engine"),
|
|
help="Output TensorRT engine path for V2.",
|
|
)
|
|
parser.add_argument(
|
|
"--trtexec",
|
|
default="",
|
|
help="Path to trtexec binary (auto-detects /usr/src/tensorrt/bin/trtexec if empty).",
|
|
)
|
|
parser.add_argument("--min-shape", type=int, default=224, help="TRT min H/W.")
|
|
parser.add_argument("--opt-shape", type=int, default=384, help="TRT opt H/W.")
|
|
parser.add_argument("--max-shape", type=int, default=768, help="TRT max H/W.")
|
|
parser.add_argument(
|
|
"--no-fp16",
|
|
action="store_true",
|
|
help="Disable FP16 when building TensorRT engines.",
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
v1_onnx = Path(args.v1_out)
|
|
v2_onnx = Path(args.v2_out)
|
|
if args.which in {"v1", "both"}:
|
|
export_model(args.v1_model_id, v1_onnx, args.opset, args.size)
|
|
if args.which in {"v2", "both"}:
|
|
export_model(args.v2_model_id, v2_onnx, args.opset, args.size)
|
|
|
|
if args.build_trt:
|
|
trtexec_path = args.trtexec.strip()
|
|
if not trtexec_path:
|
|
trtexec_path = shutil.which("trtexec") or "/usr/src/tensorrt/bin/trtexec"
|
|
if not Path(trtexec_path).exists():
|
|
raise FileNotFoundError(
|
|
f"trtexec not found at '{trtexec_path}'. "
|
|
"Provide --trtexec /path/to/trtexec."
|
|
)
|
|
fp16 = not args.no_fp16
|
|
if args.which in {"v1", "both"}:
|
|
build_trt_engine(
|
|
v1_onnx,
|
|
Path(args.v1_engine_out),
|
|
trtexec_path,
|
|
args.min_shape,
|
|
args.opt_shape,
|
|
args.max_shape,
|
|
fp16,
|
|
)
|
|
if args.which in {"v2", "both"}:
|
|
build_trt_engine(
|
|
v2_onnx,
|
|
Path(args.v2_engine_out),
|
|
trtexec_path,
|
|
args.min_shape,
|
|
args.opt_shape,
|
|
args.max_shape,
|
|
fp16,
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|