elapt1c committed (verified)
Commit 254b385 · Parent(s): 8bcfdd1

Create app.py

Files changed (1)
  1. app.py +209 -0
app.py ADDED
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download


# --- Hyperparameters ---
image_size = 64
latent_dim = 128
model_repo_id = "elapt1c/catGen"
model_filename = "model.pth"


# --- VAE Model --- (simplified VAE, matching the training code)
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()

        # Encoder - matching the training architecture:
        # five stride-2 convolutions take a 64x64 input down to 2x2.
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),    # 64x64 -> 32x32
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),   # 32x32 -> 16x16
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 16x16 -> 8x8
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # 8x8 -> 4x4
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), # 4x4 -> 2x2
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.encoder_fc_mu = nn.Linear(512 * 2 * 2, latent_dim)
        self.encoder_fc_logvar = nn.Linear(512 * 2 * 2, latent_dim)

        # Decoder - mirrors the encoder, upsampling 2x2 back to 64x64.
        self.decoder_fc = nn.Linear(latent_dim, 512 * 2 * 2)
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),  # outputs in [0, 1]
        )

    def encode(self, x):
        h = self.encoder_conv(x)
        h = h.view(h.size(0), -1)  # flatten to (N, 512 * 2 * 2)
        mu = self.encoder_fc_mu(h)
        logvar = self.encoder_fc_logvar(h)
        return mu, logvar

    def decode(self, z):
        z = self.decoder_fc(z)
        z = z.view(z.size(0), 512, 2, 2)  # reshape to (N, 512, 2, 2) feature maps
        reconstructed_image = self.decoder_conv(z)
        return reconstructed_image

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed_image = self.decode(z)
        return reconstructed_image, mu, logvar

85
+ def load_model(device, repo_id, filename):
86
+ try:
87
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename)
88
+ except Exception as e:
89
+ print(f"Error downloading model from Hugging Face Hub: {e}")
90
+ return None
91
+
92
+ vae_model = VAE(latent_dim=latent_dim).to(device) # Plain VAE model
93
+
94
+ try:
95
+ checkpoint = torch.load(model_path, map_location=device) # Load checkpoint dict
96
+ except FileNotFoundError:
97
+ print(f"Error: Model file not found at {model_path}. This should not happen after downloading.")
98
+ return None
99
+
100
+ new_state_dict = {} # Create a new dictionary for modified keys
101
+ for key, value in checkpoint.items():
102
+ new_key = key.replace('_orig_mod.', '') # Remove "_orig_mod." prefix
103
+ new_state_dict[new_key] = value # Add to new dict with modified key
104
+
105
+ vae_model.load_state_dict(new_state_dict) # Load state_dict with modified keys
106
+ print(f"====> Loaded existing model from {model_path} (handling Torch Compile state_dict)")
107
+ return vae_model
108
+
109
+
110
+ def preprocess_image(image):
111
+ try:
112
+ transform = transforms.Compose([
113
+ transforms.Resize((image_size, image_size)),
114
+ transforms.ToTensor(),
115
+ ])
116
+ image = transform(image).unsqueeze(0)
117
+ return image
118
+ except Exception as e:
119
+ print(f"Failed to preprocess image: {e}")
120
+ return None
121
+
122
+
123
+ def generate_single_image(model, device):
124
+ try:
125
+ model.eval()
126
+ with torch.no_grad():
127
+ sample_z = torch.randn(1, latent_dim).to(device)
128
+ generated_image = model.decode(sample_z) # Use simple VAE decode
129
+ img = generated_image.cpu().detach().numpy()
130
+ output = (img[0] * 255).transpose(1, 2, 0).astype(np.uint8)
131
+ image = Image.fromarray(output) # save from random image
132
+ return image # use the image
133
+ except Exception as e:
134
+ print(f"Image generation failed: {e}")
135
+ return None
136
+
137
+
138
+ def generate_from_base_image(model, device, base_image, noise_scale=0.1):
139
+ try:
140
+ model.eval()
141
+ with torch.no_grad():
142
+ processed_image = preprocess_image(base_image) # Process base image
143
+ if processed_image is None:
144
+ return None
145
+
146
+ processed_image = processed_image.to(device) # to device
147
+ mu, logvar = model.encode(processed_image) # encode
148
+ latent_vector = model.reparameterize(mu, logvar) # reparameterize
149
+
150
+ noise = torch.randn_like(latent_vector) * noise_scale # add noise
151
+ latent_vector = latent_vector + noise # combine
152
+
153
+ generated_image = model.decode(latent_vector) # Use simple VAE decode
154
+ img = generated_image.cpu().detach().numpy()
155
+ output = (img[0] * 255).transpose(1, 2, 0).astype(np.uint8)
156
+ output_image = Image.fromarray(output) # save from
157
+ return output_image
158
+
159
+ except Exception as e:
160
+ print(f"Seed image generation failed: {e}")
161
+ return None
162
+
163
+
164
+
165
+ # --- Gradio Interface ---
166
+ def main():
167
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
168
+ vae_model = load_model(device, model_repo_id, model_filename)
169
+ if vae_model is None:
170
+ return # Exit if model loading fails
171
+
172
+ def generate_single():
173
+ img = generate_single_image(vae_model, device)
174
+ if img:
175
+ return img
176
+ else:
177
+ return "Image generation failed. Check console for errors."
178
+
179
+ def generate_from_seed(seed_image):
180
+ if seed_image is None:
181
+ return "Please upload a seed image."
182
+
183
+ img = generate_from_base_image(vae_model, device, seed_image)
184
+ if img:
185
+ return img
186
+ else:
187
+ return "Image generation from seed failed. Check console for errors."
188
+
189
+
190
+ with gr.Blocks() as demo:
191
+ gr.Markdown("# VAE Image Generator")
192
+
193
+ with gr.Tab("Generate Single Image"):
194
+ single_button = gr.Button("Generate Random Image")
195
+ single_output = gr.Image()
196
+ single_button.click(generate_single, inputs=[], outputs=single_output)
197
+
198
+ with gr.Tab("Generate from Seed"):
199
+ seed_input = gr.Image(label="Seed Image")
200
+ seed_button = gr.Button("Generate from Seed")
201
+ seed_output = gr.Image()
202
+ seed_button.click(generate_from_seed, inputs=seed_input, outputs=seed_output)
203
+
204
+
205
+ demo.launch()
206
+
207
+
208
+ if __name__ == "__main__":
209
+ main()
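
A quick way to sanity-check the architecture locally (a minimal sketch, not part
of the commit; it only verifies tensor shapes under the hyperparameters above):

    model = VAE(latent_dim=latent_dim)
    x = torch.randn(1, 3, image_size, image_size)  # one dummy 64x64 RGB image
    recon, mu, logvar = model(x)
    assert recon.shape == (1, 3, image_size, image_size)  # decoder mirrors the encoder
    assert mu.shape == (1, latent_dim)

Running this as a Space also assumes a requirements.txt next to app.py listing
at least: torch, torchvision, gradio, huggingface_hub, Pillow, numpy.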