from .internvideo2_stage2 import InternVideo2_Stage2 as IV2S2
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
from .config import InternVideo2Config as config
import warnings
import torch
# from transformers.utils import logging
warnings.filterwarnings("ignore")
# logging.set_verbosity_error()
# model_config = config()
# model = IV2S2(model_config)
# print(model)
class InternVideo2Stage2VideoEncoder(PreTrainedModel):
    """HF-compatible wrapper exposing the InternVideo2 stage-2 vision tower
    as a video/image feature extractor.

    Accepts a (B, N, C, H, W) clip or a (B, C, H, W) image and returns one
    pooled embedding per frame.
    """
    config_class = config

    # The pretrained weights only support clips of up to 8 frames per pass.
    MAX_FRAMES = 8
    # encode_vision emits 256 patch tokens per frame, plus one extra token
    # (presumably a global/CLS token appended at the end — TODO confirm).
    TOKENS_PER_FRAME = 256

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Kept on CPU in fp16; callers can move the module afterwards.
        self.model = IV2S2(self.config).to('cpu').to(torch.float16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of videos or images.

        Args:
            x: (B, N, C, H, W) video tensor or (B, C, H, W) image tensor.

        Returns:
            (B, N, hidden_size) per-frame embeddings for video input, or
            (B, hidden_size) for image input.
            NOTE(review): an earlier docstring claimed (B*N, hidden_size);
            the code keeps the batch and frame axes separate.
        """
        # Clips longer than MAX_FRAMES are encoded in chunks and re-joined
        # along the frame axis, because the weights cap the frames per pass.
        if x.dim() == 5 and x.shape[1] > self.MAX_FRAMES:
            T = x.shape[1]
            chunks = [
                self.forward(x[:, i:i + self.MAX_FRAMES])
                for i in range(0, T, self.MAX_FRAMES)
            ]
            return torch.cat(chunks, dim=1)

        image = False
        if x.dim() == 4:
            # Treat a single image as a one-frame video.
            x = x.unsqueeze(1)
            image = True

        B, N, C, H, W = x.shape
        output = self.model.encode_vision(x)
        # output[1]: pooled vision embeds, shape (B, N*256 + 1, hidden_size).
        pooled = output[1]
        # Keep the first 256*N patch tokens (dropping the trailing extra
        # token), then average each frame's 256 tokens into one embedding.
        per_frame = pooled[:, :self.TOKENS_PER_FRAME * N, :]
        per_frame = per_frame.reshape(B, N, self.TOKENS_PER_FRAME, -1)
        per_frame = per_frame.mean(dim=2)  # (B, N, hidden_size)

        if image:
            per_frame = per_frame.squeeze(1)
        return per_frame
if __name__ == "__main__":
    # Smoke test: push a random 8-frame clip through the encoder.
    model_config = config()
    model = InternVideo2Stage2VideoEncoder(model_config)
    # forward() expects (B, N, C, H, W). The previous (2, 3, 8, 224, 224)
    # tensor had the frame and channel axes swapped, so forward() unpacked
    # N=3, C=8 — wrong on both counts. Use 8 frames of 3-channel images.
    x = torch.randn(2, 8, 3, 224, 224, dtype=torch.float16).to(model_config.device)
    output = model(x)