File size: 1,610 Bytes
cc63bcb
 
 
 
 
 
b0990de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc63bcb
 
b0990de
 
 
 
 
 
 
 
 
 
 
cc63bcb
b0990de
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from flax import linen as nn
import jax
import jax.numpy as jnp
from local_response_norm import LocalResponseNorm

# Small constant; not referenced by the generator code here. Presumably used
# for numerical stability by normalization code elsewhere — verify.
EPSILON = 1e-8
# Channel cap; not referenced by the generator code here. Presumably bounds a
# discriminator's feature counts — verify.
MAX_DISC_FEATURES = 128
# Upper bound on the channel count of any generator stage (applied via min()
# in get_gen_layers).
MAX_GEN_FEATURES = 512
# Not referenced directly by the code here; presumably the length of the
# latent vector fed to Generator — TODO confirm against callers.
LATENT_DIM = 512
# Full-depth stage count. Stage i resizes to 4 * 2**i pixels per side, so 7
# stages end at 256x256. Also anchors per-stage channel counts, which are
# 32 * 2**(MAX_LAYERS - 1 - i) before the MAX_GEN_FEATURES cap.
MAX_LAYERS = 7 # 256x256

def get_gen_layers(layer):
  """Build the upsample -> conv -> ReLU callables for one generator stage.

  Args:
    layer: zero-based stage index. The stage resizes to ``4 * 2 ** layer``
      pixels per side and produces ``32 * 2 ** (MAX_LAYERS - 1 - layer)``
      channels, capped at ``MAX_GEN_FEATURES``.

  Returns:
    A list of three single-argument callables, applied in order:
    a ``jax.image.resize`` (method "linear") of the spatial dims,
    a 3x3 ``nn.ConvTranspose``, and ``nn.relu``. The ConvTranspose is
    instantiated when its callable runs, so it relies on being invoked from
    within a Flax ``@nn.compact`` method (as ``Generator.__call__`` is).
  """
  size = int(4 * 2 ** layer)
  channels = min(int(32 * 2 ** (MAX_LAYERS - 1 - layer)), MAX_GEN_FEATURES)
  conv_name = f"ConvTranspose_{size}_{channels}"

  def upsample(x):
    # Batch (dim 0) and channel (dim 3) sizes are kept; only dims 1-2 change.
    target_shape = (x.shape[0], size, size, x.shape[3])
    return jax.image.resize(x, shape=target_shape, method="linear")

  def conv(x):
    block = nn.ConvTranspose(
        features=channels, kernel_size=(3, 3), name=conv_name)
    return block(x)

  def activate(x):
    return nn.relu(x)

  return [upsample, conv, activate]

def get_initial_gen_layers(num_layers):
  """Return the callables that turn the latent batch into a 1x1 spatial map.

  Args:
    num_layers: accepted for signature parity with the other layer
      factories; it does not affect the result.

  Returns:
    A one-element list whose callable reshapes an array with batch axis 0
    to ``(batch, 1, 1, -1)``, flattening all remaining dims into channels.
  """
  del num_layers  # unused

  def to_spatial(z):
    return z.reshape(z.shape[0], 1, 1, -1)

  return [to_spatial]

def get_final_gen_layers(num_layers):
  """Return the callables that project the last stage to 3 output channels.

  Args:
    num_layers: number of generator stages; used only to derive the
      resolution ``4 * 2 ** (num_layers - 1)`` embedded in the layer name.

  Returns:
    A one-element list holding a callable that applies a 3x3
    ``nn.ConvTranspose`` with 3 output features (presumably RGB). The module
    is instantiated when the callable runs, so it relies on being invoked
    from within a Flax ``@nn.compact`` method.
  """
  out_res = int(4 * 2 ** (num_layers - 1))
  layer_name = f"ConvTranspose_{out_res}_3"

  def to_rgb(x):
    head = nn.ConvTranspose(features=3, kernel_size=(3, 3), name=layer_name)
    return head(x)

  return [to_rgb]

class Generator(nn.Module):
  """Generator that upsamples a latent batch into an image batch.

  The layer pipeline is assembled in ``setup`` from the module-level
  factories:
  - reshape the latent to a 1x1 spatial map;
  - run ``num_layers`` stages of resize -> 3x3 ``nn.ConvTranspose`` -> ReLU,
    where stage ``i`` resizes to ``4 * 2 ** i`` pixels per side;
  - apply a final 3x3 ``nn.ConvTranspose`` down to 3 channels.

  Attributes:
    num_layers: number of upsampling stages; must be a positive int. The
      ``None`` default is only a placeholder and is rejected in ``setup``.
  """
  # Annotated ``int = None`` for backward compatibility; effectively
  # Optional[int] until validated in setup().
  num_layers: int = None

  def setup(self):
    # Fail early with a descriptive message instead of the opaque
    # "'NoneType' object cannot be interpreted as an integer" from range().
    if self.num_layers is None:
      raise TypeError(
          "Generator.num_layers must be set to a positive int, got None")
    # Zero or negative values would skip every upsampling stage and emit a
    # 1x1 output under a meaningless layer name, so reject them.
    if self.num_layers < 1:
      raise ValueError(
          f"Generator.num_layers must be >= 1, got {self.num_layers}")
    # For num_layers=7:
    # 512 => 1x1x512 => 4x4x512 => 8x8x512 => 16x16x512 => 32x32x256 => 64x64x128 => 128x128x64 => 256x256x32 => 256x256x3
    layers = []
    layers.extend(get_initial_gen_layers(self.num_layers))
    for layer in range(self.num_layers):
      layers.extend(get_gen_layers(layer))
    layers.extend(get_final_gen_layers(self.num_layers))
    self.layers = layers

  @nn.compact
  def __call__(self, x):
    """Apply every layer built in ``setup`` to ``x``, in order.

    Args:
      x: latent batch. Dim 0 is kept as the batch axis and the remaining
        dims are flattened into channels. Presumably shape
        ``(batch, LATENT_DIM)`` — TODO confirm against callers.

    Returns:
      The output of the final 3-feature ConvTranspose, a channels-last
      array whose spatial size is ``4 * 2 ** (num_layers - 1)``.
    """
    # @nn.compact is required: the ConvTranspose submodules are instantiated
    # inside the layer callables, i.e. during this call.
    result = x
    for layer in self.layers:
      result = layer(result)
    return result