|
from flax import linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
from local_response_norm import LocalResponseNorm |
|
|
|
# Small positive constant; presumably a numerical-stability epsilon — confirm at its use sites.
EPSILON: float = 1e-8

# Channel cap for the discriminator (the discriminator code is not visible here — confirm).
MAX_DISC_FEATURES: int = 128

# Upper bound on the channel count of every generator upsampling stage.
MAX_GEN_FEATURES: int = 512

# Width of the latent vector; presumably the generator's input size — confirm against callers.
LATENT_DIM: int = 512

# Stage count used by get_gen_layers to derive its per-stage channel schedule.
MAX_LAYERS: int = 7
|
|
|
def get_gen_layers(layer, *, max_layers=None, max_features=None):
    """Build the ops for one upsampling stage of the generator.

    Stage ``layer`` resizes its input to ``4 * 2**layer`` pixels per side,
    applies a 3x3 ``nn.ConvTranspose``, then a ReLU. The stage's channel
    count is ``32 * 2**(max_layers - 1 - layer)``, so it halves with each
    successive stage, and it is capped at ``max_features``.

    Args:
        layer: Zero-based stage index.
        max_layers: Total stage count used for the channel schedule. Defaults
            to the module-level ``MAX_LAYERS`` when omitted.
        max_features: Upper bound on the channel count. Defaults to the
            module-level ``MAX_GEN_FEATURES`` when omitted.

    Returns:
        A list of three callables (resize, transposed conv, ReLU) to be
        applied in order. The conv callable constructs its ``nn.ConvTranspose``
        when invoked, so it is presumably meant to run inside a Flax
        ``@nn.compact`` method (as ``Generator.__call__`` is).
    """
    # Fall back to the module-level schedule so existing callers are unchanged.
    if max_layers is None:
        max_layers = MAX_LAYERS
    if max_features is None:
        max_features = MAX_GEN_FEATURES

    # int() matters when an exponent is negative (layer < 0, or
    # layer > max_layers - 1): ``2 ** negative`` is a float there.
    resolution = int(4 * 2 ** layer)
    features = min(int(32 * 2 ** (max_layers - 1 - layer)), max_features)

    # Each lambda closes over this call's own ``resolution``/``features``, so
    # there is no late-binding hazard across stages.
    return [
        # Resize axes 1 and 2 to resolution x resolution, keeping axes 0 and 3
        # (batch / channels under an NHWC layout — assumed, not enforced here).
        lambda x: jax.image.resize(
            x,
            shape=(x.shape[0], resolution, resolution, x.shape[3]),
            method="linear",
        ),
        lambda x: nn.ConvTranspose(
            features=features,
            kernel_size=(3, 3),
            name=f"ConvTranspose_{resolution}_{features}",
        )(x),
        lambda x: nn.relu(x),
    ]
|
|
|
def get_initial_gen_layers(num_layers):
    """Return the ops that turn a latent batch into a 1x1 feature map.

    Args:
        num_layers: Not used here; accepted so the signature mirrors
            ``get_final_gen_layers(num_layers)``.

    Returns:
        A one-element list holding a callable that reshapes ``x`` to
        ``(x.shape[0], 1, 1, -1)``, folding every non-leading element into
        the last axis.
    """

    def to_single_pixel(latents):
        # Keep axis 0 and collapse everything else into one trailing axis
        # behind two unit axes.
        return latents.reshape(latents.shape[0], 1, 1, -1)

    return [to_single_pixel]
|
|
|
def get_final_gen_layers(num_layers):
    """Return the output ops that project the last feature map to 3 features.

    Args:
        num_layers: Number of upsampling stages. It is used only to derive the
            side length embedded in the layer name, ``4 * 2**(num_layers - 1)``,
            which is the resize target of the last ``get_gen_layers`` stage.

    Returns:
        A one-element list holding a callable that applies a 3x3
        ``nn.ConvTranspose`` with ``features=3``. The module is constructed
        when the callable runs, so it is presumably meant to be invoked inside
        a Flax ``@nn.compact`` method such as ``Generator.__call__``.
    """
    side = int(4 * 2 ** (num_layers - 1))
    conv_name = f"ConvTranspose_{side}_3"

    def project_to_three(features):
        return nn.ConvTranspose(features=3, kernel_size=(3, 3), name=conv_name)(features)

    return [project_to_three]
|
|
|
class Generator(nn.Module):
    """Generator that upsamples a latent batch into a 3-feature output map.

    The network is the concatenation of ``get_initial_gen_layers`` (reshape to
    1x1), ``num_layers`` stages from ``get_gen_layers`` (resize, 3x3
    transposed conv, ReLU), and ``get_final_gen_layers`` (3x3 transposed conv
    with 3 features). With ``num_layers = n`` the last resize produces
    ``4 * 2**(n - 1)`` pixels per side.

    Attributes:
        num_layers: Number of upsampling stages. Must be a non-negative int;
            the ``None`` default is rejected in ``setup``.
    """

    num_layers: int = None

    def setup(self):
        # Validate eagerly: range(None) would otherwise fail with an opaque
        # TypeError, and a negative count silently builds a network with no
        # upsampling stages.
        if self.num_layers is None:
            raise TypeError("Generator.num_layers must be set to a non-negative int, got None")
        if self.num_layers < 0:
            raise ValueError(f"Generator.num_layers must be non-negative, got {self.num_layers}")

        layers = list(get_initial_gen_layers(self.num_layers))
        for stage in range(self.num_layers):
            layers.extend(get_gen_layers(stage))
        layers.extend(get_final_gen_layers(self.num_layers))
        # Only plain callables are stored here; the conv modules inside them
        # are constructed when __call__ runs them.
        self.layers = layers

    @nn.compact
    def __call__(self, x):
        """Apply every op built in ``setup`` to ``x`` in order and return the result."""
        for op in self.layers:
            x = op(x)
        return x