import math

import torch
import torch.nn.functional as F
from torch import nn

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm

from .multiway_network import MultiwayWrapper
from .xpos_relative_position import XPOS


class MultiheadAttention(nn.Module):
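    """Multi-head attention with optional multiway projections and xPos relative positions.

    When ``one_attn`` is True the query/key/value/output projections are plain
    ``nn.Linear`` layers; otherwise they are wrapped in ``MultiwayWrapper`` so each
    modality can keep its own set of parameters. With ``subln`` enabled for
    self-attention, a LayerNorm is applied to the attention output before the output
    projection, and ``args.xpos_rel_pos`` enables xPos rotary relative position
    encoding of queries and keys.
    """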
    def __init__(
        self,
        args,
        embed_dim,
        num_heads,
        dropout=0.0,
        self_attention=False,
        encoder_decoder_attention=False,
        subln=False,
        one_attn=False,
    ):
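        # embed_dim is assumed to be divisible by num_heads, and exactly one of
        # self_attention / encoder_decoder_attention is expected to be True
        # (enforced by the assertion below).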
        super().__init__()
        self.args = args
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** (-0.5)
        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        assert self.self_attention ^ self.encoder_decoder_attention

        if one_attn:
            self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
            self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
            self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        else:
            self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
            self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
            self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
            self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))

        self.inner_attn_ln = (
            MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
            if subln and self.self_attention
            else None
        )
        self.dropout_module = torch.nn.Dropout(dropout)
        self.xpos = (
            XPOS(self.head_dim, args.xpos_scale_base)
            if args.xpos_rel_pos and self.self_attention
            else None
        )

    def reset_parameters(self):
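        """Scaled Xavier initialization: gain 1 / sqrt(2) for the q/k/v projections,
        plain Xavier for the output projection, and a zero output bias.

        Note: this method is not called from ``__init__``; apply it explicitly if the
        scaled initialization is wanted instead of the ``nn.Linear`` default.
        """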
        nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        query,
        key,
        value,
        incremental_state=None,
        key_padding_mask=None,
        attn_mask=None,
        rel_pos=None,
    ):
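        """Compute multi-head attention.

        Shapes follow the code below: ``query``/``key``/``value`` are
        ``[batch, seq_len, embed_dim]``; ``key_padding_mask`` is a boolean
        ``[batch, src_len]`` mask where True marks padding; ``attn_mask`` and
        ``rel_pos`` are additive biases broadcastable to
        ``[batch * num_heads, tgt_len, src_len]``; ``incremental_state`` is a dict
        used to cache ``prev_key``/``prev_value`` during incremental decoding.

        Returns ``(attn, attn_weights)`` with ``attn`` of shape
        ``[batch, tgt_len, embed_dim]`` and ``attn_weights`` of shape
        ``[num_heads, batch, tgt_len, src_len]``.
        """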
        bsz, tgt_len, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"

        key_bsz, src_len, _ = key.size()
        assert key_bsz == bsz, f"{query.size(), key.size()}"
        assert value is not None
        # Compare (batch, src_len) against the first two dims of value.
        assert (bsz, src_len) == value.shape[:2]

        # Project the inputs and split them into heads:
        # [bsz, seq_len, embed_dim] -> [bsz * num_heads, seq_len, head_dim].
        # The query is pre-scaled by 1 / sqrt(head_dim).
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        q = (q * self.scaling).view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
        k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
        v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)

        if incremental_state is not None:
            # Incremental decoding: prepend the cached key/value states, then write
            # the updated cache back into incremental_state.
            if "prev_key" in incremental_state:
                prev_key = incremental_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
                prev_value = incremental_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
                k = torch.cat([prev_key, k], dim=1)
                v = torch.cat([prev_value, v], dim=1)
            incremental_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            incremental_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            src_len = k.size(1)

        if self.xpos is not None:
            # xPos relative position encoding: keys are always encoded from position
            # 0 with downscaling, while queries use an offset so incremental decoding
            # lines up with the cached keys.
            if incremental_state is not None:
                offset = src_len - 1
            else:
                offset = 0
            k = self.xpos(k, offset=0, downscale=True)
            q = self.xpos(q, offset=offset, downscale=False)

        attn_weights = torch.bmm(q, k.transpose(1, 2))

        if attn_mask is not None:
            # attn_mask is additive: a 2-D mask is broadcast over all heads, while a
            # 3-D per-batch mask is repeated once per head.
            attn_weights = torch.nan_to_num(attn_weights)
            if len(attn_mask.shape) != len(attn_weights.shape):
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.repeat_interleave(self.num_heads, dim=0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # True entries in key_padding_mask mark padded source positions; they get
            # -inf scores so the softmax assigns them zero weight.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if rel_pos is not None:
            # rel_pos is an additive relative-position bias reshaped to match attn_weights.
            rel_pos = rel_pos.view(attn_weights.size())
            attn_weights = attn_weights + rel_pos

        # Softmax in float32 for numerical stability, then cast back to the input dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        attn = torch.bmm(attn_probs, v)
        # Merge the heads back into [bsz, tgt_len, embed_dim].
        attn = attn.transpose(0, 1).reshape(tgt_len, bsz, embed_dim).transpose(0, 1)

        if self.inner_attn_ln is not None:
            attn = self.inner_attn_ln(attn)

        attn = self.out_proj(attn)
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)

        return attn, attn_weights
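

if __name__ == "__main__":
    # Illustrative smoke test, not part of the library API. It assumes an args
    # namespace exposing the attributes read above (multiway, layernorm_eps,
    # xpos_rel_pos, xpos_scale_base) and that MultiwayWrapper returns the wrapped
    # module unchanged when args.multiway is False. Because this file uses
    # relative imports, run it as a module (python -m ...) rather than directly.
    from types import SimpleNamespace

    args = SimpleNamespace(multiway=False, layernorm_eps=1e-5, xpos_rel_pos=False, xpos_scale_base=512)
    attention = MultiheadAttention(args, embed_dim=64, num_heads=4, self_attention=True, subln=True)
    x = torch.randn(2, 10, 64)
    out, weights = attention(x, x, x)
    print(out.shape, weights.shape)  # expected: (2, 10, 64) and (4, 2, 10, 10)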