denizaybey committed
Commit 46556af · verified · 1 Parent(s): e05a450

Upload 3 files

Files changed (1)
  1. app.py +377 -51
app.py CHANGED
@@ -1,56 +1,382 @@
- import io
- import numpy
- import gradio
  import spaces
- import moviepy
- import supervision
- from PIL import Image
- from ultralytics import YOLOE
-
-
- @spaces.GPU
- def inference(video):
-     model = YOLOE("./model.pt").to("cuda")
-     names = ["person", "vehicle"]
-     model.set_classes(names, model.get_text_pe(names))
-     clip = moviepy.VideoFileClip(video)
      results = []
-     for i, frame in enumerate(clip.iter_frames(fps=1)):
-         image = Image.fromarray(numpy.uint8(frame))
-         result = model.predict(frame, imgsz=640, conf=0.25, iou=0.7)
-         detections = supervision.Detections.from_ultralytics(result[0])
-         resolution_wh = image.size
-         thickness = supervision.calculate_optimal_line_thickness(resolution_wh=resolution_wh)
-         text_scale = supervision.calculate_optimal_text_scale(resolution_wh=resolution_wh)
-         labels = [
-             f"{class_name} {confidence:.2f}"
-             for class_name, confidence
-             in zip(detections['class_name'], detections.confidence)
-         ]
-         annotated_image = image.copy()
-         annotated_image = supervision.MaskAnnotator(color_lookup=supervision.ColorLookup.INDEX, opacity=0.4).annotate(
-             scene=annotated_image, detections=detections)
-         annotated_image = supervision.BoxAnnotator(color_lookup=supervision.ColorLookup.INDEX,
-                                                    thickness=thickness).annotate(
-             scene=annotated_image, detections=detections)
-         annotated_image = supervision.LabelAnnotator(color_lookup=supervision.ColorLookup.INDEX, text_scale=text_scale,
-                                                      smart_position=True).annotate(
-             scene=annotated_image, detections=detections, labels=labels)
-         results.append(annotated_image)
-     frames = [numpy.array(img) for img in results]
-     output_clip = moviepy.ImageSequenceClip(frames, fps=1)
-     buf = io.BytesIO()
-     output_clip.write_videofile(buf, codec="libx264", audio=False)
-     clip.close()
-     buf.seek(0)
-     return buf
 
 
  if __name__ == "__main__":
-     gradio.Interface(
-         fn=inference,
-         inputs=gradio.Video(),
-         outputs=gradio.Video(),
-         title="Video Object Detection",
-         description="Upload a video to run object detection using YOLOE.",
-     ).launch()
 
+ import os
+ import cv2
+ import tqdm
+ import uuid
+ import logging
+
+ import torch
  import spaces
+ import trackers
+ import numpy as np
+ import gradio as gr
+ import imageio.v3 as iio
+ import supervision as sv
+
+ from pathlib import Path
+ from functools import lru_cache
+ from typing import List, Optional, Tuple
+
+ from transformers import AutoModelForObjectDetection, AutoImageProcessor
+
+ # Configuration constants
+ CHECKPOINTS = [
+     "ustc-community/dfine-xlarge-obj2coco"
+ ]
+ DEFAULT_CHECKPOINT = CHECKPOINTS[0]
+ DEFAULT_CONFIDENCE_THRESHOLD = 0.3
+
+ TORCH_DTYPE = torch.float32
+
+ # Video
+ MAX_NUM_FRAMES = 250
+ BATCH_SIZE = 4
+ ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}
+ VIDEO_OUTPUT_DIR = Path("static/videos")
+ VIDEO_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+
+ class TrackingAlgorithm:
+     BYTETRACK = "ByteTrack (2021)"
+     DEEPSORT = "DeepSORT (2017)"
+     SORT = "SORT (2016)"
+
+
+ TRACKERS = [None, TrackingAlgorithm.BYTETRACK, TrackingAlgorithm.DEEPSORT, TrackingAlgorithm.SORT]
+ VIDEO_EXAMPLES = [
+     {"path": "./examples/videos/dogs_running.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
+     {"path": "./examples/videos/traffic.mp4", "label": "Local Video", "tracker": TrackingAlgorithm.BYTETRACK,
+      "classes": "car, truck, bus"},
+     {"path": "./examples/videos/fast_and_furious.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
+     {"path": "./examples/videos/break_dance.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
+ ]
+
+ # Create a color palette for visualization
+ # These hex color codes define different colors for tracking different objects
+ color = sv.ColorPalette.from_hex([
+     "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
+     "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
+ ])
+
+ logging.basicConfig(
+     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+ )
+ logger = logging.getLogger(__name__)
+
+
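+ # Cache up to three recently used checkpoint/processor pairs so switching back
+ # to a checkpoint does not reload the weights from disk or the Hub.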
+ @lru_cache(maxsize=3)
+ def get_model_and_processor(checkpoint: str):
+     model = AutoModelForObjectDetection.from_pretrained(checkpoint, torch_dtype=TORCH_DTYPE)
+     image_processor = AutoImageProcessor.from_pretrained(checkpoint)
+     return model, image_processor
+
+
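+ # Run batched object detection over all frames; @spaces.GPU requests a (Zero)GPU
+ # worker for the call. Returns one post-processed result dict per frame plus the
+ # model's id2label mapping.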
+ @spaces.GPU(duration=20)
+ def detect_objects(
+     checkpoint: str,
+     images: List[np.ndarray] | np.ndarray,
+     confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
+     target_size: Optional[Tuple[int, int]] = None,
+     batch_size: int = BATCH_SIZE,
+     classes: Optional[List[str]] = None,
+ ):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model, image_processor = get_model_and_processor(checkpoint)
+     model = model.to(device)
+
+     if classes is not None:
+         wrong_classes = [cls for cls in classes if cls not in model.config.label2id]
+         if wrong_classes:
+             gr.Warning(f"Classes not found in model config: {wrong_classes}")
+         keep_ids = [model.config.label2id[cls] for cls in classes if cls in model.config.label2id]
+     else:
+         keep_ids = None
+
+     if isinstance(images, np.ndarray) and images.ndim == 4:
+         images = [x for x in images]  # split video array into list of images
+
+     batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
+
      results = []
+     for batch in tqdm.tqdm(batches, desc="Processing frames"):
+
+         # preprocess images
+         inputs = image_processor(images=batch, return_tensors="pt")
+         inputs = inputs.to(device).to(TORCH_DTYPE)
+
+         # forward pass
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         # postprocess outputs
+         if target_size:
+             target_sizes = [target_size] * len(batch)
+         else:
+             target_sizes = [(image.shape[0], image.shape[1]) for image in batch]
+
+         batch_results = image_processor.post_process_object_detection(
+             outputs, target_sizes=target_sizes, threshold=confidence_threshold
+         )
+
+         results.extend(batch_results)
+
+     # move results to cpu
+     for i, result in enumerate(results):
+         results[i] = {k: v.cpu() for k, v in result.items()}
+         if keep_ids is not None:
+             keep = torch.isin(results[i]["labels"], torch.tensor(keep_ids))
+             results[i] = {k: v[keep] for k, v in results[i].items()}
+
+     return results, model.config.id2label
+
+
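+ # Compute an output size whose longer side is at most max_size, preserving the
+ # aspect ratio and rounding both dimensions down to even values.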
+ def get_target_size(image_height, image_width, max_size: int):
+     if image_height < max_size and image_width < max_size:
+         new_height, new_width = image_height, image_width
+     elif image_height > image_width:
+         new_height = max_size
+         new_width = int(image_width * max_size / image_height)
+     else:
+         new_width = max_size
+         new_height = int(image_height * max_size / image_width)
+
+     # make even (for video codec compatibility)
+     new_height = new_height // 2 * 2
+     new_width = new_width // 2 * 2
+
+     return new_width, new_height
+
+
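+ # Read up to k frames with OpenCV, keeping every read_every_i_frame-th frame and
+ # converting them from BGR to RGB.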
+ def read_video_k_frames(video_path: str, k: int, read_every_i_frame: int = 1):
+     cap = cv2.VideoCapture(video_path)
+     frames = []
+     i = 0
+     progress_bar = tqdm.tqdm(total=k, desc="Reading frames")
+     while cap.isOpened() and len(frames) < k:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         if i % read_every_i_frame == 0:
+             frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+             progress_bar.update(1)
+         i += 1
+     cap.release()
+     progress_bar.close()
+     return frames
+
+
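+ # Instantiate the selected tracker: SORT and DeepSORT come from the `trackers`
+ # package, ByteTrack from supervision.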
+ def get_tracker(tracker: str, fps: float):
+     if tracker == TrackingAlgorithm.SORT:
+         return trackers.SORTTracker(frame_rate=fps)
+     elif tracker == TrackingAlgorithm.DEEPSORT:
+         feature_extractor = trackers.DeepSORTFeatureExtractor.from_timm("mobilenetv4_conv_small.e1200_r224_in1k",
+                                                                         device="cpu")
+         return trackers.DeepSORTTracker(feature_extractor, frame_rate=fps)
+     elif tracker == TrackingAlgorithm.BYTETRACK:
+         return sv.ByteTrack(frame_rate=int(fps))
+     else:
+         raise ValueError(f"Invalid tracker: {tracker}")
+
+
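+ # The trackers expose different update signatures, so dispatch on the class name;
+ # DeepSORT additionally needs the raw frame to extract appearance features.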
+ def update_tracker(tracker, detections, frame):
+     tracker_name = tracker.__class__.__name__
+     if tracker_name == "SORTTracker":
+         return tracker.update(detections)
+     elif tracker_name == "DeepSORTTracker":
+         return tracker.update(detections, frame)
+     elif tracker_name == "ByteTrack":
+         return tracker.update_with_detections(detections)
+     else:
+         raise ValueError(f"Invalid tracker: {tracker_name}")
+
+
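+ # End-to-end pipeline: validate the upload, subsample to ~25 FPS (capped at
+ # MAX_NUM_FRAMES frames), resize, detect, optionally track, annotate each frame,
+ # and write the result as an H.264 MP4 under VIDEO_OUTPUT_DIR.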
+ def process_video(
+     video_path: str,
+     checkpoint: str,
+     tracker_algorithm: Optional[str] = None,
+     classes: str = "all",
+     confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
+     progress: gr.Progress = gr.Progress(track_tqdm=True),
+ ) -> str:
+     if not video_path or not os.path.isfile(video_path):
+         raise ValueError(f"Invalid video path: {video_path}")
+
+     ext = os.path.splitext(video_path)[1].lower()
+     if ext not in ALLOWED_VIDEO_EXTENSIONS:
+         raise ValueError(f"Unsupported video format: {ext}, supported formats: {ALLOWED_VIDEO_EXTENSIONS}")
+
+     video_info = sv.VideoInfo.from_video_path(video_path)
+     read_each_i_frame = max(1, video_info.fps // 25)
+     target_fps = video_info.fps / read_each_i_frame
+     target_width, target_height = get_target_size(video_info.height, video_info.width, 1080)
+
+     n_frames_to_read = min(MAX_NUM_FRAMES, video_info.total_frames // read_each_i_frame)
+     frames = read_video_k_frames(video_path, n_frames_to_read, read_each_i_frame)
+     frames = [cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_CUBIC) for frame in frames]
+
+     # Set the color lookup mode to assign colors by track ID when tracking;
+     # this means objects with the same track ID are annotated with the same color
+     color_lookup = sv.ColorLookup.TRACK if tracker_algorithm else sv.ColorLookup.CLASS
+
+     box_annotator = sv.BoxAnnotator(color, color_lookup=color_lookup, thickness=1)
+     label_annotator = sv.LabelAnnotator(color, color_lookup=color_lookup, text_scale=0.5)
+     trace_annotator = sv.TraceAnnotator(color, color_lookup=color_lookup, thickness=1, trace_length=100)
+
+     # preprocess classes
+     if classes != "all":
+         classes_list = [cls.strip().lower() for cls in classes.split(",")]
+     else:
+         classes_list = None
+
+     results, id2label = detect_objects(
+         images=np.array(frames),
+         checkpoint=checkpoint,
+         confidence_threshold=confidence_threshold,
+         target_size=(target_height, target_width),
+         classes=classes_list,
+     )
+
+     annotated_frames = []
+
+     # detections
+     if tracker_algorithm:
+         tracker = get_tracker(tracker_algorithm, target_fps)
+         for frame, result in progress.tqdm(zip(frames, results), desc="Tracking objects", total=len(frames)):
+             detections = sv.Detections.from_transformers(result, id2label=id2label)
+             detections = detections.with_nms(threshold=0.95, class_agnostic=True)
+             detections = update_tracker(tracker, detections, frame)
+             labels = [f"#{tracker_id} {id2label[class_id]}" for class_id, tracker_id in
+                       zip(detections.class_id, detections.tracker_id)]
+             annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
+             annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+             annotated_frame = trace_annotator.annotate(scene=annotated_frame, detections=detections)
+             annotated_frames.append(annotated_frame)
+
+     else:
+         for frame, result in tqdm.tqdm(zip(frames, results), desc="Annotating frames", total=len(frames)):
+             detections = sv.Detections.from_transformers(result, id2label=id2label)
+             detections = detections.with_nms(threshold=0.95, class_agnostic=True)
+             annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
+             annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections)
+             annotated_frames.append(annotated_frame)
+
+     output_filename = os.path.join(VIDEO_OUTPUT_DIR, f"output_{uuid.uuid4()}.mp4")
+     iio.imwrite(output_filename, annotated_frames, fps=target_fps, codec="h264")
+     return output_filename
+
+
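+ # The five inputs below are returned in the order process_video expects its
+ # arguments: video, checkpoint, tracker, classes, confidence threshold.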
+ def create_video_inputs() -> List[gr.components.Component]:
+     return [
+         gr.Video(
+             label="Upload Video",
+             sources=["upload"],
+             interactive=True,
+             format="mp4",  # Ensure MP4 format
+             elem_classes="input-component",
+         ),
+         gr.Dropdown(
+             choices=CHECKPOINTS,
+             label="Select Model Checkpoint",
+             value=DEFAULT_CHECKPOINT,
+             elem_classes="input-component",
+         ),
+         gr.Dropdown(
+             choices=TRACKERS,
+             label="Select Tracker (Optional)",
+             value=None,
+             elem_classes="input-component",
+         ),
+         gr.TextArea(
+             label="Specify Class Names to Detect (comma separated)",
+             value="all",
+             lines=1,
+             elem_classes="input-component",
+         ),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=DEFAULT_CONFIDENCE_THRESHOLD,
+             step=0.1,
+             label="Confidence Threshold",
+             elem_classes="input-component",
+         ),
+     ]
+
+
+ def create_button_row() -> List[gr.Button]:
+     return [
+         gr.Button(
+             "Detect Objects", variant="primary", elem_classes="action-button"
+         ),
+         gr.Button("Clear", variant="secondary", elem_classes="action-button"),
+     ]
+
+
+ # Gradio interface
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+         # Aircraft Detection Demo
+         ## Input your video and see the detected objects!
+         """,
+         elem_classes="header-text",
+     )
+
+     with gr.Tabs():
+         with gr.Tab("Video"):
+             gr.Markdown(
+                 f"The input video will be processed at ~25 FPS (up to {MAX_NUM_FRAMES} frames in the result)."
+             )
+             with gr.Row():
+                 with gr.Column(scale=1, min_width=300):
+                     with gr.Group():
+                         video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold = create_video_inputs()
+                     video_detect_button, video_clear_button = create_button_row()
+                 with gr.Column(scale=2):
+                     video_output = gr.Video(
+                         label="Detection Results",
+                         format="mp4",  # Explicit MP4 format
+                         elem_classes="output-component",
+                     )
+
+             gr.Examples(
+                 examples=[
+                     [example["path"], DEFAULT_CHECKPOINT, example["tracker"], example["classes"],
+                      DEFAULT_CONFIDENCE_THRESHOLD]
+                     for example in VIDEO_EXAMPLES
+                 ],
+                 inputs=[video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold],
+                 outputs=[video_output],
+                 fn=process_video,
+                 cache_examples=False,
+                 label="Select a video example to populate inputs",
+             )
+
+             # Video clear button
+             video_clear_button.click(
+                 fn=lambda: (
+                     None,
+                     DEFAULT_CHECKPOINT,
+                     None,
+                     "all",
+                     DEFAULT_CONFIDENCE_THRESHOLD,
+                     None,
+                 ),
+                 outputs=[
+                     video_input,
+                     video_checkpoint,
+                     video_tracker,
+                     video_classes,
+                     video_confidence_threshold,
+                     video_output,
+                 ],
+             )
 
+             # Video detect button
+             video_detect_button.click(
+                 fn=process_video,
+                 inputs=[video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold],
+                 outputs=[video_output],
+             )
 
  if __name__ == "__main__":
+     demo.queue(max_size=20).launch()