# acai66's picture
# Upload folder using huggingface_hub
# 3440f83 verified
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch.nn as nn
from vlmo.torchscale.component.multihead_attention import MultiheadAttention
from vlmo.torchscale.component.multiway_network import MultiwayNetwork
def init_bert_params(module):
    """Initialize the parameters of a single module in place, BERT style.

    Three kinds of module are recognized; each check is independent of the
    others, and a module matching none of them is left untouched:

    * ``nn.Linear`` -- the weight is drawn from a normal distribution
      (mean 0.0, std 0.02) and the bias, when present, is zeroed.
    * ``nn.Embedding`` -- the weight is drawn from the same distribution and
      the row at ``padding_idx``, when one is set, is zeroed afterwards.
    * ``MultiheadAttention`` -- the weights of ``q_proj``, ``k_proj`` and
      ``v_proj`` are drawn from the same distribution. When ``q_proj`` is a
      ``MultiwayNetwork``, the ``A`` and ``B`` branches of each projection
      are initialized instead of a single weight.

    Args:
        module: The module to initialize.

    Returns:
        None. All changes are made to the module's tensors in place.
    """

    def _fill_normal(tensor):
        # Sample on a CPU view/copy of the tensor, then write the result back
        # on the tensor's own device.
        # NOTE(review): sampling on CPU presumably keeps the initial values
        # independent of the target device -- confirm.
        sampled = tensor.cpu().normal_(mean=0.0, std=0.02)
        tensor.copy_(sampled.to(tensor.device))

    if isinstance(module, nn.Linear):
        _fill_normal(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()

    if isinstance(module, nn.Embedding):
        _fill_normal(module.weight.data)
        pad = module.padding_idx
        if pad is not None:
            module.weight.data[pad].zero_()

    if isinstance(module, MultiheadAttention):
        # The type of q_proj alone decides how all three projections are
        # treated.
        multiway = isinstance(module.q_proj, MultiwayNetwork)
        for proj_name in ("q_proj", "k_proj", "v_proj"):
            proj = getattr(module, proj_name)
            if multiway:
                _fill_normal(proj.A.weight.data)
                _fill_normal(proj.B.weight.data)
            else:
                _fill_normal(proj.weight.data)