# AIPlane3 / generator.py
# PrakhAI's picture
# Update generator.py
# b0990de
from flax import linen as nn
import jax
import jax.numpy as jnp
from local_response_norm import LocalResponseNorm
# Module-level hyperparameters.
# NOTE(review): EPSILON, MAX_DISC_FEATURES and LATENT_DIM are not read by any
# code shown in this file; presumably other modules import them -- confirm.
EPSILON = 1e-8
MAX_DISC_FEATURES = 128
# Upper bound applied to the per-stage feature count in get_gen_layers.
MAX_GEN_FEATURES = 512
LATENT_DIM = 512
# Stage count used by get_gen_layers to derive feature counts; the deepest
# stage (index MAX_LAYERS - 1) has resolution 4 * 2**6 = 256.
MAX_LAYERS = 7 # 256x256
def get_gen_layers(layer):
    """Return the three callables that make up one generator stage.

    The stage at index ``layer`` works at a spatial size of ``4 * 2**layer``
    and uses ``32 * 2**(MAX_LAYERS - 1 - layer)`` features, capped at
    ``MAX_GEN_FEATURES``.

    Args:
        layer: Zero-based index of the stage.

    Returns:
        A list ``[upsample, convolve, activate]`` of single-argument
        callables, meant to be applied in that order.
    """
    side = int(4 * 2 ** layer)
    channels = min(int(32 * 2 ** (MAX_LAYERS - 1 - layer)), MAX_GEN_FEATURES)

    def upsample(x):
        # Axes 0 and 3 keep their size; axes 1 and 2 are resized to `side`.
        # Presumably x is (batch, height, width, channels) -- confirm.
        target_shape = (x.shape[0], side, side, x.shape[3])
        return jax.image.resize(x, shape=target_shape, method="linear")

    def convolve(x):
        # The module is built lazily, at call time, so that it is created
        # inside whatever Flax context is active when the stage is applied.
        conv = nn.ConvTranspose(
            features=channels,
            kernel_size=(3, 3),
            name=f"ConvTranspose_{side}_{channels}",
        )
        return conv(x)

    def activate(x):
        return nn.relu(x)

    return [upsample, convolve, activate]
def get_initial_gen_layers(num_layers):
    """Return the callables applied before the first generator stage.

    Args:
        num_layers: Accepted for signature symmetry with the other layer
            builders; it is not used here.

    Returns:
        A one-element list whose callable reshapes its input to
        ``(x.shape[0], 1, 1, -1)``: the leading axis is kept and all other
        axes are flattened into the last one.
    """

    def to_1x1_map(x):
        leading = x.shape[0]
        return x.reshape(leading, 1, 1, -1)

    return [to_1x1_map]
def get_final_gen_layers(num_layers):
    """Return the callables applied after the last generator stage.

    Args:
        num_layers: Number of generator stages. It is only used to derive
            the resolution ``4 * 2**(num_layers - 1)`` that appears in the
            convolution's name.

    Returns:
        A one-element list whose callable applies a 3x3 ``nn.ConvTranspose``
        with 3 output features.
    """
    side = int(4 * 2 ** (num_layers - 1))

    def to_output(x):
        # Built at call time so the module is created in the caller's
        # active Flax context.
        conv = nn.ConvTranspose(
            features=3,
            kernel_size=(3, 3),
            name=f"ConvTranspose_{side}_3",
        )
        return conv(x)

    return [to_output]
class Generator(nn.Module):
    """Generator assembled from the layer lists built by the helper functions.

    ``setup`` collects the callables from ``get_initial_gen_layers``, one
    ``get_gen_layers`` call per stage, and ``get_final_gen_layers``.
    ``__call__`` applies them to the input in that order.
    """

    # Number of generator stages to stack; must be set to an int by the caller.
    num_layers: int = None

    def setup(self):
        """Build the ordered list of callables and store it on ``self.layers``."""
        # Shape trail for num_layers == 7 and a 512-wide input:
        # 512 => 1x1x512 => 4x4x512 => 8x8x512 => 16x16x512 => 32x32x256
        #     => 64x64x128 => 128x128x64 => 256x256x32 => 256x256x3
        pipeline = list(get_initial_gen_layers(self.num_layers))
        for stage in range(self.num_layers):
            pipeline += get_gen_layers(stage)
        pipeline += get_final_gen_layers(self.num_layers)
        self.layers = pipeline

    @nn.compact
    def __call__(self, x):
        """Run ``x`` through every stored callable in order and return the result."""
        # @nn.compact is needed because some callables create their
        # ConvTranspose submodules inline when they are invoked.
        out = x
        for apply_layer in self.layers:
            out = apply_layer(out)
        return out