File size: 2,584 Bytes
3440f83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import math
import torch
import torch.nn as nn
class RelativePositionBias(nn.Module):
    """Learned per-head attention bias looked up by bucketed relative position.

    Each (query position, key position) pair is mapped to one of
    ``num_buckets`` buckets according to their signed offset, and the bucket
    index selects a learned vector of ``n_heads`` scalars from an embedding
    table.

    Args:
        bidirectional: If True, half of the buckets are used for keys located
            after the query and half for keys at or before it. If False, keys
            located after the query all share bucket 0.
        num_buckets: Number of rows in the embedding table. ``num_buckets // 2``
            (``num_buckets // 4`` when bidirectional) must be positive,
            otherwise the bucket computation divides by zero.
        max_distance: Distances at or beyond this value share the last bucket.
        n_heads: Number of bias values stored per bucket.
    """

    def __init__(self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12):
        super().__init__()
        self.bidirectional = bidirectional
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.n_heads = n_heads
        # One learned scalar per (bucket, head).
        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """Map signed relative positions to bucket indices.

        Distances below ``max_exact`` (half of the buckets available per
        direction) each get their own bucket; larger distances are binned
        logarithmically up to ``max_distance`` and clipped to the last bucket.

        Args:
            relative_position: Integer tensor of ``key_pos - query_pos`` offsets.
            bidirectional: Whether positive offsets get their own bucket range.
            num_buckets: Total number of buckets.
            max_distance: Distance at which the logarithmic range saturates.

        Returns:
            Tensor with the same shape as ``relative_position`` holding bucket
            indices in ``[0, num_buckets)``.
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            # Split the buckets between the two directions; keys located after
            # the query (n < 0) use the upper half.
            num_buckets //= 2
            ret = (n < 0).to(torch.long) * num_buckets
            n = torch.abs(n)
        else:
            # Keys located after the query collapse to distance 0.
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        # Clamp before the log so entries with n < max_exact (in particular
        # n == 0, where log would be -inf) stay finite. Those entries are
        # discarded by torch.where below, so this only avoids casting -inf to
        # an integer; selected values are unaffected.
        log_ratio = torch.log(n.float().clamp(min=max_exact) / max_exact)
        val_if_large = max_exact + (
            log_ratio / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).to(torch.long)
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret = ret + torch.where(is_small, n, val_if_large)
        return ret

    def compute_bias(self, qlen, klen, step=None):
        """Build the bias tensor for ``qlen`` queries attending to ``klen`` keys.

        Args:
            qlen: Number of query positions.
            klen: Number of key positions; keys are numbered ``0 .. klen - 1``.
            step: Position of the first query; ``None`` is treated as 0, so
                queries are numbered ``step .. step + qlen - 1``.

        Returns:
            Tensor of shape ``(1, n_heads, qlen, klen)``.
        """
        step = 0 if step is None else step
        # Build the index tensors directly on the embedding's device so no
        # transfer is needed before the lookup.
        device = self.relative_attention_bias.weight.device
        context_position = torch.arange(step, step + qlen, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(klen, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (qlen, klen)
        rp_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, n_heads)
        return values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, n_heads, qlen, klen)

    def forward(self, batch_size, qlen, klen, step=None):
        """Return the bias tiled over the batch and flattened with the heads.

        Args:
            batch_size: Number of times the bias is repeated along dim 0.
            qlen: Number of query positions.
            klen: Number of key positions.
            step: Position of the first query (see ``compute_bias``).

        Returns:
            Tensor of shape ``(batch_size * n_heads, qlen, klen)``, ordered
            batch-major (all heads of batch 0, then all heads of batch 1, ...).
        """
        bias = self.compute_bias(qlen, klen, step)
        return bias.repeat(batch_size, 1, 1, 1).view(-1, qlen, klen)
|