Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import os
|
7 |
+
|
8 |
+
# Set device
|
9 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
10 |
+
|
11 |
+
# Model parameters (must match your training)
|
12 |
+
nz = 100
|
13 |
+
ngf = 64
|
14 |
+
num_classes = 10
|
15 |
+
|
16 |
+
# Generator class (same as your training script)
|
17 |
+
class Generator(nn.Module):
|
18 |
+
def __init__(self):
|
19 |
+
super(Generator, self).__init__()
|
20 |
+
|
21 |
+
self.label_embedding = nn.Embedding(num_classes, num_classes)
|
22 |
+
|
23 |
+
self.main = nn.Sequential(
|
24 |
+
nn.ConvTranspose2d(nz + num_classes, ngf * 8, 4, 1, 0, bias=False),
|
25 |
+
nn.BatchNorm2d(ngf * 8),
|
26 |
+
nn.ReLU(True),
|
27 |
+
|
28 |
+
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
|
29 |
+
nn.BatchNorm2d(ngf * 4),
|
30 |
+
nn.ReLU(True),
|
31 |
+
|
32 |
+
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
|
33 |
+
nn.BatchNorm2d(ngf * 2),
|
34 |
+
nn.ReLU(True),
|
35 |
+
|
36 |
+
nn.ConvTranspose2d(ngf * 2, 1, 4, 2, 1, bias=False),
|
37 |
+
nn.Tanh()
|
38 |
+
)
|
39 |
+
|
40 |
+
self.resize = nn.AdaptiveAvgPool2d((28, 28))
|
41 |
+
|
42 |
+
def forward(self, noise, labels):
|
43 |
+
label_embedding = self.label_embedding(labels)
|
44 |
+
label_embedding = label_embedding.view(label_embedding.size(0), num_classes, 1, 1)
|
45 |
+
input_tensor = torch.cat([noise, label_embedding], dim=1)
|
46 |
+
output = self.main(input_tensor)
|
47 |
+
output = self.resize(output)
|
48 |
+
return output
|
49 |
+
|
50 |
+
# Load the trained model
|
51 |
+
@st.cache_resource
|
52 |
+
def load_model():
|
53 |
+
generator = Generator().to(device)
|
54 |
+
|
55 |
+
# Load the saved model
|
56 |
+
if os.path.exists('mnist_gan_model.pth'):
|
57 |
+
checkpoint = torch.load('mnist_gan_model.pth', map_location=device)
|
58 |
+
generator.load_state_dict(checkpoint['generator_state_dict'])
|
59 |
+
generator.eval()
|
60 |
+
print("Model loaded successfully!")
|
61 |
+
else:
|
62 |
+
print("Warning: Model file not found!")
|
63 |
+
|
64 |
+
return generator
|
65 |
+
|
66 |
+
# Initialize generator
|
67 |
+
generator = load_model()
|
68 |
+
|
69 |
+
# Generation function
|
70 |
+
def generate_digit_images(digit):
|
71 |
+
"""Generate 5 images of the specified digit"""
|
72 |
+
|
73 |
+
digit = int(digit)
|
74 |
+
num_images = 5
|
75 |
+
|
76 |
+
with torch.no_grad():
|
77 |
+
# Generate random noise
|
78 |
+
noise = torch.randn(num_images, nz, 1, 1).to(device)
|
79 |
+
labels = torch.full((num_images,), digit, dtype=torch.long).to(device)
|
80 |
+
|
81 |
+
# Generate images
|
82 |
+
generated_images = generator(noise, labels)
|
83 |
+
|
84 |
+
# Convert to numpy and denormalize
|
85 |
+
images = generated_images.cpu().numpy()
|
86 |
+
images = (images + 1) / 2.0 # Denormalize from [-1, 1] to [0, 1]
|
87 |
+
images = np.squeeze(images) # Remove channel dimension
|
88 |
+
|
89 |
+
# Convert to PIL Images for Gradio
|
90 |
+
pil_images = []
|
91 |
+
for img in images:
|
92 |
+
# Convert to 0-255 range and uint8
|
93 |
+
img_uint8 = (img * 255).astype(np.uint8)
|
94 |
+
pil_img = Image.fromarray(img_uint8, mode='L')
|
95 |
+
# Resize for better visibility
|
96 |
+
pil_img = pil_img.resize((112, 112), Image.NEAREST) # 4x upscale
|
97 |
+
pil_images.append(pil_img)
|
98 |
+
|
99 |
+
return pil_images
|
100 |
+
|
101 |
+
# Create Gradio interface
|
102 |
+
def create_app():
|
103 |
+
with gr.Blocks(
|
104 |
+
title="Handwritten Digit Generator",
|
105 |
+
theme=gr.themes.Soft(),
|
106 |
+
css=".gradio-container {max-width: 700px; margin: auto;}"
|
107 |
+
) as app:
|
108 |
+
|
109 |
+
gr.Markdown("# 🔢 Handwritten Digit Generator")
|
110 |
+
gr.Markdown("Generate synthetic MNIST-like digit images using a trained GAN model. Select a digit (0-9) to generate 5 unique images.")
|
111 |
+
|
112 |
+
with gr.Row():
|
113 |
+
with gr.Column(scale=1):
|
114 |
+
digit_input = gr.Dropdown(
|
115 |
+
choices=list(range(10)),
|
116 |
+
value=2,
|
117 |
+
label="Choose a digit to generate (0-9)",
|
118 |
+
interactive=True
|
119 |
+
)
|
120 |
+
generate_btn = gr.Button(
|
121 |
+
"🎨 Generate Images",
|
122 |
+
variant="primary",
|
123 |
+
size="lg"
|
124 |
+
)
|
125 |
+
|
126 |
+
with gr.Column(scale=2):
|
127 |
+
gr.Markdown("### Generated Images")
|
128 |
+
|
129 |
+
# Gallery to display 5 images
|
130 |
+
image_gallery = gr.Gallery(
|
131 |
+
label="Generated Digit Images",
|
132 |
+
show_label=False,
|
133 |
+
columns=5,
|
134 |
+
rows=1,
|
135 |
+
height=200,
|
136 |
+
object_fit="contain"
|
137 |
+
)
|
138 |
+
|
139 |
+
# Example section
|
140 |
+
gr.Markdown("---")
|
141 |
+
gr.Markdown("### How it works")
|
142 |
+
gr.Markdown("""
|
143 |
+
1. **Select** a digit from the dropdown (0-9)
|
144 |
+
2. **Click** 'Generate Images' button
|
145 |
+
3. **View** 5 unique generated images of your chosen digit
|
146 |
+
4. Each generation produces different variations of the same digit
|
147 |
+
""")
|
148 |
+
|
149 |
+
# Connect button to generation function
|
150 |
+
generate_btn.click(
|
151 |
+
fn=generate_digit_images,
|
152 |
+
inputs=[digit_input],
|
153 |
+
outputs=[image_gallery]
|
154 |
+
)
|
155 |
+
|
156 |
+
# Auto-generate on page load
|
157 |
+
app.load(
|
158 |
+
fn=generate_digit_images,
|
159 |
+
inputs=[gr.Number(value=2, visible=False)],
|
160 |
+
outputs=[image_gallery]
|
161 |
+
)
|
162 |
+
|
163 |
+
# Footer
|
164 |
+
gr.Markdown("---")
|
165 |
+
gr.Markdown("**🤖 Model**: Conditional GAN trained on MNIST | **⚡ Framework**: PyTorch + Gradio")
|
166 |
+
|
167 |
+
return app
|
168 |
+
|
169 |
+
# Launch the app
|
170 |
+
if __name__ == "__main__":
|
171 |
+
app = create_app()
|
172 |
+
app.launch(
|
173 |
+
server_name="0.0.0.0",
|
174 |
+
server_port=7860,
|
175 |
+
share=False # Set to False for Hugging Face deployment
|
176 |
+
)
|