Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,930 Bytes
d90acf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import math
import torch
from torch import nn, einsum
from einops import rearrange, repeat
from .utils import exist
class Identity(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
@staticmethod
def forward(x, *args, **kwargs):
return x
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -emb)
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim=-1)
class ConditionalGroupNorm(nn.Module):
def __init__(self, groups, normalized_shape, context_dim):
super().__init__()
self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
self.context_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(context_dim, 2 * normalized_shape)
)
self.context_mlp[1].weight.data.zero_()
self.context_mlp[1].bias.data.zero_()
def forward(self, x, context):
context = self.context_mlp(context)
ndims = ' 1' * len(x.shape[2:])
context = rearrange(context, f'b c -> b c{ndims}')
scale, shift = context.chunk(2, dim=1)
x = self.norm(x) * (scale + 1.) + shift
return x
class Attention(nn.Module):
def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
super().__init__()
assert out_channels % head_dim == 0
self.num_heads = out_channels // head_dim
self.scale = head_dim ** -0.5
self.to_query = nn.Linear(in_channels, out_channels, bias=False)
self.to_key = nn.Linear(context_dim, out_channels, bias=False)
self.to_value = nn.Linear(context_dim, out_channels, bias=False)
self.output_layer = nn.Linear(out_channels, out_channels, bias=False)
def forward(self, x, context, context_mask=None):
query = rearrange(self.to_query(x), 'b n (h d) -> b h n d', h=self.num_heads)
key = rearrange(self.to_key(context), 'b n (h d) -> b h n d', h=self.num_heads)
value = rearrange(self.to_value(context), 'b n (h d) -> b h n d', h=self.num_heads)
attention_matrix = einsum('b h i d, b h j d -> b h i j', query, key) * self.scale
if exist(context_mask):
max_neg_value = -torch.finfo(attention_matrix.dtype).max
context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
attention_matrix = attention_matrix.masked_fill(~context_mask, max_neg_value)
attention_matrix = attention_matrix.softmax(dim=-1)
out = einsum('b h i j, b h j d -> b h i d', attention_matrix, value)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.output_layer(out)
return out
|