Prathamesh1420 committed on
Commit 33eab27 · verified · 1 Parent(s): 83af1dd

Delete inference.py

Files changed (1)
  1. inference.py +0 -153
inference.py DELETED
@@ -1,153 +0,0 @@
-import time
-
-import cv2
-import numpy as np
-import onnxruntime
-
-try:
-    from demo.object_detection.utils import draw_detections
-except (ImportError, ModuleNotFoundError):
-    from utils import draw_detections
-
-
-class YOLOv10:
-    def __init__(self, path):
-        # Initialize model
-        self.initialize_model(path)
-
-    def __call__(self, image):
-        return self.detect_objects(image)
-
-    def initialize_model(self, path):
-        self.session = onnxruntime.InferenceSession(
-            path, providers=onnxruntime.get_available_providers()
-        )
-        # Get model info
-        self.get_input_details()
-        self.get_output_details()
-
-    def detect_objects(self, image, conf_threshold=0.3):
-        input_tensor = self.prepare_input(image)
-
-        # Perform inference on the image
-        new_image = self.inference(image, input_tensor, conf_threshold)
-
-        return new_image
-
-    def prepare_input(self, image):
-        self.img_height, self.img_width = image.shape[:2]
-
-        input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-        # Resize input image
-        input_img = cv2.resize(input_img, (self.input_width, self.input_height))
-
-        # Scale input pixel values to 0 to 1
-        input_img = input_img / 255.0
-        input_img = input_img.transpose(2, 0, 1)
-        input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
-
-        return input_tensor
-
-    def inference(self, image, input_tensor, conf_threshold=0.3):
-        start = time.perf_counter()
-        outputs = self.session.run(
-            self.output_names, {self.input_names[0]: input_tensor}
-        )
-
-        print(f"Inference time: {(time.perf_counter() - start) * 1000:.2f} ms")
-        (
-            boxes,
-            scores,
-            class_ids,
-        ) = self.process_output(outputs, conf_threshold)
-        return self.draw_detections(image, boxes, scores, class_ids)
-
-    def process_output(self, output, conf_threshold=0.3):
-        predictions = np.squeeze(output[0])
-
-        # Filter out object confidence scores below threshold
-        scores = predictions[:, 4]
-        predictions = predictions[scores > conf_threshold, :]
-        scores = scores[scores > conf_threshold]
-
-        if len(scores) == 0:
-            return [], [], []
-
-        # Get the class with the highest confidence
-        class_ids = predictions[:, 5].astype(int)
-
-        # Get bounding boxes for each object
-        boxes = self.extract_boxes(predictions)
-
-        return boxes, scores, class_ids
-
-    def extract_boxes(self, predictions):
-        # Extract boxes from predictions
-        boxes = predictions[:, :4]
-
-        # Scale boxes to original image dimensions
-        boxes = self.rescale_boxes(boxes)
-
-        # Convert boxes to xyxy format
-        # boxes = xywh2xyxy(boxes)
-
-        return boxes
-
-    def rescale_boxes(self, boxes):
-        # Rescale boxes to original image dimensions
-        input_shape = np.array(
-            [self.input_width, self.input_height, self.input_width, self.input_height]
-        )
-        boxes = np.divide(boxes, input_shape, dtype=np.float32)
-        boxes *= np.array(
-            [self.img_width, self.img_height, self.img_width, self.img_height]
-        )
-        return boxes
-
-    def draw_detections(
-        self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
-    ):
-        return draw_detections(image, boxes, scores, class_ids, mask_alpha)
-
-    def get_input_details(self):
-        model_inputs = self.session.get_inputs()
-        self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]
-
-        self.input_shape = model_inputs[0].shape
-        self.input_height = self.input_shape[2]
-        self.input_width = self.input_shape[3]
-
-    def get_output_details(self):
-        model_outputs = self.session.get_outputs()
-        self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]
-
-
-if __name__ == "__main__":
-    import tempfile
-
-    import requests
-    from huggingface_hub import hf_hub_download
-
-    model_file = hf_hub_download(
-        repo_id="onnx-community/yolov10s", filename="onnx/model.onnx"
-    )
-
-    yolov8_detector = YOLOv10(model_file)
-
-    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as f:
-        f.write(
-            requests.get(
-                "https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg"
-            ).content
-        )
-        f.seek(0)
-        img = cv2.imread(f.name)
-
-    # # Detect Objects
-    combined_image = yolov8_detector.detect_objects(img)
-
-    # Draw detections
-    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
-    cv2.imshow("Output", combined_image)
-    cv2.waitKey(0)
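
For reference, a minimal sketch of how the deleted wrapper could still be exercised headlessly, writing the annotated image to disk instead of opening a window. It assumes a local copy of the file above is importable as `inference.py`, that a `utils.py` providing `draw_detections` sits next to it, and that the `onnx-community/yolov10s` checkpoint is reachable; the image paths (`test.jpg`, `output.jpg`) are illustrative only.

```python
# Hypothetical headless usage of the deleted YOLOv10 wrapper.
# Assumes inference.py (the file removed above) and its utils.py are importable.
import cv2
from huggingface_hub import hf_hub_download

from inference import YOLOv10  # local copy of the deleted module

# Fetch the same ONNX checkpoint used by the original __main__ block.
model_file = hf_hub_download(
    repo_id="onnx-community/yolov10s", filename="onnx/model.onnx"
)

detector = YOLOv10(model_file)

# Read a local test image instead of downloading one over HTTP.
img = cv2.imread("test.jpg")  # illustrative path
annotated = detector.detect_objects(img, conf_threshold=0.3)

# Save the result rather than calling cv2.imshow, so it also runs without a display.
cv2.imwrite("output.jpg", annotated)
```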