Update generator.py
Browse files- generator.py +40 -50
generator.py
CHANGED
@@ -3,57 +3,47 @@ import jax
|
|
3 |
import jax.numpy as jnp
|
4 |
from local_response_norm import LocalResponseNorm
|
5 |
|
6 |
-
LATENT_DIM = 500
|
7 |
EPSILON = 1e-8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
class Generator(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
@nn.compact
|
11 |
-
def __call__(self,
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
x = nn.BatchNorm(not training)(x)
|
17 |
-
x = nn.relu(x)
|
18 |
-
x = nn.Dropout(0.25, deterministic=not training)(x)
|
19 |
-
x = x.reshape((x.shape[0], 2, 2, -1))
|
20 |
-
x4 = nn.ConvTranspose(features=512, kernel_size=(3, 3), strides=(2, 2))(x)
|
21 |
-
x4 = LocalResponseNorm()(x4)
|
22 |
-
x4 = nn.relu(x4)
|
23 |
-
x4o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x4)
|
24 |
-
x4 = nn.ConvTranspose(features=512, kernel_size=(3, 3))(x4)
|
25 |
-
x4 = LocalResponseNorm()(x4)
|
26 |
-
x4 = nn.relu(x4)
|
27 |
-
x8 = nn.ConvTranspose(features=256, kernel_size=(3, 3), strides=(2, 2))(x4)
|
28 |
-
x8 = LocalResponseNorm()(x8)
|
29 |
-
x8 = nn.relu(x8)
|
30 |
-
x8o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x8)
|
31 |
-
x8 = nn.ConvTranspose(features=256, kernel_size=(3, 3))(x8)
|
32 |
-
x8 = LocalResponseNorm()(x8)
|
33 |
-
x8 = nn.relu(x8)
|
34 |
-
x16 = nn.ConvTranspose(features=128, kernel_size=(3, 3), strides=(2, 2))(x8)
|
35 |
-
x16 = LocalResponseNorm()(x16)
|
36 |
-
x16 = nn.relu(x16)
|
37 |
-
x16o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x16)
|
38 |
-
x16 = nn.ConvTranspose(features=128, kernel_size=(3, 3))(x16)
|
39 |
-
x16 = LocalResponseNorm()(x16)
|
40 |
-
x16 = nn.relu(x16)
|
41 |
-
x32 = nn.ConvTranspose(features=64, kernel_size=(3, 3), strides=(2, 2))(x16)
|
42 |
-
x32 = LocalResponseNorm()(x32)
|
43 |
-
x32 = nn.relu(x32)
|
44 |
-
x32o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x32)
|
45 |
-
x32 = nn.ConvTranspose(features=64, kernel_size=(3, 3))(x32)
|
46 |
-
x32 = LocalResponseNorm()(x32)
|
47 |
-
x32 = nn.relu(x32)
|
48 |
-
x64 = nn.ConvTranspose(features=32, kernel_size=(3, 3), strides=(2, 2))(x32)
|
49 |
-
x64 = LocalResponseNorm()(x64)
|
50 |
-
x64 = nn.relu(x64)
|
51 |
-
x64o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x64)
|
52 |
-
x64 = nn.ConvTranspose(features=32, kernel_size=(3, 3))(x64)
|
53 |
-
x64 = LocalResponseNorm()(x64)
|
54 |
-
x64 = nn.relu(x64)
|
55 |
-
x128 = nn.ConvTranspose(features=64, kernel_size=(3, 3), strides=(2, 2))(x64)
|
56 |
-
x128 = LocalResponseNorm()(x128)
|
57 |
-
x128 = nn.relu(x128)
|
58 |
-
x128o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x128)
|
59 |
-
return (nn.tanh(x128o), nn.tanh(x64o), nn.tanh(x32o), nn.tanh(x16o), nn.tanh(x8o), nn.tanh(x4o))
|
|
|
3 |
import jax.numpy as jnp
|
4 |
from local_response_norm import LocalResponseNorm
|
5 |
|
|
|
6 |
EPSILON = 1e-8
|
7 |
+
MAX_DISC_FEATURES = 128
|
8 |
+
MAX_GEN_FEATURES = 512
|
9 |
+
LATENT_DIM = 512
|
10 |
+
MAX_LAYERS = 7 # 256x256
|
11 |
+
|
12 |
+
def get_gen_layers(layer):
|
13 |
+
resolution = int(4 * 2 ** layer)
|
14 |
+
features = min(int(32 * 2 ** (MAX_LAYERS - 1 - layer)), MAX_GEN_FEATURES)
|
15 |
+
layers = []
|
16 |
+
layers.append(lambda x: jax.image.resize(x, shape=(x.shape[0], resolution, resolution, x.shape[3]), method="linear"))
|
17 |
+
layers.append(lambda x: nn.ConvTranspose(features=features, kernel_size=(3, 3), name=f"ConvTranspose_{resolution}_{features}")(x))
|
18 |
+
layers.append(lambda x: nn.relu(x))
|
19 |
+
return layers
|
20 |
+
|
21 |
+
def get_initial_gen_layers(num_layers):
|
22 |
+
layers = []
|
23 |
+
layers.append(lambda x: x.reshape(x.shape[0], 1, 1, -1))
|
24 |
+
return layers
|
25 |
+
|
26 |
+
def get_final_gen_layers(num_layers):
|
27 |
+
resolution = int(4 * 2 ** (num_layers - 1))
|
28 |
+
layers = []
|
29 |
+
layers.append(lambda x: nn.ConvTranspose(features=3, kernel_size=(3, 3), name=f"ConvTranspose_{resolution}_3")(x))
|
30 |
+
return layers
|
31 |
|
32 |
class Generator(nn.Module):
|
33 |
+
num_layers: int = None
|
34 |
+
|
35 |
+
def setup(self):
|
36 |
+
# 512 => 1x1x512 => 4x4x512 => 8x8x512 => 16x16x512 => 32x32x256 => 64x64x128 => 128x128x64 => 256x256x32 => 256x256x3
|
37 |
+
layers = []
|
38 |
+
layers.extend(get_initial_gen_layers(self.num_layers))
|
39 |
+
for layer in range(self.num_layers):
|
40 |
+
layers.extend(get_gen_layers(layer))
|
41 |
+
layers.extend(get_final_gen_layers(self.num_layers))
|
42 |
+
self.layers = layers
|
43 |
+
|
44 |
@nn.compact
|
45 |
+
def __call__(self, x):
|
46 |
+
result = x
|
47 |
+
for layer in self.layers:
|
48 |
+
result = layer(result)
|
49 |
+
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|