# yashbyname's picture
# Update app.py
# 676fa4e verified
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import os
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Generator hyper-parameters (must match the checkpoint that gets loaded).
nz = 100          # channels of the latent noise tensor fed to the generator
ngf = 64          # base width; generator channel counts are multiples of it
num_classes = 10  # number of digit classes the generator is conditioned on
# Conditional generator network
class Generator(nn.Module):
    """Conditional generator: (noise, class label) -> single-channel 28x28 image.

    The label is embedded into a ``num_classes``-long vector, appended to the
    noise along the channel axis and pushed through a stack of transposed
    convolutions ending in ``Tanh``; the result is average-pooled to 28x28.
    """

    def __init__(self):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)

        # (in_channels, out_channels, stride, padding) of the normalised
        # up-sampling stages; every stage uses a 4x4 kernel.  With a 1x1
        # spatial input these produce 4x4 -> 8x8 -> 16x16 feature maps.
        stages = (
            (nz + num_classes, ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
        )
        layers = []
        for in_channels, out_channels, stride, padding in stages:
            layers.append(
                nn.ConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=4,
                    stride=stride,
                    padding=padding,
                    bias=False,
                )
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))

        # Output stage: one image channel squashed into [-1, 1] by Tanh.
        layers.append(
            nn.ConvTranspose2d(
                ngf * 2, 1, kernel_size=4, stride=2, padding=1, bias=False
            )
        )
        layers.append(nn.Tanh())

        # Same module order as a literal nn.Sequential(...), so parameter
        # names ("main.0.weight", ...) are unchanged.
        self.main = nn.Sequential(*layers)
        self.resize = nn.AdaptiveAvgPool2d((28, 28))

    def forward(self, noise, labels):
        """Generate images for ``labels`` from ``noise``.

        ``noise`` is concatenated with the reshaped label embedding on
        dim 1, so it must have ``nz`` channels and 1x1 spatial size for the
        first convolution to accept the result.
        """
        embedded = self.label_embedding(labels)
        batch_size = embedded.size(0)
        # Give the embedding 1x1 spatial dims so it can be stacked on the noise.
        embedded = embedded.view(batch_size, num_classes, 1, 1)
        conditioned = torch.cat([noise, embedded], dim=1)
        return self.resize(self.main(conditioned))
# Load model function
def load_model():
    """Build a ``Generator`` and load its trained weights.

    Weights are read from ``mnist_gan_model.pth`` (relative to the current
    working directory) under the ``'generator_state_dict'`` key.

    Returns:
        The generator in eval mode on ``device``, or ``None`` when the
        checkpoint cannot be read or applied (the error is printed so the
        app can still start and show placeholders).
    """
    generator = Generator().to(device)
    try:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code.  Only ship checkpoints from a trusted source;
        # consider weights_only=True if the checkpoint holds only tensors
        # and plain containers.
        checkpoint = torch.load('mnist_gan_model.pth', map_location=device)
        generator.load_state_dict(checkpoint['generator_state_dict'])
    except Exception as e:
        # Deliberate best-effort: a missing or incompatible checkpoint must
        # not crash the app at import time.
        print(f"❌ Error loading model: {e}")
        return None

    # Inference only: put BatchNorm into evaluation mode.
    generator.eval()
    print("✅ Model loaded successfully!")
    return generator
# Initialize generator once at import time.  load_model() returns None when
# the checkpoint could not be loaded; generate_digit_images() checks for that
# before using it.
generator = load_model()
# Generation function
def generate_digit_images(digit, num_images=5):
    """Generate images of the specified digit with the loaded generator.

    Args:
        digit: Class index to draw; anything ``int()`` accepts, in the range
            ``0 .. num_classes - 1``.
        num_images: How many images to produce (default 5).

    Returns:
        A list of ``num_images`` 112x112 grayscale PIL images.  When the
        model failed to load, the list holds flat mid-gray placeholders.

    Raises:
        ValueError: If ``digit`` is not a valid class index or
            ``num_images`` is smaller than 1.
    """
    num_images = int(num_images)
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    if generator is None:
        return [Image.new('L', (112, 112), 128)] * num_images

    # Validate before the embedding lookup: an out-of-range index fails with
    # an opaque error there (and, on CUDA, as a device-side assert).
    try:
        digit = int(digit)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"digit must be an integer in 0-{num_classes - 1}, got {digit!r}"
        ) from err
    if not 0 <= digit < num_classes:
        raise ValueError(
            f"digit must be in 0-{num_classes - 1}, got {digit}"
        )

    with torch.no_grad():
        noise = torch.randn(num_images, nz, 1, 1, device=device)
        labels = torch.full(
            (num_images,), digit, dtype=torch.long, device=device
        )
        generated_images = generator(noise, labels)

    # Generator output is (N, 1, 28, 28) in [-1, 1] (Tanh); map it to [0, 1].
    images = (generated_images.cpu().numpy() + 1) / 2.0
    # Drop only the channel axis.  A bare np.squeeze would also drop the
    # batch axis when N == 1 and break the per-image loop below.
    images = images[:, 0]

    # A 2-D uint8 array is already read as mode 'L', so no explicit mode is
    # passed (the argument is deprecated in recent Pillow releases).
    return [
        Image.fromarray((img * 255).astype(np.uint8)).resize(
            (112, 112), Image.NEAREST
        )
        for img in images
    ]
# Gradio interface
def create_app():
    """Build the Gradio Blocks UI for the digit generator.

    The layout is a digit dropdown plus a button on the left and a five-column
    gallery on the right.  Clicking the button, and the initial page load,
    both call ``generate_digit_images`` with the dropdown's current value.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="Handwritten Digit Generator", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🔢 Handwritten Digit Generator")
        gr.Markdown("Generate synthetic MNIST-like digit images using a trained GAN model.")

        with gr.Row():
            with gr.Column(scale=1):
                digit_input = gr.Dropdown(
                    choices=list(range(10)),
                    value=8,
                    label="Choose a digit (0-9)",
                )
                generate_btn = gr.Button("🎨 Generate Images", variant="primary")

            with gr.Column(scale=2):
                gr.Markdown("### Generated Images")
                image_gallery = gr.Gallery(
                    label="Generated Digit Images",
                    show_label=False,
                    columns=5,
                    rows=1,
                    height=200,
                )

        generate_btn.click(
            fn=generate_digit_images,
            inputs=[digit_input],
            outputs=[image_gallery],
        )

        # Auto-generate on load.  Reading the dropdown keeps the first gallery
        # in sync with its default value instead of repeating the number here.
        app.load(
            fn=generate_digit_images,
            inputs=[digit_input],
            outputs=[image_gallery],
        )

        gr.Markdown("---")
        gr.Markdown("**🤖 Model**: Conditional GAN | **⚡ Framework**: PyTorch + Gradio")

    return app
# Launch the UI only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # Build the Blocks app, then start Gradio's server with default settings.
    app = create_app()
    app.launch()