Update app.py
Browse files
app.py
CHANGED
@@ -13,7 +13,7 @@ generator = Generator()
|
|
13 |
variables = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), training=False)
|
14 |
|
15 |
fs = HfFileSystem()
|
16 |
-
with fs.open("PrakhAI/
|
17 |
g_state = from_state_dict(variables, msgpack_restore(f.read()))
|
18 |
|
19 |
def sample_latent(batch, key):
|
@@ -22,12 +22,12 @@ def sample_latent(batch, key):
|
|
22 |
def to_img(normalized):
|
23 |
return ((normalized+1)*255./2.).astype(np.uint8)
|
24 |
|
25 |
-
st.write("The model and its details are at https://huggingface.co/PrakhAI/
|
26 |
if st.button('Generate Random'):
|
27 |
st.session_state['generate'] = None
|
28 |
|
29 |
-
ROWS =
|
30 |
-
COLUMNS =
|
31 |
|
32 |
def set_latent(latent):
|
33 |
st.session_state['generate'] = latent
|
@@ -40,9 +40,9 @@ if 'generate' in st.session_state:
|
|
40 |
if "similarity" not in st.session_state:
|
41 |
st.session_state["similarity"] = 0.5
|
42 |
similarity = st.number_input(label="Mutation (for \"Generate Similar\") - lower value generates more similar images", key="similarity", min_value=0.01, max_value=1.0)
|
43 |
-
latents = np.repeat([previous], repeats=
|
44 |
-
|
45 |
-
img = np.array(to_img(
|
46 |
for row in range(ROWS):
|
47 |
with st.container():
|
48 |
for (col_idx, col) in enumerate(st.columns(COLUMNS)):
|
|
|
13 |
variables = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), training=False)
|
14 |
|
15 |
fs = HfFileSystem()
|
16 |
+
with fs.open("PrakhAI/AIPlane3/g_checkpoint_200000.msgpack", "rb") as f:
|
17 |
g_state = from_state_dict(variables, msgpack_restore(f.read()))
|
18 |
|
19 |
def sample_latent(batch, key):
|
|
|
22 |
def to_img(normalized):
|
23 |
return ((normalized+1)*255./2.).astype(np.uint8)
|
24 |
|
25 |
+
st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane3")
|
26 |
if st.button('Generate Random'):
|
27 |
st.session_state['generate'] = None
|
28 |
|
29 |
+
ROWS = 2
|
30 |
+
COLUMNS = 2
|
31 |
|
32 |
def set_latent(latent):
|
33 |
st.session_state['generate'] = latent
|
|
|
40 |
if "similarity" not in st.session_state:
|
41 |
st.session_state["similarity"] = 0.5
|
42 |
similarity = st.number_input(label="Mutation (for \"Generate Similar\") - lower value generates more similar images", key="similarity", min_value=0.01, max_value=1.0)
|
43 |
+
latents = np.repeat([previous], repeats=4, axis=0) + similarity * latents
|
44 |
+
g_out = generator.apply({'params': g_state['params']}, latents)
|
45 |
+
img = np.array(to_img(g_out))
|
46 |
for row in range(ROWS):
|
47 |
with st.container():
|
48 |
for (col_idx, col) in enumerate(st.columns(COLUMNS)):
|