denizaybey committed on
Commit
5f2f6f3
·
verified ·
1 Parent(s): b834543

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -70
app.py CHANGED
@@ -41,15 +41,6 @@ class TrackingAlgorithm:
41
  SORT = "SORT (2016)"
42
 
43
 
44
- TRACKERS = [None, TrackingAlgorithm.BYTETRACK, TrackingAlgorithm.DEEPSORT, TrackingAlgorithm.SORT]
45
- VIDEO_EXAMPLES = [
46
- {"path": "./examples/videos/dogs_running.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
47
- {"path": "./examples/videos/traffic.mp4", "label": "Local Video", "tracker": TrackingAlgorithm.BYTETRACK,
48
- "classes": "car, truck, bus"},
49
- {"path": "./examples/videos/fast_and_furious.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
50
- {"path": "./examples/videos/break_dance.mp4", "label": "Local Video", "tracker": None, "classes": "all"},
51
- ]
52
-
53
  # Create a color palette for visualization
54
  # These hex color codes define different colors for tracking different objects
55
  color = sv.ColorPalette.from_hex([
@@ -72,17 +63,16 @@ def get_model_and_processor(checkpoint: str):
72
 
73
  @spaces.GPU(duration=20)
74
  def detect_objects(
75
- checkpoint: str,
76
  images: List[np.ndarray] | np.ndarray,
77
- confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
78
  target_size: Optional[Tuple[int, int]] = None,
79
- batch_size: int = BATCH_SIZE,
80
- classes: Optional[List[str]] = None,
81
  ):
 
 
82
  device = "cuda" if torch.cuda.is_available() else "cpu"
83
  model, image_processor = get_model_and_processor(checkpoint)
84
  model = model.to(device)
85
-
86
  if classes is not None:
87
  wrong_classes = [cls for cls in classes if cls not in model.config.label2id]
88
  if wrong_classes:
@@ -92,7 +82,7 @@ def detect_objects(
92
  keep_ids = None
93
 
94
  if isinstance(images, np.ndarray) and images.ndim == 4:
95
- images = [x for x in images] # split video array into list of images
96
 
97
  batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
98
 
@@ -164,7 +154,8 @@ def read_video_k_frames(video_path: str, k: int, read_every_i_frame: int = 1):
164
  return frames
165
 
166
 
167
- def get_tracker(tracker: str, fps: float):
 
168
  if tracker == TrackingAlgorithm.SORT:
169
  return trackers.SORTTracker(frame_rate=fps)
170
  elif tracker == TrackingAlgorithm.DEEPSORT:
@@ -272,33 +263,7 @@ def create_video_inputs() -> List[gr.components.Component]:
272
  interactive=True,
273
  format="mp4", # Ensure MP4 format
274
  elem_classes="input-component",
275
- ),
276
- gr.Dropdown(
277
- choices=CHECKPOINTS,
278
- label="Select Model Checkpoint",
279
- value=DEFAULT_CHECKPOINT,
280
- elem_classes="input-component",
281
- ),
282
- gr.Dropdown(
283
- choices=TRACKERS,
284
- label="Select Tracker (Optional)",
285
- value=None,
286
- elem_classes="input-component",
287
- ),
288
- gr.TextArea(
289
- label="Specify Class Names to Detect (comma separated)",
290
- value="all",
291
- lines=1,
292
- elem_classes="input-component",
293
- ),
294
- gr.Slider(
295
- minimum=0.1,
296
- maximum=1.0,
297
- value=DEFAULT_CONFIDENCE_THRESHOLD,
298
- step=0.1,
299
- label="Confidence Threshold",
300
- elem_classes="input-component",
301
- ),
302
  ]
303
 
304
 
@@ -329,7 +294,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
329
  with gr.Row():
330
  with gr.Column(scale=1, min_width=300):
331
  with gr.Group():
332
- video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold = create_video_inputs()
333
  video_detect_button, video_clear_button = create_button_row()
334
  with gr.Column(scale=2):
335
  video_output = gr.Video(
@@ -337,44 +302,19 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
337
  format="mp4", # Explicit MP4 format
338
  elem_classes="output-component",
339
  )
340
-
341
- gr.Examples(
342
- examples=[
343
- [example["path"], DEFAULT_CHECKPOINT, example["tracker"], example["classes"],
344
- DEFAULT_CONFIDENCE_THRESHOLD]
345
- for example in VIDEO_EXAMPLES
346
- ],
347
- inputs=[video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold],
348
- outputs=[video_output],
349
- fn=process_video,
350
- cache_examples=False,
351
- label="Select a video example to populate inputs",
352
- )
353
-
354
- # Video clear button
355
  video_clear_button.click(
356
  fn=lambda: (
357
  None,
358
- DEFAULT_CHECKPOINT,
359
- None,
360
- "all",
361
- DEFAULT_CONFIDENCE_THRESHOLD,
362
  None,
363
  ),
364
  outputs=[
365
  video_input,
366
- video_checkpoint,
367
- video_tracker,
368
- video_classes,
369
- video_confidence_threshold,
370
  video_output,
371
  ],
372
  )
373
-
374
- # Video detect button
375
  video_detect_button.click(
376
  fn=process_video,
377
- inputs=[video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold],
378
  outputs=[video_output],
379
  )
380
 
 
41
  SORT = "SORT (2016)"
42
 
43
 
 
 
 
 
 
 
 
 
 
44
  # Create a color palette for visualization
45
  # These hex color codes define different colors for tracking different objects
46
  color = sv.ColorPalette.from_hex([
 
63
 
64
  @spaces.GPU(duration=20)
65
  def detect_objects(
 
66
  images: List[np.ndarray] | np.ndarray,
 
67
  target_size: Optional[Tuple[int, int]] = None,
68
+ batch_size: int = BATCH_SIZE
 
69
  ):
70
+ checkpoint = "ustc-community/dfine-xlarge-obj2coco"
71
+ confidence_threshold = 0.3
72
  device = "cuda" if torch.cuda.is_available() else "cpu"
73
  model, image_processor = get_model_and_processor(checkpoint)
74
  model = model.to(device)
75
+ classes = ["Airplane", "Drone", "Helicopter", "Satellite", "Quadcopter", "Vehicle"]
76
  if classes is not None:
77
  wrong_classes = [cls for cls in classes if cls not in model.config.label2id]
78
  if wrong_classes:
 
82
  keep_ids = None
83
 
84
  if isinstance(images, np.ndarray) and images.ndim == 4:
85
+ images = [x for x in images]
86
 
87
  batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
88
 
 
154
  return frames
155
 
156
 
157
+ def get_tracker(fps: float):
158
+ tracker = TrackingAlgorithm.BYTETRACK
159
  if tracker == TrackingAlgorithm.SORT:
160
  return trackers.SORTTracker(frame_rate=fps)
161
  elif tracker == TrackingAlgorithm.DEEPSORT:
 
263
  interactive=True,
264
  format="mp4", # Ensure MP4 format
265
  elem_classes="input-component",
266
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  ]
268
 
269
 
 
294
  with gr.Row():
295
  with gr.Column(scale=1, min_width=300):
296
  with gr.Group():
297
+ video_input = create_video_inputs()
298
  video_detect_button, video_clear_button = create_button_row()
299
  with gr.Column(scale=2):
300
  video_output = gr.Video(
 
302
  format="mp4", # Explicit MP4 format
303
  elem_classes="output-component",
304
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  video_clear_button.click(
306
  fn=lambda: (
307
  None,
 
 
 
 
308
  None,
309
  ),
310
  outputs=[
311
  video_input,
 
 
 
 
312
  video_output,
313
  ],
314
  )
 
 
315
  video_detect_button.click(
316
  fn=process_video,
317
+ inputs=[video_input],
318
  outputs=[video_output],
319
  )
320