PrakhAI committed
Commit b0990de · 1 Parent(s): c571337

Update generator.py

Files changed (1): generator.py +40 -50
generator.py CHANGED
@@ -3,57 +3,47 @@ import jax
 import jax.numpy as jnp
 from local_response_norm import LocalResponseNorm
 
-LATENT_DIM = 500
 EPSILON = 1e-8
+MAX_DISC_FEATURES = 128
+MAX_GEN_FEATURES = 512
+LATENT_DIM = 512
+MAX_LAYERS = 7 # 256x256
+
+def get_gen_layers(layer):
+    resolution = int(4 * 2 ** layer)
+    features = min(int(32 * 2 ** (MAX_LAYERS - 1 - layer)), MAX_GEN_FEATURES)
+    layers = []
+    layers.append(lambda x: jax.image.resize(x, shape=(x.shape[0], resolution, resolution, x.shape[3]), method="linear"))
+    layers.append(lambda x: nn.ConvTranspose(features=features, kernel_size=(3, 3), name=f"ConvTranspose_{resolution}_{features}")(x))
+    layers.append(lambda x: nn.relu(x))
+    return layers
+
+def get_initial_gen_layers(num_layers):
+    layers = []
+    layers.append(lambda x: x.reshape(x.shape[0], 1, 1, -1))
+    return layers
+
+def get_final_gen_layers(num_layers):
+    resolution = int(4 * 2 ** (num_layers - 1))
+    layers = []
+    layers.append(lambda x: nn.ConvTranspose(features=3, kernel_size=(3, 3), name=f"ConvTranspose_{resolution}_3")(x))
+    return layers
 
 class Generator(nn.Module):
+    num_layers: int = None
+
+    def setup(self):
+        # 512 => 1x1x512 => 4x4x512 => 8x8x512 => 16x16x512 => 32x32x256 => 64x64x128 => 128x128x64 => 256x256x32 => 256x256x3
+        layers = []
+        layers.extend(get_initial_gen_layers(self.num_layers))
+        for layer in range(self.num_layers):
+            layers.extend(get_gen_layers(layer))
+        layers.extend(get_final_gen_layers(self.num_layers))
+        self.layers = layers
+
     @nn.compact
-    def __call__(self, latent, training=True):
-        x = nn.Dense(features=64)(latent)
-        # x = nn.BatchNorm(not training)(x)
-        x = nn.relu(x)
-        x = nn.Dense(features=2*2*1024)(x)
-        x = nn.BatchNorm(not training)(x)
-        x = nn.relu(x)
-        x = nn.Dropout(0.25, deterministic=not training)(x)
-        x = x.reshape((x.shape[0], 2, 2, -1))
-        x4 = nn.ConvTranspose(features=512, kernel_size=(3, 3), strides=(2, 2))(x)
-        x4 = LocalResponseNorm()(x4)
-        x4 = nn.relu(x4)
-        x4o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x4)
-        x4 = nn.ConvTranspose(features=512, kernel_size=(3, 3))(x4)
-        x4 = LocalResponseNorm()(x4)
-        x4 = nn.relu(x4)
-        x8 = nn.ConvTranspose(features=256, kernel_size=(3, 3), strides=(2, 2))(x4)
-        x8 = LocalResponseNorm()(x8)
-        x8 = nn.relu(x8)
-        x8o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x8)
-        x8 = nn.ConvTranspose(features=256, kernel_size=(3, 3))(x8)
-        x8 = LocalResponseNorm()(x8)
-        x8 = nn.relu(x8)
-        x16 = nn.ConvTranspose(features=128, kernel_size=(3, 3), strides=(2, 2))(x8)
-        x16 = LocalResponseNorm()(x16)
-        x16 = nn.relu(x16)
-        x16o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x16)
-        x16 = nn.ConvTranspose(features=128, kernel_size=(3, 3))(x16)
-        x16 = LocalResponseNorm()(x16)
-        x16 = nn.relu(x16)
-        x32 = nn.ConvTranspose(features=64, kernel_size=(3, 3), strides=(2, 2))(x16)
-        x32 = LocalResponseNorm()(x32)
-        x32 = nn.relu(x32)
-        x32o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x32)
-        x32 = nn.ConvTranspose(features=64, kernel_size=(3, 3))(x32)
-        x32 = LocalResponseNorm()(x32)
-        x32 = nn.relu(x32)
-        x64 = nn.ConvTranspose(features=32, kernel_size=(3, 3), strides=(2, 2))(x32)
-        x64 = LocalResponseNorm()(x64)
-        x64 = nn.relu(x64)
-        x64o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x64)
-        x64 = nn.ConvTranspose(features=32, kernel_size=(3, 3))(x64)
-        x64 = LocalResponseNorm()(x64)
-        x64 = nn.relu(x64)
-        x128 = nn.ConvTranspose(features=64, kernel_size=(3, 3), strides=(2, 2))(x64)
-        x128 = LocalResponseNorm()(x128)
-        x128 = nn.relu(x128)
-        x128o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x128)
-        return (nn.tanh(x128o), nn.tanh(x64o), nn.tanh(x32o), nn.tanh(x16o), nn.tanh(x8o), nn.tanh(x4o))
+    def __call__(self, x):
+        result = x
+        for layer in self.layers:
+            result = layer(result)
+        return result
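A minimal usage sketch of the committed Generator (not part of the diff). It assumes generator.py imports flax.linen as nn in the lines above the hunk shown, and that the file is importable as a module; the expected output shape follows the shape comment in setup().

import jax
from generator import Generator, LATENT_DIM, MAX_LAYERS

# num_layers=MAX_LAYERS (7) walks 4x4 up to 256x256; smaller values stop earlier.
model = Generator(num_layers=MAX_LAYERS)

rng = jax.random.PRNGKey(0)
latent = jax.random.normal(rng, (1, LATENT_DIM))  # (batch, 512)

# setup() stores only plain lambdas, so the ConvTranspose submodules are
# created inline when the compact __call__ runs; init traces it once.
variables = model.init(rng, latent)
images = model.apply(variables, latent)
print(images.shape)  # expected: (1, 256, 256, 3)

Design note: because each conv layer is named f"ConvTranspose_{resolution}_{features}", a given resolution keeps the same parameter names as num_layers grows, which reads as groundwork for progressively growing the generator and carrying earlier-stage weights forward; that is an inference from the code, not something the commit message states.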