Update app.py
Browse files
app.py
CHANGED
@@ -8,12 +8,12 @@ import os
|
|
8 |
# Set device
|
9 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
10 |
|
11 |
-
# Model parameters
|
12 |
nz = 100
|
13 |
ngf = 64
|
14 |
num_classes = 10
|
15 |
|
16 |
-
# Generator class
|
17 |
class Generator(nn.Module):
|
18 |
def __init__(self):
|
19 |
super(Generator, self).__init__()
|
@@ -47,21 +47,19 @@ class Generator(nn.Module):
|
|
47 |
output = self.resize(output)
|
48 |
return output
|
49 |
|
50 |
-
# Load
|
51 |
-
@st.cache_resource
|
52 |
def load_model():
|
53 |
generator = Generator().to(device)
|
54 |
|
55 |
-
|
56 |
-
if os.path.exists('mnist_gan_model.pth'):
|
57 |
checkpoint = torch.load('mnist_gan_model.pth', map_location=device)
|
58 |
generator.load_state_dict(checkpoint['generator_state_dict'])
|
59 |
generator.eval()
|
60 |
-
print("Model loaded successfully!")
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
|
66 |
# Initialize generator
|
67 |
generator = load_model()
|
@@ -70,107 +68,75 @@ generator = load_model()
|
|
70 |
def generate_digit_images(digit):
|
71 |
"""Generate 5 images of the specified digit"""
|
72 |
|
|
|
|
|
|
|
73 |
digit = int(digit)
|
74 |
num_images = 5
|
75 |
|
76 |
with torch.no_grad():
|
77 |
-
# Generate random noise
|
78 |
noise = torch.randn(num_images, nz, 1, 1).to(device)
|
79 |
labels = torch.full((num_images,), digit, dtype=torch.long).to(device)
|
80 |
|
81 |
-
# Generate images
|
82 |
generated_images = generator(noise, labels)
|
83 |
|
84 |
-
# Convert to numpy and denormalize
|
85 |
images = generated_images.cpu().numpy()
|
86 |
-
images = (images + 1) / 2.0
|
87 |
-
images = np.squeeze(images)
|
88 |
|
89 |
-
# Convert to PIL Images for Gradio
|
90 |
pil_images = []
|
91 |
for img in images:
|
92 |
-
# Convert to 0-255 range and uint8
|
93 |
img_uint8 = (img * 255).astype(np.uint8)
|
94 |
pil_img = Image.fromarray(img_uint8, mode='L')
|
95 |
-
|
96 |
-
pil_img = pil_img.resize((112, 112), Image.NEAREST) # 4x upscale
|
97 |
pil_images.append(pil_img)
|
98 |
|
99 |
return pil_images
|
100 |
|
101 |
-
#
|
102 |
def create_app():
|
103 |
-
with gr.Blocks(
|
104 |
-
title="Handwritten Digit Generator",
|
105 |
-
theme=gr.themes.Soft(),
|
106 |
-
css=".gradio-container {max-width: 700px; margin: auto;}"
|
107 |
-
) as app:
|
108 |
|
109 |
gr.Markdown("# π’ Handwritten Digit Generator")
|
110 |
-
gr.Markdown("Generate synthetic MNIST-like digit images using a trained GAN model.
|
111 |
|
112 |
with gr.Row():
|
113 |
with gr.Column(scale=1):
|
114 |
digit_input = gr.Dropdown(
|
115 |
choices=list(range(10)),
|
116 |
value=2,
|
117 |
-
label="Choose a digit
|
118 |
-
interactive=True
|
119 |
-
)
|
120 |
-
generate_btn = gr.Button(
|
121 |
-
"π¨ Generate Images",
|
122 |
-
variant="primary",
|
123 |
-
size="lg"
|
124 |
)
|
|
|
125 |
|
126 |
with gr.Column(scale=2):
|
127 |
gr.Markdown("### Generated Images")
|
128 |
-
|
129 |
-
# Gallery to display 5 images
|
130 |
image_gallery = gr.Gallery(
|
131 |
label="Generated Digit Images",
|
132 |
show_label=False,
|
133 |
columns=5,
|
134 |
rows=1,
|
135 |
-
height=200
|
136 |
-
object_fit="contain"
|
137 |
)
|
138 |
|
139 |
-
# Example section
|
140 |
-
gr.Markdown("---")
|
141 |
-
gr.Markdown("### How it works")
|
142 |
-
gr.Markdown("""
|
143 |
-
1. **Select** a digit from the dropdown (0-9)
|
144 |
-
2. **Click** 'Generate Images' button
|
145 |
-
3. **View** 5 unique generated images of your chosen digit
|
146 |
-
4. Each generation produces different variations of the same digit
|
147 |
-
""")
|
148 |
-
|
149 |
-
# Connect button to generation function
|
150 |
generate_btn.click(
|
151 |
fn=generate_digit_images,
|
152 |
inputs=[digit_input],
|
153 |
outputs=[image_gallery]
|
154 |
)
|
155 |
|
156 |
-
# Auto-generate on
|
157 |
app.load(
|
158 |
-
fn=generate_digit_images,
|
159 |
-
inputs=[gr.Number(value=2, visible=False)],
|
160 |
outputs=[image_gallery]
|
161 |
)
|
162 |
|
163 |
-
# Footer
|
164 |
gr.Markdown("---")
|
165 |
-
gr.Markdown("**π€ Model**: Conditional GAN
|
166 |
|
167 |
return app
|
168 |
|
169 |
-
# Launch
|
170 |
if __name__ == "__main__":
|
171 |
app = create_app()
|
172 |
-
app.launch(
|
173 |
-
server_name="0.0.0.0",
|
174 |
-
server_port=7860,
|
175 |
-
share=False # Set to False for Hugging Face deployment
|
176 |
-
)
|
|
|
8 |
# Set device
|
9 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
10 |
|
11 |
+
# Model parameters
|
12 |
nz = 100
|
13 |
ngf = 64
|
14 |
num_classes = 10
|
15 |
|
16 |
+
# Generator class
|
17 |
class Generator(nn.Module):
|
18 |
def __init__(self):
|
19 |
super(Generator, self).__init__()
|
|
|
47 |
output = self.resize(output)
|
48 |
return output
|
49 |
|
50 |
+
# Load model function (NO @st.cache_resource decorator!)
|
|
|
51 |
def load_model():
|
52 |
generator = Generator().to(device)
|
53 |
|
54 |
+
try:
|
|
|
55 |
checkpoint = torch.load('mnist_gan_model.pth', map_location=device)
|
56 |
generator.load_state_dict(checkpoint['generator_state_dict'])
|
57 |
generator.eval()
|
58 |
+
print("β
Model loaded successfully!")
|
59 |
+
return generator
|
60 |
+
except Exception as e:
|
61 |
+
print(f"β Error loading model: {e}")
|
62 |
+
return None
|
63 |
|
64 |
# Initialize generator
|
65 |
generator = load_model()
|
|
|
68 |
def generate_digit_images(digit):
|
69 |
"""Generate 5 images of the specified digit"""
|
70 |
|
71 |
+
if generator is None:
|
72 |
+
return [Image.new('L', (112, 112), 128)] * 5
|
73 |
+
|
74 |
digit = int(digit)
|
75 |
num_images = 5
|
76 |
|
77 |
with torch.no_grad():
|
|
|
78 |
noise = torch.randn(num_images, nz, 1, 1).to(device)
|
79 |
labels = torch.full((num_images,), digit, dtype=torch.long).to(device)
|
80 |
|
|
|
81 |
generated_images = generator(noise, labels)
|
82 |
|
|
|
83 |
images = generated_images.cpu().numpy()
|
84 |
+
images = (images + 1) / 2.0
|
85 |
+
images = np.squeeze(images)
|
86 |
|
|
|
87 |
pil_images = []
|
88 |
for img in images:
|
|
|
89 |
img_uint8 = (img * 255).astype(np.uint8)
|
90 |
pil_img = Image.fromarray(img_uint8, mode='L')
|
91 |
+
pil_img = pil_img.resize((112, 112), Image.NEAREST)
|
|
|
92 |
pil_images.append(pil_img)
|
93 |
|
94 |
return pil_images
|
95 |
|
96 |
+
# Gradio interface
|
97 |
def create_app():
|
98 |
+
with gr.Blocks(title="Handwritten Digit Generator", theme=gr.themes.Soft()) as app:
|
|
|
|
|
|
|
|
|
99 |
|
100 |
gr.Markdown("# π’ Handwritten Digit Generator")
|
101 |
+
gr.Markdown("Generate synthetic MNIST-like digit images using a trained GAN model.")
|
102 |
|
103 |
with gr.Row():
|
104 |
with gr.Column(scale=1):
|
105 |
digit_input = gr.Dropdown(
|
106 |
choices=list(range(10)),
|
107 |
value=2,
|
108 |
+
label="Choose a digit (0-9)"
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
)
|
110 |
+
generate_btn = gr.Button("π¨ Generate Images", variant="primary")
|
111 |
|
112 |
with gr.Column(scale=2):
|
113 |
gr.Markdown("### Generated Images")
|
|
|
|
|
114 |
image_gallery = gr.Gallery(
|
115 |
label="Generated Digit Images",
|
116 |
show_label=False,
|
117 |
columns=5,
|
118 |
rows=1,
|
119 |
+
height=200
|
|
|
120 |
)
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
generate_btn.click(
|
123 |
fn=generate_digit_images,
|
124 |
inputs=[digit_input],
|
125 |
outputs=[image_gallery]
|
126 |
)
|
127 |
|
128 |
+
# Auto-generate on load
|
129 |
app.load(
|
130 |
+
fn=lambda: generate_digit_images(2),
|
|
|
131 |
outputs=[image_gallery]
|
132 |
)
|
133 |
|
|
|
134 |
gr.Markdown("---")
|
135 |
+
gr.Markdown("**π€ Model**: Conditional GAN | **β‘ Framework**: PyTorch + Gradio")
|
136 |
|
137 |
return app
|
138 |
|
139 |
+
# Launch
|
140 |
if __name__ == "__main__":
|
141 |
app = create_app()
|
142 |
+
app.launch()
|
|
|
|
|
|
|
|