import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download


# --- Hyperparameters ---
image_size = 64
latent_dim = 128
model_repo_id = "elapt1c/catGen"
model_filename = "model.pth"


# --- VAE Model (simplified; the architecture must match the training code) ---
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()

        # Encoder (mirrors the training-code architecture)
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),    # 64x64 -> 32x32
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),   # 32x32 -> 16x16
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 16x16 -> 8x8
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # 8x8 -> 4x4
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), # 4x4 -> 2x2
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Five stride-2 convolutions take the 64x64 input down to 2x2, so the
        # flattened feature size is 512 * 2 * 2 = 2048.
        self.encoder_fc_mu = nn.Linear(512 * 2 * 2, latent_dim)
        self.encoder_fc_logvar = nn.Linear(512 * 2 * 2, latent_dim)

        # Decoder (mirrors the encoder in reverse)
        self.decoder_fc = nn.Linear(latent_dim, 512 * 2 * 2)
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # 2x2 -> 4x4
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 4x4 -> 8x8
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # 8x8 -> 16x16
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),    # 16x16 -> 32x32
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),     # 32x32 -> 64x64
            nn.Sigmoid()  # pixel values in [0, 1]
        )

    def encode(self, x):
        h = self.encoder_conv(x)
        h = h.view(h.size(0), -1)
        mu = self.encoder_fc_mu(h)
        logvar = self.encoder_fc_logvar(h)
        return mu, logvar

    def decode(self, z):
        z = self.decoder_fc(z)
        z = z.view(z.size(0), 512, 2, 2)  # reshape to (batch, 512, 2, 2) for the conv stack
        reconstructed_image = self.decoder_conv(z)
        return reconstructed_image

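    # Reparameterization trick: sampling z ~ N(mu, sigma^2) directly would block
    # gradients, so sample eps ~ N(0, I) and compute z = mu + sigma * eps with
    # sigma = exp(0.5 * logvar), keeping mu and logvar differentiable.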
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed_image = self.decode(z)
        return reconstructed_image, mu, logvar

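# For reference only -- training is assumed to happen in a separate script, not
# in this app. The standard VAE objective pairs a reconstruction term with the
# closed-form KL divergence between q(z|x) = N(mu, sigma^2) and the prior N(0, I):
def vae_kl_divergence(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over the batch."""
    return -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
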

# --- Helper Functions ---
def load_model(device, repo_id, filename):
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        print(f"Error downloading model from Hugging Face Hub: {e}")
        return None

    vae_model = VAE(latent_dim=latent_dim).to(device)  # Plain VAE model

    try:
        checkpoint = torch.load(model_path, map_location=device)  # checkpoint is a raw state_dict
    except Exception as e:
        print(f"Error loading checkpoint from {model_path}: {e}")
        return None

    # torch.compile wraps the module and prefixes every state_dict key with
    # "_orig_mod."; strip the prefix so the weights load into this plain model.
    new_state_dict = {key.replace('_orig_mod.', ''): value for key, value in checkpoint.items()}

    vae_model.load_state_dict(new_state_dict)
    print(f"====> Loaded model from {model_path} (handled torch.compile state_dict keys)")
    return vae_model


def preprocess_image(image):
    try:
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        image = transform(image.convert('RGB')).unsqueeze(0)  # force 3 channels (handles RGBA/grayscale uploads)
        return image
    except Exception as e:
        print(f"Failed to preprocess image: {e}")
        return None


def generate_single_image(model, device):
    try:
        model.eval()
        with torch.no_grad():
            sample_z = torch.randn(1, latent_dim).to(device)  # sample z ~ N(0, I) from the prior
            generated_image = model.decode(sample_z)
            img = generated_image.squeeze(0).cpu().numpy()            # (3, H, W), values in [0, 1]
            output = (img * 255).transpose(1, 2, 0).astype(np.uint8)  # -> (H, W, 3) uint8
            return Image.fromarray(output)
    except Exception as e:
        print(f"Image generation failed: {e}")
        return None

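# Illustrative batch variant (an assumption of how one might extend the app; it
# is not wired into the UI below): decoding several prior samples in one forward
# pass amortizes the model call.
def generate_image_batch(model, device, n=4):
    model.eval()
    with torch.no_grad():
        z = torch.randn(n, latent_dim).to(device)  # n independent prior samples
        imgs = model.decode(z).cpu().numpy()       # (n, 3, H, W), values in [0, 1]
        return [Image.fromarray((im * 255).transpose(1, 2, 0).astype(np.uint8)) for im in imgs]
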

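# Encode the uploaded image to a latent code, perturb it with Gaussian noise,
# and decode: the result is a variation of the seed rather than a pure sample.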
def generate_from_base_image(model, device, base_image, noise_scale=0.1):
    try:
        model.eval()
        with torch.no_grad():
            processed_image = preprocess_image(base_image)
            if processed_image is None:
                return None

            processed_image = processed_image.to(device)
            mu, logvar = model.encode(processed_image)
            latent_vector = model.reparameterize(mu, logvar)

            # Perturb the latent code; a larger noise_scale drifts further from the seed.
            latent_vector = latent_vector + torch.randn_like(latent_vector) * noise_scale

            generated_image = model.decode(latent_vector)
            img = generated_image.squeeze(0).cpu().numpy()
            output = (img * 255).transpose(1, 2, 0).astype(np.uint8)
            return Image.fromarray(output)

    except Exception as e:
        print(f"Seed image generation failed: {e}")
        return None



# --- Gradio Interface ---
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae_model = load_model(device, model_repo_id, model_filename)
    if vae_model is None:
        return  # Exit if model loading fails

    def generate_single():
        img = generate_single_image(vae_model, device)
        if img is None:
            raise gr.Error("Image generation failed. Check the console logs for details.")
        return img

    def generate_from_seed(seed_image):
        if seed_image is None:
            raise gr.Error("Please upload a seed image.")

        img = generate_from_base_image(vae_model, device, seed_image)
        if img is None:
            raise gr.Error("Image generation from seed failed. Check the console logs for details.")
        return img


    with gr.Blocks() as demo:
        gr.Markdown("# VAE Image Generator")

        with gr.Tab("Generate Single Image"):
            single_button = gr.Button("Generate Random Image")
            single_output = gr.Image()
            single_button.click(generate_single, inputs=[], outputs=single_output)

        with gr.Tab("Generate from Seed"):
            seed_input = gr.Image(label="Seed Image", type="pil")  # PIL input so the torchvision transforms apply directly
            seed_button = gr.Button("Generate from Seed")
            seed_output = gr.Image()
            seed_button.click(generate_from_seed, inputs=seed_input, outputs=seed_output)
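            # A noise-strength control could be threaded through like this
            # (illustrative sketch; generate_from_seed would need a noise_scale
            # parameter forwarded to generate_from_base_image):
            # noise_slider = gr.Slider(0.0, 1.0, value=0.1, label="Noise scale")
            # seed_button.click(generate_from_seed, inputs=[seed_input, noise_slider], outputs=seed_output)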


    demo.launch()


if __name__ == "__main__":
    main()