yashbyname commited on
Commit
a95533b
Β·
verified Β·
1 Parent(s): de91025

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -60
app.py CHANGED
@@ -8,12 +8,12 @@ import os
8
  # Set device
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
- # Model parameters (must match your training)
12
  nz = 100
13
  ngf = 64
14
  num_classes = 10
15
 
16
- # Generator class (same as your training script)
17
  class Generator(nn.Module):
18
  def __init__(self):
19
  super(Generator, self).__init__()
@@ -47,21 +47,19 @@ class Generator(nn.Module):
47
  output = self.resize(output)
48
  return output
49
 
50
- # Load the trained model
51
- @st.cache_resource
52
  def load_model():
53
  generator = Generator().to(device)
54
 
55
- # Load the saved model
56
- if os.path.exists('mnist_gan_model.pth'):
57
  checkpoint = torch.load('mnist_gan_model.pth', map_location=device)
58
  generator.load_state_dict(checkpoint['generator_state_dict'])
59
  generator.eval()
60
- print("Model loaded successfully!")
61
- else:
62
- print("Warning: Model file not found!")
63
-
64
- return generator
65
 
66
  # Initialize generator
67
  generator = load_model()
@@ -70,107 +68,75 @@ generator = load_model()
70
  def generate_digit_images(digit):
71
  """Generate 5 images of the specified digit"""
72
 
 
 
 
73
  digit = int(digit)
74
  num_images = 5
75
 
76
  with torch.no_grad():
77
- # Generate random noise
78
  noise = torch.randn(num_images, nz, 1, 1).to(device)
79
  labels = torch.full((num_images,), digit, dtype=torch.long).to(device)
80
 
81
- # Generate images
82
  generated_images = generator(noise, labels)
83
 
84
- # Convert to numpy and denormalize
85
  images = generated_images.cpu().numpy()
86
- images = (images + 1) / 2.0 # Denormalize from [-1, 1] to [0, 1]
87
- images = np.squeeze(images) # Remove channel dimension
88
 
89
- # Convert to PIL Images for Gradio
90
  pil_images = []
91
  for img in images:
92
- # Convert to 0-255 range and uint8
93
  img_uint8 = (img * 255).astype(np.uint8)
94
  pil_img = Image.fromarray(img_uint8, mode='L')
95
- # Resize for better visibility
96
- pil_img = pil_img.resize((112, 112), Image.NEAREST) # 4x upscale
97
  pil_images.append(pil_img)
98
 
99
  return pil_images
100
 
101
- # Create Gradio interface
102
  def create_app():
103
- with gr.Blocks(
104
- title="Handwritten Digit Generator",
105
- theme=gr.themes.Soft(),
106
- css=".gradio-container {max-width: 700px; margin: auto;}"
107
- ) as app:
108
 
109
  gr.Markdown("# πŸ”’ Handwritten Digit Generator")
110
- gr.Markdown("Generate synthetic MNIST-like digit images using a trained GAN model. Select a digit (0-9) to generate 5 unique images.")
111
 
112
  with gr.Row():
113
  with gr.Column(scale=1):
114
  digit_input = gr.Dropdown(
115
  choices=list(range(10)),
116
  value=2,
117
- label="Choose a digit to generate (0-9)",
118
- interactive=True
119
- )
120
- generate_btn = gr.Button(
121
- "🎨 Generate Images",
122
- variant="primary",
123
- size="lg"
124
  )
 
125
 
126
  with gr.Column(scale=2):
127
  gr.Markdown("### Generated Images")
128
-
129
- # Gallery to display 5 images
130
  image_gallery = gr.Gallery(
131
  label="Generated Digit Images",
132
  show_label=False,
133
  columns=5,
134
  rows=1,
135
- height=200,
136
- object_fit="contain"
137
  )
138
 
139
- # Example section
140
- gr.Markdown("---")
141
- gr.Markdown("### How it works")
142
- gr.Markdown("""
143
- 1. **Select** a digit from the dropdown (0-9)
144
- 2. **Click** 'Generate Images' button
145
- 3. **View** 5 unique generated images of your chosen digit
146
- 4. Each generation produces different variations of the same digit
147
- """)
148
-
149
- # Connect button to generation function
150
  generate_btn.click(
151
  fn=generate_digit_images,
152
  inputs=[digit_input],
153
  outputs=[image_gallery]
154
  )
155
 
156
- # Auto-generate on page load
157
  app.load(
158
- fn=generate_digit_images,
159
- inputs=[gr.Number(value=2, visible=False)],
160
  outputs=[image_gallery]
161
  )
162
 
163
- # Footer
164
  gr.Markdown("---")
165
- gr.Markdown("**πŸ€– Model**: Conditional GAN trained on MNIST | **⚑ Framework**: PyTorch + Gradio")
166
 
167
  return app
168
 
169
- # Launch the app
170
  if __name__ == "__main__":
171
  app = create_app()
172
- app.launch(
173
- server_name="0.0.0.0",
174
- server_port=7860,
175
- share=False # Set to False for Hugging Face deployment
176
- )
 
8
  # Set device
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
+ # Model parameters
12
  nz = 100
13
  ngf = 64
14
  num_classes = 10
15
 
16
+ # Generator class
17
  class Generator(nn.Module):
18
  def __init__(self):
19
  super(Generator, self).__init__()
 
47
  output = self.resize(output)
48
  return output
49
 
50
+ # Load model function (NO @st.cache_resource decorator!)
 
51
  def load_model():
52
  generator = Generator().to(device)
53
 
54
+ try:
 
55
  checkpoint = torch.load('mnist_gan_model.pth', map_location=device)
56
  generator.load_state_dict(checkpoint['generator_state_dict'])
57
  generator.eval()
58
+ print("βœ… Model loaded successfully!")
59
+ return generator
60
+ except Exception as e:
61
+ print(f"❌ Error loading model: {e}")
62
+ return None
63
 
64
  # Initialize generator
65
  generator = load_model()
 
68
  def generate_digit_images(digit):
69
  """Generate 5 images of the specified digit"""
70
 
71
+ if generator is None:
72
+ return [Image.new('L', (112, 112), 128)] * 5
73
+
74
  digit = int(digit)
75
  num_images = 5
76
 
77
  with torch.no_grad():
 
78
  noise = torch.randn(num_images, nz, 1, 1).to(device)
79
  labels = torch.full((num_images,), digit, dtype=torch.long).to(device)
80
 
 
81
  generated_images = generator(noise, labels)
82
 
 
83
  images = generated_images.cpu().numpy()
84
+ images = (images + 1) / 2.0
85
+ images = np.squeeze(images)
86
 
 
87
  pil_images = []
88
  for img in images:
 
89
  img_uint8 = (img * 255).astype(np.uint8)
90
  pil_img = Image.fromarray(img_uint8, mode='L')
91
+ pil_img = pil_img.resize((112, 112), Image.NEAREST)
 
92
  pil_images.append(pil_img)
93
 
94
  return pil_images
95
 
96
+ # Gradio interface
97
  def create_app():
98
+ with gr.Blocks(title="Handwritten Digit Generator", theme=gr.themes.Soft()) as app:
 
 
 
 
99
 
100
  gr.Markdown("# πŸ”’ Handwritten Digit Generator")
101
+ gr.Markdown("Generate synthetic MNIST-like digit images using a trained GAN model.")
102
 
103
  with gr.Row():
104
  with gr.Column(scale=1):
105
  digit_input = gr.Dropdown(
106
  choices=list(range(10)),
107
  value=2,
108
+ label="Choose a digit (0-9)"
 
 
 
 
 
 
109
  )
110
+ generate_btn = gr.Button("🎨 Generate Images", variant="primary")
111
 
112
  with gr.Column(scale=2):
113
  gr.Markdown("### Generated Images")
 
 
114
  image_gallery = gr.Gallery(
115
  label="Generated Digit Images",
116
  show_label=False,
117
  columns=5,
118
  rows=1,
119
+ height=200
 
120
  )
121
 
 
 
 
 
 
 
 
 
 
 
 
122
  generate_btn.click(
123
  fn=generate_digit_images,
124
  inputs=[digit_input],
125
  outputs=[image_gallery]
126
  )
127
 
128
+ # Auto-generate on load
129
  app.load(
130
+ fn=lambda: generate_digit_images(2),
 
131
  outputs=[image_gallery]
132
  )
133
 
 
134
  gr.Markdown("---")
135
+ gr.Markdown("**πŸ€– Model**: Conditional GAN | **⚑ Framework**: PyTorch + Gradio")
136
 
137
  return app
138
 
139
+ # Launch
140
  if __name__ == "__main__":
141
  app = create_app()
142
+ app.launch()