PrakhAI commited on
Commit
9c9ee7c
·
1 Parent(s): dad1c8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -3
app.py CHANGED
@@ -28,6 +28,7 @@ class Generator(nn.Module):
28
  x = nn.relu(x)
29
  x = nn.ConvTranspose(features=1, kernel_size=(2, 2), strides=(2, 2))(x)
30
  x = nn.tanh(x)
 
31
 
32
  generator = Generator()
33
  variables = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), training=False)
@@ -42,8 +43,5 @@ def sample_latent(key):
42
  if st.button('Generate Digit'):
43
  latents = sample_latent(jax.random.PRNGKey(int(1_000_000 * time.time())))
44
  g_out = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
45
- st.write(g_state['params'])
46
- st.write(g_state['batch_stats'])
47
- st.write(g_out)
48
  img = ((np.array(g_out)+1)*255./2.).astype(np.uint8)[0]
49
  st.image(Image.fromarray(np.repeat(img, repeats=3, axis=2)))
 
28
  x = nn.relu(x)
29
  x = nn.ConvTranspose(features=1, kernel_size=(2, 2), strides=(2, 2))(x)
30
  x = nn.tanh(x)
31
+ return x
32
 
33
  generator = Generator()
34
  variables = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), training=False)
 
43
  if st.button('Generate Digit'):
44
  latents = sample_latent(jax.random.PRNGKey(int(1_000_000 * time.time())))
45
  g_out = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
 
 
 
46
  img = ((np.array(g_out)+1)*255./2.).astype(np.uint8)[0]
47
  st.image(Image.fromarray(np.repeat(img, repeats=3, axis=2)))