import math

import torch
from torch import nn, einsum
from einops import rearrange

from .utils import exist


class Identity(nn.Module):
    """No-op module: forward returns its input unchanged, ignoring extra arguments."""
    def __init__(self, *args, **kwargs):
        super().__init__()

    @staticmethod
    def forward(x, *args, **kwargs):
        return x


class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of scalar positions (e.g. diffusion
    timesteps): half of the output channels are sines, half cosines, at
    geometrically spaced frequencies."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # Geometric frequency schedule: exp(-log(10000) * k / (half_dim - 1)), k = 0..half_dim-1
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -emb)
        # Outer product of positions and frequencies: (i,) x (j,) -> (i, j)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim=-1)


class ConditionalGroupNorm(nn.Module):
    """GroupNorm whose scale and shift are predicted from a context vector
    (FiLM-style conditioning). The projection is zero-initialized, so the layer
    starts out as plain group normalization."""

    def __init__(self, groups, normalized_shape, context_dim):
        super().__init__()
        self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
        self.context_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(context_dim, 2 * normalized_shape)
        )
        # Zero-init the projection so that scale = 1 and shift = 0 at initialization
        self.context_mlp[1].weight.data.zero_()
        self.context_mlp[1].bias.data.zero_()

    def forward(self, x, context):
        context = self.context_mlp(context)
        # Broadcast the context over the spatial dims of x: (b, 2c) -> (b, 2c, 1, ..., 1)
        ndims = ' 1' * len(x.shape[2:])
        context = rearrange(context, f'b c -> b c{ndims}')

        scale, shift = context.chunk(2, dim=1)
        x = self.norm(x) * (scale + 1.) + shift
        return x


class Attention(nn.Module):
    """Multi-head cross-attention: queries come from x, keys and values from
    context. An optional boolean context_mask marks valid context positions."""

    def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
        super().__init__()
        assert out_channels % head_dim == 0
        self.num_heads = out_channels // head_dim
        self.scale = head_dim ** -0.5

        self.to_query = nn.Linear(in_channels, out_channels, bias=False)
        self.to_key = nn.Linear(context_dim, out_channels, bias=False)
        self.to_value = nn.Linear(context_dim, out_channels, bias=False)

        self.output_layer = nn.Linear(out_channels, out_channels, bias=False)

    def forward(self, x, context, context_mask=None):
        # Project and split into heads: (b, n, h*d) -> (b, h, n, d)
        query = rearrange(self.to_query(x), 'b n (h d) -> b h n d', h=self.num_heads)
        key = rearrange(self.to_key(context), 'b n (h d) -> b h n d', h=self.num_heads)
        value = rearrange(self.to_value(context), 'b n (h d) -> b h n d', h=self.num_heads)

        # Scaled dot-product attention scores
        attention_matrix = einsum('b h i d, b h j d -> b h i j', query, key) * self.scale
        if exist(context_mask):
            # Exclude masked (padded) context positions before the softmax
            max_neg_value = -torch.finfo(attention_matrix.dtype).max
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
            attention_matrix = attention_matrix.masked_fill(~context_mask, max_neg_value)
        attention_matrix = attention_matrix.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attention_matrix, value)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.output_layer(out)
        return out
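

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The shapes and hyperparameters below are assumptions chosen just to exercise
# SinusoidalPosEmb, ConditionalGroupNorm and Attention together. Because of the
# relative import of `exist` above, run this as a module, e.g.
# `python -m <your_package>.nn`.
if __name__ == '__main__':
    batch, channels, height, width = 2, 64, 8, 8
    context_len, context_dim = 5, 32

    # Timestep embedding: (batch,) -> (batch, channels)
    time = torch.rand(batch)
    time_emb = SinusoidalPosEmb(channels)(time)
    assert time_emb.shape == (batch, channels)

    # Context-conditioned group normalization on a 4D feature map
    x = torch.randn(batch, channels, height, width)
    norm = ConditionalGroupNorm(groups=8, normalized_shape=channels, context_dim=channels)
    assert norm(x, time_emb).shape == x.shape

    # Cross-attention over flattened spatial positions, with a boolean padding mask
    context = torch.randn(batch, context_len, context_dim)
    context_mask = torch.ones(batch, context_len, dtype=torch.bool)
    attention = Attention(in_channels=channels, out_channels=channels, context_dim=context_dim)
    tokens = rearrange(x, 'b c h w -> b (h w) c')
    out = attention(tokens, context, context_mask)
    assert out.shape == (batch, height * width, channels)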