import json
import warnings

import torch
from torch import nn

from .config import InternVideo2Config, EasyDict
from .internvideo2 import pretrain_internvideo2_1b_patch14_224, pretrain_internvideo2_6b_patch14_224
from transformers.utils import logging

warnings.filterwarnings("ignore")

logger = logging.get_logger(__name__)


class InternVideo2_Stage2(nn.Module):
    """InternVideo2 Stage2 model: a vision encoder with a linear projection
    head and a learnable logit temperature for video-text alignment."""

    def __init__(self, config, is_pretrain=True):
        super().__init__()
        # if isinstance(config, InternVideo2Config):
        #     config_str = str(config)
        #     config_str = config_str.replace('InternVideo2Config ', '')
        #     config_json = json.loads(config_str)
        #     config = EasyDict(config_json)
        self.config = config
        self.is_pretrain = is_pretrain
        self.vision_width = config.model.vision_encoder.clip_embed_dim
        # self.text_width = config.model.text_encoder.d_model
        self.embed_dim = config.model.embed_dim

        # Create modules.
        self.vision_encoder = self.build_vision_encoder()
        if config.model.get("freeze_vision", False):
            self.freeze_vision()

        self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
        self.temp = nn.Parameter(torch.ones([]) * config.model.temp)

        self.uta_image_only = config.criterion.get('uta_image_only', False)
        # logger.info(f"uta_image_only={self.uta_image_only}")

    def freeze_vision(self):
        """Freeze the vision encoder."""
        for p in self.vision_encoder.parameters():
            p.requires_grad = False

    def no_weight_decay(self):
        """Return the names of parameters that should be excluded from weight decay."""
        ret = {"temp"}
        ret.update(
            {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
        )
        # ret.update(
        #     {"text_encoder." + k for k in self.text_encoder.no_weight_decay()}
        # )
        return ret

    @property
    def dtype(self):
        return self.vision_encoder.patch_embed.proj.weight.dtype

    def encode_vision(self, image):
        """Encode images / videos as features.

        Args:
            image (torch.Tensor): The input images or videos.
                Shape: [B,N,C,H,W], where N is the number of frames
                (N == 1 for still images).

        Returns: tuple.
            - vision_embeds (torch.Tensor): The output features. Shape: [B,N,C].
            - pooled_vision_embeds (torch.Tensor): The pooled output features. Shape: [B,1,C].
        """
        T = image.shape[1]
        use_image = (T == 1)
        image = image.permute(0, 2, 1, 3, 4)  # [B,N,C,H,W] -> [B,C,N,H,W]
        # whether to keep the temporal dimension
        # keep_temporal=self.config.model.vision_encoder.keep_temporal
        # The encoder also returns alignment (student) and CLIP-teacher outputs,
        # which are only needed for distillation and are discarded here.
        vision_embeds, pooled_vision_embeds, _, _ = self.vision_encoder(
            image, None, use_image)
        return vision_embeds, pooled_vision_embeds

    def build_vision_encoder(self):
        """Build the vision encoder.

        Returns:
            vision_encoder (nn.Module)
        """
        encoder_name = self.config.model.vision_encoder.name
        # logger.info(f"Build vision_encoder: {encoder_name}")

        if encoder_name == 'pretrain_internvideo2_1b_patch14_224':
            vision_encoder = pretrain_internvideo2_1b_patch14_224(self.config.model)
        elif encoder_name == 'pretrain_internvideo2_6b_patch14_224':
            vision_encoder = pretrain_internvideo2_6b_patch14_224(self.config.model)
        else:
            raise ValueError(f"Not implemented: {encoder_name}")

        return vision_encoder
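
# ---------------------------------------------------------------------------
# Sketch (not part of the original module): how `vision_proj` and `temp` are
# typically combined downstream in CLIP-style contrastive scoring. This is an
# assumption about the intended use; the module above only defines the pieces
# (projection head and temperature), not this scoring step. The helper name
# `clip_style_logits` and the text-embedding input are hypothetical.
# ---------------------------------------------------------------------------
import torch.nn.functional as F


def clip_style_logits(pooled_vision_embeds, text_embeds, vision_proj, temp):
    """Compute a [B_v, B_t] similarity matrix scaled by the temperature.

    pooled_vision_embeds: [B,1,C] as returned by `encode_vision`.
    text_embeds:          [B,D] already in the shared embedding space.
    """
    v = F.normalize(vision_proj(pooled_vision_embeds[:, 0]), dim=-1)  # [B,D]
    t = F.normalize(text_embeds, dim=-1)                              # [B,D]
    return v @ t.T / temp

# Example with random tensors (proj dims chosen for illustration only):
#   proj = nn.Linear(768, 512)
#   logits = clip_style_logits(torch.randn(4, 1, 768), torch.randn(4, 512), proj, 0.07)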
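
# ---------------------------------------------------------------------------
# Sketch (illustrative only): wiring `no_weight_decay()` into an optimizer.
# Names returned by `no_weight_decay()` (the logit temperature plus any
# encoder parameters the backbone flags, e.g. position embeddings) go into a
# zero-weight-decay group; everything else keeps the regular decay. The
# helper name and the decay value are assumptions, not repo API.
# ---------------------------------------------------------------------------
def build_param_groups(model: nn.Module, weight_decay: float = 0.05):
    skip = model.no_weight_decay()
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # skip frozen params (e.g. when freeze_vision is set)
        (no_decay if name in skip else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Usage with a built model:
#   optimizer = torch.optim.AdamW(build_param_groups(model), lr=1e-4)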