yashbyname committed on
Commit
bb3fdf2
·
verified ·
1 Parent(s): 4456bcf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import os
7
+
8
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Generator hyper-parameters (must match the values used during training).
nz = 100  # number of channels in the latent noise tensor
ngf = 64  # base width of the generator's feature maps
num_classes = 10  # number of digit classes the generator is conditioned on
15
+
16
+ # Generator class (same as your training script)
17
class Generator(nn.Module):
    """Conditional generator that turns (noise, digit label) into a 28x28 image.

    The label is embedded into a ``num_classes``-long vector, concatenated
    with the noise along the channel axis and upsampled by a stack of
    transposed convolutions. Attribute names and layer order are part of the
    checkpoint format (``state_dict`` keys), so they must not change.
    """

    def __init__(self):
        super().__init__()

        # One learned vector of length ``num_classes`` per class index.
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        # Spatial size for a 1x1 input: 1 -> 4 -> 8 -> 16 -> 32.
        layers = [
            nn.ConvTranspose2d(nz + num_classes, ngf * 8,
                               kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 8, ngf * 4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 4, ngf * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            # Single output channel; Tanh keeps values in [-1, 1].
            nn.ConvTranspose2d(ngf * 2, 1,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

        # Average-pool the upsampled map down to a 28x28 output.
        self.resize = nn.AdaptiveAvgPool2d((28, 28))

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: Tensor concatenated with the label embedding on dim 1, so
                it must have spatial size 1x1 (the caller in this file passes
                shape ``(batch, nz, 1, 1)``).
            labels: Integer class indices fed to the embedding layer.

        Returns:
            The output of ``main`` pooled to 28x28.
        """
        embedded = self.label_embedding(labels)
        batch_size = embedded.size(0)
        # Reshape to (batch, num_classes, 1, 1) so it can be stacked on the
        # channel axis next to the noise.
        conditioning = embedded.view(batch_size, num_classes, 1, 1)
        latent = torch.cat((noise, conditioning), dim=1)
        return self.resize(self.main(latent))
49
+
50
+ # Load the trained model
51
def load_model():
    """Build the generator and load trained weights from disk if present.

    Looks for ``mnist_gan_model.pth`` in the current working directory and
    restores the ``generator_state_dict`` entry of that checkpoint. When the
    file is missing, a warning is printed and the randomly initialised
    generator is returned so the app can still start.

    No caching decorator is applied: the previous ``@st.cache_resource``
    referred to a name (``st``) that this file never imports, which raised
    ``NameError`` at import time. The function is called once at module level
    below, so the model is only built once anyway.

    Returns:
        The ``Generator`` instance on ``device``, in evaluation mode.
    """
    generator = Generator().to(device)

    checkpoint_path = 'mnist_gan_model.pth'
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        print("Model loaded successfully!")
    else:
        print("Warning: Model file not found!")

    # Always switch to eval mode (also on the fallback path) so BatchNorm
    # uses its running statistics rather than per-batch statistics.
    generator.eval()
    return generator


# Initialize generator
generator = load_model()
68
+
69
+ # Generation function
70
def generate_digit_images(digit):
    """Generate 5 images of the specified digit.

    Args:
        digit: The class to generate; anything ``int()`` accepts (the UI
            passes the dropdown value). Must be in ``[0, num_classes)``.

    Returns:
        A list of 5 grayscale PIL images, each upscaled to 112x112.

    Raises:
        ValueError: If ``digit`` is ``None`` or outside the valid class range.
    """
    # Validate before touching the model: an out-of-range index would
    # otherwise fail inside the embedding lookup.
    if digit is None:
        raise ValueError("Please choose a digit to generate.")
    digit = int(digit)
    if not 0 <= digit < num_classes:
        raise ValueError(
            f"Digit must be between 0 and {num_classes - 1}, got {digit}."
        )

    num_images = 5

    with torch.no_grad():
        # Fresh noise per call, so every click yields different samples.
        noise = torch.randn(num_images, nz, 1, 1, device=device)
        labels = torch.full(
            (num_images,), digit, dtype=torch.long, device=device
        )
        generated_images = generator(noise, labels)

    images = generated_images.cpu().numpy()
    # Drop only the channel axis (index 1). A bare np.squeeze would also drop
    # the batch axis if num_images were ever 1.
    images = (images[:, 0] + 1) / 2.0  # denormalize from [-1, 1] to [0, 1]
    # Clip before casting so rounding noise cannot wrap around in uint8.
    images_uint8 = (np.clip(images, 0.0, 1.0) * 255).astype(np.uint8)

    # A 2-D uint8 array becomes a mode "L" image; 4x nearest-neighbour
    # upscale keeps the pixels crisp for display.
    return [
        Image.fromarray(img).resize((112, 112), Image.NEAREST)
        for img in images_uint8
    ]
100
+
101
+ # Create Gradio interface
102
def create_app():
    """Build the Gradio Blocks UI for the digit generator.

    The layout has a digit dropdown and a generate button on the left and a
    five-column gallery on the right. ``generate_digit_images`` is bound to
    both the button click and the page-load event.

    Returns:
        The configured ``gr.Blocks`` app (not yet launched).
    """
    # Single source of truth for the digit shown when the page first loads.
    default_digit = 2

    with gr.Blocks(
        title="Handwritten Digit Generator",
        theme=gr.themes.Soft(),
        css=".gradio-container {max-width: 700px; margin: auto;}"
    ) as app:

        gr.Markdown("# 🔢 Handwritten Digit Generator")
        gr.Markdown("Generate synthetic MNIST-like digit images using a trained GAN model. Select a digit (0-9) to generate 5 unique images.")

        with gr.Row():
            with gr.Column(scale=1):
                digit_input = gr.Dropdown(
                    choices=list(range(10)),
                    value=default_digit,
                    label="Choose a digit to generate (0-9)",
                    interactive=True
                )
                generate_btn = gr.Button(
                    "🎨 Generate Images",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=2):
                gr.Markdown("### Generated Images")

                # Gallery to display 5 images
                image_gallery = gr.Gallery(
                    label="Generated Digit Images",
                    show_label=False,
                    columns=5,
                    rows=1,
                    height=200,
                    object_fit="contain"
                )

        # Example section
        gr.Markdown("---")
        gr.Markdown("### How it works")
        gr.Markdown("""
        1. **Select** a digit from the dropdown (0-9)
        2. **Click** 'Generate Images' button
        3. **View** 5 unique generated images of your chosen digit
        4. Each generation produces different variations of the same digit
        """)

        # Connect button to generation function
        generate_btn.click(
            fn=generate_digit_images,
            inputs=[digit_input],
            outputs=[image_gallery]
        )

        # Auto-generate on page load. Read the dropdown itself rather than a
        # separate hidden component, so the first gallery always matches the
        # digit shown in the dropdown.
        app.load(
            fn=generate_digit_images,
            inputs=[digit_input],
            outputs=[image_gallery]
        )

        # Footer
        gr.Markdown("---")
        gr.Markdown("**🤖 Model**: Conditional GAN trained on MNIST | **⚡ Framework**: PyTorch + Gradio")

    return app
168
+
169
+ # Launch the app
170
if __name__ == "__main__":
    # Listen on all interfaces at port 7860. ``share`` stays False for
    # Hugging Face deployment.
    app = create_app()
    app.launch(server_name="0.0.0.0", server_port=7860, share=False)