PrakhAI commited on
Commit
36527b7
·
1 Parent(s): cc63bcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -13,7 +13,7 @@ generator = Generator()
13
  variables = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), training=False)
14
 
15
  fs = HfFileSystem()
16
- with fs.open("PrakhAI/AIPlane2/g_checkpoint.msgpack", "rb") as f:
17
  g_state = from_state_dict(variables, msgpack_restore(f.read()))
18
 
19
  def sample_latent(batch, key):
@@ -22,12 +22,12 @@ def sample_latent(batch, key):
22
  def to_img(normalized):
23
  return ((normalized+1)*255./2.).astype(np.uint8)
24
 
25
- st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane2")
26
  if st.button('Generate Random'):
27
  st.session_state['generate'] = None
28
 
29
- ROWS = 4
30
- COLUMNS = 4
31
 
32
  def set_latent(latent):
33
  st.session_state['generate'] = latent
@@ -40,9 +40,9 @@ if 'generate' in st.session_state:
40
  if "similarity" not in st.session_state:
41
  st.session_state["similarity"] = 0.5
42
  similarity = st.number_input(label="Mutation (for \"Generate Similar\") - lower value generates more similar images", key="similarity", min_value=0.01, max_value=1.0)
43
- latents = np.repeat([previous], repeats=16, axis=0) + similarity * latents
44
- (g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
45
- img = np.array(to_img(g_out128))
46
  for row in range(ROWS):
47
  with st.container():
48
  for (col_idx, col) in enumerate(st.columns(COLUMNS)):
 
13
  variables = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), training=False)
14
 
15
  fs = HfFileSystem()
16
+ with fs.open("PrakhAI/AIPlane3/g_checkpoint_200000.msgpack", "rb") as f:
17
  g_state = from_state_dict(variables, msgpack_restore(f.read()))
18
 
19
  def sample_latent(batch, key):
 
22
  def to_img(normalized):
23
  return ((normalized+1)*255./2.).astype(np.uint8)
24
 
25
+ st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane3")
26
  if st.button('Generate Random'):
27
  st.session_state['generate'] = None
28
 
29
+ ROWS = 2
30
+ COLUMNS = 2
31
 
32
  def set_latent(latent):
33
  st.session_state['generate'] = latent
 
40
  if "similarity" not in st.session_state:
41
  st.session_state["similarity"] = 0.5
42
  similarity = st.number_input(label="Mutation (for \"Generate Similar\") - lower value generates more similar images", key="similarity", min_value=0.01, max_value=1.0)
43
+ latents = np.repeat([previous], repeats=4, axis=0) + similarity * latents
44
+ g_out = generator.apply({'params': g_state['params']}, latents)
45
+ img = np.array(to_img(g_out))
46
  for row in range(ROWS):
47
  with st.container():
48
  for (col_idx, col) in enumerate(st.columns(COLUMNS)):