import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from diffusers import StableDiffusionPipeline, DDPMScheduler
import os
import shutil
import zipfile
import gradio as gr

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define your custom dataset
class CustomImageDataset(Dataset):
    def __init__(self, images, prompts, transform=None):
        self.images = images
        self.prompts = prompts
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        prompt = self.prompts[idx]
        return image, prompt

# Function to fine-tune the model
def fine_tune_model(images, prompts, model_save_path, num_epochs=3):
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = CustomImageDataset(images, prompts, transform)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
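    # Note: the DataLoader's default collate_fn stacks the image tensors
    # and groups the prompt strings into a tuple, which the tokenizer accepts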

    # Load Stable Diffusion model
    pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)

    # Load model components; only the UNet is trained, the VAE and text encoder stay frozen
    vae = pipeline.vae.to(device)
    unet = pipeline.unet.to(device)
    text_encoder = pipeline.text_encoder.to(device)
    tokenizer = pipeline.tokenizer  # Reuse the pipeline's own tokenizer so token ids match the text encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.train()
    optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-6)

    # Build a training-time noise scheduler from the pipeline's scheduler config
    noise_scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)

    # Fine-tuning loop: train the UNet to predict the noise (or velocity) added to the latents
    for epoch in range(num_epochs):
        for i, (batch_images, batch_prompts) in enumerate(dataloader):
            batch_images = batch_images.to(device)  # Move images to GPU if available

            # Tokenize the prompts to CLIP's fixed context length
            inputs = tokenizer(
                list(batch_prompts),
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).to(device)

            # Encode images to latents and apply the Stable Diffusion latent scaling factor
            latents = vae.encode(batch_images).latent_dist.sample() * 0.18215
            text_embeddings = text_encoder(inputs.input_ids).last_hidden_state

            # Sample noise and a random diffusion timestep per image, then noise the latents
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (latents.size(0),), device=device,
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Predict with the UNet, conditioned on the text embeddings and timestep
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample

            # Match the loss target to the scheduler's prediction type
            # (stabilityai/stable-diffusion-2 uses v-prediction, not plain noise prediction)
            if noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                target = noise

            loss = torch.nn.functional.mse_loss(model_pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Save the fine-tuned model
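    # pipeline.unet is the same object that was just trained, so save_pretrained
    # writes the updated UNet weights along with the other pipeline components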
    pipeline.save_pretrained(model_save_path)

# Function to convert a tensor to a PIL Image
def tensor_to_pil(tensor):
    tensor = tensor.squeeze().cpu().clamp(0, 1)  # Drop singleton batch dimension and clamp to [0, 1]
    return transforms.ToPILImage()(tensor)

# Function to generate images
def generate_images(pipeline, prompt):
    with torch.no_grad():
        # Generate image from the prompt
        output = pipeline(prompt)

        # Convert the output to PIL Image
        image = output.images[0]  # Get the first generated image
    return image
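# Note: diffusers pipelines already run inference without gradients, so the
# torch.no_grad() above is defensive rather than strictly required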

# Function to zip the fine-tuned model
def zip_model(model_path):
    zip_path = f"{model_path}.zip"
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for root, _, files in os.walk(model_path):
            for file in files:
                zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), model_path))
    return zip_path

# Function to save uploaded files
def save_uploaded_file(uploaded_file, save_path):
    # Open the file in binary write mode
    with open(save_path, 'wb') as f:
        f.write(uploaded_file.data)  # Use .data for the file content
    return f"File saved at {save_path}"

# Gradio interface functions
def start_fine_tuning(uploaded_files, prompts, num_epochs):
    images = [Image.open(file).convert("RGB") for file in uploaded_files]
    # The textbox holds a single comma-separated string; split it into one prompt per image
    prompt_list = [p.strip() for p in prompts.split(",")]
    if len(prompt_list) != len(images):
        return "Error: the number of prompts must match the number of uploaded images."
    model_save_path = "fine_tuned_model"
    fine_tune_model(images, prompt_list, model_save_path, num_epochs=int(num_epochs))
    return "Fine-tuning completed! Model is ready for download."

def download_model():
    model_save_path = "fine_tuned_model"
    if os.path.exists(model_save_path):
        return zip_model(model_save_path)
    else:
        return None

def generate_new_image(prompt):
    model_save_path = "fine_tuned_model"
    if os.path.exists(model_save_path):
        pipeline = StableDiffusionPipeline.from_pretrained(model_save_path).to(device)
    else:
        pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)
    image = generate_images(pipeline, prompt)
    image_path = "generated_image.png"
    image.save(image_path)
    return image_path
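# Note: reloading the pipeline from disk on every request is slow; caching the
# loaded pipeline in a module-level variable would avoid the repeated cost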

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Fine-Tune Stable Diffusion and Generate Images")

    with gr.Tab("Fine-Tune Model"):
        with gr.Row():
            uploaded_files = gr.File(label="Upload Images", file_types=[".png", ".jpg", ".jpeg"], file_count="multiple")
        with gr.Row():
            prompts = gr.Textbox(label="Enter Prompts (comma-separated)")
            num_epochs = gr.Number(label="Number of Epochs", value=3)
        with gr.Row():
            fine_tune_button = gr.Button("Start Fine-Tuning")
        fine_tune_output = gr.Textbox(label="Output")

        fine_tune_button.click(start_fine_tuning, [uploaded_files, prompts, num_epochs], fine_tune_output)

    with gr.Tab("Download Fine-Tuned Model"):
        download_button = gr.Button("Download Fine-Tuned Model")
        download_output = gr.File()

        download_button.click(download_model, [], download_output)

    with gr.Tab("Generate New Images"):
        prompt_input = gr.Textbox(label="Enter a Prompt")
        generate_button = gr.Button("Generate Image")
        generated_image = gr.Image(label="Generated Image")

        generate_button.click(generate_new_image, [prompt_input], generated_image)

demo.launch()
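
# Note: full UNet fine-tuning at 512x512 is memory-hungry (roughly 24 GB+ of
# VRAM); on smaller GPUs, consider LoRA, gradient checkpointing, or mixed
# precision instead of the full fine-tune above.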