import os

import requests
import torch
from torchvision import transforms
from tqdm import tqdm

from .videomaev2_finetune import vit_giant_patch14_224


def to_normalized_float_tensor(vid):
    # (T, H, W, C) uint8 video -> (C, T, H, W) float tensor scaled to [0, 1].
    return vid.permute(3, 0, 1, 2).to(torch.float32) / 255


def resize(vid, size, interpolation='bilinear'):
    # If `size` is an int, rescale so the shorter spatial side matches it
    # (preserving aspect ratio); if it is an (H, W) tuple, resize to it exactly.
    scale = None
    if isinstance(size, int):
        scale = float(size) / min(vid.shape[-2:])
        size = None
    return torch.nn.functional.interpolate(
        vid,
        size=size,
        scale_factor=scale,
        mode=interpolation,
        align_corners=False)


class ToFloatTensorInZeroOne(object):
    # Transform wrapper around to_normalized_float_tensor, for use in Compose.

    def __call__(self, vid):
        return to_normalized_float_tensor(vid)


class Resize(object):
    # Transform wrapper around resize with a fixed target size.

    def __init__(self, size):
        self.size = size

    def __call__(self, vid):
        return resize(vid, self.size)


def preprocess_videomae(videos):
    # `videos` is a uint8 numpy array of shape (N, T, H, W, C); each clip is
    # scaled to [0, 1] and resized, yielding an (N, C, T, 224, 224) batch.
    transform = transforms.Compose(
        [ToFloatTensorInZeroOne(),
         Resize((224, 224))])
    return torch.stack([transform(f) for f in torch.from_numpy(videos)])


def load_videomae_model(device, ckpt_path=None):
    if ckpt_path is None:
        current_dir = os.path.dirname(os.path.abspath(__file__))
        ckpt_path = os.path.join(current_dir, 'vit_g_hybrid_pt_1200e_ssv2_ft.pth')

    if not os.path.exists(ckpt_path):
        # Download the checkpoint with a streaming request and a progress bar.
        ckpt_url = 'https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/videomaev2/vit_g_hybrid_pt_1200e_ssv2_ft.pth'
        response = requests.get(ckpt_url, stream=True, allow_redirects=True)
        response.raise_for_status()  # fail early rather than saving an error page
        total_size = int(response.headers.get("content-length", 0))
        block_size = 1024

        with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
            with open(ckpt_path, "wb") as fw:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    fw.write(data)

    # ViT-giant backbone fine-tuned on Something-Something V2 (174 classes),
    # operating on 16-frame clips with a temporal tubelet size of 2.
    model = vit_giant_patch14_224(
        img_size=224,
        pretrained=False,
        num_classes=174,
        all_frames=16,
        tubelet_size=2,
        drop_path_rate=0.3,
        use_mean_pooling=True)

    # Fine-tuned checkpoints may nest the weights under 'model' or 'module'.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    for model_key in ['model', 'module']:
        if model_key in ckpt:
            ckpt = ckpt[model_key]
            break
    model.load_state_dict(ckpt)
    return model.to(device)
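

# Minimal usage sketch (an illustration, not part of the original module): the
# __main__ guard and the dummy clip shape below are assumptions for demo only.
if __name__ == '__main__':
    import numpy as np

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_videomae_model(device).eval()

    # One fake uint8 clip: (batch, frames, height, width, channels).
    videos = np.random.randint(0, 256, size=(1, 16, 256, 320, 3), dtype=np.uint8)
    batch = preprocess_videomae(videos).to(device)  # -> (1, 3, 16, 224, 224)

    with torch.no_grad():
        logits = model(batch)
    print(logits.shape)  # expected: torch.Size([1, 174])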