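# NOTE: repository-page residue captured together with this file, kept as a
# comment so the module remains valid Python:
#   uploader: yashbyname | commit: bb3fdf2 (verified) | message: "Create app.py"
#   page links: raw, history, blame | size: 5.72 kB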
import functools
import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model hyperparameters (must match your training):
#   nz          - number of noise channels fed to the generator
#   ngf         - base width used to size the generator's feature maps
#   num_classes - number of digit classes (also the label-embedding width)
nz, ngf, num_classes = 100, 64, 10
# Generator class (same as your training script)
class Generator(nn.Module):
    """Label-conditioned transposed-convolution generator.

    The class label is embedded into a ``num_classes``-wide vector, reshaped
    to ``(batch, num_classes, 1, 1)`` and concatenated with the noise along
    the channel axis. For a 1x1 spatial input the transposed convolutions
    produce 4x4 -> 8x8 -> 16x16 -> 32x32 feature maps, and the result is
    average-pooled down to 28x28.

    Attribute names and the flat layer order inside ``main`` are kept exactly
    as they are so that saved ``state_dict`` keys continue to match.
    """

    def __init__(self):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        # (in_channels, out_channels, stride, padding) of each hidden stage;
        # every stage is ConvTranspose2d(kernel=4) -> BatchNorm2d -> ReLU.
        stage_specs = (
            (nz + num_classes, ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
        )
        layers = []
        for in_channels, out_channels, stride, padding in stage_specs:
            layers.append(
                nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(True))

        # Output stage: single-channel image squashed into [-1, 1] by tanh.
        layers.append(nn.ConvTranspose2d(ngf * 2, 1, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)
        self.resize = nn.AdaptiveAvgPool2d((28, 28))

    def forward(self, noise, labels):
        """Generate images for ``labels`` from ``noise``.

        Args:
            noise: Tensor concatenated with the label embedding on dim 1;
                the caller in this file passes shape ``(batch, nz, 1, 1)``.
            labels: Integer class indices, one per sample.

        Returns:
            Tensor of generated images with 28x28 spatial size.
        """
        embedded = self.label_embedding(labels)
        batch_size = embedded.size(0)
        conditioning = embedded.view(batch_size, num_classes, 1, 1)
        stacked = torch.cat([noise, conditioning], dim=1)
        return self.resize(self.main(stacked))
# Load the trained model
@functools.lru_cache(maxsize=1)
def load_model():
    """Build the generator and load trained weights if a checkpoint exists.

    The result is cached, so repeated calls return the same instance instead
    of re-reading the checkpoint. (The previous ``@st.cache_resource``
    decorator referenced Streamlit, which this file never imports, and made
    the module fail with ``NameError`` on import.)

    Returns:
        The ``Generator`` on ``device``, always in eval mode. If
        ``mnist_gan_model.pth`` is missing, a warning is printed and the
        generator keeps its randomly initialised weights.

    Raises:
        KeyError: If the checkpoint has no ``'generator_state_dict'`` entry.
    """
    generator = Generator().to(device)

    # Load the saved model
    if os.path.exists('mnist_gan_model.pth'):
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted source.
        checkpoint = torch.load('mnist_gan_model.pth', map_location=device)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        print("Model loaded successfully!")
    else:
        print("Warning: Model file not found!")

    # Always switch to inference mode, even without a checkpoint, so that
    # BatchNorm uses its running statistics rather than per-batch statistics.
    generator.eval()
    return generator
# Initialize generator
# Module-level instance, created once at import time and used by
# generate_digit_images for every request.
generator = load_model()
# Generation function
def generate_digit_images(digit, num_images=5):
    """Generate images of the specified digit.

    Args:
        digit: Class label to generate; any value accepted by ``int()``
            (e.g. a dropdown choice or a numeric widget value). Must lie in
            ``[0, num_classes)``.
        num_images: How many samples to draw. Defaults to 5.

    Returns:
        A list of ``num_images`` grayscale PIL images, each resized to
        112x112 with nearest-neighbour resampling.

    Raises:
        ValueError: If ``digit`` is outside ``[0, num_classes)`` or
            ``num_images`` is less than 1.
    """
    digit = int(digit)
    num_images = int(num_images)

    # Reject bad labels up front: an out-of-range index would otherwise fail
    # inside the embedding lookup with a far less helpful error.
    if not 0 <= digit < num_classes:
        raise ValueError(
            f"digit must be between 0 and {num_classes - 1}, got {digit}"
        )
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    with torch.no_grad():
        # Create noise and labels directly on the target device.
        noise = torch.randn(num_images, nz, 1, 1, device=device)
        labels = torch.full(
            (num_images,), digit, dtype=torch.long, device=device
        )
        generated_images = generator(noise, labels)

    # Convert to numpy and denormalize from [-1, 1] to [0, 1].
    images = (generated_images.cpu().numpy() + 1) / 2.0

    # Drop only the channel axis. A bare np.squeeze would also remove the
    # batch axis when a single image is requested.
    images = images[:, 0]

    # Scale to 0-255; the clip is a defensive guard before the uint8 cast.
    pixels = np.clip(images * 255, 0, 255).astype(np.uint8)

    # A 2-D uint8 array is interpreted as mode "L", so the `mode` argument
    # (deprecated in recent Pillow releases) is not needed.
    # Resize to 112x112 for better visibility (4x upscale of a 28x28 image).
    return [
        Image.fromarray(frame).resize((112, 112), Image.NEAREST)
        for frame in pixels
    ]
# Create Gradio interface
def create_app():
    """Build the Gradio Blocks UI for the digit generator.

    The page has a digit dropdown and a generate button on the left, a
    five-column gallery on the right, a short "How it works" section and a
    footer. Clicking the button, and loading the page, both call
    ``generate_digit_images`` with the dropdown's current value and show the
    result in the gallery.

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    how_it_works = (
        "\n"
        "1. **Select** a digit from the dropdown (0-9)\n"
        "2. **Click** 'Generate Images' button\n"
        "3. **View** 5 unique generated images of your chosen digit\n"
        "4. Each generation produces different variations of the same digit\n"
    )

    with gr.Blocks(
        title="Handwritten Digit Generator",
        theme=gr.themes.Soft(),
        css=".gradio-container {max-width: 700px; margin: auto;}",
    ) as app:
        # Emoji below were mis-encoded (mojibake) in the original strings.
        gr.Markdown("# 🔢 Handwritten Digit Generator")
        gr.Markdown("Generate synthetic MNIST-like digit images using a trained GAN model. Select a digit (0-9) to generate 5 unique images.")

        with gr.Row():
            with gr.Column(scale=1):
                digit_input = gr.Dropdown(
                    choices=list(range(10)),
                    value=2,
                    label="Choose a digit to generate (0-9)",
                    interactive=True,
                )
                generate_btn = gr.Button(
                    "🎨 Generate Images",
                    variant="primary",
                    size="lg",
                )
            with gr.Column(scale=2):
                gr.Markdown("### Generated Images")
                # Gallery to display 5 images
                image_gallery = gr.Gallery(
                    label="Generated Digit Images",
                    show_label=False,
                    columns=5,
                    rows=1,
                    height=200,
                    object_fit="contain",
                )

        # Example section
        gr.Markdown("---")
        gr.Markdown("### How it works")
        gr.Markdown(how_it_works)

        # Connect button to generation function
        generate_btn.click(
            fn=generate_digit_images,
            inputs=[digit_input],
            outputs=[image_gallery],
        )

        # Auto-generate on page load. Reuse the dropdown (default value 2)
        # instead of creating a throwaway hidden Number component as input.
        app.load(
            fn=generate_digit_images,
            inputs=[digit_input],
            outputs=[image_gallery],
        )

        # Footer
        gr.Markdown("---")
        gr.Markdown("**🤖 Model**: Conditional GAN trained on MNIST | **⚡ Framework**: PyTorch + Gradio")

    return app
# Launch the app
if __name__ == "__main__":
    # Listen on all interfaces at port 7860. `share` stays False
    # (set to False for Hugging Face deployment).
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    app = create_app()
    app.launch(**launch_options)