"""Generator network assembled from a flat sequence of per-stage callables."""

from flax import linen as nn
import jax
import jax.numpy as jnp
from local_response_norm import LocalResponseNorm

EPSILON = 1e-8
MAX_DISC_FEATURES = 128
MAX_GEN_FEATURES = 512
LATENT_DIM = 512
MAX_LAYERS = 7  # 256x256 (stage resolution is 4 * 2 ** layer, so stage 6 is 256)


def get_gen_layers(layer):
    """Return the callables that make up generator stage ``layer``.

    The stage resizes its input to ``resolution x resolution`` (with
    ``resolution = 4 * 2 ** layer``), applies a 3x3 ``nn.ConvTranspose`` and
    then a ReLU. The convolution width halves with each stage, starting from
    ``32 * 2 ** (MAX_LAYERS - 1)`` and capped at ``MAX_GEN_FEATURES``.

    Args:
        layer: Zero-based index of the stage.

    Returns:
        A list of three single-argument callables, in application order:
        resize, transposed convolution, ReLU.
    """
    resolution = int(4 * 2 ** layer)
    uncapped_features = int(32 * 2 ** (MAX_LAYERS - 1 - layer))
    features = min(uncapped_features, MAX_GEN_FEATURES)

    def upsample(x):
        # Only axes 1 and 2 change size; axes 0 and 3 are carried over, so the
        # input must be 4-D (presumably batch/height/width/channels).
        target_shape = (x.shape[0], resolution, resolution, x.shape[3])
        return jax.image.resize(x, shape=target_shape, method="linear")

    def convolve(x):
        # The submodule is created lazily at call time, so this callable has
        # to be invoked from inside a Flax compact method.
        conv = nn.ConvTranspose(
            features=features,
            kernel_size=(3, 3),
            name=f"ConvTranspose_{resolution}_{features}",
        )
        return conv(x)

    def activate(x):
        return nn.relu(x)

    return [upsample, convolve, activate]


def get_initial_gen_layers(num_layers):
    """Return the callables applied before the first generator stage.

    Args:
        num_layers: Not used by this function; kept so the signature mirrors
            ``get_final_gen_layers``.

    Returns:
        A one-element list holding a callable that reshapes its input to
        ``(x.shape[0], 1, 1, -1)``.
    """

    def to_single_pixel(x):
        return x.reshape(x.shape[0], 1, 1, -1)

    return [to_single_pixel]


def get_final_gen_layers(num_layers):
    """Return the callables applied after the last generator stage.

    Args:
        num_layers: Total number of generator stages. It is only used to
            derive the resolution embedded in the submodule name.

    Returns:
        A one-element list holding a callable that applies a 3x3
        ``nn.ConvTranspose`` with 3 output features.
    """
    resolution = int(4 * 2 ** (num_layers - 1))

    def project_to_three_features(x):
        conv = nn.ConvTranspose(
            features=3,
            kernel_size=(3, 3),
            name=f"ConvTranspose_{resolution}_3",
        )
        return conv(x)

    return [project_to_three_features]


class Generator(nn.Module):
    """Generator that pipes its input through a list of stage callables.

    NOTE(review): this module defines both ``setup()`` and an ``@nn.compact``
    ``__call__``; confirm the installed Flax version allows mixing the two.
    """

    # Number of upsampling stages. The default of None is not usable as-is:
    # setup() passes it to range(), so callers must supply an int.
    num_layers: int = None

    def setup(self):
        """Collect the initial, per-stage and final callables in order."""
        # Shapes for num_layers == 7:
        # 512 => 1x1x512 => 4x4x512 => 8x8x512 => 16x16x512 => 32x32x256
        #     => 64x64x128 => 128x128x64 => 256x256x32 => 256x256x3
        self.layers = [
            *get_initial_gen_layers(self.num_layers),
            *(
                step
                for stage_index in range(self.num_layers)
                for step in get_gen_layers(stage_index)
            ),
            *get_final_gen_layers(self.num_layers),
        ]

    @nn.compact
    def __call__(self, x):
        """Apply every collected callable to ``x`` in sequence.

        Args:
            x: Input array; the first callable reshapes it to
                ``(x.shape[0], 1, 1, -1)``.

        Returns:
            The output of the last callable in ``self.layers``.
        """
        activations = x
        for step in self.layers:
            activations = step(activations)
        return activations