remove predictor from state
app.py CHANGED
@@ -87,21 +87,45 @@ def get_video_fps(video_path):
     return fps


+def reset_state(inference_state):
+    for v in inference_state["point_inputs_per_obj"].values():
+        v.clear()
+    for v in inference_state["mask_inputs_per_obj"].values():
+        v.clear()
+    for v in inference_state["output_dict_per_obj"].values():
+        v["cond_frame_outputs"].clear()
+        v["non_cond_frame_outputs"].clear()
+    for v in inference_state["temp_output_dict_per_obj"].values():
+        v["cond_frame_outputs"].clear()
+        v["non_cond_frame_outputs"].clear()
+    inference_state["output_dict"]["cond_frame_outputs"].clear()
+    inference_state["output_dict"]["non_cond_frame_outputs"].clear()
+    inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
+    inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
+    inference_state["tracking_has_started"] = False
+    inference_state["frames_already_tracked"].clear()
+    inference_state["obj_id_to_idx"].clear()
+    inference_state["obj_idx_to_id"].clear()
+    inference_state["obj_ids"].clear()
+    inference_state["point_inputs_per_obj"].clear()
+    inference_state["mask_inputs_per_obj"].clear()
+    inference_state["output_dict_per_obj"].clear()
+    inference_state["temp_output_dict_per_obj"].clear()
+    return inference_state
+
+
 def reset(
     first_frame,
     all_frames,
     input_points,
     input_labels,
     inference_state,
-    predictor,
 ):
     first_frame = None
     all_frames = None
     input_points = []
     input_labels = []

-    if inference_state and predictor:
-        predictor.reset_state(inference_state)
     inference_state = None
     return (
         None,
@@ -114,7 +138,6 @@ def reset(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     )


@@ -124,12 +147,11 @@ def clear_points(
     input_points,
     input_labels,
     inference_state,
-    predictor,
 ):
     input_points = []
     input_labels = []
-    if inference_state and predictor:
-        predictor.reset_state(inference_state)
+    if inference_state and inference_state["tracking_has_started"]:
+        inference_state = reset_state(inference_state)
     return (
         first_frame,
         None,
@@ -139,7 +161,6 @@ def clear_points(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     )


@@ -150,7 +171,6 @@ def preprocess_video_in(
     input_points,
     input_labels,
     inference_state,
-    predictor,
 ):
     if video_path is None:
         return (
@@ -163,7 +183,6 @@ def preprocess_video_in(
             input_points,
             input_labels,
             inference_state,
-            predictor,
         )

     # Read the first frame
@@ -180,12 +199,8 @@ def preprocess_video_in(
             input_points,
             input_labels,
             inference_state,
-            predictor,
         )

-    if predictor is None:
-        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-
     frame_number = 0
     _first_frame = None
     all_frames = []
@@ -207,10 +222,19 @@ def preprocess_video_in(

     cap.release()
     first_frame = copy.deepcopy(_first_frame)
-    inference_state = predictor.init_state(video_path=video_path)
     input_points = []
     input_labels = []

+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+    if torch.cuda.is_available():
+        predictor.to("cuda")
+        inference_state["device"] = "cuda"
+        if torch.cuda.get_device_properties(0).major >= 8:
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+            torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+    inference_state = predictor.init_state(video_path=video_path)
+
     return [
         gr.update(open=False),  # video_in_drawer
         first_frame,  # points_map
@@ -221,7 +245,6 @@ def preprocess_video_in(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     ]


@@ -232,9 +255,9 @@ def segment_with_points(
     input_points,
     input_labels,
     inference_state,
-    predictor,
     evt: gr.SelectData,
 ):
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
     if torch.cuda.is_available():
         predictor.to("cuda")
         inference_state["device"] = "cuda"
@@ -299,7 +322,6 @@ def segment_with_points(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     )


@@ -325,8 +347,8 @@ def propagate_to_all(
     input_points,
     input_labels,
     inference_state,
-    predictor,
 ):
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
     if torch.cuda.is_available():
         predictor.to("cuda")
         inference_state["device"] = "cuda"
@@ -383,15 +405,15 @@ def propagate_to_all(
         input_points,
         input_labels,
         inference_state,
-        predictor,
     )


 try:
     from spaces import GPU

-
-
+    preprocess_video_in = GPU(preprocess_video_in, duration=10)
+    segment_with_points = GPU(segment_with_points, duration=5)
+    propagate_to_all = GPU(propagate_to_all, duration=30)
 except:
     print("spaces unavailable")

@@ -406,7 +428,6 @@ with gr.Blocks() as demo:
     input_points = gr.State([])
     input_labels = gr.State([])
     inference_state = gr.State()
-    predictor = gr.State()

     with gr.Column():
         # Title
@@ -461,7 +482,6 @@ with gr.Blocks() as demo:
             input_points,
             input_labels,
             inference_state,
-            predictor,
         ],
         outputs=[
             video_in_drawer,  # Accordion to hide uploaded video player
@@ -473,7 +493,6 @@ with gr.Blocks() as demo:
             input_points,
             input_labels,
             inference_state,
-            predictor,
         ],
         queue=False,
     )
@@ -487,7 +506,6 @@ with gr.Blocks() as demo:
             input_points,
             input_labels,
             inference_state,
-            predictor,
         ],
         outputs=[
             video_in_drawer,  # Accordion to hide uploaded video player
@@ -499,7 +517,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-           predictor,
        ],
        queue=False,
    )
@@ -514,7 +531,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-           predictor,
        ],
        outputs=[
            points_map,  # updated image with points
@@ -524,7 +540,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-           predictor,
        ],
        queue=False,
    )
@@ -538,7 +553,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-           predictor,
        ],
        outputs=[
            points_map,
@@ -549,7 +563,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-           predictor,
        ],
        queue=False,
    )
@@ -562,7 +575,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-           predictor,
        ],
        outputs=[
            video_in,
@@ -575,7 +587,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-           predictor,
        ],
        queue=False,
    )
@@ -594,7 +605,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-           predictor,
        ],
        outputs=[
            output_video,
@@ -603,7 +613,6 @@ with gr.Blocks() as demo:
            input_points,
            input_labels,
            inference_state,
-           predictor,
        ],
        concurrency_limit=10,
        queue=False,
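
With the predictor gone from session state, each heavy handler (preprocess_video_in, segment_with_points, propagate_to_all) now rebuilds it on CPU and moves it to CUDA when available. The TF32 flags and the bfloat16 autocast mirror the setup SAM 2's reference notebooks use on Ampere (compute capability 8.0) and newer GPUs. A hedged sketch of that per-call pattern, factored into a helper whose name is ours, not the commit's:

```python
import torch
from sam2.build_sam import build_sam2_video_predictor

def fresh_predictor(model_cfg, sam2_checkpoint):
    """Hypothetical helper mirroring the setup each handler now inlines."""
    # Build on CPU so the same code path works on GPU-less fallback hardware.
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
    if torch.cuda.is_available():
        predictor.to("cuda")
        if torch.cuda.get_device_properties(0).major >= 8:
            # Ampere and newer: allow TF32 matmuls and enter bfloat16 autocast,
            # as in SAM 2's example notebooks. The context manager is entered
            # and deliberately never exited, exactly as in the diff.
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    return predictor
```

Rebuilding per call pays a checkpoint load each time, but it keeps session state free of device-bound objects, which is what lets these handlers run inside short-lived GPU workers.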
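
The try block now routes all three handlers through spaces.GPU, which on ZeroGPU Spaces reserves a GPU slice for at most duration seconds per call; where the spaces package is absent, the import fails and the functions run unwrapped. The call form used in the diff is equivalent to the decorator form, sketched here with a hypothetical stand-in function:

```python
import spaces  # available on Hugging Face Spaces; ImportError elsewhere

@spaces.GPU(duration=30)  # same effect as: fn = spaces.GPU(fn, duration=30)
def run_propagation(video_path):
    """Hypothetical stand-in for propagate_to_all's GPU-bound work."""
    ...
```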
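
What remains in gr.State (input_points, input_labels, inference_state) is plain Python data, which is presumably the point of the commit: Gradio deep-copies a State's value per session, and a CUDA-resident SAM2VideoPredictor tolerates neither that copy nor ZeroGPU's process boundaries, while lists and dicts do. A toy illustration of the distinction, with illustrative comments:

```python
import gradio as gr

with gr.Blocks() as demo:
    # Fine in session state: plain, deep-copyable Python data.
    input_points = gr.State([])
    input_labels = gr.State([])
    inference_state = gr.State()  # later holds the predictor's state dict

    # Avoid: gr.State(predictor). A loaded torch module would be deep-copied
    # per session and cannot cross ZeroGPU worker boundaries, hence this commit.
```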