|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
from PIL import Image |
|
import os |
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Generator hyper-parameters.
nz = 100  # number of channels in the noise input fed to the generator
ngf = 64  # base feature-map width of the generator's transposed convolutions
num_classes = 10  # number of digit labels (also the label-embedding width)
|
|
|
|
|
class Generator(nn.Module):
    """Conditional DCGAN-style generator that outputs 1-channel 28x28 images.

    The class label is embedded, concatenated with the noise along the channel
    axis, and upsampled by four transposed convolutions
    (1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32). The 32x32 map is then adaptively
    average-pooled to 28x28. Outputs lie in [-1, 1] because of the final Tanh.

    Args:
        latent_dim: Number of noise channels. ``None`` (default) uses the
            module-level ``nz``.
        base_channels: Base feature-map width. ``None`` (default) uses the
            module-level ``ngf``.
        n_classes: Number of labels, also used as the embedding width.
            ``None`` (default) uses the module-level ``num_classes``.

    Sub-module names (``label_embedding``, ``main``, ``resize``) and the order
    of layers inside ``main`` determine the ``state_dict`` keys, so they must
    not be changed if existing checkpoints are to keep loading.
    """

    def __init__(self, latent_dim=None, base_channels=None, n_classes=None):
        super().__init__()

        # Fall back to the module-level configuration only when a size is not
        # given explicitly, so ``Generator()`` builds the same network as before.
        latent_dim = nz if latent_dim is None else latent_dim
        base_channels = ngf if base_channels is None else base_channels
        n_classes = num_classes if n_classes is None else n_classes

        self.label_embedding = nn.Embedding(n_classes, n_classes)

        self.main = nn.Sequential(
            # (latent_dim + n_classes) x 1 x 1 -> (base_channels * 8) x 4 x 4
            nn.ConvTranspose2d(latent_dim + n_classes, base_channels * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(base_channels * 8),
            nn.ReLU(True),

            # -> (base_channels * 4) x 8 x 8
            nn.ConvTranspose2d(base_channels * 8, base_channels * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels * 4),
            nn.ReLU(True),

            # -> (base_channels * 2) x 16 x 16
            nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels * 2),
            nn.ReLU(True),

            # -> 1 x 32 x 32
            nn.ConvTranspose2d(base_channels * 2, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

        # Shrink the 32x32 output to 28x28.
        self.resize = nn.AdaptiveAvgPool2d((28, 28))

    def forward(self, noise: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Generate images for ``labels`` from ``noise``.

        Args:
            noise: Tensor of shape ``(batch, latent_dim, 1, 1)``.
            labels: Integer tensor of shape ``(batch,)`` holding class indices.

        Returns:
            Tensor of shape ``(batch, 1, 28, 28)`` with values in [-1, 1].
        """
        embedded = self.label_embedding(labels)
        # Use the embedding's own width rather than a module global so the
        # reshape always agrees with how this instance was built.
        embedded = embedded.view(
            embedded.size(0), self.label_embedding.embedding_dim, 1, 1
        )
        stacked = torch.cat([noise, embedded], dim=1)
        return self.resize(self.main(stacked))
|
|
|
|
|
def load_model(model_path='mnist_gan_model.pth'):
    """Build a ``Generator`` and load its trained weights from a checkpoint.

    Args:
        model_path: Path to the checkpoint file. The loaded object must be a
            mapping with a ``'generator_state_dict'`` entry.

    Returns:
        The generator on ``device`` in eval mode, or ``None`` when the
        checkpoint could not be read or applied (the error is printed).
    """
    generator = Generator().to(device)

    try:
        # NOTE(security): torch.load may unpickle arbitrary objects depending
        # on the installed PyTorch version and the `weights_only` setting.
        # Only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)
        generator.load_state_dict(checkpoint['generator_state_dict'])
    except Exception as e:
        # Deliberate best-effort: report the failure and return None instead
        # of raising, so the caller can decide how to degrade.
        print(f"❌ Error loading model: {e}")
        return None

    generator.eval()
    print("✅ Model loaded successfully!")
    return generator
|
|
|
|
|
# Load the generator once at import time. load_model() returns None when the
# checkpoint cannot be loaded, so users of this global must handle None.
generator = load_model()
|
|
|
|
|
def generate_digit_images(digit, num_images=5):
    """Generate images of the requested digit with the loaded generator.

    Args:
        digit: Class label to condition on. Anything ``int()`` accepts, and
            it must fall in ``range(num_classes)``.
        num_images: How many samples to draw (default 5).

    Returns:
        A list of ``num_images`` 112x112 grayscale PIL images. When the model
        failed to load (``generator is None``), uniform mid-gray placeholders
        are returned instead.

    Raises:
        ValueError: If ``num_images`` is less than 1, or ``digit`` is outside
            ``range(num_classes)``.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    if generator is None:
        return [Image.new('L', (112, 112), 128) for _ in range(num_images)]

    digit = int(digit)
    # Reject bad labels here: an out-of-range index would otherwise only fail
    # inside the embedding lookup (on CUDA, as a device-side assertion).
    if not 0 <= digit < num_classes:
        raise ValueError(
            f"digit must be between 0 and {num_classes - 1}, got {digit}"
        )

    with torch.no_grad():
        # Allocate directly on the target device instead of CPU-then-copy.
        noise = torch.randn(num_images, nz, 1, 1, device=device)
        labels = torch.full(
            (num_images,), digit, dtype=torch.long, device=device
        )
        generated_images = generator(noise, labels)

    images = generated_images.cpu().numpy()
    # Map the generator's [-1, 1] output range to [0, 1].
    images = (images + 1) / 2.0
    # Drop only the channel axis; an axis-less squeeze would also collapse the
    # batch axis when num_images == 1.
    images = np.squeeze(images, axis=1)

    return [
        Image.fromarray((img * 255).astype(np.uint8), mode='L').resize(
            (112, 112), Image.NEAREST
        )
        for img in images
    ]
|
|
|
|
|
def create_app():
    """Build the Gradio Blocks UI for the digit generator.

    The page has a title, a row with a digit dropdown and a generate button on
    the left and an image gallery on the right, and a footer. Clicking the
    button, and the initial page load, both call ``generate_digit_images`` and
    put the result in the gallery.

    Returns:
        The constructed ``gr.Blocks`` app; it is not launched here.
    """
    # Single source of truth for the digit shown in the dropdown and rendered
    # on first page load.
    default_digit = 8

    with gr.Blocks(title="Handwritten Digit Generator", theme=gr.themes.Soft()) as app:

        gr.Markdown("# 🔢 Handwritten Digit Generator")
        gr.Markdown("Generate synthetic MNIST-like digit images using a trained GAN model.")

        with gr.Row():
            with gr.Column(scale=1):
                digit_input = gr.Dropdown(
                    choices=list(range(10)),
                    value=default_digit,
                    label="Choose a digit (0-9)"
                )
                generate_btn = gr.Button("🎨 Generate Images", variant="primary")

            with gr.Column(scale=2):
                gr.Markdown("### Generated Images")
                image_gallery = gr.Gallery(
                    label="Generated Digit Images",
                    show_label=False,
                    columns=5,
                    rows=1,
                    height=200
                )

        generate_btn.click(
            fn=generate_digit_images,
            inputs=[digit_input],
            outputs=[image_gallery]
        )

        # Populate the gallery as soon as the page loads.
        app.load(
            fn=lambda: generate_digit_images(default_digit),
            outputs=[image_gallery]
        )

        gr.Markdown("---")
        gr.Markdown("**🤖 Model**: Conditional GAN | **⚡ Framework**: PyTorch + Gradio")

    return app
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI and launch it only when this file is executed as a script.
    app = create_app()
    app.launch()