chongzhou committed
Commit 3e65e00 · 1 Parent(s): 4277d6f

offload_state_to_cpu

Files changed (1)
  1. app.py +116 -119
app.py CHANGED
@@ -237,15 +237,16 @@ def preprocess_video_in(
     input_points = []
     input_labels = []

-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-    predictor.to("cuda")
-    if inference_state:
-        inference_state["device"] = "cuda"
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
     if torch.cuda.get_device_properties(0).major >= 8:
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
-    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-    inference_state = predictor.init_state(video_path=video_path)
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        inference_state = predictor.init_state(
+            offload_video_to_cpu=True,
+            offload_state_to_cpu=True,
+            video_path=video_path,
+        )

     return [
         gr.update(open=False),  # video_in_drawer
@@ -270,72 +271,68 @@ def segment_with_points(
     inference_state,
     evt: gr.SelectData,
 ):
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-    predictor.to("cuda")
-    if inference_state:
-        inference_state["device"] = "cuda"
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
     if torch.cuda.get_device_properties(0).major >= 8:
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
-    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-
-    input_points.append(evt.index)
-    print(f"TRACKING INPUT POINT: {input_points}")
-
-    if point_type == "include":
-        input_labels.append(1)
-    elif point_type == "exclude":
-        input_labels.append(0)
-    print(f"TRACKING INPUT LABEL: {input_labels}")
-
-    # Open the image and get its dimensions
-    transparent_background = Image.fromarray(first_frame).convert("RGBA")
-    w, h = transparent_background.size
-
-    # Define the circle radius as a fraction of the smaller dimension
-    fraction = 0.01  # You can adjust this value as needed
-    radius = int(fraction * min(w, h))
-
-    # Create a transparent layer to draw on
-    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
-
-    for index, track in enumerate(input_points):
-        if input_labels[index] == 1:
-            cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
-        else:
-            cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
-
-    # Convert the transparent layer back to an image
-    transparent_layer = Image.fromarray(transparent_layer, "RGBA")
-    selected_point_map = Image.alpha_composite(
-        transparent_background, transparent_layer
-    )
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        input_points.append(evt.index)
+        print(f"TRACKING INPUT POINT: {input_points}")
+
+        if point_type == "include":
+            input_labels.append(1)
+        elif point_type == "exclude":
+            input_labels.append(0)
+        print(f"TRACKING INPUT LABEL: {input_labels}")
+
+        # Open the image and get its dimensions
+        transparent_background = Image.fromarray(first_frame).convert("RGBA")
+        w, h = transparent_background.size
+
+        # Define the circle radius as a fraction of the smaller dimension
+        fraction = 0.01  # You can adjust this value as needed
+        radius = int(fraction * min(w, h))
+
+        # Create a transparent layer to draw on
+        transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
+
+        for index, track in enumerate(input_points):
+            if input_labels[index] == 1:
+                cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
+            else:
+                cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
+
+        # Convert the transparent layer back to an image
+        transparent_layer = Image.fromarray(transparent_layer, "RGBA")
+        selected_point_map = Image.alpha_composite(
+            transparent_background, transparent_layer
+        )

-    # Let's add a positive click at (x, y) = (210, 350) to get started
-    points = np.array(input_points, dtype=np.float32)
-    # for labels, `1` means positive click and `0` means negative click
-    labels = np.array(input_labels, dtype=np.int32)
-    _, _, out_mask_logits = predictor.add_new_points(
-        inference_state=inference_state,
-        frame_idx=0,
-        obj_id=OBJ_ID,
-        points=points,
-        labels=labels,
-    )
+        # Let's add a positive click at (x, y) = (210, 350) to get started
+        points = np.array(input_points, dtype=np.float32)
+        # for labels, `1` means positive click and `0` means negative click
+        labels = np.array(input_labels, dtype=np.int32)
+        _, _, out_mask_logits = predictor.add_new_points(
+            inference_state=inference_state,
+            frame_idx=0,
+            obj_id=OBJ_ID,
+            points=points,
+            labels=labels,
+        )

-    mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
-    first_frame_output = Image.alpha_composite(transparent_background, mask_image)
+        mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
+        first_frame_output = Image.alpha_composite(transparent_background, mask_image)

-    torch.cuda.empty_cache()
-    return (
-        selected_point_map,
-        first_frame_output,
-        first_frame,
-        all_frames,
-        input_points,
-        input_labels,
-        inference_state,
-    )
+        torch.cuda.empty_cache()
+        return (
+            selected_point_map,
+            first_frame_output,
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
+            inference_state,
+        )


 def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
@@ -362,64 +359,64 @@ def propagate_to_all(
     input_labels,
     inference_state,
 ):
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
-    predictor.to("cuda")
-    if inference_state:
-        inference_state["device"] = "cuda"
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
     if torch.cuda.get_device_properties(0).major >= 8:
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
-    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-
-    if len(input_points) == 0 or video_in is None or inference_state is None:
-        return None
-    # run propagation throughout the video and collect the results in a dict
-    video_segments = {}  # video_segments contains the per-frame segmentation results
-    print("starting propagate_in_video")
-    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
-        inference_state
-    ):
-        video_segments[out_frame_idx] = {
-            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
-            for i, out_obj_id in enumerate(out_obj_ids)
-        }
-
-    # obtain the segmentation results every few frames
-    vis_frame_stride = 1
-
-    output_frames = []
-    for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
-        transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
-            "RGBA"
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        if len(input_points) == 0 or video_in is None or inference_state is None:
+            return None
+        # run propagation throughout the video and collect the results in a dict
+        video_segments = (
+            {}
+        )  # video_segments contains the per-frame segmentation results
+        print("starting propagate_in_video")
+        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
+            inference_state
+        ):
+            video_segments[out_frame_idx] = {
+                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+                for i, out_obj_id in enumerate(out_obj_ids)
+            }
+
+        # obtain the segmentation results every few frames
+        vis_frame_stride = 1
+
+        output_frames = []
+        for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
+            transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
+                "RGBA"
+            )
+            out_mask = video_segments[out_frame_idx][OBJ_ID]
+            mask_image = show_mask(out_mask)
+            output_frame = Image.alpha_composite(transparent_background, mask_image)
+            output_frame = np.array(output_frame)
+            output_frames.append(output_frame)
+
+        torch.cuda.empty_cache()
+
+        # Create a video clip from the image sequence
+        original_fps = get_video_fps(video_in)
+        fps = original_fps  # Frames per second
+        clip = ImageSequenceClip(output_frames, fps=fps)
+        # Write the result to a file
+        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
+        final_vid_output_path = f"output_video_{unique_id}.mp4"
+        final_vid_output_path = os.path.join(
+            tempfile.gettempdir(), final_vid_output_path
         )
-        out_mask = video_segments[out_frame_idx][OBJ_ID]
-        mask_image = show_mask(out_mask)
-        output_frame = Image.alpha_composite(transparent_background, mask_image)
-        output_frame = np.array(output_frame)
-        output_frames.append(output_frame)
-
-    torch.cuda.empty_cache()
-
-    # Create a video clip from the image sequence
-    original_fps = get_video_fps(video_in)
-    fps = original_fps  # Frames per second
-    clip = ImageSequenceClip(output_frames, fps=fps)
-    # Write the result to a file
-    unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
-    final_vid_output_path = f"output_video_{unique_id}.mp4"
-    final_vid_output_path = os.path.join(tempfile.gettempdir(), final_vid_output_path)
-
-    # Write the result to a file
-    clip.write_videofile(final_vid_output_path, codec="libx264")

-    return (
-        gr.update(value=final_vid_output_path),
-        first_frame,
-        all_frames,
-        input_points,
-        input_labels,
-        inference_state,
-    )
+        # Write the result to a file
+        clip.write_videofile(final_vid_output_path, codec="libx264")
+
+        return (
+            gr.update(value=final_vid_output_path),
+            first_frame,
+            all_frames,
+            input_points,
+            input_labels,
+            inference_state,
+        )


 def update_ui():
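
For reference, here is a minimal standalone sketch of the initialization pattern this commit switches to: build the predictor directly on the GPU, enable TF32 where supported, and call `init_state` with both offload flags so the decoded frames and the per-frame inference state stay in CPU RAM, trading some host-device transfer time for a smaller GPU memory footprint. The config name, checkpoint path, and video path below are placeholders rather than the Space's actual values, and the import assumes the upstream SAM 2 package layout.

```python
import torch
from sam2.build_sam import build_sam2_video_predictor

# Placeholders -- substitute the Space's actual config, checkpoint, and video.
MODEL_CFG = "sam2_hiera_t.yaml"
CHECKPOINT = "./checkpoints/sam2_hiera_tiny.pt"
VIDEO_PATH = "example.mp4"

# Build the predictor on the GPU directly, instead of building on CPU,
# calling .to("cuda"), and patching inference_state["device"] afterwards.
predictor = build_sam2_video_predictor(MODEL_CFG, CHECKPOINT, device="cuda")

# Allow TF32 matmuls on Ampere (compute capability >= 8) and newer GPUs.
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Use autocast as a scoped context manager rather than calling __enter__()
# and never exiting, so bfloat16 autocast only covers this block.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    inference_state = predictor.init_state(
        video_path=VIDEO_PATH,
        offload_video_to_cpu=True,   # keep decoded frames in CPU RAM
        offload_state_to_cpu=True,   # keep per-frame state tensors in CPU RAM
    )
```

The same three changes (GPU-built predictor, with-scoped autocast, offloaded video and state) are repeated in `segment_with_points` and `propagate_to_all` in the diff above.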