# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch.nn as nn

from vlmo.torchscale.component.multihead_attention import MultiheadAttention
from vlmo.torchscale.component.multiway_network import MultiwayNetwork


def init_bert_params(module):
    """Re-initialise ``module`` in place using a BERT-style scheme.

    Weights are redrawn from a normal distribution with mean 0.0 and
    standard deviation 0.02. What gets touched depends on the module type:

    * ``nn.Linear``: weight is redrawn; bias, when present, is zeroed.
    * ``nn.Embedding``: weight is redrawn; the ``padding_idx`` row, when
      one is configured, is zeroed afterwards.
    * ``MultiheadAttention``: the weights of ``q_proj``, ``k_proj`` and
      ``v_proj`` are redrawn. When ``q_proj`` is a ``MultiwayNetwork``,
      the ``A`` and ``B`` branches of all three projections are redrawn
      instead.

    Modules of any other type are left untouched. The function takes a
    single module and returns ``None``.
    """

    def _fill_normal(tensor):
        # Sample on a CPU copy and write the result back to the tensor's
        # own device (for a CPU tensor ``.cpu()`` returns the same tensor).
        # NOTE(review): presumably done so draws come from the CPU RNG
        # regardless of where the parameter lives -- confirm.
        sampled = tensor.cpu().normal_(mean=0.0, std=0.02)
        tensor.copy_(sampled.to(tensor.device))

    if isinstance(module, nn.Linear):
        _fill_normal(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()

    if isinstance(module, nn.Embedding):
        _fill_normal(module.weight.data)
        pad = module.padding_idx
        if pad is not None:
            module.weight.data[pad].zero_()

    if isinstance(module, MultiheadAttention):
        # Only q_proj is inspected; k_proj and v_proj are assumed to be
        # built the same way.
        multiway = isinstance(module.q_proj, MultiwayNetwork)
        # Order matters for reproducibility: q, k, v (and A before B).
        for proj_name in ("q_proj", "k_proj", "v_proj"):
            proj = getattr(module, proj_name)
            if multiway:
                _fill_normal(proj.A.weight.data)
                _fill_normal(proj.B.weight.data)
            else:
                _fill_normal(proj.weight.data)