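"""Gradio app for fine-tuning Stable Diffusion 2 on user-uploaded images.

Three tabs: fine-tune the UNet on image/prompt pairs, download the fine-tuned
model as a zip archive, and generate images from a text prompt.
"""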
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from diffusers import StableDiffusionPipeline
import os
import zipfile
import tempfile
import shutil
import gradio as gr
# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define your custom dataset
class CustomImageDataset(Dataset):
    def __init__(self, images, prompts, transform=None):
        self.images = images
        self.prompts = prompts
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        prompt = self.prompts[idx]
        return image, prompt

# Function to fine-tune the model
def fine_tune_model(images, prompts, model_save_path, num_epochs=3):
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # scale pixels to [-1, 1], as the VAE expects
    ])
    dataset = CustomImageDataset(images, prompts, transform)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Load the Stable Diffusion pipeline and pull out its components
    pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)
    vae = pipeline.vae
    unet = pipeline.unet
    text_encoder = pipeline.text_encoder
    tokenizer = pipeline.tokenizer  # use the pipeline's own tokenizer (SD 2 ships an OpenCLIP tokenizer)
    noise_scheduler = pipeline.scheduler

    # Only the UNet is trained; freeze the VAE and text encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.train()

    optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-6)

    # Fine-tuning loop
    for epoch in range(num_epochs):
        for batch_images, batch_prompts in dataloader:
            batch_images = batch_images.to(device)

            # Tokenize the prompts; encode images and text without tracking gradients
            inputs = tokenizer(list(batch_prompts), padding=True, truncation=True, return_tensors="pt").to(device)
            with torch.no_grad():
                latents = vae.encode(batch_images).latent_dist.sample() * vae.config.scaling_factor
                text_embeddings = text_encoder(inputs.input_ids).last_hidden_state

            # Sample noise and random timesteps, then add the noise with the scheduler
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (latents.size(0),), device=device
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # The training target depends on the scheduler's prediction type
            # (stabilityai/stable-diffusion-2 uses v-prediction)
            if noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                target = noise

            # Predict, compute the loss, and take an optimization step
            pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
            loss = torch.nn.functional.mse_loss(pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Save the fine-tuned model
    pipeline.save_pretrained(model_save_path)
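
# Usage sketch (hypothetical file names and prompts). Note that full fine-tuning
# of the SD 2 UNet is memory-hungry; expect to need a large GPU:
#
#   imgs = [Image.open("cat1.png").convert("RGB"), Image.open("cat2.png").convert("RGB")]
#   fine_tune_model(imgs, ["a photo of a cat", "a cat on a sofa"], "fine_tuned_model", num_epochs=1)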
# Function to convert tensor to PIL Image
def tensor_to_pil(tensor):
    tensor = tensor.squeeze().cpu().clamp(0, 1)  # Remove batch dimension if necessary
    return transforms.ToPILImage()(tensor)
# Function to generate images
def generate_images(pipeline, prompt):
    with torch.no_grad():
        # Generate an image from the prompt
        output = pipeline(prompt)
    image = output.images[0]  # the pipeline already returns PIL Images; take the first
    return image
# Function to zip the fine-tuned model
def zip_model(model_path):
    zip_path = f"{model_path}.zip"
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for root, _, files in os.walk(model_path):
            for file in files:
                file_path = os.path.join(root, file)
                zipf.write(file_path, os.path.relpath(file_path, model_path))
    return zip_path
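
# Example: zip_model("fine_tuned_model") writes "fine_tuned_model.zip" and returns its path.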
# Gradio interface functions
def start_fine_tuning(uploaded_files, prompts, num_epochs):
    # Create a temporary directory for storing the uploaded files
    temp_dir = tempfile.mkdtemp()
    images = []
    for file in uploaded_files:
        # Gradio may hand back temp-file objects (with a .name path) or plain path strings
        src_path = file.name if hasattr(file, "name") else file
        image_path = os.path.join(temp_dir, os.path.basename(src_path))
        shutil.copy(src_path, image_path)
        # .convert("RGB") forces a full load, so the image survives the cleanup below
        images.append(Image.open(image_path).convert("RGB"))

    # Split the comma-separated prompt string into one prompt per image
    prompt_list = [p.strip() for p in prompts.split(",") if p.strip()]
    if len(prompt_list) != len(images):
        shutil.rmtree(temp_dir)
        return f"Error: {len(images)} images but {len(prompt_list)} prompts were provided."

    model_save_path = "fine_tuned_model"
    fine_tune_model(images, prompt_list, model_save_path, num_epochs=int(num_epochs))

    # Clean up the temporary directory after fine-tuning
    shutil.rmtree(temp_dir)
    return "Fine-tuning completed! Model is ready for download."
def download_model():
    model_save_path = "fine_tuned_model"
    if os.path.exists(model_save_path):
        return zip_model(model_save_path)
    return None
def generate_new_image(prompt):
    model_save_path = "fine_tuned_model"
    # Prefer the fine-tuned model if one has been saved; otherwise fall back to the base model
    if os.path.exists(model_save_path):
        pipeline = StableDiffusionPipeline.from_pretrained(model_save_path).to(device)
    else:
        pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)
    image = generate_images(pipeline, prompt)
    image_path = "generated_image.png"
    image.save(image_path)
    return image_path
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Fine-Tune Stable Diffusion and Generate Images")

    with gr.Tab("Fine-Tune Model"):
        with gr.Row():
            uploaded_files = gr.File(label="Upload Images", file_types=[".png", ".jpg", ".jpeg"], file_count="multiple")
        with gr.Row():
            prompts = gr.Textbox(label="Enter Prompts (comma-separated)")
            num_epochs = gr.Number(label="Number of Epochs", value=3)
        with gr.Row():
            fine_tune_button = gr.Button("Start Fine-Tuning")
            fine_tune_output = gr.Textbox(label="Output")
        fine_tune_button.click(start_fine_tuning, [uploaded_files, prompts, num_epochs], fine_tune_output)

    with gr.Tab("Download Fine-Tuned Model"):
        download_button = gr.Button("Download Fine-Tuned Model")
        download_output = gr.File()
        download_button.click(download_model, [], download_output)

    with gr.Tab("Generate New Images"):
        prompt_input = gr.Textbox(label="Enter a Prompt")
        generate_button = gr.Button("Generate Image")
        generated_image = gr.Image(label="Generated Image")
        generate_button.click(generate_new_image, [prompt_input], generated_image)

demo.launch()