PrakhAI commited on
Commit
cc63bcb
·
0 Parent(s):

Duplicate from PrakhAI/AIPlane2

Browse files
Files changed (7) hide show
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. __init__.py +0 -0
  4. app.py +52 -0
  5. generator.py +59 -0
  6. local_response_norm.py +11 -0
  7. requirements.txt +1 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: AIPlane2
3
+ emoji: 🌖
4
+ colorFrom: green
5
+ colorTo: indigo
6
+ sdk: streamlit
7
+ sdk_version: 1.25.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: PrakhAI/AIPlane2
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import jax
4
+ import jax.numpy as jnp # JAX NumPy
5
+ import numpy as np
6
+ from huggingface_hub import HfFileSystem
7
+ from flax.serialization import msgpack_restore, from_state_dict
8
+ import time
9
+ from generator import Generator, LATENT_DIM
10
+ import math
11
+
12
+ generator = Generator()
13
+ variables = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), training=False)
14
+
15
+ fs = HfFileSystem()
16
+ with fs.open("PrakhAI/AIPlane2/g_checkpoint.msgpack", "rb") as f:
17
+ g_state = from_state_dict(variables, msgpack_restore(f.read()))
18
+
19
+ def sample_latent(batch, key):
20
+ return jax.random.normal(key, shape=(batch, LATENT_DIM))
21
+
22
+ def to_img(normalized):
23
+ return ((normalized+1)*255./2.).astype(np.uint8)
24
+
25
+ st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane2")
26
+ if st.button('Generate Random'):
27
+ st.session_state['generate'] = None
28
+
29
+ ROWS = 4
30
+ COLUMNS = 4
31
+
32
+ def set_latent(latent):
33
+ st.session_state['generate'] = latent
34
+
35
+ if 'generate' in st.session_state:
36
+ unique_id = int(1_000_000 * time.time())
37
+ latents = sample_latent(ROWS * COLUMNS, jax.random.PRNGKey(unique_id))
38
+ previous = st.session_state['generate']
39
+ if previous is not None:
40
+ if "similarity" not in st.session_state:
41
+ st.session_state["similarity"] = 0.5
42
+ similarity = st.number_input(label="Mutation (for \"Generate Similar\") - lower value generates more similar images", key="similarity", min_value=0.01, max_value=1.0)
43
+ latents = np.repeat([previous], repeats=16, axis=0) + similarity * latents
44
+ (g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
45
+ img = np.array(to_img(g_out128))
46
+ for row in range(ROWS):
47
+ with st.container():
48
+ for (col_idx, col) in enumerate(st.columns(COLUMNS)):
49
+ with col:
50
+ idx = row*COLUMNS + col_idx
51
+ st.image(Image.fromarray(img[idx]))
52
+ st.button(label="Generate Similar", key="%d_%d" % (unique_id, idx), on_click=set_latent, args=(latents[idx],))
generator.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flax import linen as nn
2
+ import jax
3
+ import jax.numpy as jnp
4
+ from local_response_norm import LocalResponseNorm
5
+
6
+ LATENT_DIM = 500
7
+ EPSILON = 1e-8
8
+
9
+ class Generator(nn.Module):
10
+ @nn.compact
11
+ def __call__(self, latent, training=True):
12
+ x = nn.Dense(features=64)(latent)
13
+ # x = nn.BatchNorm(not training)(x)
14
+ x = nn.relu(x)
15
+ x = nn.Dense(features=2*2*1024)(x)
16
+ x = nn.BatchNorm(not training)(x)
17
+ x = nn.relu(x)
18
+ x = nn.Dropout(0.25, deterministic=not training)(x)
19
+ x = x.reshape((x.shape[0], 2, 2, -1))
20
+ x4 = nn.ConvTranspose(features=512, kernel_size=(3, 3), strides=(2, 2))(x)
21
+ x4 = LocalResponseNorm()(x4)
22
+ x4 = nn.relu(x4)
23
+ x4o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x4)
24
+ x4 = nn.ConvTranspose(features=512, kernel_size=(3, 3))(x4)
25
+ x4 = LocalResponseNorm()(x4)
26
+ x4 = nn.relu(x4)
27
+ x8 = nn.ConvTranspose(features=256, kernel_size=(3, 3), strides=(2, 2))(x4)
28
+ x8 = LocalResponseNorm()(x8)
29
+ x8 = nn.relu(x8)
30
+ x8o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x8)
31
+ x8 = nn.ConvTranspose(features=256, kernel_size=(3, 3))(x8)
32
+ x8 = LocalResponseNorm()(x8)
33
+ x8 = nn.relu(x8)
34
+ x16 = nn.ConvTranspose(features=128, kernel_size=(3, 3), strides=(2, 2))(x8)
35
+ x16 = LocalResponseNorm()(x16)
36
+ x16 = nn.relu(x16)
37
+ x16o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x16)
38
+ x16 = nn.ConvTranspose(features=128, kernel_size=(3, 3))(x16)
39
+ x16 = LocalResponseNorm()(x16)
40
+ x16 = nn.relu(x16)
41
+ x32 = nn.ConvTranspose(features=64, kernel_size=(3, 3), strides=(2, 2))(x16)
42
+ x32 = LocalResponseNorm()(x32)
43
+ x32 = nn.relu(x32)
44
+ x32o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x32)
45
+ x32 = nn.ConvTranspose(features=64, kernel_size=(3, 3))(x32)
46
+ x32 = LocalResponseNorm()(x32)
47
+ x32 = nn.relu(x32)
48
+ x64 = nn.ConvTranspose(features=32, kernel_size=(3, 3), strides=(2, 2))(x32)
49
+ x64 = LocalResponseNorm()(x64)
50
+ x64 = nn.relu(x64)
51
+ x64o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x64)
52
+ x64 = nn.ConvTranspose(features=32, kernel_size=(3, 3))(x64)
53
+ x64 = LocalResponseNorm()(x64)
54
+ x64 = nn.relu(x64)
55
+ x128 = nn.ConvTranspose(features=64, kernel_size=(3, 3), strides=(2, 2))(x64)
56
+ x128 = LocalResponseNorm()(x128)
57
+ x128 = nn.relu(x128)
58
+ x128o = nn.ConvTranspose(features=3, kernel_size=(3, 3))(x128)
59
+ return (nn.tanh(x128o), nn.tanh(x64o), nn.tanh(x32o), nn.tanh(x16o), nn.tanh(x8o), nn.tanh(x4o))
local_response_norm.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flax import linen as nn
2
+ import jax
3
+ import jax.numpy as jnp
4
+
5
+ class LocalResponseNorm(nn.Module):
6
+ @nn.compact
7
+ def __call__(
8
+ self,
9
+ value: jax.Array
10
+ ) -> jax.Array:
11
+ return value / jnp.repeat(jnp.expand_dims((1e-8 + (value**2).mean(axis=-1))**0.5, axis=-1), repeats=value.shape[-1], axis=-1)
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ flax