import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download

# Model / app configuration.
image_size = 64
latent_dim = 128
model_repo_id = "elapt1c/catGen"
model_filename = "model.pth"

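# Shape bookkeeping for the network below: each of the five stride-2
# convolutions halves the spatial resolution, so a 64x64 input shrinks
# 64 -> 32 -> 16 -> 8 -> 4 -> 2, leaving a 512 x 2 x 2 feature map. That is why
# the fully connected layers flatten to and from 512 * 2 * 2 = 2048 features.
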
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()

        # Encoder: 3x64x64 image -> 512x2x2 feature map.
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.encoder_fc_mu = nn.Linear(512 * 2 * 2, latent_dim)
        self.encoder_fc_logvar = nn.Linear(512 * 2 * 2, latent_dim)

        # Decoder: mirrors the encoder, mapping a latent vector back to a 3x64x64 image.
        self.decoder_fc = nn.Linear(latent_dim, 512 * 2 * 2)
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()  # pixel values in [0, 1]
        )

    def encode(self, x):
        h = self.encoder_conv(x)
        h = h.view(h.size(0), -1)
        mu = self.encoder_fc_mu(h)
        logvar = self.encoder_fc_logvar(h)
        return mu, logvar

    def decode(self, z):
        z = self.decoder_fc(z)
        z = z.view(z.size(0), 512, 2, 2)
        reconstructed_image = self.decoder_conv(z)
        return reconstructed_image

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # which keeps sampling differentiable with respect to mu and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed_image = self.decode(z)
        return reconstructed_image, mu, logvar

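# For reference only: this app only runs inference, and the exact objective used
# to train the published checkpoint is not shown here. A standard VAE loss for a
# Sigmoid-output decoder like the one above pairs a reconstruction term with a
# KL divergence against the N(0, I) prior:
def vae_loss(reconstructed, original, mu, logvar):
    # Pixel-wise binary cross-entropy, summed over the batch.
    recon = nn.functional.binary_cross_entropy(reconstructed, original, reduction='sum')
    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kld
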
def load_model(device, repo_id, filename):
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        print(f"Error downloading model from Hugging Face Hub: {e}")
        return None

    vae_model = VAE(latent_dim=latent_dim).to(device)

    try:
        checkpoint = torch.load(model_path, map_location=device)
    except Exception as e:
        print(f"Error loading checkpoint from {model_path}: {e}")
        return None

    # Checkpoints saved from a torch.compile()-wrapped model prefix every key
    # with '_orig_mod.'; strip that prefix so the weights load into this
    # uncompiled module.
    new_state_dict = {}
    for key, value in checkpoint.items():
        new_key = key.replace('_orig_mod.', '')
        new_state_dict[new_key] = value

    vae_model.load_state_dict(new_state_dict)
    print(f"====> Loaded existing model from {model_path} (handling torch.compile state_dict)")
    return vae_model

def preprocess_image(image):
    try:
        # Gradio may hand the seed image over as a numpy array; normalise to a
        # 3-channel PIL image before applying the torchvision transforms.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = image.convert("RGB")
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        image = transform(image).unsqueeze(0)
        return image
    except Exception as e:
        print(f"Failed to preprocess image: {e}")
        return None

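# Quick sanity check (hypothetical, not part of the app flow):
# preprocess_image(Image.new("RGB", (300, 200))) returns a float tensor of shape
# (1, 3, 64, 64) with values in [0, 1], ready for VAE.encode().
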
def generate_single_image(model, device):
    try:
        model.eval()
        with torch.no_grad():
            # Sample a latent vector from the standard normal prior and decode it.
            sample_z = torch.randn(1, latent_dim).to(device)
            generated_image = model.decode(sample_z)
            img = generated_image.cpu().numpy()
            # Convert from CHW float in [0, 1] to an HWC uint8 PIL image.
            output = (img[0] * 255).transpose(1, 2, 0).astype(np.uint8)
            image = Image.fromarray(output)
            return image
    except Exception as e:
        print(f"Image generation failed: {e}")
        return None

def generate_from_base_image(model, device, base_image, noise_scale=0.1):
    try:
        model.eval()
        with torch.no_grad():
            processed_image = preprocess_image(base_image)
            if processed_image is None:
                return None

            # Encode the seed image, then perturb its latent code with Gaussian
            # noise so the output is a variation rather than a reconstruction.
            processed_image = processed_image.to(device)
            mu, logvar = model.encode(processed_image)
            latent_vector = model.reparameterize(mu, logvar)

            noise = torch.randn_like(latent_vector) * noise_scale
            latent_vector = latent_vector + noise

            generated_image = model.decode(latent_vector)
            img = generated_image.cpu().numpy()
            output = (img[0] * 255).transpose(1, 2, 0).astype(np.uint8)
            output_image = Image.fromarray(output)
            return output_image

    except Exception as e:
        print(f"Seed image generation failed: {e}")
        return None

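# A possible extension, sketched here but not wired into the UI below: walk the
# latent space between two seed images. The helper name, the step count, and the
# use of the latent means are illustrative assumptions, not part of the original app.
def interpolate_images(model, device, image_a, image_b, steps=8):
    model.eval()
    with torch.no_grad():
        # Use the latent means of both seed images as interpolation endpoints.
        z_a, _ = model.encode(preprocess_image(image_a).to(device))
        z_b, _ = model.encode(preprocess_image(image_b).to(device))
        frames = []
        for t in np.linspace(0.0, 1.0, steps):
            # Linear blend of the two latent codes, decoded back to an image.
            z = (1.0 - float(t)) * z_a + float(t) * z_b
            img = model.decode(z).cpu().numpy()
            output = (img[0] * 255).transpose(1, 2, 0).astype(np.uint8)
            frames.append(Image.fromarray(output))
        return frames
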
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae_model = load_model(device, model_repo_id, model_filename)
    if vae_model is None:
        return

    def generate_single():
        img = generate_single_image(vae_model, device)
        if img is None:
            # Surface the failure in the UI instead of returning a string to an Image output.
            raise gr.Error("Image generation failed. Check console for errors.")
        return img

    def generate_from_seed(seed_image):
        if seed_image is None:
            raise gr.Error("Please upload a seed image.")

        img = generate_from_base_image(vae_model, device, seed_image)
        if img is None:
            raise gr.Error("Image generation from seed failed. Check console for errors.")
        return img

    with gr.Blocks() as demo:
        gr.Markdown("# VAE Image Generator")

        with gr.Tab("Generate Single Image"):
            single_button = gr.Button("Generate Random Image")
            single_output = gr.Image()
            single_button.click(generate_single, inputs=[], outputs=single_output)

        with gr.Tab("Generate from Seed"):
            # type="pil" hands the upload to generate_from_seed as a PIL image
            # rather than Gradio's default numpy array.
            seed_input = gr.Image(label="Seed Image", type="pil")
            seed_button = gr.Button("Generate from Seed")
            seed_output = gr.Image()
            seed_button.click(generate_from_seed, inputs=seed_input, outputs=seed_output)

    demo.launch()

if __name__ == "__main__":
    main()
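
# To run the app locally (assuming torch, torchvision, gradio, pillow, numpy and
# huggingface_hub are installed):
#   python <this_file>.py
# Gradio serves the interface on http://127.0.0.1:7860 by default.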