import math
import os
import time
import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed as dist
import torch.nn as nn
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from timm.models import create_model
from transformers import AutoTokenizer, BertTokenizer, XLMRobertaTokenizer # noqa
from vlmo.modules import heads, objectives, vlmo_utils
from vlmo.tokenizer.tokenization_glm import GLMChineseTokenizer # noqa
from vlmo.torchscale.architecture.encoder import Encoder
from vlmo.torchscale.model.BEiT3 import BEiT3 as ts_backbone
from vlmo.transforms.utils import inception_normalize as img_norm
from .modeling_utils import _get_base_config, _get_large_config, _get_huge_config, trunc_normal_ # noqa
def convert_pl_ckpt(state_dict, num_visual_token=197):
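    """Adapt a PyTorch Lightning checkpoint for the current image resolution.

    Visual-tokenizer weights are dropped. The backbone position-embedding table
    ("backbone.encoder.embed_positions.A.weight") is resized to
    ``num_visual_token + 2`` rows: the first three rows are kept as-is, and the
    remaining patch positions are reshaped to a square grid and 2-D interpolated
    (mode="area") when the checkpoint has fewer positions, or simply truncated
    when it has more. All other entries are copied unchanged.
    """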
print("start convert_pl_ckpt!!!")
new_state_dict = {}
for key in state_dict:
value = state_dict[key]
if "visual_tokenizer" in key:
continue
elif "backbone.encoder.embed_positions.A.weight" in key:
if value.shape[0] < num_visual_token + 2:
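                # e.g. num_visual_token = 197 -> a 14x14 patch grid plus one extra position
                # (224-pixel image, 16-pixel patches); the checkpoint grid is sqrt(N) x sqrt(N)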
N = value.shape[0] - 3
dim = value.shape[-1]
class_pos_embed = value[:3, ]
patch_pos_embed = value[3:, ]
w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
patch_pos_embed = patch_pos_embed.float()
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
size=(w0, h0),
mode="area",
)
patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, dim)
new_value = torch.cat((class_pos_embed, patch_pos_embed), dim=0)
new_state_dict[key] = new_value
print("reshape ", key, "raw shape: ", value.shape, "new shape: ", new_value.shape, num_visual_token)
elif value.shape[0] > num_visual_token + 2:
new_state_dict[key] = value[: num_visual_token + 2, :]
print("first ", key, "raw shape: ", value.shape, new_state_dict[key].shape, num_visual_token)
else:
new_state_dict[key] = value
print("raw shape")
else:
new_state_dict[key] = state_dict[key]
return new_state_dict
def convert_deepspeed_ckpt(state_dict, num_visual_token=197):
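    """Strip the DeepSpeed "_forward_module." prefix from parameter names.

    Position-embedding tables (the visual tokenizer's ``pos_embed`` and the
    backbone's ``embed_positions.A.weight``) are also resized to match
    ``num_visual_token`` by 2-D interpolation of the patch grid, so a checkpoint
    trained at one image resolution can be loaded at another. Keys without the
    prefix are copied unchanged.
    """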
new_state_dict = {}
for key in state_dict:
if key.startswith("_forward_module."):
new_key = key[len("_forward_module."):]
value = state_dict[key]
new_state_dict[new_key] = value
if "visual_tokenizer.encoder.pos_embed" in new_key or "visual_tokenizer.decoder.pos_embed" in new_key:
if value.shape[1] != num_visual_token:
N = value.shape[1] - 1
dim = value.shape[-1]
class_pos_embed = value[:, 0]
patch_pos_embed = value[:, 1:]
w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
patch_pos_embed = patch_pos_embed.float()
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
size=(w0, h0),
mode="area",
)
patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
new_value = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
new_state_dict[new_key] = new_value
print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape)
if "backbone.encoder.embed_positions.A.weight" in new_key:
                if value.shape[0] != num_visual_token + 2:
N = value.shape[0] - 3
dim = value.shape[-1]
class_pos_embed = value[:3, ]
patch_pos_embed = value[3:, ]
w0, h0 = int(math.sqrt(num_visual_token - 1)), int(math.sqrt(num_visual_token - 1))
patch_pos_embed = patch_pos_embed.float()
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
size=(w0, h0),
mode="area",
)
patch_pos_embed = patch_pos_embed.to(class_pos_embed.dtype)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, dim)
new_value = torch.cat((class_pos_embed, patch_pos_embed), dim=0)
new_state_dict[new_key] = new_value
print("reshape ", new_key, "raw shape: ", value.shape, "new_shape: ", new_value.shape)
else:
new_state_dict[key] = state_dict[key]
return new_state_dict
def get_visual_tokenizer(config):
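    """Create the visual tokenizer named in ``config["tokenizer_model"]`` via timm's
    ``create_model`` and return it in eval mode."""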
tokenizer_name = config["tokenizer_model"]
print(f"Creating visual tokenizer: {tokenizer_name}")
model = create_model(
config["tokenizer_model"],
img_size=config["image_size"],
n_code=config["codebook_size"],
code_dim=config["codebook_dim"],
).eval()
return model
def get_pretrained_tokenizer(tokenizer_type, from_pretrained):
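    """Instantiate the text tokenizer class named by ``tokenizer_type``.

    Under distributed training, rank 0 downloads and caches the pretrained files
    first while the remaining ranks wait at a barrier, so the cache is populated
    exactly once before every rank loads from it.
    """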
    _Tokenizer = eval(tokenizer_type)  # resolve one of the tokenizer classes imported above
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
_Tokenizer.from_pretrained(from_pretrained)
torch.distributed.barrier()
return _Tokenizer.from_pretrained(from_pretrained)
class VLMo(pl.LightningModule):
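    """LightningModule wrapping a BEiT-3-style multiway backbone (``ts_backbone``),
    an optional vision-language fusion encoder (``backbone_vl``), and the task
    heads (pooler and ITC projections) used by the configured objectives."""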
def __init__(self, config):
super().__init__()
self.save_hyperparameters()
s_t = time.time()
# tokenizer & backbone
self.img_size = config["image_size"]
if not config["test_only"]:
self.visual_tokenizer = get_visual_tokenizer(config)
kwargs = {}
if "encoder_attention_heads" in config:
kwargs["encoder_attention_heads"] = config["encoder_attention_heads"]
if "atorch_config" in config and config["atorch_config"]:
            checkpoint_activations = False  # disable activation checkpointing when an atorch config is used
else:
checkpoint_activations = config["checkpoint_activations"]
args = eval(f'_get_{config["beit_version"]}_config')(
img_size=config["image_size"],
patch_size=config["patch_size"],
vocab_size=config["vocab_size"],
encoder_layers=config["encoder_layers"],
encoder_embed_dim=config["encoder_embed_dim"],
checkpoint_activations=checkpoint_activations,
share_layer=config["share_layer"],
share_attn=config["share_attn"],
deepnorm=config["deepnorm"],
mask_ratio=config["mask_ratio"],
max_text_len=config["max_text_len"],
one_attn=config["one_attn"],
**kwargs,
)
self.num_features = args.encoder_embed_dim
self.out_features = config["out_embed_dim"]
self.cap_onlytext = config["cap_onlytext"]
self.lang = config["lang"]
self.num_frames = config["num_frames"]
self.tokenizer_type = config["tokenizer_type"]
self.text_tokenizer = get_pretrained_tokenizer(self.tokenizer_type, from_pretrained=config["tokenizer"]) # noqa
print("BEiT args", args.__dict__)
self.backbone = ts_backbone(args)
self.use_vl = config["beit3_vl_layers"] > 0
if self.use_vl:
args.encoder_layers = config["beit3_vl_layers"]
self.backbone_vl = Encoder(args)
self.norm = nn.LayerNorm(self.num_features, eps=1e-6)
# task layers
self.pooler = heads.Pooler(self.num_features)
self.pooler.apply(objectives.init_weights)
# contrastive loss (or sampling for global hard negative)
if config["loss_names"]["itc"] > 0:
self.itc_text_proj = heads.ITCHead(self.num_features, self.out_features)
self.itc_image_proj = heads.ITCHead(self.num_features, self.out_features)
self.itc_text_proj.apply(objectives.init_weights)
self.itc_image_proj.apply(objectives.init_weights)
self.itc_vl_text_proj = heads.ITCHead(self.num_features, self.out_features)
self.itc_vl_image_proj = heads.ITCHead(self.num_features, self.out_features)
self.itc_vl_text_proj.apply(objectives.init_weights)
self.itc_vl_image_proj.apply(objectives.init_weights)
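            # learnable softmax temperatures for the contrastive losses, initialized to 1/0.07 as in CLIP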
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.logit_vl_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
lp_s_t = time.time()
self.load_pretrained_weight()
load_pretrain_time = time.time() - lp_s_t
self.current_tasks = list()
# ===================== load downstream (test_only) ======================
if self.hparams.config["load_path"] != "" and self.hparams.config["test_only"]:
rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))
ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
state_dict = None
for state_dict_key in ("state_dict", "module", "model"):
if state_dict_key in ckpt:
rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
state_dict = ckpt[state_dict_key]
break
if state_dict_key == "module":
state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
if state_dict_key == "state_dict":
state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
if state_dict is None:
if list(ckpt.keys())[0].startswith('_forward_module.'):
rank_zero_info("Read state dict from ckpt with _forward_module prefix. ")
state_dict = convert_deepspeed_ckpt(ckpt, self.backbone.vision_embed.num_position_embeddings())
else:
rank_zero_info("Read state dict from ckpt. ")
state_dict = ckpt
missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
rank_zero_info("missing_keys: {}".format(missing_keys))
rank_zero_info("unexpected_keys: {}".format(unexpected_keys))
construct_time = time.time() - s_t
print(
f"Process {os.getpid()}. VLMo Constructor time: {construct_time}s;",
f"load_pretrain_time: {load_pretrain_time}s",
flush=True,
)
# coalesce backbone calls
self._coalesce_backbone = config["coalesce_backbone"]
self._mask_data = config["mask_data"]
self._backbone_inputs = {}
self._backbone_inputs_current_size = 0
self._backbone_inputs_keys = {}
self._backbone_outputs = None
self._default_attn_masks = {}
self._itc_group = None
self._itc_aggregate_dict = None
self._itc_mask = config["itc_mask"]
self._local_loss = config["local_loss"]
self._aggregate_nodes = config["aggregate_nodes"]
self.accumulated_batches_reached = False
vlmo_utils.set_task(self)
self._only_itc_single_machine = (
self._aggregate_nodes > 0 and len(self.current_tasks) == 1 and "itc" in self.current_tasks
)
self._split_data_for_imagemlm = config["split_data_for_imagemlm"]
self.log_metric_steps = config["log_metric_steps"]
def _init_weights(self, m):
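        """Truncated-normal init (std 0.02) for Linear weights with zero bias, and
        unit weight / zero bias for LayerNorm."""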
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def fix_init_weight(self):
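        """Rescale each block's attention value/output projections and second FFN
        layer by 1/sqrt(2 * layer_id), so deeper layers start with smaller
        weights; the vision-language layers continue the depth count after the
        shared backbone layers."""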
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.backbone.encoder.layers):
rescale(layer.self_attn.v_proj.A.weight.data, layer_id + 1)
rescale(layer.self_attn.v_proj.B.weight.data, layer_id + 1)
rescale(layer.self_attn.out_proj.A.weight.data, layer_id + 1)
rescale(layer.self_attn.out_proj.B.weight.data, layer_id + 1)
rescale(layer.ffn.A.fc2.weight.data, layer_id + 1)
rescale(layer.ffn.B.fc2.weight.data, layer_id + 1)
if self.use_vl:
pre_layers = len(self.backbone.encoder.layers) + 1
for layer_id, layer in enumerate(self.backbone_vl.layers):
rescale(layer.self_attn.v_proj.A.weight.data, layer_id + pre_layers)
rescale(layer.self_attn.v_proj.B.weight.data, layer_id + pre_layers)
rescale(layer.self_attn.out_proj.A.weight.data, layer_id + pre_layers)
rescale(layer.self_attn.out_proj.B.weight.data, layer_id + pre_layers)
rescale(layer.ffn.A.fc2.weight.data, layer_id + pre_layers)
rescale(layer.ffn.B.fc2.weight.data, layer_id + pre_layers)
def load_pretrained_weight(self):
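        """Load weights from ``config["load_path"]`` for training runs (the
        test-only path is handled in ``__init__``), converting DeepSpeed or
        Lightning checkpoints and resizing position embeddings as needed."""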
if self.hparams.config["load_path"] != "" and not self.hparams.config["test_only"]:
config = self.hparams.config
ckpt = torch.load(self.hparams.config["load_path"], map_location="cpu")
rank_zero_info("Load ckpt from: {}".format(self.hparams.config["load_path"]))
state_dict = None
for state_dict_key in ("state_dict", "module", "model"):
if state_dict_key in ckpt:
rank_zero_info("Read state dict from ckpt[%s]. " % state_dict_key)
state_dict = ckpt[state_dict_key]
break
if state_dict_key == "module":
state_dict = convert_deepspeed_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
if state_dict_key == "state_dict":
state_dict = convert_pl_ckpt(state_dict, self.backbone.vision_embed.num_position_embeddings())
if state_dict is None:
if list(ckpt.keys())[0].startswith('_forward_module.'):
rank_zero_info("Read state dict from ckpt with _forward_module prefix. ")
state_dict = convert_deepspeed_ckpt(ckpt,
self.backbone.vision_embed.num_position_embeddings())
else:
rank_zero_info("Read state dict from ckpt. ")
state_dict = ckpt
missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
missing_keys = [k for k in missing_keys if "itc_teacher" not in k]
rank_zero_info("missing_keys: {}".format(missing_keys))
rank_zero_info("unexpected_keys: {}".format(unexpected_keys))
def infer_text(
self,
batch,
mask_text=False,
):
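        """Encode the text side only.

        Runs the backbone in text mode, then the vision-language encoder on top,
        and returns L2-normalized ITC projections of the position-0 token from
        both stages ("cls_feats", "cls_vlffn_feats") along with the raw text
        token embeddings.
        """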
do_mlm = "_mlm" if mask_text else ""
text_ids = batch[f"text_ids{do_mlm}"]
text_labels = batch[f"text_labels{do_mlm}"]
        text_masks = batch["text_masks"]
text_embed = self.backbone.text_embed(text_ids)
text_padding_position = 1 - text_masks
lffn_hiddens = self.backbone(
textual_tokens=text_ids,
text_padding_position=text_padding_position,
)["encoder_out"]
vlffn_hiddens = self.backbone_vl(
src_tokens=None,
token_embeddings=lffn_hiddens,
encoder_padding_mask=text_padding_position,
multiway_split_position=-1,
)["encoder_out"]
cls_feats = self.itc_text_proj(lffn_hiddens[:, 0])
cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)
cls_vlffn_feats = self.itc_vl_text_proj(vlffn_hiddens[:, 0])
cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)
ret = {
"cls_feats": cls_feats,
"cls_vlffn_feats": cls_vlffn_feats,
"text_embed": text_embed,
}
return ret
def infer_image(
self,
batch,
mask_image=False,
image_token_type_idx=1,
image_embeds=None,
image_masks=None,
):
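        """Encode the image side only.

        When ``mask_image`` is set, the visual tokenizer is run (without
        gradients) to build masked-image-modeling labels; the image is then
        normalized and passed through the backbone in vision mode and the
        vision-language encoder, returning the vision features and L2-normalized
        ITC projections of the position-0 token.
        """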
if f"image_{image_token_type_idx - 1}" in batch:
imgkey = f"image_{image_token_type_idx - 1}"
else:
imgkey = "image"
img = batch[imgkey][0]
if mask_image:
image_masks = batch[f"{imgkey}_masks"][0].flatten(1)
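            # tokenize the image without gradients to obtain masked-image-modeling targets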
with torch.no_grad():
img = self.visual_tokenizer.pre_process(img)
quantize, embed_ind, _ = self.visual_tokenizer.encode(img)
image_ids = embed_ind.view(img.shape[0], -1)
image_labels = torch.full_like(image_ids, -100)
bool_masked_pos = image_masks.to(torch.bool)
image_labels[bool_masked_pos] = image_ids[bool_masked_pos]
img_tensor = img_norm(img)
vffn_hiddens = self.backbone(visual_tokens=img_tensor)["encoder_out"]
vlffn_hiddens = self.backbone_vl(
src_tokens=None,
token_embeddings=vffn_hiddens,
multiway_split_position=-1,
)["encoder_out"]
cls_feats = self.itc_image_proj(vffn_hiddens[:, 0])
cls_feats = cls_feats / cls_feats.norm(dim=-1, keepdim=True)
cls_vlffn_feats = self.itc_vl_image_proj(vlffn_hiddens[:, 0])
cls_vlffn_feats = cls_vlffn_feats / cls_vlffn_feats.norm(dim=-1, keepdim=True)
ret = {
"image_feats": vffn_hiddens,
"cls_feats": cls_feats,
"cls_vlffn_feats": cls_vlffn_feats,
}
return ret