PrakhAI commited on
Commit
5bf3307
·
1 Parent(s): 48e419b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -34,14 +34,13 @@ params = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), train
34
 
35
  fs = HfFileSystem()
36
  with fs.open("PrakhAI/DigitGAN/g_checkpoint.msgpack", "rb") as f:
37
- params = from_state_dict(params, msgpack_restore(f.read())["params"])
38
- batch_stats = from_state_dict(params, msgpack_restore(f.read())["batch_stats"])
39
 
40
  def sample_latent(key):
41
  return jax.random.normal(key, shape=(1, LATENT_DIM))
42
 
43
  if st.button('Generate Digit'):
44
  latents = sample_latent(jax.random.PRNGKey(int(1_000_000 * time.time())))
45
- g_out = Generator().apply({'params': params, 'batch_stats': batch_stats}, latents, training=False)
46
  img = ((np.array(g_out)+1)*255./2.).astype(np.uint8)[0]
47
  st.image(Image.fromarray(np.repeat(img, repeats=3, axis=2)))
 
34
 
35
  fs = HfFileSystem()
36
  with fs.open("PrakhAI/DigitGAN/g_checkpoint.msgpack", "rb") as f:
37
+ g_state = from_state_dict(params, msgpack_restore(f.read()))
 
38
 
39
  def sample_latent(key):
40
  return jax.random.normal(key, shape=(1, LATENT_DIM))
41
 
42
  if st.button('Generate Digit'):
43
  latents = sample_latent(jax.random.PRNGKey(int(1_000_000 * time.time())))
44
+ g_out = Generator().apply({'params': g_state.params, 'batch_stats': g_state.batch_stats}, latents, training=False)
45
  img = ((np.array(g_out)+1)*255./2.).astype(np.uint8)[0]
46
  st.image(Image.fromarray(np.repeat(img, repeats=3, axis=2)))