import torch.nn as nn

from vlmo.torchscale.component.multihead_attention import MultiheadAttention
from vlmo.torchscale.component.multiway_network import MultiwayNetwork


def init_bert_params(module):
    """BERT-style initialization: weights ~ N(0, 0.02), biases and padding
    embeddings zeroed. Intended to be passed to ``nn.Module.apply`` so that
    it is invoked on every submodule.
    """

    def normal_(data):
        # Sample on CPU and copy the result back to the original device, so
        # the random stream does not depend on where the parameter lives.
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            # Keep the padding embedding at zero so it contributes nothing.
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        if isinstance(module.q_proj, MultiwayNetwork):
            # Multiway attention holds two modality-specific copies (A and B)
            # of each projection; initialize both branches.
            normal_(module.q_proj.A.weight.data)
            normal_(module.q_proj.B.weight.data)
            normal_(module.k_proj.A.weight.data)
            normal_(module.k_proj.B.weight.data)
            normal_(module.v_proj.A.weight.data)
            normal_(module.v_proj.B.weight.data)
        else:
            normal_(module.q_proj.weight.data)
            normal_(module.k_proj.weight.data)
            normal_(module.v_proj.weight.data)
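
if __name__ == "__main__":
    # Usage sketch (illustrative addition, not from the original file): the
    # initializer is meant for nn.Module.apply, which calls it recursively on
    # every submodule. The small model below is a hypothetical stand-in.
    model = nn.Sequential(
        nn.Embedding(100, 64, padding_idx=0),
        nn.Linear(64, 32),
    )
    model.apply(init_bert_params)  # re-initializes all weights in place
    # The padding row is re-zeroed after the normal initialization.
    assert model[0].weight.data[0].abs().sum().item() == 0.0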