chongzhou committed
Commit defc201 · 1 Parent(s): 93472c9

remove predictor from state

Files changed (1)
  1. app.py +44 -35
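
In effect, the SAM2 video predictor is no longer carried in a gr.State; each handler that needs it rebuilds it on entry and moves it to the GPU inside the ZeroGPU-wrapped call, so only picklable values and the inference_state dict remain in session state. A minimal sketch of that pattern, assuming the upstream SAM 2 package layout and the model_cfg / sam2_checkpoint names used elsewhere in app.py (the paths below are placeholders, and get_predictor is a hypothetical helper name; in app.py the construction is inlined in each handler):

    import torch
    from sam2.build_sam import build_sam2_video_predictor

    model_cfg = "configs/sam2_hiera_t.yaml"   # placeholder config name
    sam2_checkpoint = "checkpoints/sam2.pt"   # placeholder checkpoint path

    def get_predictor():
        # Rebuild per call instead of keeping the predictor in gr.State.
        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
        if torch.cuda.is_available():
            # Move to the GPU only inside the GPU-allocated handler.
            predictor.to("cuda")
        return predictor
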
app.py CHANGED
@@ -87,21 +87,45 @@ def get_video_fps(video_path):
     return fps
 
 
+def reset_state(inference_state):
+    for v in inference_state["point_inputs_per_obj"].values():
+        v.clear()
+    for v in inference_state["mask_inputs_per_obj"].values():
+        v.clear()
+    for v in inference_state["output_dict_per_obj"].values():
+        v["cond_frame_outputs"].clear()
+        v["non_cond_frame_outputs"].clear()
+    for v in inference_state["temp_output_dict_per_obj"].values():
+        v["cond_frame_outputs"].clear()
+        v["non_cond_frame_outputs"].clear()
+    inference_state["output_dict"]["cond_frame_outputs"].clear()
+    inference_state["output_dict"]["non_cond_frame_outputs"].clear()
+    inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
+    inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
+    inference_state["tracking_has_started"] = False
+    inference_state["frames_already_tracked"].clear()
+    inference_state["obj_id_to_idx"].clear()
+    inference_state["obj_idx_to_id"].clear()
+    inference_state["obj_ids"].clear()
+    inference_state["point_inputs_per_obj"].clear()
+    inference_state["mask_inputs_per_obj"].clear()
+    inference_state["output_dict_per_obj"].clear()
+    inference_state["temp_output_dict_per_obj"].clear()
+    return inference_state
+
+
 def reset(
     first_frame,
     all_frames,
     input_points,
     input_labels,
     inference_state,
-    predictor,
 ):
     first_frame = None
     all_frames = None
     input_points = []
     input_labels = []
 
-    if inference_state and predictor:
-        predictor.reset_state(inference_state)
     inference_state = None
     return (
         None,
@@ -114,7 +138,6 @@ def reset(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     )
 
 
@@ -124,12 +147,11 @@ def clear_points(
     input_points,
     input_labels,
     inference_state,
-    predictor,
 ):
     input_points = []
     input_labels = []
-    if inference_state and predictor and inference_state["tracking_has_started"]:
-        predictor.reset_state(inference_state)
+    if inference_state and inference_state["tracking_has_started"]:
+        inference_state = reset_state(inference_state)
     return (
         first_frame,
         None,
@@ -139,7 +161,6 @@ def clear_points(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     )
 
 
@@ -150,7 +171,6 @@ def preprocess_video_in(
     input_points,
     input_labels,
     inference_state,
-    predictor,
 ):
     if video_path is None:
         return (
@@ -163,7 +183,6 @@ def preprocess_video_in(
             input_points,
             input_labels,
             inference_state,
-            predictor,
         )
 
     # Read the first frame
@@ -180,12 +199,8 @@ def preprocess_video_in(
             input_points,
             input_labels,
             inference_state,
-            predictor,
         )
 
-    if predictor is None:
-        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-
     frame_number = 0
     _first_frame = None
     all_frames = []
@@ -207,10 +222,19 @@ def preprocess_video_in(
 
     cap.release()
     first_frame = copy.deepcopy(_first_frame)
-    inference_state = predictor.init_state(video_path=video_path)
     input_points = []
     input_labels = []
 
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+    if torch.cuda.is_available():
+        predictor.to("cuda")
+        inference_state["device"] = "cuda"
+        if torch.cuda.get_device_properties(0).major >= 8:
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+        torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+    inference_state = predictor.init_state(video_path=video_path)
+
     return [
         gr.update(open=False), # video_in_drawer
         first_frame, # points_map
@@ -221,7 +245,6 @@ def preprocess_video_in(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     ]
 
 
@@ -232,9 +255,9 @@ def segment_with_points(
     input_points,
     input_labels,
     inference_state,
-    predictor,
     evt: gr.SelectData,
 ):
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
     if torch.cuda.is_available():
         predictor.to("cuda")
         inference_state["device"] = "cuda"
@@ -299,7 +322,6 @@ def segment_with_points(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     )
 
 
@@ -325,8 +347,8 @@ def propagate_to_all(
     input_points,
     input_labels,
     inference_state,
-    predictor,
 ):
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
    if torch.cuda.is_available():
         predictor.to("cuda")
         inference_state["device"] = "cuda"
@@ -383,15 +405,15 @@ def propagate_to_all(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     )
 
 
 try:
     from spaces import GPU
 
-    segment_with_points = GPU(segment_with_points)
-    propagate_to_all = GPU(propagate_to_all)
+    preprocess_video_in = GPU(preprocess_video_in, duration=10)
+    segment_with_points = GPU(segment_with_points, duration=5)
+    propagate_to_all = GPU(propagate_to_all, duration=30)
 except:
     print("spaces unavailable")
 
@@ -406,7 +428,6 @@ with gr.Blocks() as demo:
     input_points = gr.State([])
     input_labels = gr.State([])
     inference_state = gr.State()
-    predictor = gr.State()
 
     with gr.Column():
         # Title
@@ -461,7 +482,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        outputs=[
            video_in_drawer, # Accordion to hide uploaded video player
@@ -473,7 +493,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        queue=False,
    )
@@ -487,7 +506,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        outputs=[
            video_in_drawer, # Accordion to hide uploaded video player
@@ -499,7 +517,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        queue=False,
    )
@@ -514,7 +531,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        outputs=[
            points_map, # updated image with points
@@ -524,7 +540,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        queue=False,
    )
@@ -538,7 +553,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        outputs=[
            points_map,
@@ -549,7 +563,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        queue=False,
    )
@@ -562,7 +575,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        outputs=[
            video_in,
@@ -575,7 +587,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        queue=False,
    )
@@ -594,7 +605,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        outputs=[
            output_video,
@@ -603,7 +613,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-            predictor,
        ],
        concurrency_limit=10,
        queue=False,
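
For reference, the GPU(fn, duration=...) calls near the bottom of the diff use the functional form of the spaces ZeroGPU wrapper; the decorator form does the same job. A minimal sketch, assuming the code runs on a Space where the spaces package is installed (the handler name and body are stand-ins, not the app's real propagate_to_all):

    import spaces
    import torch

    @spaces.GPU(duration=30)  # request a GPU slot for up to ~30 s per call
    def propagate_to_all_demo(inference_state):
        # Stand-in body: the real handler runs SAM2 mask propagation here.
        return "cuda" if torch.cuda.is_available() else "cpu"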