autocast
Browse files
app.py
CHANGED
@@ -337,57 +337,54 @@ def propagate_to_all(
|
|
337 |
input_points,
|
338 |
inference_state,
|
339 |
):
|
|
|
340 |
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
|
341 |
if torch.cuda.get_device_properties(0).major >= 8:
|
342 |
torch.backends.cuda.matmul.allow_tf32 = True
|
343 |
torch.backends.cudnn.allow_tf32 = True
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
|
366 |
-
"RGBA"
|
367 |
-
)
|
368 |
-
out_mask = video_segments[out_frame_idx][OBJ_ID]
|
369 |
-
mask_image = show_mask(out_mask)
|
370 |
-
output_frame = Image.alpha_composite(transparent_background, mask_image)
|
371 |
-
output_frame = np.array(output_frame)
|
372 |
-
output_frames.append(output_frame)
|
373 |
-
|
374 |
-
torch.cuda.empty_cache()
|
375 |
-
|
376 |
-
# Create a video clip from the image sequence
|
377 |
-
original_fps = get_video_fps(video_in)
|
378 |
-
fps = original_fps # Frames per second
|
379 |
-
clip = ImageSequenceClip(output_frames, fps=fps)
|
380 |
-
# Write the result to a file
|
381 |
-
unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
|
382 |
-
final_vid_output_path = f"output_video_{unique_id}.mp4"
|
383 |
-
final_vid_output_path = os.path.join(
|
384 |
-
tempfile.gettempdir(), final_vid_output_path
|
385 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
386 |
|
387 |
-
|
388 |
-
|
389 |
|
390 |
-
|
391 |
|
392 |
|
393 |
def update_ui():
|
|
|
337 |
input_points,
|
338 |
inference_state,
|
339 |
):
|
340 |
+
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
341 |
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
|
342 |
if torch.cuda.get_device_properties(0).major >= 8:
|
343 |
torch.backends.cuda.matmul.allow_tf32 = True
|
344 |
torch.backends.cudnn.allow_tf32 = True
|
345 |
+
|
346 |
+
if len(input_points) == 0 or video_in is None or inference_state is None:
|
347 |
+
return None
|
348 |
+
# run propagation throughout the video and collect the results in a dict
|
349 |
+
video_segments = {} # video_segments contains the per-frame segmentation results
|
350 |
+
print("starting propagate_in_video")
|
351 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
|
352 |
+
inference_state
|
353 |
+
):
|
354 |
+
video_segments[out_frame_idx] = {
|
355 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
356 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
357 |
+
}
|
358 |
+
|
359 |
+
# obtain the segmentation results every few frames
|
360 |
+
vis_frame_stride = 1
|
361 |
+
|
362 |
+
output_frames = []
|
363 |
+
for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
|
364 |
+
transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
|
365 |
+
"RGBA"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
)
|
367 |
+
out_mask = video_segments[out_frame_idx][OBJ_ID]
|
368 |
+
mask_image = show_mask(out_mask)
|
369 |
+
output_frame = Image.alpha_composite(transparent_background, mask_image)
|
370 |
+
output_frame = np.array(output_frame)
|
371 |
+
output_frames.append(output_frame)
|
372 |
+
|
373 |
+
torch.cuda.empty_cache()
|
374 |
+
|
375 |
+
# Create a video clip from the image sequence
|
376 |
+
original_fps = get_video_fps(video_in)
|
377 |
+
fps = original_fps # Frames per second
|
378 |
+
clip = ImageSequenceClip(output_frames, fps=fps)
|
379 |
+
# Write the result to a file
|
380 |
+
unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
|
381 |
+
final_vid_output_path = f"output_video_{unique_id}.mp4"
|
382 |
+
final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
|
383 |
|
384 |
+
# Write the result to a file
|
385 |
+
clip.write_videofile(final_vid_output_path, codec="libx264")
|
386 |
|
387 |
+
return gr.update(value=final_vid_output_path)
|
388 |
|
389 |
|
390 |
def update_ui():
|