chongzhou committed
Commit 238c545 · 1 Parent(s): 129e201

carefully move between CPU and GPU

app.py CHANGED
@@ -165,7 +165,6 @@ def clear_points(
     )
 
 
-@spaces.GPU(duration=10)
 def preprocess_video_in(
     video_path,
     first_frame,
@@ -227,16 +226,12 @@ def preprocess_video_in(
     input_points = []
     input_labels = []
 
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
-    if torch.cuda.get_device_properties(0).major >= 8:
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
-    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-        inference_state = predictor.init_state(
-            offload_video_to_cpu=True,
-            offload_state_to_cpu=True,
-            video_path=video_path,
-        )
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
+    inference_state = predictor.init_state(
+        offload_video_to_cpu=True,
+        offload_state_to_cpu=True,
+        video_path=video_path,
+    )
 
     return [
         gr.update(open=False),  # video_in_drawer
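
In `preprocess_video_in`, the `@spaces.GPU(duration=10)` decorator is removed and the predictor is built with `device="cpu"`, so loading the video and calling `init_state` no longer claim a ZeroGPU slot; the dropped TF32 and bfloat16-autocast setup is CUDA-only and would presumably move into whichever handler still runs on the GPU. A minimal sketch of the resulting CPU-only preprocessing path, assuming `model_cfg` and `sam2_checkpoint` are defined at module level as in the Space (here passed as parameters to keep the sketch self-contained):

# Sketch only, not the Space's full app.py.
from sam2.build_sam import build_sam2_video_predictor

def build_cpu_inference_state(video_path, model_cfg, sam2_checkpoint):
    # No @spaces.GPU decorator and no torch.cuda calls, so this runs on the
    # CPU-only worker that a ZeroGPU Space uses between GPU allocations.
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
    inference_state = predictor.init_state(
        offload_video_to_cpu=True,   # keep decoded frames in CPU RAM
        offload_state_to_cpu=True,   # keep per-frame state tensors in CPU RAM
        video_path=video_path,
    )
    return predictor, inference_state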
sam2/modeling/sam2_base.py CHANGED
@@ -617,7 +617,7 @@ class SAM2Base(torch.nn.Module):
                             if self.use_signed_tpos_enc_to_obj_ptrs
                             else abs(frame_idx - t)
                         ),
-                        out["obj_ptr"],
+                        out["obj_ptr"].to(device),
                     )
                     for t, out in ptr_cond_outputs.items()
                 ]
@@ -630,7 +630,7 @@ class SAM2Base(torch.nn.Module):
                         t, unselected_cond_outputs.get(t, None)
                     )
                     if out is not None:
-                        pos_and_ptrs.append((t_diff, out["obj_ptr"]))
+                        pos_and_ptrs.append((t_diff, out["obj_ptr"].to(device)))
                 # If we have at least one object pointer, add them to the across attention
                 if len(pos_and_ptrs) > 0:
                     pos_list, ptrs_list = zip(*pos_and_ptrs)
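
With `offload_state_to_cpu=True`, the cached `obj_ptr` tensors live on the CPU storage device, so they have to be moved back to the compute device before they are stacked with GPU tensors for the object-pointer attention; otherwise PyTorch raises a device-mismatch error. An illustrative, self-contained sketch (not the library's code) of what the two added `.to(device)` calls avoid:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Object pointers offloaded to CPU, as they are when offload_state_to_cpu=True.
ptr_cond_outputs = {t: {"obj_ptr": torch.randn(1, 256)} for t in range(3)}
frame_idx = 10

pos_and_ptrs = [
    # Move each stored pointer onto the compute device at the point of use.
    (abs(frame_idx - t), out["obj_ptr"].to(device))
    for t, out in ptr_cond_outputs.items()
]
pos_list, ptrs_list = zip(*pos_and_ptrs)
obj_ptrs = torch.stack(ptrs_list, dim=0)  # succeeds because every pointer is on `device`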
sam2/sam2_video_predictor.py CHANGED
@@ -107,7 +107,7 @@ class SAM2VideoPredictor(SAM2Base):
         inference_state["tracking_has_started"] = False
         inference_state["frames_already_tracked"] = {}
         # Warm up the visual backbone and cache the image feature on frame 0
-        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
+        # self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
         return inference_state
 
     @classmethod
@@ -470,7 +470,7 @@ class SAM2VideoPredictor(SAM2Base):
                 size=(batch_size, self.hidden_dim),
                 fill_value=NO_OBJ_SCORE,
                 dtype=torch.float32,
-                device=inference_state["device"],
+                device=inference_state["storage_device"],
             ),
             "object_score_logits": torch.full(
                 size=(batch_size, 1),
@@ -478,7 +478,7 @@ class SAM2VideoPredictor(SAM2Base):
                 # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
                 fill_value=10.0,
                 dtype=torch.float32,
-                device=inference_state["device"],
+                device=inference_state["storage_device"],
             ),
         }
         empty_mask_ptr = None
@@ -545,7 +545,9 @@ class SAM2VideoPredictor(SAM2Base):
                 frame_idx=frame_idx,
                 batch_size=batch_size,
                 high_res_masks=high_res_masks,
-                object_score_logits=consolidated_out["object_score_logits"],
+                object_score_logits=consolidated_out["object_score_logits"].to(
+                    device, non_blocking=True
+                ),
                 is_mask_from_pts=True,  # these frames are what the user interacted with
             )
             consolidated_out["maskmem_features"] = maskmem_features
@@ -879,9 +881,10 @@ class SAM2VideoPredictor(SAM2Base):
     def _get_image_feature(self, inference_state, frame_idx, batch_size):
         """Compute the image features on a given frame."""
         # Look up in the cache first
-        image, backbone_out = inference_state["cached_features"].get(
-            frame_idx, (None, None)
-        )
+        # image, backbone_out = inference_state["cached_features"].get(
+        #     frame_idx, (None, None)
+        # )
+        image, backbone_out = None, None
         if backbone_out is None:
             # Cache miss -- we will run inference on a single image
             device = inference_state["device"]
@@ -889,7 +892,7 @@ class SAM2VideoPredictor(SAM2Base):
             backbone_out = self.forward_image(image)
             # Cache the most recent frame's feature (for repeated interactions with
             # a frame; we can use an LRU cache for more frames in the future).
-            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
+            # inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
 
         # expand the features to have the same dimension as the number of objects
         expanded_image = image.expand(batch_size, -1, -1, -1)
@@ -964,9 +967,11 @@ class SAM2VideoPredictor(SAM2Base):
         pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
         # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
         maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
-        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
-        obj_ptr = current_out["obj_ptr"]
-        object_score_logits = current_out["object_score_logits"]
+        # object pointer is a small tensor, so we always keep it on GPU memory for fast access (modified for ZeroGPU)
+        obj_ptr = current_out["obj_ptr"].to(storage_device, non_blocking=True)
+        object_score_logits = current_out["object_score_logits"].to(
+            storage_device, non_blocking=True
+        )
         # make a compact version of this frame's output to reduce the state size
         compact_current_out = {
             "maskmem_features": maskmem_features,
@@ -1018,6 +1023,7 @@ class SAM2VideoPredictor(SAM2Base):
         `maskmem_pos_enc` is the same across frames and objects, so we cache it as
         a constant in the inference session to reduce session storage size.
         """
+        storage_device = inference_state["storage_device"]
         model_constants = inference_state["constants"]
         # "out_maskmem_pos_enc" should be either a list of tensors or None
         out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
@@ -1026,7 +1032,9 @@ class SAM2VideoPredictor(SAM2Base):
             assert isinstance(out_maskmem_pos_enc, list)
             # only take the slice for one object, since it's same across objects
             maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
-            model_constants["maskmem_pos_enc"] = maskmem_pos_enc
+            model_constants["maskmem_pos_enc"] = maskmem_pos_enc.to(
+                storage_device, non_blocking=True
+            )
         else:
             maskmem_pos_enc = model_constants["maskmem_pos_enc"]
         # expand the cached maskmem_pos_enc to the actual batch size
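
Taken together, these hunks keep the small per-frame outputs (`obj_ptr`, `object_score_logits`, the placeholder outputs in `consolidated_out`, and the cached `maskmem_pos_enc`) on `inference_state["storage_device"]` instead of the GPU, copying them back to the compute device only where they are consumed; the frame-0 warm-up and the image-feature cache are commented out, presumably because `init_state` now runs on CPU and cached tensors could otherwise be bound to a GPU that ZeroGPU has already released. A minimal sketch of that offload/reload pattern, with hypothetical helper names:

import torch

def offload_small_outputs(current_out, storage_device):
    # Park the small per-frame tensors on the storage device (CPU) so nothing
    # CUDA-resident lingers in the inference state between GPU allocations.
    return {
        "obj_ptr": current_out["obj_ptr"].to(storage_device, non_blocking=True),
        "object_score_logits": current_out["object_score_logits"].to(
            storage_device, non_blocking=True
        ),
    }

def reload_for_compute(stored_out, device):
    # Move the stored tensors back to the compute device right before use.
    return {k: v.to(device, non_blocking=True) for k, v in stored_out.items()}

Note that `non_blocking=True` only overlaps the copy with computation when the CPU side uses pinned memory; otherwise the transfer simply completes synchronously, which is still correct here.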