|
|
|
|
|
|
|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class RelativePositionBias(nn.Module):
    """Learned, bucketed relative-position bias for attention logits.

    Relative offsets ``key_pos - query_pos`` are mapped to one of
    ``num_buckets`` buckets: small offsets get one exact bucket each, larger
    offsets get logarithmically wider buckets, and offsets whose magnitude is
    ``max_distance`` or more share the last bucket. Each bucket owns one
    learned bias value per attention head.

    Args:
        bidirectional: If True, keys before and after the query use separate
            halves of the bucket range. If False, only keys before the query
            are distinguished; keys at or after the query all map to bucket 0.
        num_buckets: Total number of buckets (rows of the embedding table).
        max_distance: Offset magnitude from which all offsets share the last
            bucket (of their half, when bidirectional).
        n_heads: Number of attention heads (columns of the embedding table).
    """

    def __init__(self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12):
        super().__init__()
        self.bidirectional = bidirectional
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.n_heads = n_heads
        # One learned bias per (bucket, head) pair.
        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """Map relative positions to bucket indices.

        Args:
            relative_position: Integer tensor of offsets ``key_pos - query_pos``.
            bidirectional: Split the buckets between the two signs of offset.
            num_buckets: Total number of buckets.
            max_distance: Magnitude from which offsets share the last bucket.

        Returns:
            Integer tensor shaped like ``relative_position`` holding bucket
            indices in ``[0, num_buckets)``.
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            num_buckets //= 2
            # n < 0 means the key comes after the query: use the upper half.
            ret += (n < 0).to(torch.long) * num_buckets
            n = torch.abs(n)
        else:
            # Keys at or after the query collapse onto bucket 0.
            n = n.clamp(min=0)

        # Half of the (remaining) buckets hold exact, one-per-offset values.
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # Larger offsets get logarithmically wider buckets up to max_distance.
        # Clamp before the log: entries below max_exact are discarded by the
        # torch.where below anyway, and clamping avoids log(0) = -inf, whose
        # cast to an integer dtype is undefined. Entries >= max_exact are
        # unchanged by the clamp, so their bucket values are identical.
        n_large = n.float().clamp(min=max_exact)
        val_if_large = max_exact + (
            torch.log(n_large / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).to(torch.long)
        val_if_large = val_if_large.clamp(max=num_buckets - 1)

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def compute_bias(self, qlen, klen, step=None):
        """Return the per-head bias for a ``qlen`` x ``klen`` attention map.

        Args:
            qlen: Number of query positions.
            klen: Number of key positions (keys start at position 0).
            step: Absolute position of the first query; defaults to 0. Lets a
                query block that starts mid-sequence get the right offsets.

        Returns:
            Tensor of shape ``(1, n_heads, qlen, klen)``.
        """
        step = 0 if step is None else step
        device = self.relative_attention_bias.weight.device
        context_position = torch.arange(step, step + qlen, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(klen, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # (qlen, klen)

        rp_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        # rp_bucket is already on the embedding's device (all inputs were
        # created there), so no extra transfer is needed before the lookup.
        values = self.relative_attention_bias(rp_bucket)  # (qlen, klen, n_heads)
        return values.permute([2, 0, 1]).unsqueeze(0)

    def forward(self, batch_size, qlen, klen, step=None):
        """Return the relative-position bias tiled over the batch.

        Args:
            batch_size: Number of copies to stack along the leading dimension.
            qlen: Number of query positions.
            klen: Number of key positions.
            step: Absolute position of the first query (see ``compute_bias``).

        Returns:
            Tensor of shape ``(batch_size * n_heads, qlen, klen)``, ordered
            batch-major (all heads of batch item 0 come first).
        """
        return self.compute_bias(qlen, klen, step).repeat(batch_size, 1, 1, 1).view(-1, qlen, klen)
|
|