carefully move between CPU and GPU
Browse files- app.py +6 -11
- sam2/modeling/sam2_base.py +2 -2
- sam2/sam2_video_predictor.py +20 -12
app.py
CHANGED
@@ -165,7 +165,6 @@ def clear_points(
|
|
165 |
)
|
166 |
|
167 |
|
168 |
-
@spaces.GPU(duration=10)
|
169 |
def preprocess_video_in(
|
170 |
video_path,
|
171 |
first_frame,
|
@@ -227,16 +226,12 @@ def preprocess_video_in(
|
|
227 |
input_points = []
|
228 |
input_labels = []
|
229 |
|
230 |
-
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
offload_video_to_cpu=True,
|
237 |
-
offload_state_to_cpu=True,
|
238 |
-
video_path=video_path,
|
239 |
-
)
|
240 |
|
241 |
return [
|
242 |
gr.update(open=False), # video_in_drawer
|
|
|
165 |
)
|
166 |
|
167 |
|
|
|
168 |
def preprocess_video_in(
|
169 |
video_path,
|
170 |
first_frame,
|
|
|
226 |
input_points = []
|
227 |
input_labels = []
|
228 |
|
229 |
+
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
|
230 |
+
inference_state = predictor.init_state(
|
231 |
+
offload_video_to_cpu=True,
|
232 |
+
offload_state_to_cpu=True,
|
233 |
+
video_path=video_path,
|
234 |
+
)
|
|
|
|
|
|
|
|
|
235 |
|
236 |
return [
|
237 |
gr.update(open=False), # video_in_drawer
|
sam2/modeling/sam2_base.py
CHANGED
@@ -617,7 +617,7 @@ class SAM2Base(torch.nn.Module):
|
|
617 |
if self.use_signed_tpos_enc_to_obj_ptrs
|
618 |
else abs(frame_idx - t)
|
619 |
),
|
620 |
-
out["obj_ptr"],
|
621 |
)
|
622 |
for t, out in ptr_cond_outputs.items()
|
623 |
]
|
@@ -630,7 +630,7 @@ class SAM2Base(torch.nn.Module):
|
|
630 |
t, unselected_cond_outputs.get(t, None)
|
631 |
)
|
632 |
if out is not None:
|
633 |
-
pos_and_ptrs.append((t_diff, out["obj_ptr"]))
|
634 |
# If we have at least one object pointer, add them to the across attention
|
635 |
if len(pos_and_ptrs) > 0:
|
636 |
pos_list, ptrs_list = zip(*pos_and_ptrs)
|
|
|
617 |
if self.use_signed_tpos_enc_to_obj_ptrs
|
618 |
else abs(frame_idx - t)
|
619 |
),
|
620 |
+
out["obj_ptr"].to(device),
|
621 |
)
|
622 |
for t, out in ptr_cond_outputs.items()
|
623 |
]
|
|
|
630 |
t, unselected_cond_outputs.get(t, None)
|
631 |
)
|
632 |
if out is not None:
|
633 |
+
pos_and_ptrs.append((t_diff, out["obj_ptr"].to(device)))
|
634 |
# If we have at least one object pointer, add them to the across attention
|
635 |
if len(pos_and_ptrs) > 0:
|
636 |
pos_list, ptrs_list = zip(*pos_and_ptrs)
|
sam2/sam2_video_predictor.py
CHANGED
@@ -107,7 +107,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
107 |
inference_state["tracking_has_started"] = False
|
108 |
inference_state["frames_already_tracked"] = {}
|
109 |
# Warm up the visual backbone and cache the image feature on frame 0
|
110 |
-
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
111 |
return inference_state
|
112 |
|
113 |
@classmethod
|
@@ -470,7 +470,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
470 |
size=(batch_size, self.hidden_dim),
|
471 |
fill_value=NO_OBJ_SCORE,
|
472 |
dtype=torch.float32,
|
473 |
-
device=inference_state["
|
474 |
),
|
475 |
"object_score_logits": torch.full(
|
476 |
size=(batch_size, 1),
|
@@ -478,7 +478,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
478 |
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
479 |
fill_value=10.0,
|
480 |
dtype=torch.float32,
|
481 |
-
device=inference_state["
|
482 |
),
|
483 |
}
|
484 |
empty_mask_ptr = None
|
@@ -545,7 +545,9 @@ class SAM2VideoPredictor(SAM2Base):
|
|
545 |
frame_idx=frame_idx,
|
546 |
batch_size=batch_size,
|
547 |
high_res_masks=high_res_masks,
|
548 |
-
object_score_logits=consolidated_out["object_score_logits"]
|
|
|
|
|
549 |
is_mask_from_pts=True, # these frames are what the user interacted with
|
550 |
)
|
551 |
consolidated_out["maskmem_features"] = maskmem_features
|
@@ -879,9 +881,10 @@ class SAM2VideoPredictor(SAM2Base):
|
|
879 |
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
880 |
"""Compute the image features on a given frame."""
|
881 |
# Look up in the cache first
|
882 |
-
image, backbone_out = inference_state["cached_features"].get(
|
883 |
-
|
884 |
-
)
|
|
|
885 |
if backbone_out is None:
|
886 |
# Cache miss -- we will run inference on a single image
|
887 |
device = inference_state["device"]
|
@@ -889,7 +892,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
889 |
backbone_out = self.forward_image(image)
|
890 |
# Cache the most recent frame's feature (for repeated interactions with
|
891 |
# a frame; we can use an LRU cache for more frames in the future).
|
892 |
-
inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
|
893 |
|
894 |
# expand the features to have the same dimension as the number of objects
|
895 |
expanded_image = image.expand(batch_size, -1, -1, -1)
|
@@ -964,9 +967,11 @@ class SAM2VideoPredictor(SAM2Base):
|
|
964 |
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
|
965 |
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
966 |
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
|
967 |
-
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
|
968 |
-
obj_ptr = current_out["obj_ptr"]
|
969 |
-
object_score_logits = current_out["object_score_logits"]
|
|
|
|
|
970 |
# make a compact version of this frame's output to reduce the state size
|
971 |
compact_current_out = {
|
972 |
"maskmem_features": maskmem_features,
|
@@ -1018,6 +1023,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|
1018 |
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
|
1019 |
a constant in the inference session to reduce session storage size.
|
1020 |
"""
|
|
|
1021 |
model_constants = inference_state["constants"]
|
1022 |
# "out_maskmem_pos_enc" should be either a list of tensors or None
|
1023 |
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
@@ -1026,7 +1032,9 @@ class SAM2VideoPredictor(SAM2Base):
|
|
1026 |
assert isinstance(out_maskmem_pos_enc, list)
|
1027 |
# only take the slice for one object, since it's same across objects
|
1028 |
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
|
1029 |
-
model_constants["maskmem_pos_enc"] = maskmem_pos_enc
|
|
|
|
|
1030 |
else:
|
1031 |
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
|
1032 |
# expand the cached maskmem_pos_enc to the actual batch size
|
|
|
107 |
inference_state["tracking_has_started"] = False
|
108 |
inference_state["frames_already_tracked"] = {}
|
109 |
# Warm up the visual backbone and cache the image feature on frame 0
|
110 |
+
# self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
111 |
return inference_state
|
112 |
|
113 |
@classmethod
|
|
|
470 |
size=(batch_size, self.hidden_dim),
|
471 |
fill_value=NO_OBJ_SCORE,
|
472 |
dtype=torch.float32,
|
473 |
+
device=inference_state["storage_device"],
|
474 |
),
|
475 |
"object_score_logits": torch.full(
|
476 |
size=(batch_size, 1),
|
|
|
478 |
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
479 |
fill_value=10.0,
|
480 |
dtype=torch.float32,
|
481 |
+
device=inference_state["storage_device"],
|
482 |
),
|
483 |
}
|
484 |
empty_mask_ptr = None
|
|
|
545 |
frame_idx=frame_idx,
|
546 |
batch_size=batch_size,
|
547 |
high_res_masks=high_res_masks,
|
548 |
+
object_score_logits=consolidated_out["object_score_logits"].to(
|
549 |
+
device, non_blocking=True
|
550 |
+
),
|
551 |
is_mask_from_pts=True, # these frames are what the user interacted with
|
552 |
)
|
553 |
consolidated_out["maskmem_features"] = maskmem_features
|
|
|
881 |
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
882 |
"""Compute the image features on a given frame."""
|
883 |
# Look up in the cache first
|
884 |
+
# image, backbone_out = inference_state["cached_features"].get(
|
885 |
+
# frame_idx, (None, None)
|
886 |
+
# )
|
887 |
+
image, backbone_out = None, None
|
888 |
if backbone_out is None:
|
889 |
# Cache miss -- we will run inference on a single image
|
890 |
device = inference_state["device"]
|
|
|
892 |
backbone_out = self.forward_image(image)
|
893 |
# Cache the most recent frame's feature (for repeated interactions with
|
894 |
# a frame; we can use an LRU cache for more frames in the future).
|
895 |
+
# inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
|
896 |
|
897 |
# expand the features to have the same dimension as the number of objects
|
898 |
expanded_image = image.expand(batch_size, -1, -1, -1)
|
|
|
967 |
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
|
968 |
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
969 |
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
|
970 |
+
# object pointer is a small tensor, so we always keep it on GPU memory for fast access (modified for ZeroGPU)
|
971 |
+
obj_ptr = current_out["obj_ptr"].to(storage_device, non_blocking=True)
|
972 |
+
object_score_logits = current_out["object_score_logits"].to(
|
973 |
+
storage_device, non_blocking=True
|
974 |
+
)
|
975 |
# make a compact version of this frame's output to reduce the state size
|
976 |
compact_current_out = {
|
977 |
"maskmem_features": maskmem_features,
|
|
|
1023 |
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
|
1024 |
a constant in the inference session to reduce session storage size.
|
1025 |
"""
|
1026 |
+
storage_device = inference_state["storage_device"]
|
1027 |
model_constants = inference_state["constants"]
|
1028 |
# "out_maskmem_pos_enc" should be either a list of tensors or None
|
1029 |
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
|
|
1032 |
assert isinstance(out_maskmem_pos_enc, list)
|
1033 |
# only take the slice for one object, since it's same across objects
|
1034 |
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
|
1035 |
+
model_constants["maskmem_pos_enc"] = maskmem_pos_enc.to(
|
1036 |
+
storage_device, non_blocking=True
|
1037 |
+
)
|
1038 |
else:
|
1039 |
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
|
1040 |
# expand the cached maskmem_pos_enc to the actual batch size
|