File size: 1,270 Bytes
3440f83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch.nn as nn
from vlmo.torchscale.component.multihead_attention import MultiheadAttention
from vlmo.torchscale.component.multiway_network import MultiwayNetwork
def init_bert_params(module):
    """Re-initialize the parameters of a single module, in place.

    What gets touched depends on the type of ``module``; the three type
    checks are independent of each other:

    * ``nn.Linear``: the weight is redrawn from a normal distribution with
      mean 0.0 and std 0.02, and the bias (when present) is zeroed.
    * ``nn.Embedding``: the weight is redrawn the same way, and the row at
      ``padding_idx`` (when set) is zeroed afterwards.
    * ``MultiheadAttention``: the weights of ``q_proj``, ``k_proj`` and
      ``v_proj`` are redrawn the same way. When ``q_proj`` is a
      ``MultiwayNetwork``, all three projections are treated as multiway and
      the weights of their ``A`` and ``B`` sub-modules are redrawn instead.

    Any other module type is left untouched. Nothing is returned.

    NOTE(review): the single-module signature looks intended for
    ``nn.Module.apply`` -- callers are not visible here, confirm.

    Args:
        module: The module whose parameters should be re-initialized.
    """

    def _fill_normal(tensor):
        # Draw the values through ``tensor.cpu()`` and copy the result back
        # onto the tensor's own device.
        # NOTE(review): presumably done so the draw uses the CPU generator
        # regardless of where the parameter lives -- confirm.
        sampled = tensor.cpu().normal_(mean=0.0, std=0.02)
        tensor.copy_(sampled.to(tensor.device))

    if isinstance(module, nn.Linear):
        _fill_normal(module.weight.data)
        bias = module.bias
        if bias is not None:
            bias.data.zero_()

    if isinstance(module, nn.Embedding):
        _fill_normal(module.weight.data)
        pad_index = module.padding_idx
        if pad_index is not None:
            module.weight.data[pad_index].zero_()

    if isinstance(module, MultiheadAttention):
        # Only q_proj is inspected; k_proj and v_proj are assumed to have the
        # same kind (multiway or plain) as q_proj.
        is_multiway = isinstance(module.q_proj, MultiwayNetwork)
        # Projections are visited in q, k, v order (A before B within each),
        # so the sequence of random draws is q.A, q.B, k.A, k.B, v.A, v.B.
        for proj_name in ("q_proj", "k_proj", "v_proj"):
            projection = getattr(module, proj_name)
            if is_multiway:
                _fill_normal(projection.A.weight.data)
                _fill_normal(projection.B.weight.data)
            else:
                _fill_normal(projection.weight.data)
|