import math
import os
import time

import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed as dist
import torch.nn as nn
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from timm.models import create_model
# AutoTokenizer, BertTokenizer, XLMRobertaTokenizer and GLMChineseTokenizer are
# imported so that get_pretrained_tokenizer() can resolve them by name via eval().
from transformers import AutoTokenizer, BertTokenizer, XLMRobertaTokenizer
from vlmo.modules import heads, objectives, vlmo_utils
from vlmo.tokenizer.tokenization_glm import GLMChineseTokenizer
from vlmo.torchscale.architecture.encoder import Encoder
from vlmo.torchscale.model.BEiT3 import BEiT3 as ts_backbone
from vlmo.transforms.utils import inception_normalize as img_norm

from .modeling_utils import _get_base_config, _get_large_config, _get_huge_config, trunc_normal_


def convert_pl_ckpt(state_dict, num_visual_token=197):
    """Adapt a PyTorch Lightning state dict to the current number of visual tokens.

    Visual tokenizer weights are dropped, and the backbone's visual position
    embedding table is interpolated or truncated to ``num_visual_token + 2`` rows.
    """
    print("start convert_pl_ckpt")
    new_state_dict = {}
    for key in state_dict:
        value = state_dict[key]
        if "visual_tokenizer" in key:
            # Visual tokenizer weights are not loaded from this checkpoint.
            continue
        elif "backbone.encoder.embed_positions.A.weight" in key:
            if value.shape[0] < num_visual_token + 2:
                # Too few positions: interpolate the patch position embeddings to the
                # new grid; the first 3 rows (non-patch positions) are kept as-is.
                N = value.shape[0] - 3
                dim = value.shape[-1]
                class_pos_embed = value[:3, ]
                patch_pos_embed = value[3:, ]
                w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
                patch_pos_embed = patch_pos_embed.float()
                patch_pos_embed = nn.functional.interpolate(
                    patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
                    size=(w0, h0),
                    mode="area",
                )
                patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
                patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, dim)
                new_value = torch.cat((class_pos_embed, patch_pos_embed), dim=0)
                new_state_dict[key] = new_value
                print("reshape ", key, "raw shape: ", value.shape, "new shape: ", new_value.shape, num_visual_token)
            elif value.shape[0] > num_visual_token + 2:
                # Too many positions: keep only the first num_visual_token + 2 rows.
                new_state_dict[key] = value[: num_visual_token + 2, :]
                print("truncate ", key, "raw shape: ", value.shape, new_state_dict[key].shape, num_visual_token)
            else:
                new_state_dict[key] = value
                print("keep ", key, "raw shape: ", value.shape)
        else:
            new_state_dict[key] = state_dict[key]

    return new_state_dict
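

# Example (illustrative; ``model`` and the checkpoint path are placeholders). The
# converters in this file are applied to a raw checkpoint before load_state_dict,
# mirroring what VLMo.load_pretrained_weight() does below:
#
#   ckpt = torch.load("/path/to/vlmo.ckpt", map_location="cpu")
#   state_dict = ckpt.get("state_dict", ckpt)
#   state_dict = convert_pl_ckpt(state_dict, num_visual_token=197)
#   missing, unexpected = model.load_state_dict(state_dict, strict=False)
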
def convert_deepspeed_ckpt(state_dict, num_visual_token=197):
    """Adapt a DeepSpeed state dict: strip the ``_forward_module.`` prefix and
    resize position embeddings to the current number of visual tokens."""
    new_state_dict = {}
    for key in state_dict:
        if key.startswith("_forward_module."):
            new_key = key[len("_forward_module."):]
            value = state_dict[key]
            new_state_dict[new_key] = value
            if "visual_tokenizer.encoder.pos_embed" in new_key or "visual_tokenizer.decoder.pos_embed" in new_key:
                if value.shape[1] != num_visual_token:
                    # pos_embed is [1, num_tokens, dim]; interpolate the patch part
                    # (everything after the cls position) to the new grid.
                    N = value.shape[1] - 1
                    dim = value.shape[-1]
                    class_pos_embed = value[:, 0]
                    patch_pos_embed = value[:, 1:]
                    w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
                    patch_pos_embed = patch_pos_embed.float()
                    patch_pos_embed = nn.functional.interpolate(
                        patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
                        size=(w0, h0),
                        mode="area",
                    )
                    patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
                    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
                    new_value = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
                    new_state_dict[new_key] = new_value
                    print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape)
if "backbone.encoder.embed_positions.A.weight" in new_key: |
|
if value.shape[1] != num_visual_token + 2: |
|
N = value.shape[0] - 3 |
|
dim = value.shape[-1] |
|
class_pos_embed = value[:3, ] |
|
patch_pos_embed = value[3:, ] |
|
w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1)) |
|
patch_pos_embed = patch_pos_embed.float() |
|
patch_pos_embed = nn.functional.interpolate( |
|
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), |
|
size=(w0, h0), |
|
mode="area", |
|
) |
|
patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype) |
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, dim) |
|
new_value = torch.cat((class_pos_embed, patch_pos_embed), dim=0) |
|
new_state_dict[new_key] = new_value |
|
print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape) |
|
|
|
else: |
|
new_state_dict[key] = state_dict[key] |
|
|
|
return new_state_dict |
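

# Key conversion example (illustrative): a DeepSpeed-wrapped entry such as
#   "_forward_module.backbone.encoder.embed_positions.A.weight"
# is stored back as
#   "backbone.encoder.embed_positions.A.weight"
# with its position table resized to num_visual_token + 2 rows when needed.
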
def get_visual_tokenizer(config):
    """Create the image tokenizer (in eval mode) that maps images to discrete codes."""
    tokenizer_name = config["tokenizer_model"]
    print(f"Creating visual tokenizer: {tokenizer_name}")
    model = create_model(
        tokenizer_name,
        img_size=config["image_size"],
        n_code=config["codebook_size"],
        code_dim=config["codebook_dim"],
    ).eval()
    return model


def get_pretrained_tokenizer(tokenizer_type, from_pretrained):
    """Instantiate the text tokenizer class named by ``tokenizer_type``.

    Under torch.distributed, rank 0 downloads/caches the files first and the
    other ranks wait at the barrier before loading from the local cache.
    """
    _Tokenizer = eval(tokenizer_type)  # one of the tokenizer classes imported above
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            _Tokenizer.from_pretrained(from_pretrained)
        torch.distributed.barrier()
    return _Tokenizer.from_pretrained(from_pretrained)
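

# Example (illustrative; the model name is a placeholder):
#   tokenizer = get_pretrained_tokenizer("XLMRobertaTokenizer", "xlm-roberta-base")
#   input_ids = tokenizer("a photo of a dog")["input_ids"]
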
class VLMo(pl.LightningModule):
    """VLMo vision-language model built on a BEiT-3 (torchscale) backbone."""

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        s_t = time.time()

        self.img_size = config["image_size"]
        if not config["test_only"]:
            self.visual_tokenizer = get_visual_tokenizer(config)
        kwargs = {}
        if "encoder_attention_heads" in config:
            kwargs["encoder_attention_heads"] = config["encoder_attention_heads"]
        if "atorch_config" in config and config["atorch_config"]:
            # With an atorch config, activation checkpointing is assumed to be
            # handled outside this module.
            checkpoint_activations = False
        else:
            checkpoint_activations = config["checkpoint_activations"]
        # Pick _get_base_config / _get_large_config / _get_huge_config by name.
        args = eval(f'_get_{config["beit_version"]}_config')(
            img_size=config["image_size"],
            patch_size=config["patch_size"],
            vocab_size=config["vocab_size"],
            encoder_layers=config["encoder_layers"],
            encoder_embed_dim=config["encoder_embed_dim"],
            checkpoint_activations=checkpoint_activations,
            share_layer=config["share_layer"],
            share_attn=config["share_attn"],
            deepnorm=config["deepnorm"],
            mask_ratio=config["mask_ratio"],
            max_text_len=config["max_text_len"],
            one_attn=config["one_attn"],
            **kwargs,
        )
        self.num_features = args.encoder_embed_dim
        self.out_features = config["out_embed_dim"]
        self.cap_onlytext = config["cap_onlytext"]
        self.lang = config["lang"]
        self.num_frames = config["num_frames"]
        self.tokenizer_type = config["tokenizer_type"]
        self.text_tokenizer = get_pretrained_tokenizer(self.tokenizer_type, from_pretrained=config["tokenizer"])
        print("BEiT args", args.__dict__)
        self.backbone = ts_backbone(args)

        # Optional fusion ("VL") layers applied on top of the shared backbone.
        self.use_vl = config["beit3_vl_layers"] > 0
        if self.use_vl:
            args.encoder_layers = config["beit3_vl_layers"]
            self.backbone_vl = Encoder(args)

        self.norm = nn.LayerNorm(self.num_features, eps=1e-6)

        self.pooler = heads.Pooler(self.num_features)
        self.pooler.apply(objectives.init_weights)

        if config["loss_names"]["itc"] > 0:
            # Projection heads for image-text contrastive learning (ITC), for both
            # the unimodal backbone outputs and the fused (VL) outputs.
            self.itc_text_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_image_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_text_proj.apply(objectives.init_weights)
            self.itc_image_proj.apply(objectives.init_weights)

            self.itc_vl_text_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_vl_image_proj = heads.ITCHead(self.num_features, self.out_features)
            self.itc_vl_text_proj.apply(objectives.init_weights)
            self.itc_vl_image_proj.apply(objectives.init_weights)

            # Learnable temperatures for the contrastive losses, initialized to
            # log(1 / 0.07) as in CLIP.
            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
            self.logit_vl_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        lp_s_t = time.time()

        self.load_pretrained_weight()
        load_pretrain_time = time.time() - lp_s_t

        self.current_tasks = list()

        # In test-only mode the checkpoint is loaded here, after all heads exist
        # (load_pretrained_weight() above only loads when not test_only).
        if self.hparams.config["load_path"] != "" and self.hparams.config["test_only"]:
            rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))
            ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")

            state_dict = None

            for state_dict_key in ("state_dict", "module", "model"):
                if state_dict_key in ckpt:
                    rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
                    state_dict = ckpt[state_dict_key]
                    break
            if state_dict_key == "module":
                state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
            if state_dict_key == "state_dict":
                state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
            if state_dict is None:
                if list(ckpt.keys())[0].startswith('_forward_module.'):
                    rank_zero_info("Read state dict from ckpt with _forward_module prefix. ")
                    state_dict = convert_deepspeed_ckpt(ckpt, self.backbone.vision_embed.num_position_embeddings())
                else:
                    rank_zero_info("Read state dict from ckpt. ")
                    state_dict = ckpt

            missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
            rank_zero_info("missing_keys: {}".format(missing_keys))
            rank_zero_info("unexpected_keys: {}".format(unexpected_keys))

        construct_time = time.time() - s_t
        print(
            f"Process {os.getpid()}. VLMo Constructor time: {construct_time}s;",
            f"load_pretrain_time: {load_pretrain_time}s",
            flush=True,
        )

        self._coalesce_backbone = config["coalesce_backbone"]
        self._mask_data = config["mask_data"]
        self._backbone_inputs = {}
        self._backbone_inputs_current_size = 0
        self._backbone_inputs_keys = {}
        self._backbone_outputs = None
        self._default_attn_masks = {}
        self._itc_group = None
        self._itc_aggregate_dict = None
        self._itc_mask = config["itc_mask"]
        self._local_loss = config["local_loss"]
        self._aggregate_nodes = config["aggregate_nodes"]
        self.accumulated_batches_reached = False
        vlmo_utils.set_task(self)
        self._only_itc_single_machine = (
            self._aggregate_nodes > 0 and len(self.current_tasks) == 1 and "itc" in self.current_tasks
        )
        self._split_data_for_imagemlm = config["split_data_for_imagemlm"]
        self.log_metric_steps = config["log_metric_steps"]
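
    # Illustrative sketch only (not a complete config): VLMo takes a flat dict
    # covering every key read in __init__ above, e.g.
    #
    #   config = {
    #       "beit_version": "base",        # selects _get_base_config
    #       "image_size": 224, "patch_size": 16,
    #       "load_path": "", "test_only": False,
    #       "loss_names": {"itc": 1},
    #       ...                            # remaining keys omitted here
    #   }
    #   model = VLMo(config)
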
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def fix_init_weight(self):
        # Depth-dependent rescaling (divide by sqrt(2 * layer_id)) of the attention
        # value/output projections and the second FFN linear, for both branches (A and B).
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.backbone.encoder.layers):
            rescale(layer.self_attn.v_proj.A.weight.data, layer_id + 1)
            rescale(layer.self_attn.v_proj.B.weight.data, layer_id + 1)
            rescale(layer.self_attn.out_proj.A.weight.data, layer_id + 1)
            rescale(layer.self_attn.out_proj.B.weight.data, layer_id + 1)
            rescale(layer.ffn.A.fc2.weight.data, layer_id + 1)
            rescale(layer.ffn.B.fc2.weight.data, layer_id + 1)

        if self.use_vl:
            # VL layers continue the depth numbering after the backbone layers.
            pre_layers = len(self.backbone.encoder.layers) + 1
            for layer_id, layer in enumerate(self.backbone_vl.layers):
                rescale(layer.self_attn.v_proj.A.weight.data, layer_id + pre_layers)
                rescale(layer.self_attn.v_proj.B.weight.data, layer_id + pre_layers)
                rescale(layer.self_attn.out_proj.A.weight.data, layer_id + pre_layers)
                rescale(layer.self_attn.out_proj.B.weight.data, layer_id + pre_layers)
                rescale(layer.ffn.A.fc2.weight.data, layer_id + pre_layers)
                rescale(layer.ffn.B.fc2.weight.data, layer_id + pre_layers)

    def load_pretrained_weight(self):
        if self.hparams.config["load_path"] != "" and not self.hparams.config["test_only"]:
            config = self.hparams.config
            ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
            rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))

            state_dict = None

            for state_dict_key in ("state_dict", "module", "model"):
                if state_dict_key in ckpt:
                    rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
                    state_dict = ckpt[state_dict_key]
                    break
            if state_dict_key == "module":
                state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
            if state_dict_key == "state_dict":
                state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
            if state_dict is None:
                if list(ckpt.keys())[0].startswith('_forward_module.'):
                    rank_zero_info("Read state dict from ckpt with _forward_module prefix. ")
                    state_dict = convert_deepspeed_ckpt(ckpt, self.backbone.vision_embed.num_position_embeddings())
                else:
                    rank_zero_info("Read state dict from ckpt. ")
                    state_dict = ckpt

            missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
            # Do not report itc_teacher parameters as missing.
            missing_keys = [k for k in missing_keys if "itc_teacher" not in k]
            rank_zero_info("missing_keys: {}".format(missing_keys))
            rank_zero_info("unexpected_keys: {}".format(unexpected_keys))

    def infer_text(
        self,
        batch,
        mask_text=False,
    ):
        do_mlm = "_mlm" if mask_text else ""
        text_ids = batch[f"text_ids{do_mlm}"]
        text_labels = batch[f"text_labels{do_mlm}"]
        text_masks = batch["text_masks"]
        text_embed = self.backbone.text_embed(text_ids)
        # The batch mask marks valid tokens (1); the backbone expects padding
        # positions, hence the inversion.
        text_padding_position = 1 - text_masks
        lffn_hiddens = self.backbone(
            textual_tokens=text_ids,
            text_padding_position=text_padding_position,
        )["encoder_out"]
        vlffn_hiddens = self.backbone_vl(
            src_tokens=None,
            token_embeddings=lffn_hiddens,
            encoder_padding_mask=text_padding_position,
            multiway_split_position=-1,
        )["encoder_out"]

        # L2-normalized projections of the first-token representation for ITC.
        cls_feats = self.itc_text_proj(lffn_hiddens[:, 0])
        cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)

        cls_vlffn_feats = self.itc_vl_text_proj(vlffn_hiddens[:, 0])
        cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)

        ret = {
            "cls_feats": cls_feats,
            "cls_vlffn_feats": cls_vlffn_feats,
            "text_embed": text_embed,
        }

        return ret

    def infer_image(
        self,
        batch,
        mask_image=False,
        image_token_type_idx=1,
        image_embeds=None,
        image_masks=None,
    ):
        if f"image_{image_token_type_idx - 1}" in batch:
            imgkey = f"image_{image_token_type_idx - 1}"
        else:
            imgkey = "image"

        img = batch[imgkey][0]
        if mask_image:
            image_masks = batch[f"{imgkey}_masks"][0].flatten(1)

        # Discretize the image with the (eval-mode) visual tokenizer to build
        # masked-image-modeling labels. image_masks must be available here, either
        # from the batch when mask_image=True or passed in explicitly.
        with torch.no_grad():
            img = self.visual_tokenizer.pre_process(img)
            quantize, embed_ind, _ = self.visual_tokenizer.encode(img)
            image_ids = embed_ind.view(img.shape[0], -1)

        image_labels = torch.full_like(image_ids, -100)
        bool_masked_pos = image_masks.to(torch.bool)
        image_labels[bool_masked_pos] = image_ids[bool_masked_pos]

        img_tensor = img_norm(img)
        vffn_hiddens = self.backbone(visual_tokens=img_tensor)["encoder_out"]
        vlffn_hiddens = self.backbone_vl(
            src_tokens=None,
            token_embeddings=vffn_hiddens,
            multiway_split_position=-1,
        )["encoder_out"]

        # L2-normalized projections of the first-token representation for ITC.
        cls_feats = self.itc_image_proj(vffn_hiddens[:, 0])
        cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)

        cls_vlffn_feats = self.itc_vl_image_proj(vlffn_hiddens[:, 0])
        cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)

        ret = {
            "image_feats": vffn_hiddens,
            "cls_feats": cls_feats,
            "cls_vlffn_feats": cls_vlffn_feats,
        }

        return ret
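
    # End-to-end usage sketch (illustrative; ``model`` and ``batch`` are placeholders
    # providing the fields read above):
    #
    #   with torch.no_grad():
    #       txt = model.infer_text(batch)            # needs text_ids / text_labels / text_masks
    #       img = model.infer_image(batch, mask_image=True)
    #       sim = txt["cls_feats"] @ img["cls_feats"].t()   # image-text similarity
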