|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import time |
|
|
|
import numpy as np |
|
import torch |
|
from tqdm import tqdm |
|
|
|
from sam2.build_sam import build_sam2_video_predictor |
|
|
|
|
|
# The benchmark is CUDA-only. Raise explicitly instead of using `assert`,
# which is stripped when Python runs with -O and would let the script fail
# later with a less clear error.
if not torch.cuda.is_available():
    raise RuntimeError("This benchmark requires a CUDA-capable GPU.")
device = torch.device("cuda")

# Enter bfloat16 autocast for the remainder of the script. It is deliberately
# never exited, so model construction, init_state and prompting below all run
# under it as well. A handle is kept so it could be exited explicitly if needed.
_script_autocast = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
_script_autocast.__enter__()

# Check the device the model will actually run on (the current CUDA device,
# which is what `torch.device("cuda")` resolves to), not a hard-coded index.
if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
    # Compute capability 8.0+ (Ampere and newer) provides TF32 tensor cores:
    # allow TF32 for float32 matmuls and cuDNN convolutions.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
|
|
|
|
|
# Hydra model config and the matching SAM 2.1 Hiera-B+ checkpoint (paths are
# relative to the working directory the script is launched from).
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"

# Build the video predictor on the CUDA device. `vos_optimized=True` enables
# the builder's VOS-optimized mode (presumably the speed-optimized path that
# this benchmark targets; confirm in sam2.build_sam).
predictor = build_sam2_video_predictor(
    model_cfg,
    sam2_checkpoint,
    device=device,
    vos_optimized=True,
)
|
|
|
|
|
|
|
# Directory of individually numbered JPEG frames (e.g. "00000.jpg").
video_dir = "notebooks/videos/bedroom"

# Collect the JPEG frame files and order them numerically by file stem.
# A non-integer stem raises ValueError here.
frame_names = sorted(
    (
        name
        for name in os.listdir(video_dir)
        if os.path.splitext(name)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
    ),
    key=lambda name: int(os.path.splitext(name)[0]),
)

# Load the video from the same directory into the predictor's tracking state.
inference_state = predictor.init_state(video_path=video_dir)
|
|
|
|
|
|
|
# Benchmark settings: `runs` full propagation passes in total. The first
# `warm_up` passes are timed separately and then discarded from the average.
warm_up = 5
runs = 25
verbose = True  # show the tqdm progress bar

# Frames per propagation pass, as used by the FPS formula (assumed to match
# the frames init_state loaded from the same directory).
num_frames = len(frame_names)

# Running sums of elapsed seconds and completed passes.
total = 0
count = 0

# Return unused cached CUDA allocator memory before timing starts.
torch.cuda.empty_cache()
|
|
|
|
|
|
|
# Prompt object id 1 on frame 0 with a single click.
ann_frame_idx = 0
ann_obj_id = 1

# One click point; presumably (x, y) pixel coordinates per the SAM convention.
points = np.array([[210, 350]], dtype=np.float32)
# Label for that click; 1 presumably marks a positive/foreground click.
labels = np.array([1], dtype=np.int32)

# Register the prompt. The first returned value is discarded; the object ids
# and mask logits for the prompted frame are kept.
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)
|
|
|
|
|
# Measure end-to-end propagation throughput. The first `warm_up` passes absorb
# one-off start-up costs; their FPS is printed separately and then the
# counters are reset, so the final FPS averages only the remaining passes.
with torch.autocast("cuda", torch.bfloat16):
    with torch.inference_mode():
        for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
            # CUDA kernels launch asynchronously: drain pending work before
            # starting the clock and again before stopping it, so the timed
            # interval covers exactly this pass's GPU work.
            torch.cuda.synchronize()
            start = time.perf_counter()

            # Only throughput matters; the per-frame outputs are discarded.
            for _ in predictor.propagate_in_video(inference_state):
                pass

            torch.cuda.synchronize()
            end = time.perf_counter()

            total += end - start
            count += 1

            if i == warm_up - 1:
                print("Warmup FPS: ", count * num_frames / total)
                # Drop the warm-up passes from the reported average.
                total = 0
                count = 0

# Guard against dividing by zero when no passes remain after warm-up
# (e.g. runs == warm_up).
if count > 0:
    print("FPS: ", count * num_frames / total)
else:
    print(f"FPS: no timed runs after warm-up (runs={runs}, warm_up={warm_up})")
|
|