# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import math

import torch
import torch.nn as nn


class RelativePositionBias(nn.Module):
    """T5-style relative position bias: each (query, key) offset is mapped to a
    bucket, and each bucket holds one learned bias value per attention head."""

    def __init__(self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12):
        super().__init__()
        self.bidirectional = bidirectional
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.n_heads = n_heads
        # One learned bias per (bucket, head) pair.
        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        ret = 0
        n = -relative_position
        if bidirectional:
            # Split the buckets evenly between negative and positive offsets.
            num_buckets //= 2
            ret += (n < 0).to(torch.long) * num_buckets
            n = torch.abs(n)
        else:
            # Causal case: offsets pointing to future positions collapse into bucket 0.
            n = torch.max(n, torch.zeros_like(n))

        # The first half of the buckets covers small offsets exactly; larger
        # offsets are binned logarithmically up to max_distance.
        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def compute_bias(self, qlen, klen, step=None):
        step = 0 if step is None else step
        context_position = torch.arange(
            step,
            step + qlen,
            dtype=torch.long,
            device=self.relative_attention_bias.weight.device,
        )[:, None]
        memory_position = torch.arange(
            klen, dtype=torch.long, device=self.relative_attention_bias.weight.device
        )[None, :]
        relative_position = memory_position - context_position  # shape (qlen, klen)

        rp_bucket = self._relative_position_bucket(
            relative_position,  # shape (qlen, klen)
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
        return values

    def forward(self, batch_size, qlen, klen, step=None):
        # shape (batch * num_heads, qlen, klen)
        return self.compute_bias(qlen, klen, step).repeat(batch_size, 1, 1, 1).view(-1, qlen, klen)
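

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): it shows how the bias
    # returned by RelativePositionBias can be added to raw attention scores before
    # the softmax. The tensor names below (attn_scores, attn_weights) and the dummy
    # shapes are illustrative assumptions, not an API defined elsewhere in this file.
    batch_size, n_heads, qlen, klen = 2, 12, 8, 8
    rel_pos = RelativePositionBias(
        bidirectional=True, num_buckets=32, max_distance=128, n_heads=n_heads
    )

    # Dummy attention scores in the same (batch * num_heads, qlen, klen) layout
    # that forward() returns, so the bias can be added elementwise.
    attn_scores = torch.randn(batch_size * n_heads, qlen, klen)
    bias = rel_pos(batch_size, qlen, klen)  # shape (batch * num_heads, qlen, klen)
    attn_weights = torch.softmax(attn_scores + bias, dim=-1)
    print(bias.shape, attn_weights.shape)  # torch.Size([24, 8, 8]) twice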