chongzhou committed
Commit db1d5c6 · 1 Parent(s): 3ccde9c
Files changed (1)
  1. app.py +41 -44
app.py CHANGED
@@ -337,57 +337,54 @@ def propagate_to_all(
     input_points,
     inference_state,
 ):
+    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
     if torch.cuda.get_device_properties(0).major >= 8:
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
-    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-        if len(input_points) == 0 or video_in is None or inference_state is None:
-            return None
-        # run propagation throughout the video and collect the results in a dict
-        video_segments = (
-            {}
-        )  # video_segments contains the per-frame segmentation results
-        print("starting propagate_in_video")
-        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
-            inference_state
-        ):
-            video_segments[out_frame_idx] = {
-                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
-                for i, out_obj_id in enumerate(out_obj_ids)
-            }
-
-        # obtain the segmentation results every few frames
-        vis_frame_stride = 1
-
-        output_frames = []
-        for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
-            transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
-                "RGBA"
-            )
-            out_mask = video_segments[out_frame_idx][OBJ_ID]
-            mask_image = show_mask(out_mask)
-            output_frame = Image.alpha_composite(transparent_background, mask_image)
-            output_frame = np.array(output_frame)
-            output_frames.append(output_frame)
-
-        torch.cuda.empty_cache()
-
-        # Create a video clip from the image sequence
-        original_fps = get_video_fps(video_in)
-        fps = original_fps  # Frames per second
-        clip = ImageSequenceClip(output_frames, fps=fps)
-        # Write the result to a file
-        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
-        final_vid_output_path = f"output_video_{unique_id}.mp4"
-        final_vid_output_path = os.path.join(
-            tempfile.gettempdir(), final_vid_output_path
+
+    if len(input_points) == 0 or video_in is None or inference_state is None:
+        return None
+    # run propagation throughout the video and collect the results in a dict
+    video_segments = {}  # video_segments contains the per-frame segmentation results
+    print("starting propagate_in_video")
+    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
+        inference_state
+    ):
+        video_segments[out_frame_idx] = {
+            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+            for i, out_obj_id in enumerate(out_obj_ids)
+        }
+
+    # obtain the segmentation results every few frames
+    vis_frame_stride = 1
+
+    output_frames = []
+    for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
+        transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
+            "RGBA"
         )
+        out_mask = video_segments[out_frame_idx][OBJ_ID]
+        mask_image = show_mask(out_mask)
+        output_frame = Image.alpha_composite(transparent_background, mask_image)
+        output_frame = np.array(output_frame)
+        output_frames.append(output_frame)
+
+    torch.cuda.empty_cache()
+
+    # Create a video clip from the image sequence
+    original_fps = get_video_fps(video_in)
+    fps = original_fps  # Frames per second
+    clip = ImageSequenceClip(output_frames, fps=fps)
+    # Write the result to a file
+    unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
+    final_vid_output_path = f"output_video_{unique_id}.mp4"
+    final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
 
-        # Write the result to a file
-        clip.write_videofile(final_vid_output_path, codec="libx264")
+    # Write the result to a file
+    clip.write_videofile(final_vid_output_path, codec="libx264")
 
-        return gr.update(value=final_vid_output_path)
+    return gr.update(value=final_vid_output_path)
 
 
 def update_ui():
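
The substantive change in this commit: bfloat16 autocast moves from a with block that scoped the whole function body to a bare __enter__() call at the top of propagate_to_all, which enables autocast for the current thread and never exits it; the body is de-indented accordingly. A minimal standalone sketch contrasting the two patterns (not from this repo; model and x are placeholders, and it assumes a CUDA-capable PyTorch build):

import torch

model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(2, 8, device="cuda")

# Scoped (the old pattern): autocast is active only inside the block.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y_scoped = model(x)  # computed in bfloat16
y_after = model(x)       # back to float32 out here

# Unscoped (the new pattern): __enter__() with no matching __exit__()
# leaves bfloat16 autocast enabled for the rest of the thread.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
y_global = model(x)      # still computed under autocast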
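
The allow_tf32 switches just above the autocast call are gated on compute capability >= 8 because TensorFloat-32 matmul/cuDNN modes only exist on Ampere-and-newer GPUs. A quick standalone check of what a given device reports (a sketch, not app code):

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"compute capability {major}.{minor}; TF32 eligible: {major >= 8}")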
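
show_mask is defined elsewhere in app.py and is not part of this diff; it returns the RGBA overlay that Image.alpha_composite blends onto each frame. A hypothetical sketch of what such a helper can look like (the name show_mask_sketch, the red color, and the 128 alpha are assumptions, not the app's actual values):

import numpy as np
from PIL import Image

def show_mask_sketch(mask: np.ndarray, color=(255, 0, 0, 128)) -> Image.Image:
    # Collapse a (1, H, W) or (H, W) boolean mask to (H, W).
    mask = np.squeeze(mask).astype(bool)
    # Masked pixels get a semi-transparent color; the rest stay fully transparent.
    rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
    rgba[mask] = color
    return Image.fromarray(rgba, mode="RGBA")

# Usage mirroring the loop in the diff:
# overlay = Image.alpha_composite(transparent_background, show_mask_sketch(out_mask))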