import os
import zipfile

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from diffusers import DDPMScheduler, StableDiffusionPipeline
import gradio as gr


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


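# Minimal dataset that pairs each training image with its text prompt.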
class CustomImageDataset(Dataset):
    def __init__(self, images, prompts, transform=None):
        self.images = images
        self.prompts = prompts
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        prompt = self.prompts[idx]
        return image, prompt


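# Fine-tuning loop: encode images to latents, add scheduled noise, and train the
# UNet to predict that noise (the usual latent-diffusion objective).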
def fine_tune_model(images, prompts, model_save_path, num_epochs=3):
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # scale pixels to [-1, 1] for the VAE
    ])
    dataset = CustomImageDataset(images, prompts, transform)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)

    vae = pipeline.vae.to(device)
    unet = pipeline.unet.to(device)
    text_encoder = pipeline.text_encoder.to(device)
    tokenizer = pipeline.tokenizer  # use the tokenizer that matches the pipeline's text encoder

    # Only the UNet is trained; freeze the VAE and text encoder.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.train()
    optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-6)

    # DDPM noise schedule built from the pipeline's scheduler config, as in the
    # standard diffusers text-to-image training loop.
    noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")

    for epoch in range(num_epochs):
        for batch_images, batch_prompts in dataloader:
            batch_images = batch_images.to(device)

            inputs = tokenizer(
                list(batch_prompts),
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).to(device)

            with torch.no_grad():
                latents = vae.encode(batch_images).latent_dist.sample() * vae.config.scaling_factor
                text_embeddings = text_encoder(inputs.input_ids).last_hidden_state

            # Sample a random timestep per image and add the corresponding amount of noise.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (latents.size(0),), device=device
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Predict the added noise and regress it against the true noise.
            pred_noise = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
            loss = torch.nn.functional.mse_loss(pred_noise, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    pipeline.save_pretrained(model_save_path)


def tensor_to_pil(tensor):
    tensor = tensor.squeeze().cpu().clamp(0, 1)
    return transforms.ToPILImage()(tensor)


def generate_images(pipeline, prompt):
    with torch.no_grad():
        output = pipeline(prompt)
    return output.images[0]


def zip_model(model_path):
    zip_path = f"{model_path}.zip"
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for root, _, files in os.walk(model_path):
            for file in files:
                file_path = os.path.join(root, file)
                zipf.write(file_path, os.path.relpath(file_path, model_path))
    return zip_path


def save_uploaded_file(uploaded_file, save_path):
    # Gradio hands uploaded files to callbacks as paths to temporary copies on disk.
    with open(uploaded_file, "rb") as src, open(save_path, "wb") as dst:
        dst.write(src.read())
    return f"File saved at {save_path}"


def start_fine_tuning(uploaded_files, prompts, num_epochs):
    images = [Image.open(file).convert("RGB") for file in uploaded_files]
    # The textbox holds comma-separated prompts; reuse the last prompt if fewer
    # prompts than images were provided.
    prompt_list = [p.strip() for p in prompts.split(",") if p.strip()] or [""]
    if len(prompt_list) < len(images):
        prompt_list += [prompt_list[-1]] * (len(images) - len(prompt_list))
    model_save_path = "fine_tuned_model"
    fine_tune_model(images, prompt_list, model_save_path, num_epochs=int(num_epochs))
    return "Fine-tuning completed! Model is ready for download."


def download_model():
    model_save_path = "fine_tuned_model"
    if os.path.exists(model_save_path):
        return zip_model(model_save_path)
    return None


def generate_new_image(prompt):
    model_save_path = "fine_tuned_model"
    if os.path.exists(model_save_path):
        pipeline = StableDiffusionPipeline.from_pretrained(model_save_path).to(device)
    else:
        pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)
    image = generate_images(pipeline, prompt)
    image_path = "generated_image.png"
    image.save(image_path)
    return image_path


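# Gradio UI: one tab each for fine-tuning, downloading the fine-tuned model, and generating images.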
with gr.Blocks() as demo:
    gr.Markdown("# Fine-Tune Stable Diffusion and Generate Images")

    with gr.Tab("Fine-Tune Model"):
        with gr.Row():
            uploaded_files = gr.File(label="Upload Images", file_types=[".png", ".jpg", ".jpeg"], file_count="multiple")
        with gr.Row():
            prompts = gr.Textbox(label="Enter Prompts (comma-separated)")
            num_epochs = gr.Number(label="Number of Epochs", value=3)
        with gr.Row():
            fine_tune_button = gr.Button("Start Fine-Tuning")
            fine_tune_output = gr.Textbox(label="Output")

        fine_tune_button.click(start_fine_tuning, [uploaded_files, prompts, num_epochs], fine_tune_output)

    with gr.Tab("Download Fine-Tuned Model"):
        download_button = gr.Button("Download Fine-Tuned Model")
        download_output = gr.File()

        download_button.click(download_model, [], download_output)

    with gr.Tab("Generate New Images"):
        prompt_input = gr.Textbox(label="Enter a Prompt")
        generate_button = gr.Button("Generate Image")
        generated_image = gr.Image(label="Generated Image")

        generate_button.click(generate_new_image, [prompt_input], generated_image)

demo.launch()