import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download

# --- Hyperparameters ---
image_size = 64
latent_dim = 128
model_repo_id = "elapt1c/catGen"
model_filename = "model.pth"
# model_path = 'model.pth'  # unused fallback; the checkpoint is downloaded from the Hub instead
generated_images_folder = 'generated_images'
# --- VAE Model --- (Simplified VAE - MATCHING TRAINING CODE)
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        # Encoder - MATCHING TRAINING CODE ARCHITECTURE
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),  # Increased initial channels
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),  # Increased final channels
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.encoder_fc_mu = nn.Linear(512 * 2 * 2, latent_dim)
        self.encoder_fc_logvar = nn.Linear(512 * 2 * 2, latent_dim)

        # Decoder - MATCHING TRAINING CODE ARCHITECTURE
        self.decoder_fc = nn.Linear(latent_dim, 512 * 2 * 2)
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder_conv(x)
        h = h.view(h.size(0), -1)
        mu = self.encoder_fc_mu(h)
        logvar = self.encoder_fc_logvar(h)
        return mu, logvar

    def decode(self, z):
        z = self.decoder_fc(z)
        z = z.view(z.size(0), 512, 2, 2)  # Corrected view shape to 512 channels
        reconstructed_image = self.decoder_conv(z)
        return reconstructed_image

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed_image = self.decode(z)
        return reconstructed_image, mu, logvar
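
# Shape check (for reference): with image_size = 64, each of the five stride-2
# convolutions halves the spatial resolution (64 -> 32 -> 16 -> 8 -> 4 -> 2),
# so the encoder output is 512 x 2 x 2 = 2048 features. This matches the
# nn.Linear(512 * 2 * 2, latent_dim) projections and the decoder's
# view(z.size(0), 512, 2, 2) reshape.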

# --- Helper Functions ---
def load_model(device, repo_id, filename):
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        print(f"Error downloading model from Hugging Face Hub: {e}")
        return None
    vae_model = VAE(latent_dim=latent_dim).to(device)  # Plain VAE model
    try:
        checkpoint = torch.load(model_path, map_location=device)  # Load checkpoint dict
    except FileNotFoundError:
        print(f"Error: Model file not found at {model_path}. This should not happen after downloading.")
        return None
    # torch.compile() wraps the module and prefixes every state_dict key with
    # "_orig_mod."; strip that prefix so the plain (uncompiled) VAE can load the weights.
    new_state_dict = {}
    for key, value in checkpoint.items():
        new_key = key.replace('_orig_mod.', '')
        new_state_dict[new_key] = value
    vae_model.load_state_dict(new_state_dict)
    print(f"====> Loaded existing model from {model_path} (handling Torch Compile state_dict)")
    return vae_model

def preprocess_image(image):
    try:
        image = image.convert("RGB")  # ensure a 3-channel input for the encoder
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        image = transform(image).unsqueeze(0)
        return image
    except Exception as e:
        print(f"Failed to preprocess image: {e}")
        return None
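
# Note: transforms.ToTensor() maps pixel values into [0, 1], matching the decoder's
# Sigmoid output range; generated tensors are therefore scaled back to [0, 255]
# before being converted to PIL images below.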

def generate_single_image(model, device):
    try:
        model.eval()
        with torch.no_grad():
            sample_z = torch.randn(1, latent_dim).to(device)
            generated_image = model.decode(sample_z)  # decode a random latent vector
            img = generated_image.cpu().detach().numpy()
            output = (img[0] * 255).transpose(1, 2, 0).astype(np.uint8)  # CHW float -> HWC uint8
            image = Image.fromarray(output)  # convert to a PIL image
            return image
    except Exception as e:
        print(f"Image generation failed: {e}")
        return None

def generate_from_base_image(model, device, base_image, noise_scale=0.1):
    try:
        model.eval()
        with torch.no_grad():
            processed_image = preprocess_image(base_image)  # resize and convert seed to a tensor
            if processed_image is None:
                return None
            processed_image = processed_image.to(device)
            mu, logvar = model.encode(processed_image)  # encode the seed image
            latent_vector = model.reparameterize(mu, logvar)
            noise = torch.randn_like(latent_vector) * noise_scale  # perturb the latent vector
            latent_vector = latent_vector + noise
            generated_image = model.decode(latent_vector)  # decode the perturbed latent vector
            img = generated_image.cpu().detach().numpy()
            output = (img[0] * 255).transpose(1, 2, 0).astype(np.uint8)  # CHW float -> HWC uint8
            output_image = Image.fromarray(output)  # convert to a PIL image
            return output_image
    except Exception as e:
        print(f"Seed image generation failed: {e}")
        return None
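
# Example usage outside the Gradio UI (hypothetical file names, illustration only):
#   device = torch.device("cpu")
#   model = load_model(device, model_repo_id, model_filename)
#   generate_single_image(model, device).save("random_cat.png")
#   seed = Image.open("cat.jpg")  # any RGB photo to use as a seed
#   generate_from_base_image(model, device, seed, noise_scale=0.2).save("cat_variation.png")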

# --- Gradio Interface ---
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae_model = load_model(device, model_repo_id, model_filename)
    if vae_model is None:
        return  # Exit if model loading fails

    def generate_single():
        img = generate_single_image(vae_model, device)
        if img is None:
            raise gr.Error("Image generation failed. Check console for errors.")
        return img

    def generate_from_seed(seed_image):
        if seed_image is None:
            raise gr.Error("Please upload a seed image.")
        img = generate_from_base_image(vae_model, device, seed_image)
        if img is None:
            raise gr.Error("Image generation from seed failed. Check console for errors.")
        return img

    with gr.Blocks() as demo:
        gr.Markdown("# VAE Image Generator")
        with gr.Tab("Generate Single Image"):
            single_button = gr.Button("Generate Random Image")
            single_output = gr.Image()
            single_button.click(generate_single, inputs=[], outputs=single_output)
        with gr.Tab("Generate from Seed"):
            seed_input = gr.Image(label="Seed Image", type="pil")  # pass a PIL image to preprocess_image
            seed_button = gr.Button("Generate from Seed")
            seed_output = gr.Image()
            seed_button.click(generate_from_seed, inputs=seed_input, outputs=seed_output)
    demo.launch()


if __name__ == "__main__":
    main()