File size: 3,701 Bytes
14ce5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


class SineLayer(nn.Module):
    """Linear layer followed by a frequency-scaled sine activation (SIREN).

    Paper: Implicit Neural Representations with Periodic Activation Functions
    (SIREN).

    The forward pass computes ``sin(omega_0 * linear(input))``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the wrapped ``nn.Linear`` has a bias term.
        is_first: if True, use the first-layer weight initialisation.
        omega_0: frequency scale applied to the linear output before the sine.
    """

    def __init__(
        self, in_features, out_features, bias=True, is_first=False, omega_0=30
    ):
        super().__init__()
        # Plain (non-module) attributes; their assignment order is irrelevant.
        self.in_features = in_features
        self.is_first = is_first
        self.omega_0 = omega_0

        # nn.Linear draws its default init first; init_weights() then
        # overwrites only the weight matrix.
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Re-initialise ``self.linear.weight`` in place, uniformly in [-limit, limit].

        * First layer: ``limit = 1 / in_features``.
        * Other layers: ``limit = sqrt(6 / in_features) / omega_0``.

        The bias (if any) keeps ``nn.Linear``'s default initialisation.
        """
        if self.is_first:
            limit = 1 / self.in_features
        else:
            limit = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)

    def forward(self, input):
        """Return ``sin(omega_0 * linear(input))``."""
        pre_activation = self.linear(input)
        return torch.sin(self.omega_0 * pre_activation)


class ToPixel(nn.Module):
    """Project a sequence of token embeddings back to image pixels.

    Supported ``to_pixel`` modes:

    * ``"linear"``: per-token ``nn.Linear`` to ``in_channels * patch_size**2``
      values, then :meth:`unpatchify` to ``(N, in_channels, H, W)``.
    * ``"conv"``: rearrange tokens ``b (h w) c -> b c h w`` and upsample each
      token to a ``patch_size x patch_size`` patch with ``nn.ConvTranspose2d``.
    * ``"siren"``: two ``SineLayer`` modules followed by a raw ``view`` to
      ``(N, in_channels, H, W)``.
    * ``"identity"``: return the input unchanged.

    Args:
        to_pixel: one of the mode names above.
        img_size: side length of the square output image.
        in_channels: number of channels in the output image.
        in_dim: embedding size of each input token.
        patch_size: side length of each square patch.

    Raises:
        NotImplementedError: if ``to_pixel`` is not a supported mode.
    """

    def __init__(
        self, to_pixel="linear", img_size=256, in_channels=3, in_dim=512, patch_size=16
    ) -> None:
        super().__init__()
        self.to_pixel_name = to_pixel
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.in_channels = in_channels
        if to_pixel == "linear":
            self.model = nn.Linear(in_dim, in_channels * patch_size * patch_size)
        elif to_pixel == "conv":
            self.model = nn.Sequential(
                Rearrange("b (h w) c -> b c h w", h=img_size // patch_size),
                nn.ConvTranspose2d(
                    in_dim, in_channels, kernel_size=patch_size, stride=patch_size
                ),
            )
        elif to_pixel == "siren":
            # NOTE(review): each token is mapped to
            # `img_size // patch_size * patch_size * in_channels` features,
            # not `patch_size**2 * in_channels`. Assuming the sequence length
            # equals num_patches, the later view only succeeds when
            # (img_size // patch_size) ** 2 == img_size (true for the 256/16
            # default). Left as-is to keep parameter shapes
            # checkpoint-compatible.
            self.model = nn.Sequential(
                SineLayer(in_dim, in_dim * 2, is_first=True, omega_0=30.0),
                SineLayer(
                    in_dim * 2,
                    img_size // patch_size * patch_size * in_channels,
                    is_first=False,
                    omega_0=30,
                ),
            )
        elif to_pixel == "identity":
            self.model = nn.Identity()
        else:
            raise NotImplementedError(f"Unsupported to_pixel mode: {to_pixel!r}")

    def get_last_layer(self):
        """Return the weight tensor of the final learnable layer.

        Returns ``None`` for the ``"identity"`` mode, which has no parameters.
        """
        if self.to_pixel_name == "linear":
            return self.model.weight
        elif self.to_pixel_name == "siren":
            return self.model[1].linear.weight
        elif self.to_pixel_name == "conv":
            return self.model[1].weight
        else:
            return None

    def unpatchify(self, x):
        """Reassemble per-patch pixel vectors into images.

        x: (N, L, patch_size**2 * in_channels), where L must be a perfect
           square (tokens laid out row-major on an h x w grid)
        imgs: (N, in_channels, h * patch_size, w * patch_size)

        The last dimension of ``x`` is read as (patch_row, patch_col, channel).
        """
        p = self.patch_size
        c = self.in_channels
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        # Interleave the grid and intra-patch axes: (n, c, h, p, w, q).
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward(self, x):
        """Map token embeddings to pixels according to ``self.to_pixel_name``.

        For ``"linear"`` and ``"siren"`` the result is shaped
        ``(N, in_channels, H, W)``. ``"conv"`` returns the
        ``nn.ConvTranspose2d`` output. ``"identity"`` returns ``x`` unchanged.
        """
        if self.to_pixel_name == "linear":
            x = self.model(x)
            x = self.unpatchify(x)
        elif self.to_pixel_name == "siren":
            x = self.model(x)
            side = self.patch_size * int(self.num_patches**0.5)
            # Raw reinterpretation of the (N, L, F) output as an image; no
            # patch reordering is performed (unlike the "linear" mode).
            x = x.view(x.shape[0], self.in_channels, side, side)
        elif self.to_pixel_name == "conv":
            x = self.model(x)
        elif self.to_pixel_name == "identity":
            pass
        return x