Spaces:
Running
on
Zero
Running
on
Zero
Upload app.py
Browse files
app.py
CHANGED
@@ -41,15 +41,6 @@ class TrackingAlgorithm:
|
|
41 |
SORT = "SORT (2016)"
|
42 |
|
43 |
|
44 |
-
TRACKERS = [None, TrackingAlgorithm.BYTETRACK, TrackingAlgorithm.DEEPSORT, TrackingAlgorithm.SORT]
|
45 |
-
VIDEO_EXAMPLES = [
|
46 |
-
{"path": "./examples/videos/dogs_running.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
|
47 |
-
{"path": "./examples/videos/traffic.mp4", "label": "Local Video", "tracker": TrackingAlgorithm.BYTETRACK,
|
48 |
-
"classes": "car, truck, bus"},
|
49 |
-
{"path": "./examples/videos/fast_and_furious.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
|
50 |
-
{"path": "./examples/videos/break_dance.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
|
51 |
-
]
|
52 |
-
|
53 |
# Create a color palette for visualization
|
54 |
# These hex color codes define different colors for tracking different objects
|
55 |
color = sv.ColorPalette.from_hex([
|
@@ -72,17 +63,16 @@ def get_model_and_processor(checkpoint: str):
|
|
72 |
|
73 |
@spaces.GPU(duration=20)
|
74 |
def detect_objects(
|
75 |
-
checkpoint: str,
|
76 |
images: List[np.ndarray] | np.ndarray,
|
77 |
-
confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
|
78 |
target_size: Optional[Tuple[int, int]] = None,
|
79 |
-
batch_size: int = BATCH_SIZE
|
80 |
-
classes: Optional[List[str]] = None,
|
81 |
):
|
|
|
|
|
82 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
83 |
model, image_processor = get_model_and_processor(checkpoint)
|
84 |
model = model.to(device)
|
85 |
-
|
86 |
if classes is not None:
|
87 |
wrong_classes = [cls for cls in classes if cls not in model.config.label2id]
|
88 |
if wrong_classes:
|
@@ -92,7 +82,7 @@ def detect_objects(
|
|
92 |
keep_ids = None
|
93 |
|
94 |
if isinstance(images, np.ndarray) and images.ndim == 4:
|
95 |
-
images = [x for x in images]
|
96 |
|
97 |
batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
|
98 |
|
@@ -164,7 +154,8 @@ def read_video_k_frames(video_path: str, k: int, read_every_i_frame: int = 1):
|
|
164 |
return frames
|
165 |
|
166 |
|
167 |
-
def get_tracker(
|
|
|
168 |
if tracker == TrackingAlgorithm.SORT:
|
169 |
return trackers.SORTTracker(frame_rate=fps)
|
170 |
elif tracker == TrackingAlgorithm.DEEPSORT:
|
@@ -272,33 +263,7 @@ def create_video_inputs() -> List[gr.components.Component]:
|
|
272 |
interactive=True,
|
273 |
format="mp4", # Ensure MP4 format
|
274 |
elem_classes="input-component",
|
275 |
-
)
|
276 |
-
gr.Dropdown(
|
277 |
-
choices=CHECKPOINTS,
|
278 |
-
label="Select Model Checkpoint",
|
279 |
-
value=DEFAULT_CHECKPOINT,
|
280 |
-
elem_classes="input-component",
|
281 |
-
),
|
282 |
-
gr.Dropdown(
|
283 |
-
choices=TRACKERS,
|
284 |
-
label="Select Tracker (Optional)",
|
285 |
-
value=None,
|
286 |
-
elem_classes="input-component",
|
287 |
-
),
|
288 |
-
gr.TextArea(
|
289 |
-
label="Specify Class Names to Detect (comma separated)",
|
290 |
-
value="all",
|
291 |
-
lines=1,
|
292 |
-
elem_classes="input-component",
|
293 |
-
),
|
294 |
-
gr.Slider(
|
295 |
-
minimum=0.1,
|
296 |
-
maximum=1.0,
|
297 |
-
value=DEFAULT_CONFIDENCE_THRESHOLD,
|
298 |
-
step=0.1,
|
299 |
-
label="Confidence Threshold",
|
300 |
-
elem_classes="input-component",
|
301 |
-
),
|
302 |
]
|
303 |
|
304 |
|
@@ -329,7 +294,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
329 |
with gr.Row():
|
330 |
with gr.Column(scale=1, min_width=300):
|
331 |
with gr.Group():
|
332 |
-
video_input
|
333 |
video_detect_button, video_clear_button = create_button_row()
|
334 |
with gr.Column(scale=2):
|
335 |
video_output = gr.Video(
|
@@ -337,44 +302,19 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
337 |
format="mp4", # Explicit MP4 format
|
338 |
elem_classes="output-component",
|
339 |
)
|
340 |
-
|
341 |
-
gr.Examples(
|
342 |
-
examples=[
|
343 |
-
[example["path"], DEFAULT_CHECKPOINT, example["tracker"], example["classes"],
|
344 |
-
DEFAULT_CONFIDENCE_THRESHOLD]
|
345 |
-
for example in VIDEO_EXAMPLES
|
346 |
-
],
|
347 |
-
inputs=[video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold],
|
348 |
-
outputs=[video_output],
|
349 |
-
fn=process_video,
|
350 |
-
cache_examples=False,
|
351 |
-
label="Select a video example to populate inputs",
|
352 |
-
)
|
353 |
-
|
354 |
-
# Video clear button
|
355 |
video_clear_button.click(
|
356 |
fn=lambda: (
|
357 |
None,
|
358 |
-
DEFAULT_CHECKPOINT,
|
359 |
-
None,
|
360 |
-
"all",
|
361 |
-
DEFAULT_CONFIDENCE_THRESHOLD,
|
362 |
None,
|
363 |
),
|
364 |
outputs=[
|
365 |
video_input,
|
366 |
-
video_checkpoint,
|
367 |
-
video_tracker,
|
368 |
-
video_classes,
|
369 |
-
video_confidence_threshold,
|
370 |
video_output,
|
371 |
],
|
372 |
)
|
373 |
-
|
374 |
-
# Video detect button
|
375 |
video_detect_button.click(
|
376 |
fn=process_video,
|
377 |
-
inputs=[video_input
|
378 |
outputs=[video_output],
|
379 |
)
|
380 |
|
|
|
41 |
SORT = "SORT (2016)"
|
42 |
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
# Create a color palette for visualization
|
45 |
# These hex color codes define different colors for tracking different objects
|
46 |
color = sv.ColorPalette.from_hex([
|
|
|
63 |
|
64 |
@spaces.GPU(duration=20)
|
65 |
def detect_objects(
|
|
|
66 |
images: List[np.ndarray] | np.ndarray,
|
|
|
67 |
target_size: Optional[Tuple[int, int]] = None,
|
68 |
+
batch_size: int = BATCH_SIZE
|
|
|
69 |
):
|
70 |
+
checkpoint = "ustc-community/dfine-xlarge-obj2coco"
|
71 |
+
confidence_threshold = 0.3
|
72 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
73 |
model, image_processor = get_model_and_processor(checkpoint)
|
74 |
model = model.to(device)
|
75 |
+
classes = ["Airplane", "Drone", "Helicopter", "Satellite", "Quadcopter", "Vehicle"]
|
76 |
if classes is not None:
|
77 |
wrong_classes = [cls for cls in classes if cls not in model.config.label2id]
|
78 |
if wrong_classes:
|
|
|
82 |
keep_ids = None
|
83 |
|
84 |
if isinstance(images, np.ndarray) and images.ndim == 4:
|
85 |
+
images = [x for x in images]
|
86 |
|
87 |
batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
|
88 |
|
|
|
154 |
return frames
|
155 |
|
156 |
|
157 |
+
def get_tracker(fps: float):
|
158 |
+
tracker = TrackingAlgorithm.BYTETRACK
|
159 |
if tracker == TrackingAlgorithm.SORT:
|
160 |
return trackers.SORTTracker(frame_rate=fps)
|
161 |
elif tracker == TrackingAlgorithm.DEEPSORT:
|
|
|
263 |
interactive=True,
|
264 |
format="mp4", # Ensure MP4 format
|
265 |
elem_classes="input-component",
|
266 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
]
|
268 |
|
269 |
|
|
|
294 |
with gr.Row():
|
295 |
with gr.Column(scale=1, min_width=300):
|
296 |
with gr.Group():
|
297 |
+
video_input = create_video_inputs()
|
298 |
video_detect_button, video_clear_button = create_button_row()
|
299 |
with gr.Column(scale=2):
|
300 |
video_output = gr.Video(
|
|
|
302 |
format="mp4", # Explicit MP4 format
|
303 |
elem_classes="output-component",
|
304 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
video_clear_button.click(
|
306 |
fn=lambda: (
|
307 |
None,
|
|
|
|
|
|
|
|
|
308 |
None,
|
309 |
),
|
310 |
outputs=[
|
311 |
video_input,
|
|
|
|
|
|
|
|
|
312 |
video_output,
|
313 |
],
|
314 |
)
|
|
|
|
|
315 |
video_detect_button.click(
|
316 |
fn=process_video,
|
317 |
+
inputs=[video_input],
|
318 |
outputs=[video_output],
|
319 |
)
|
320 |
|