import warnings

import torch
from transformers import PreTrainedModel

from .config import InternVideo2Config as config
from .internvideo2_stage2 import InternVideo2_Stage2 as IV2S2

warnings.filterwarnings("ignore")


class InternVideo2Stage2VideoEncoder(PreTrainedModel):
    """Hugging Face wrapper exposing the InternVideo2 stage-2 vision encoder."""

    config_class = config

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Build the stage-2 model on CPU in float16; inputs are expected in
        # float16 as well (see the smoke test under __main__).
        self.model = IV2S2(self.config).to('cpu').to(torch.float16)

    def forward(self, x: torch.Tensor):
        """Encode a batch of videos or images.

        Args:
            x (torch.Tensor): Shape (B, N, C, H, W) for video, or (B, C, H, W) for image.
        Returns:
            torch.Tensor: Shape (B, N, hidden_size) or (B, hidden_size).
        """
        if len(x.shape) == 5 and x.shape[1] > 8:
            # The pretrained weights accept at most 8 frames per clip, so longer
            # clips are split into chunks of up to 8 frames, encoded separately,
            # and concatenated back along the frame dimension.
            T = x.shape[1]
            embs = torch.cat([self.forward(x[:, i:i + 8, :, :, :]) for i in range(0, T, 8)], dim=1)
            return embs

        image = False
        if len(x.shape) == 4:
            # Treat a single image as a one-frame video.
            x = x.unsqueeze(1)
            image = True
        B, N, C, H, W = x.shape
        output = self.model.encode_vision(x)
        pooled_vision_embeds = output[1]               # (B, N*256 + 1, hidden_size)
        output = pooled_vision_embeds[:, :256 * N, :]  # drop the extra token -> (B, N*256, hidden_size)
        output = output.reshape(B, N, 256, -1)         # (B, N, 256, hidden_size)
        output = output.mean(dim=2)                    # mean-pool the 256 patch tokens per frame -> (B, N, hidden_size)
        if image:
            output = output.squeeze(1)
        return output
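
# A minimal sketch of wiring this encoder into the transformers Auto classes
# (hypothetical: the "internvideo2" model_type string is an assumption and must
# match `config.model_type` for the registration below to succeed):
#
#   from transformers import AutoConfig, AutoModel
#   AutoConfig.register("internvideo2", config)
#   AutoModel.register(config, InternVideo2Stage2VideoEncoder)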
    
if __name__ == "__main__":
    model_config = config()
    model = InternVideo2Stage2VideoEncoder(model_config).to(model_config.device)
    # forward() expects (B, N, C, H, W): here, a batch of 2 clips with 8 frames each.
    x = torch.randn(2, 8, 3, 224, 224, dtype=torch.float16).to(model_config.device)
    output = model(x)
    print(output.shape)  # expected: (2, 8, hidden_size)
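
    # A quick check of the >8-frame chunking path: a 16-frame clip is encoded
    # as two 8-frame chunks whose embeddings are concatenated along the frame
    # dimension (`long_clip` is a hypothetical input for illustration).
    long_clip = torch.randn(2, 16, 3, 224, 224, dtype=torch.float16).to(model_config.device)
    long_output = model(long_clip)
    print(long_output.shape)  # expected: (2, 16, hidden_size)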