forked from nianticlabs/monodepth2
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinfer_original.py
More file actions
63 lines (59 loc) · 3.18 KB
/
Copy pathinfer_original.py
File metadata and controls
63 lines (59 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/usr/bin/env python3
"""
Run inference on all images in a folder using original ONNX models (encoder.onnx, depth.onnx).
Usage:
    python infer_original.py --model_path ./models/mono+stereo_640x192 --data_dir ./kitti_data --width 640 --height 192
"""
import os
import argparse
import numpy as np
import cv2
from pathlib import Path
import onnxruntime as ort
def _positive_int(text):
    """argparse ``type=`` callback that accepts only integers greater than zero.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not an integer or is <= 0.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, so existing callers that
            pass nothing behave as before.

    Returns:
        argparse.Namespace with the attributes ``model_path``, ``data_dir``,
        ``width``, ``height`` and ``use_cuda``.
    """
    parser = argparse.ArgumentParser(
        description="Run inference with original ONNX models on a folder of images."
    )
    parser.add_argument('--model_path', type=str, required=True,
                        help='Path to ONNX models directory')
    parser.add_argument('--data_dir', type=str, required=True,
                        help='Top-level KITTI data directory (e.g. ./kitti_data)')
    # Reject non-positive sizes here instead of letting the resize step fail
    # later with a far less readable error.
    parser.add_argument('--width', type=_positive_int, default=640,
                        help='Input image width')
    parser.add_argument('--height', type=_positive_int, default=192,
                        help='Input image height')
    parser.add_argument('--use_cuda', action='store_true',
                        help='Use CUDA provider if available')
    return parser.parse_args(argv)
def preprocess_image(image_path, width, height):
    """Load an image file and turn it into a batched network input array.

    The file is decoded with OpenCV, resized to ``(width, height)``, converted
    from BGR to RGB channel order, cast to float32 and divided by 255, then
    rearranged from height/width/channel order to a 1 x C x H x W array.

    Args:
        image_path: Path of the image to read (converted with ``str()``).
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        A float32 NumPy array with a leading batch axis of size 1.

    Raises:
        ValueError: if OpenCV cannot read the file.
    """
    bgr = cv2.imread(str(image_path))
    if bgr is None:
        raise ValueError(f"Failed to load image: {image_path}")

    rgb = cv2.cvtColor(cv2.resize(bgr, (width, height)), cv2.COLOR_BGR2RGB)
    scaled = rgb.astype(np.float32) / 255.0
    # HWC -> CHW, then prepend the batch axis.
    return scaled.transpose(2, 0, 1)[np.newaxis, ...]
def _dims_compatible(expected, actual):
    """Return True when the concrete shape ``actual`` fits ``expected``.

    Only integer entries of ``expected`` are compared. Any other entry
    (for example ``None`` or a symbolic dimension name, which is how a
    dynamic axis is expected to be reported) accepts every size.
    """
    if len(expected) != len(actual):
        return False
    return all(
        not isinstance(want, int) or want == got
        for want, got in zip(expected, actual)
    )


def _bind_decoder_inputs(decoder_specs, encoder_features):
    """Build the decoder feed dict from the encoder outputs.

    Args:
        decoder_specs: Sequence of ``(input_name, expected_shape)`` pairs.
        encoder_features: Mapping of encoder output name to its array.

    Returns:
        Dict mapping every decoder input name to one encoder output array.

    Raises:
        ValueError: if no encoder output has a shape compatible with one of
            the decoder inputs.
    """
    feed = {}
    used = set()
    for input_name, expected_shape in decoder_specs:
        candidates = [
            name for name, arr in encoder_features.items()
            if _dims_compatible(expected_shape, arr.shape)
        ]
        if not candidates:
            available = {name: tuple(arr.shape) for name, arr in encoder_features.items()}
            raise ValueError(
                f"No encoder output matches decoder input '{input_name}' "
                f"with shape {expected_shape}; encoder outputs: {available}"
            )
        if input_name in candidates:
            # An identically named output is the least ambiguous choice.
            chosen = input_name
        else:
            # Prefer an output that has not been handed out yet so that
            # same-shaped inputs do not all receive the first match.
            chosen = next((name for name in candidates if name not in used), candidates[0])
        used.add(chosen)
        feed[input_name] = encoder_features[chosen]
    return feed


def main():
    """Run the encoder and depth decoder ONNX models on every image found.

    Loads ``encoder.onnx`` and ``depth.onnx`` from ``--model_path``, walks
    ``--data_dir`` recursively for ``.jpg``/``.jpeg``/``.png`` files, and for
    each image feeds the preprocessed array through the encoder and then the
    decoder. Progress is printed per image; the decoder outputs themselves
    are not stored or written anywhere.

    Raises:
        ValueError: if an image cannot be loaded, or if a decoder input
            cannot be matched to any encoder output.
    """
    args = parse_args()

    encoder_path = os.path.join(args.model_path, "encoder.onnx")
    decoder_path = os.path.join(args.model_path, "depth.onnx")
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if args.use_cuda else ['CPUExecutionProvider']
    encoder = ort.InferenceSession(encoder_path, providers=providers)
    decoder = ort.InferenceSession(decoder_path, providers=providers)

    encoder_input_name = encoder.get_inputs()[0].name
    encoder_output_names = [output.name for output in encoder.get_outputs()]
    decoder_output_names = [output.name for output in decoder.get_outputs()]
    # Session metadata does not change between images, so read it once here
    # rather than re-querying the session for every input of every image.
    decoder_specs = [(inp.name, tuple(inp.shape)) for inp in decoder.get_inputs()]

    # Recursively find all images below the data directory. Directories whose
    # name happens to end in an image suffix are skipped.
    image_suffixes = {'.jpg', '.png', '.jpeg'}
    image_files = sorted(
        p for p in Path(args.data_dir).rglob('*')
        if p.suffix.lower() in image_suffixes and p.is_file()
    )
    total = len(image_files)
    print(f"Found {total} images in {args.data_dir}.")

    for idx, img_path in enumerate(image_files, start=1):
        input_tensor = preprocess_image(img_path, args.width, args.height)
        encoder_outputs = encoder.run(encoder_output_names, {encoder_input_name: input_tensor})
        encoder_features = dict(zip(encoder_output_names, encoder_outputs))
        decoder_inputs = _bind_decoder_inputs(decoder_specs, encoder_features)
        # The result is intentionally discarded: this script only exercises
        # the models and reports progress.
        decoder.run(decoder_output_names, decoder_inputs)
        print(f"[{idx}/{total}] Inference done for {img_path}")
# Run the inference loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()