PrakhAI commited on
Commit
31266d7
·
1 Parent(s): 5bf3307

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -30,11 +30,11 @@ class Generator(nn.Module):
30
  x = nn.tanh(x)
31
 
32
  generator = Generator()
33
- params = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), training=False)['params']
34
 
35
  fs = HfFileSystem()
36
  with fs.open("PrakhAI/DigitGAN/g_checkpoint.msgpack", "rb") as f:
37
- g_state = from_state_dict(params, msgpack_restore(f.read()))
38
 
39
  def sample_latent(key):
40
  return jax.random.normal(key, shape=(1, LATENT_DIM))
 
30
  x = nn.tanh(x)
31
 
32
  generator = Generator()
33
+ variables = generator.init(jax.random.PRNGKey(0), jnp.zeros([1, LATENT_DIM]), training=False)
34
 
35
  fs = HfFileSystem()
36
  with fs.open("PrakhAI/DigitGAN/g_checkpoint.msgpack", "rb") as f:
37
+ g_state = from_state_dict(variables, msgpack_restore(f.read()))
38
 
39
  def sample_latent(key):
40
  return jax.random.normal(key, shape=(1, LATENT_DIM))