PrakhAI commited on
Commit
c77e6c8
·
1 Parent(s): d9e2ba0

Create generator.py

Browse files
Files changed (1) hide show
  1. generator.py +54 -0
generator.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LATENT_DIM = 500
2
+ EPSILON = 1e-8
3
+
4
+ class Generator(nn.Module):
5
+ @nn.compact
6
+ def __call__(self, latent, training=True):
7
+ x = nn.Dense(features=64)(latent)
8
+ # x = nn.BatchNorm(not training)(x)
9
+ x = nn.relu(x)
10
+ x = nn.Dense(features=2*2*1024)(x)
11
+ x = nn.BatchNorm(not training)(x)
12
+ x = nn.relu(x)
13
+ x = nn.Dropout(0.25, deterministic=not training)(x)
14
+ x = x.reshape((x.shape[0], 2, 2, -1))
15
+ x4 = nn.ConvTranspose(features=512, kernel_size=(3, 3), strides=(2, 2))(x)
16
+ x4 = LocalResponseNorm()(x4)
17
+ x4 = nn.relu(x4)
18
+ x4o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x4)
19
+ x4 = nn.ConvTranspose(features=512, kernel_size=(3, 3))(x4)
20
+ x4 = LocalResponseNorm()(x4)
21
+ x4 = nn.relu(x4)
22
+ x8 = nn.ConvTranspose(features=256, kernel_size=(3, 3), strides=(2, 2))(x4)
23
+ x8 = LocalResponseNorm()(x8)
24
+ x8 = nn.relu(x8)
25
+ x8o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x8)
26
+ x8 = nn.ConvTranspose(features=256, kernel_size=(3, 3))(x8)
27
+ x8 = LocalResponseNorm()(x8)
28
+ x8 = nn.relu(x8)
29
+ x16 = nn.ConvTranspose(features=128, kernel_size=(3, 3), strides=(2, 2))(x8)
30
+ x16 = LocalResponseNorm()(x16)
31
+ x16 = nn.relu(x16)
32
+ x16o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x16)
33
+ x16 = nn.ConvTranspose(features=128, kernel_size=(3, 3))(x16)
34
+ x16 = LocalResponseNorm()(x16)
35
+ x16 = nn.relu(x16)
36
+ x32 = nn.ConvTranspose(features=64, kernel_size=(3, 3), strides=(2, 2))(x16)
37
+ x32 = LocalResponseNorm()(x32)
38
+ x32 = nn.relu(x32)
39
+ x32o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x32)
40
+ x32 = nn.ConvTranspose(features=64, kernel_size=(3, 3))(x32)
41
+ x32 = LocalResponseNorm()(x32)
42
+ x32 = nn.relu(x32)
43
+ x64 = nn.ConvTranspose(features=32, kernel_size=(3, 3), strides=(2, 2))(x32)
44
+ x64 = LocalResponseNorm()(x64)
45
+ x64 = nn.relu(x64)
46
+ x64o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x64)
47
+ x64 = nn.ConvTranspose(features=32, kernel_size=(3, 3))(x64)
48
+ x64 = LocalResponseNorm()(x64)
49
+ x64 = nn.relu(x64)
50
+ x128 = nn.ConvTranspose(features=64, kernel_size=(3, 3), strides=(2, 2))(x64)
51
+ x128 = LocalResponseNorm()(x128)
52
+ x128 = nn.relu(x128)
53
+ x128o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x128)
54
+ return (nn.tanh(x128o), nn.tanh(x64o), nn.tanh(x32o), nn.tanh(x16o), nn.tanh(x8o), nn.tanh(x4o))