PrakhAI commited on
Commit
de78f83
·
1 Parent(s): 03d5be6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import jax
4
+ import numpy as np
5
+ from flax import linen as nn # Linen API
6
+ from huggingface_hub import HfFileSystem
7
+ from flax.serialization import msgpack_restore, from_state_dict
8
+ import time
9
+
10
+ LATENT_DIM = 100
11
+
12
+ class Generator(nn.Module):
13
+ @nn.compact
14
+ def __call__(self, latent, training=True):
15
+ x = latent
16
+ x = nn.Dense(features=64)(x)
17
+ x = nn.BatchNorm(not training)(x)
18
+ x = nn.relu(x)
19
+ x = nn.Dense(features=2*2*512)(x)
20
+ x = nn.relu(x)
21
+ x = x.reshape((x.shape[0], 2, 2, -1))
22
+ x = nn.ConvTranspose(features=256, kernel_size=(2, 2), strides=(2, 2))(x)
23
+ x = nn.relu(x)
24
+ x = nn.ConvTranspose(features=128, kernel_size=(2, 2), strides=(2, 2))(x)
25
+ x = nn.relu(x)
26
+ x = nn.ConvTranspose(features=64, kernel_size=(2, 2), strides=(2, 2))(x)
27
+ x = nn.relu(x)
28
+ x = nn.ConvTranspose(features=1, kernel_size=(2, 2), strides=(2, 2))(x)
29
+ x = nn.tanh(x)
30
+
31
+ generator = Generator()
32
+
33
+ fs = HfFileSystem()
34
+ with fs.open("PrakhAI/DigitGAN/g_checkpoint.msgpack", "rb") as f:
35
+ params = from_state_dict(params, msgpack_restore(f.read())["params"])
36
+
37
+ def sample_latent(key):
38
+ return jax.random.normal(key, shape=(1, LATENT_DIM))
39
+
40
+ if st.button('Generate Digit'):
41
+ latents = sample_latent(jax.random.PRNGKey(int(1_000_000 * time.time())))
42
+ g_out = Generator().apply({'params': g_state.params, 'batch_stats': g_state.batch_stats}, latents, training=False)
43
+ img = ((np.array(g_out)+1)*255./2.).astype(np.uint8)[0]
44
+ st.image(Image.fromarray(np.repeat(img, repeats=3, axis=2)))