brendenc commited on
Commit
4c4da3f
·
1 Parent(s): 5d4ac8b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from load_model import load_model
3
+ import matplotlib.pyplot as plt
4
+ from tensorflow.keras import layers
5
+ from sklearn.datasets import make_moons
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+
9
+ model = load_model()
10
+
11
+ # Load the Data
12
+ data = make_moons(3000, noise=0.05)[0].astype("float32")
13
+ norm = layers.experimental.preprocessing.Normalization()
14
+ norm.adapt(data)
15
+ normalized_data = norm(data)
16
+ z, _ = model(normalized_data)
17
+
18
+ demo = gr.Blocks()
19
+
20
+ with demo:
21
+ gr.Markdown("""# Density estimation using Real NVP <br>
22
+ This demo shows a toy example of using Real NVP (real-valued non-volume preserving transformations)
23
+ from this [example](https://keras.io/examples/generative/real_nvp/). Below we have two tabs. The first, Inference, shows
24
+ our mapping from a data distribution (moons) to a latent space with a known distribution (Gaussian). Click the button to see how a data point from our distribution maps
25
+ to our latent space. Our second tab allows you to generate a sample from our latent space, and view the generated data space that is associated with it.
26
+
27
+ Full credits for this model & example
28
+ go to <br>[Mandolini Giorgio Maria](https://www.linkedin.com/in/giorgio-maria-mandolini-a2a1b71b4/),
29
+ [Sanna Daniele](https://www.linkedin.com/in/daniele-sanna-338629bb/),
30
+ and [Zannini Quirini Giorgio](https://www.linkedin.com/in/giorgio-zannini-quirini-16ab181a0/).<br>
31
+ Demo by [Brenden Connors](https://www.linkedin.com/in/brenden-connors-6a0512195).""")
32
+
33
+ with gr.Tabs():
34
+ with gr.TabItem('Inference'):
35
+ button = gr.Button(value='Infer Sample Point')
36
+
37
+ with gr.Row():
38
+ fig = plt.figure()
39
+ plt.scatter(normalized_data[:, 0], normalized_data[:, 1], color="r")
40
+ plt.xlim([-2, 2])
41
+ plt.ylim([-2, 2])
42
+ plt.title('Inference Data Space')
43
+ fig2 = plt.figure()
44
+ plt.scatter(z[:, 0], z[:, 1], color="r")
45
+ plt.xlim([-3.5, 4])
46
+ plt.ylim([-3.5, 4])
47
+ plt.title('Inference Latent Space')
48
+ data_space = gr.Plot(value = fig)
49
+ latent_space = gr.Plot(value = fig2)
50
+ with gr.TabItem('Generation'):
51
+ button_generate = gr.Button('Generate')
52
+
53
+ with gr.Row():
54
+ fig3 = plt.figure()
55
+
56
+ fig4 = plt.figure()
57
+ generated_lspace = gr.Plot(fig3)
58
+ generated_dspace = gr.Plot(fig4)
59
+
60
+ def inference_sample():
61
+ idx = np.random.choice(normalized_data.shape[0])
62
+ new_fig1 = plt.figure()
63
+ plt.scatter(normalized_data[:, 0], normalized_data[:, 1], color="r")
64
+ plt.scatter(normalized_data[idx, 0], normalized_data[idx, 1], color="b")
65
+ plt.title('Inference Data Space')
66
+ plt.xlim([-2, 2])
67
+ plt.ylim([-2, 2])
68
+ output, _ = model(np.array(normalized_data[idx, :]).reshape((1, 2)))
69
+
70
+ new_fig2 = plt.figure()
71
+ plt.scatter(z[:, 0], z[:, 1], color="r")
72
+ plt.scatter(output[0,0] , output[0,1], color="b")
73
+ plt.xlim([-3.5, 4])
74
+ plt.ylim([-3.5, 4])
75
+ plt.title('Inference Latent Space')
76
+ return new_fig1, new_fig2
77
+
78
+ def generate():
79
+ samples = model.distribution.sample(3000)
80
+ x, _ = model.predict(samples)
81
+
82
+ new_fig1=plt.figure()
83
+ plt.scatter(samples[:,0], samples[:,1])
84
+ plt.title('Generated Latent Space')
85
+ plt.xlim([-3.5, 4])
86
+ plt.ylim([-3.5, 4])
87
+
88
+ new_fig2=plt.figure()
89
+ plt.scatter(x[:,0], x[:,1])
90
+ plt.title('Generated Data Space')
91
+ plt.xlim([-2, 2])
92
+ plt.ylim([-2, 2])
93
+ return new_fig1, new_fig2
94
+ button.click(inference_sample, inputs=[], outputs=[data_space, latent_space])
95
+ button_generate.click(generate, inputs=[], outputs=[generated_lspace, generated_dspace])
96
+
97
+ demo.launch()