Update app.py
Browse files
app.py
CHANGED
@@ -34,14 +34,13 @@ params = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), train
|
|
34 |
|
35 |
fs = HfFileSystem()
|
36 |
with fs.open("PrakhAI/DigitGAN/g_checkpoint.msgpack", "rb") as f:
|
37 |
-
|
38 |
-
batch_stats = from_state_dict(params, msgpack_restore(f.read())["batch_stats"])
|
39 |
|
40 |
def sample_latent(key):
|
41 |
return jax.random.normal(key, shape=(1, LATENT_DIM))
|
42 |
|
43 |
if st.button('Generate Digit'):
|
44 |
latents = sample_latent(jax.random.PRNGKey(int(1_000_000 * time.time())))
|
45 |
-
g_out = Generator().apply({'params': params, 'batch_stats': batch_stats}, latents, training=False)
|
46 |
img = ((np.array(g_out)+1)*255./2.).astype(np.uint8)[0]
|
47 |
st.image(Image.fromarray(np.repeat(img, repeats=3, axis=2)))
|
|
|
34 |
|
35 |
fs = HfFileSystem()
|
36 |
with fs.open("PrakhAI/DigitGAN/g_checkpoint.msgpack", "rb") as f:
|
37 |
+
g_state = from_state_dict(params, msgpack_restore(f.read()))
|
|
|
38 |
|
39 |
def sample_latent(key):
|
40 |
return jax.random.normal(key, shape=(1, LATENT_DIM))
|
41 |
|
42 |
if st.button('Generate Digit'):
|
43 |
latents = sample_latent(jax.random.PRNGKey(int(1_000_000 * time.time())))
|
44 |
+
g_out = Generator().apply({'params': g_state.params, 'batch_stats': g_state.batch_stats}, latents, training=False)
|
45 |
img = ((np.array(g_out)+1)*255./2.).astype(np.uint8)[0]
|
46 |
st.image(Image.fromarray(np.repeat(img, repeats=3, axis=2)))
|