VinitT committed
Commit 22ec225 · verified · 1 parent: 2a8791f

Create app.py

Files changed (1)
1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
+ import torch
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision import transforms
+ from PIL import Image
+ from diffusers import StableDiffusionPipeline, DDPMScheduler
+ import os
+ import zipfile
+ import gradio as gr
+
+ # Select the compute device (GPU if available)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
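+ # The app has three parts: a small dataset wrapper, a fine-tuning loop over the
+ # Stable Diffusion UNet, and a Gradio UI for uploading training data,
+ # downloading the tuned model, and generating images.
+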
+ # Define your custom dataset: each item pairs one training image with its prompt
+ class CustomImageDataset(Dataset):
+     def __init__(self, images, prompts, transform=None):
+         self.images = images
+         self.prompts = prompts
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.images)
+
+     def __getitem__(self, idx):
+         image = self.images[idx]
+         if self.transform:
+             image = self.transform(image)
+         prompt = self.prompts[idx]
+         return image, prompt
+
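+ # The default DataLoader collate stacks the image tensors into one batch and
+ # gathers the prompt strings into a tuple, so no custom collate_fn is needed.
+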
+ # Function to fine-tune the model on the uploaded image/prompt pairs
+ def fine_tune_model(images, prompts, model_save_path, num_epochs=3):
+     transform = transforms.Compose([
+         transforms.Resize((512, 512)),
+         transforms.ToTensor(),
+         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # scale pixels to [-1, 1] as the VAE expects
+     ])
+     dataset = CustomImageDataset(images, prompts, transform)
+     dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
+
+     # Load Stable Diffusion model
+     pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)
+
+     # Load model components; use the pipeline's own tokenizer, which matches its
+     # text encoder (SD 2 does not use the openai/clip-vit-base-patch32 vocabulary)
+     vae = pipeline.vae
+     unet = pipeline.unet
+     text_encoder = pipeline.text_encoder
+     tokenizer = pipeline.tokenizer
+     noise_scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
+     optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-6)  # only the UNet is trained
+
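+     # Training objective: corrupt the image latents with scheduler-controlled noise
+     # at a random timestep, then train the UNet to predict that noise from the
+     # noisy latents and the prompt embedding (the standard epsilon-prediction loss).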
+     # Fine-tuning loop
+     unet.train()
+     for epoch in range(num_epochs):
+         for batch_images, batch_prompts in dataloader:
+             batch_images = batch_images.to(device)  # move images to GPU if available
+
+             # Tokenize the prompts
+             inputs = tokenizer(list(batch_prompts), padding="max_length",
+                                max_length=tokenizer.model_max_length,
+                                truncation=True, return_tensors="pt").to(device)
+
+             # Encode images and prompts; the VAE and text encoder stay frozen
+             with torch.no_grad():
+                 latents = vae.encode(batch_images).latent_dist.sample() * 0.18215  # SD latent scaling factor
+                 text_embeddings = text_encoder(inputs.input_ids).last_hidden_state
+
+             # Sample noise and a random timestep per example, then add the noise
+             # according to the scheduler's schedule
+             noise = torch.randn_like(latents)
+             timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                       (latents.size(0),), device=device).long()
+             noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+             # Predict the noise and take a gradient step on the MSE loss
+             pred_noise = unet(noisy_latents, timestep=timesteps, encoder_hidden_states=text_embeddings).sample
+             loss = torch.nn.functional.mse_loss(pred_noise, noise)
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+
+     # Save the fine-tuned model (the trained UNet is saved as part of the pipeline)
+     pipeline.save_pretrained(model_save_path)
+
+ # Function to convert a tensor in [0, 1] to a PIL Image (helper, unused by the UI below)
+ def tensor_to_pil(tensor):
+     tensor = tensor.squeeze().cpu().clamp(0, 1)  # remove the batch dimension if present
+     return transforms.ToPILImage()(tensor)
+
+ # Function to generate an image from a prompt with the given pipeline
+ def generate_images(pipeline, prompt):
+     with torch.no_grad():
+         # Generate image from the prompt
+         output = pipeline(prompt)
+
+     # The pipeline returns PIL images; take the first one
+     return output.images[0]
+
+ # Function to zip the fine-tuned model directory into a single downloadable file
+ def zip_model(model_path):
+     zip_path = f"{model_path}.zip"
+     with zipfile.ZipFile(zip_path, "w") as zipf:
+         for root, _, files in os.walk(model_path):
+             for file in files:
+                 full_path = os.path.join(root, file)
+                 zipf.write(full_path, os.path.relpath(full_path, model_path))
+     return zip_path
+
+ # Gradio interface functions
+ def start_fine_tuning(uploaded_files, prompts, num_epochs):
+     images = [Image.open(file).convert("RGB") for file in uploaded_files]
+     # The textbox delivers a single comma-separated string; split it into one prompt per image
+     prompt_list = [p.strip() for p in prompts.split(",")]
+     if len(prompt_list) != len(images):
+         return "Please provide exactly one prompt per uploaded image."
+     model_save_path = "fine_tuned_model"
+     fine_tune_model(images, prompt_list, model_save_path, num_epochs=int(num_epochs))
+     return "Fine-tuning completed! Model is ready for download."
+
+ def download_model():
+     model_save_path = "fine_tuned_model"
+     if os.path.exists(model_save_path):
+         return zip_model(model_save_path)
+     else:
+         return None
+
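+ # Note: the pipeline below is reloaded from disk on every generation request;
+ # caching it in a module-level variable would avoid the reload cost.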
+ def generate_new_image(prompt):
+     model_save_path = "fine_tuned_model"
+     # Prefer the fine-tuned model if one has been saved; otherwise fall back to the base model
+     if os.path.exists(model_save_path):
+         pipeline = StableDiffusionPipeline.from_pretrained(model_save_path).to(device)
+     else:
+         pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to(device)
+     image = generate_images(pipeline, prompt)
+     image_path = "generated_image.png"
+     image.save(image_path)
+     return image_path
+
+ # Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Fine-Tune Stable Diffusion and Generate Images")
+
+     with gr.Tab("Fine-Tune Model"):
+         with gr.Row():
+             uploaded_files = gr.File(label="Upload Images", file_types=[".png", ".jpg", ".jpeg"], file_count="multiple")
+         with gr.Row():
+             prompts = gr.Textbox(label="Enter Prompts (comma-separated, one per image)")
+             num_epochs = gr.Number(label="Number of Epochs", value=3)
+         with gr.Row():
+             fine_tune_button = gr.Button("Start Fine-Tuning")
+             fine_tune_output = gr.Textbox(label="Output")
+
+         fine_tune_button.click(start_fine_tuning, [uploaded_files, prompts, num_epochs], fine_tune_output)
+
+     with gr.Tab("Download Fine-Tuned Model"):
+         download_button = gr.Button("Download Fine-Tuned Model")
+         download_output = gr.File()
+
+         download_button.click(download_model, [], download_output)
+
+     with gr.Tab("Generate New Images"):
+         prompt_input = gr.Textbox(label="Enter a Prompt")
+         generate_button = gr.Button("Generate Image")
+         generated_image = gr.Image(label="Generated Image")
+
+         generate_button.click(generate_new_image, [prompt_input], generated_image)
+
+ demo.launch()