# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import math

import torch
import torch.nn.functional as F
from torch import nn

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm

from .multiway_network import MultiwayWrapper
from .xpos_relative_position import XPOS


class MultiheadAttention(nn.Module):
    def __init__(
        self,
        args,
        embed_dim,
        num_heads,
        dropout=0.0,
        self_attention=False,
        encoder_decoder_attention=False,
        subln=False,
        one_attn=False,
    ):
        super().__init__()
        self.args = args
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** (-0.5)
        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        # exactly one of self-attention / encoder-decoder attention must be set
        assert self.self_attention ^ self.encoder_decoder_attention

        if one_attn:
            # single set of projections, without the MultiwayWrapper
            self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
            self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
            self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        else:
            self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
            self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
            self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
            # self.qkv_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim * 3, bias=True))
            self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))

        # optional LayerNorm on the attention output (sub-LN), used for self-attention only
        self.inner_attn_ln = (
            MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
            if subln and self.self_attention
            else None
        )
        self.dropout_module = torch.nn.Dropout(dropout)
        # xPos relative position embedding, applied to q/k in self-attention
        self.xpos = (
            XPOS(self.head_dim, args.xpos_scale_base)
            if args.xpos_rel_pos and self.self_attention
            else None
        )

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        query,
        key,
        value,
        incremental_state=None,
        key_padding_mask=None,
        attn_mask=None,
        rel_pos=None,
    ):
        bsz, tgt_len, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"

        key_bsz, src_len, _ = key.size()
        assert key_bsz == bsz, f"{query.size(), key.size()}"
        assert value is not None
        # value must share batch size and source length with key
        assert (bsz, src_len) == value.shape[:2]

        # if query is key and key is value:
        #     qkv = self.qkv_proj(query)
        # else:
        #     # W*(q+k+v) = W(q) + W(k) + W(v)
        #     qkv = self.qkv_proj(query + key + value)
        # q, k, v = qkv.split(self.embed_dim, dim=-1)

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # (bsz, len, embed_dim) -> (bsz * num_heads, len, head_dim)
        q = (q * self.scaling).view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)

        if incremental_state is not None:
            # reuse and update the cached key/value states for incremental decoding
            if "prev_key" in incremental_state:
                prev_key = incremental_state["prev_key"].view(
                    bsz * self.num_heads, -1, self.head_dim
                )
                prev_value = incremental_state["prev_value"].view(
                    bsz * self.num_heads, -1, self.head_dim
                )
                k = torch.cat([prev_key, k], dim=1)
                v = torch.cat([prev_value, v], dim=1)
            incremental_state["prev_key"] = k.view(
                bsz, self.num_heads, -1, self.head_dim
            )
            incremental_state["prev_value"] = v.view(
                bsz, self.num_heads, -1, self.head_dim
            )
            src_len = k.size(1)

        if self.xpos is not None:
            if incremental_state is not None:
                offset = src_len - 1
            else:
                offset = 0
            k = self.xpos(k, offset=0, downscale=True)
            q = self.xpos(q, offset=offset, downscale=False)

        attn_weights = torch.bmm(q, k.transpose(1, 2))

        if attn_mask is not None:
            attn_weights = torch.nan_to_num(attn_weights)
            if len(attn_mask.shape) != len(attn_weights.shape):
                # 2-D mask: broadcast over the batch * heads dimension
                attn_mask = attn_mask.unsqueeze(0)
            else:
                # batched 3-D mask: expand each batch entry across its heads
                attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # mask out padded source positions
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if rel_pos is not None:
            rel_pos = rel_pos.view(attn_weights.size())
            attn_weights = attn_weights + rel_pos

        # softmax in fp32 for numerical stability, then cast back
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
            attn_weights
        )
        attn_probs = self.dropout_module(attn_weights)

        attn = torch.bmm(attn_probs, v)
        # merge heads: (bsz * num_heads, tgt_len, head_dim) -> (bsz, tgt_len, embed_dim)
        attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)

        if self.inner_attn_ln is not None:
            attn = self.inner_attn_ln(attn)

        attn = self.out_proj(attn)
        # attention weights returned as (num_heads, bsz, tgt_len, src_len)
        attn_weights = attn_weights.view(
            bsz, self.num_heads, tgt_len, src_len
        ).transpose(1, 0)

        return attn, attn_weights
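

# Usage sketch (illustrative addition, not part of the original module): a minimal
# self-attention forward pass with one_attn=True, which bypasses MultiwayWrapper.
# The `args` namespace below is an assumed stand-in for the real config object and
# only carries the attribute read on this code path (xpos_rel_pos); xpos_scale_base
# and layernorm_eps would also be needed if xPos or subln were enabled. Because of
# the relative imports above, run this as a module from the package root, e.g.
# `python -m <package>.multihead_attention`.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(xpos_rel_pos=False)  # hypothetical minimal config
    attn_module = MultiheadAttention(
        args,
        embed_dim=64,
        num_heads=4,
        dropout=0.0,
        self_attention=True,
        one_attn=True,
    )
    x = torch.randn(2, 8, 64)  # (bsz, seq_len, embed_dim)
    out, weights = attn_module(x, x, x)
    print(out.shape)      # torch.Size([2, 8, 64])
    print(weights.shape)  # torch.Size([4, 2, 8, 8]) -> (num_heads, bsz, tgt_len, src_len)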