import gradio as gr import torch import torch.nn as nn import numpy as np from PIL import Image import os # Set device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Model parameters nz = 100 ngf = 64 num_classes = 10 # Generator class class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.label_embedding = nn.Embedding(num_classes, num_classes) self.main = nn.Sequential( nn.ConvTranspose2d(nz + num_classes, ngf * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf * 8), nn.ReLU(True), nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 4), nn.ReLU(True), nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 2), nn.ReLU(True), nn.ConvTranspose2d(ngf * 2, 1, 4, 2, 1, bias=False), nn.Tanh() ) self.resize = nn.AdaptiveAvgPool2d((28, 28)) def forward(self, noise, labels): label_embedding = self.label_embedding(labels) label_embedding = label_embedding.view(label_embedding.size(0), num_classes, 1, 1) input_tensor = torch.cat([noise, label_embedding], dim=1) output = self.main(input_tensor) output = self.resize(output) return output # Load model function (NO @st.cache_resource decorator!) 
def load_model(path="mnist_gan_model.pth"):
    """Build a Generator and load trained weights from a checkpoint file.

    Args:
        path: Checkpoint to read. It must be a mapping with a
            ``"generator_state_dict"`` entry. Relative paths resolve against
            the current working directory.

    Returns:
        The generator on ``device`` in eval mode, or ``None`` if loading
        failed. The error is printed so the UI can still start and show
        placeholder images.
    """
    generator = Generator().to(device)
    try:
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        checkpoint = torch.load(path, map_location=device)
        generator.load_state_dict(checkpoint["generator_state_dict"])
    except Exception as e:
        # Deliberately best-effort: the app still launches without a model.
        print(f"❌ Error loading model: {e}")
        return None
    generator.eval()
    print("✅ Model loaded successfully!")
    return generator


# Loaded once at import time and shared by every request.
generator = load_model()


def generate_digit_images(digit, num_images=5):
    """Generate ``num_images`` images of ``digit`` with the conditional GAN.

    Args:
        digit: Class label in ``0..num_classes-1``. Either an int or a
            numeric string is accepted.
        num_images: Number of samples to draw. The default of 5 matches the
            5-column gallery in ``create_app``.

    Returns:
        A list of ``num_images`` 112x112 grayscale PIL images. If the model
        failed to load, mid-gray placeholders are returned instead.

    Raises:
        ValueError: If ``digit`` is outside ``0..num_classes-1``. Such a
            label would otherwise index past the embedding table (on CUDA
            this is a device-side assert).
    """
    display_size = (112, 112)
    if generator is None:
        return [Image.new("L", display_size, 128) for _ in range(num_images)]

    digit = int(digit)
    if not 0 <= digit < num_classes:
        raise ValueError(f"digit must be in 0..{num_classes - 1}, got {digit}")

    with torch.no_grad():
        noise = torch.randn(num_images, nz, 1, 1, device=device)
        labels = torch.full((num_images,), digit, dtype=torch.long, device=device)
        generated = generator(noise, labels)

    # Generator output is (N, 1, 28, 28) in [-1, 1]; map it to [0, 1].
    # Index the single channel rather than np.squeeze, so that
    # num_images == 1 keeps its batch axis.
    images = (generated[:, 0].cpu().numpy() + 1) / 2.0
    images = np.clip(np.rint(images * 255), 0, 255).astype(np.uint8)

    # fromarray infers mode "L" for 2-D uint8 arrays; its `mode` argument is
    # deprecated in recent Pillow releases.
    return [Image.fromarray(img).resize(display_size, Image.NEAREST) for img in images]


def create_app():
    """Build the Gradio UI: a digit dropdown, a generate button and a gallery."""
    with gr.Blocks(title="Handwritten Digit Generator", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🔢 Handwritten Digit Generator")
        gr.Markdown("Generate synthetic MNIST-like digit images using a trained GAN model.")

        with gr.Row():
            with gr.Column(scale=1):
                digit_input = gr.Dropdown(
                    choices=list(range(num_classes)),
                    value=8,
                    label="Choose a digit (0-9)",
                )
                generate_btn = gr.Button("🎨 Generate Images", variant="primary")

            with gr.Column(scale=2):
                gr.Markdown("### Generated Images")
                image_gallery = gr.Gallery(
                    label="Generated Digit Images",
                    show_label=False,
                    columns=5,
                    rows=1,
                    height=200,
                )

        generate_btn.click(
            fn=generate_digit_images,
            inputs=[digit_input],
            outputs=[image_gallery],
        )

        # Auto-generate on page load from the dropdown's initial value, so the
        # default digit is defined in one place only.
        app.load(
            fn=generate_digit_images,
            inputs=[digit_input],
            outputs=[image_gallery],
        )

        gr.Markdown("---")
        gr.Markdown("**🤖 Model**: Conditional GAN | **⚡ Framework**: PyTorch + Gradio")

    return app


# Launch
if __name__ == "__main__":
    app = create_app()
    app.launch()