srikanthp07 committed on
Commit d36648e · 1 Parent(s): 75110aa

Create app.py

Files changed (1)
  1. app.py +286 -0
app.py ADDED
@@ -0,0 +1,286 @@
+ from base64 import b64encode
+ import gradio as gr
+ import numpy as np
+ import torch
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
+
+ from matplotlib import pyplot as plt
+ from pathlib import Path
+ from PIL import Image
+ from torch import autocast
+ from torchvision import transforms as tfms
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
+ import os
+ import cv2
+ import torchvision.transforms as T
+
+ torch.manual_seed(1)
+ logging.set_verbosity_error()
+
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load the autoencoder
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder='vae')
+
+ # Load tokenizer and text encoder to tokenize and encode the text
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+ # Unet model for generating latents
+ unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder='unet')
+
+ # Noise scheduler
+ scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
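+ # These betas and the scaled-linear schedule match the noise schedule Stable Diffusion v1 was trained with.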
+
+ # Move everything to the chosen device (GPU if available, otherwise CPU)
+ vae = vae.to(torch_device)
+ text_encoder = text_encoder.to(torch_device)
+ unet = unet.to(torch_device)
+
+ style_files = ['stable_diffusion/learned_embeddings/arcane-style-jv.bin', 'stable_diffusion/learned_embeddings/birb-style.bin',
+                'stable_diffusion/learned_embeddings/dr-strange.bin', 'stable_diffusion/learned_embeddings/midjourney-style.bin',
+                'stable_diffusion/learned_embeddings/oil_style.bin']
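+ # Each .bin file is a textual-inversion embedding: a small dict mapping the learned placeholder token to its embedding vector.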
+
+ images_without_loss = []
+ images_with_loss = []
+
+ seed_values = [8,16,50,80,128]
+ height = 512                  # default height of Stable Diffusion
+ width = 512                   # default width of Stable Diffusion
+ num_inference_steps = 5       # Number of denoising steps
+ guidance_scale = 7.5          # Scale for classifier-free guidance
+ num_styles = len(style_files)
+
+ # Prep Scheduler
+ def set_timesteps(scheduler, num_inference_steps):
+     scheduler.set_timesteps(num_inference_steps)
+     scheduler.timesteps = scheduler.timesteps.to(torch.float32) # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925
+
+ def get_output_embeds(input_embeddings):
+     # CLIP's text model uses causal mask, so we prepare it here:
+     bsz, seq_len = input_embeddings.shape[:2]
+     causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(bsz, seq_len, dtype=input_embeddings.dtype)
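+     # NOTE: _build_causal_attention_mask is a private transformers helper; newer transformers releases may rename or remove it.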
+
+     # Getting the output embeddings involves calling the model with output_hidden_states=True
+     # so that it doesn't just return the pooled final predictions:
+     encoder_outputs = text_encoder.text_model.encoder(
+         inputs_embeds=input_embeddings,
+         attention_mask=None, # We aren't using an attention mask so that can be None
+         causal_attention_mask=causal_attention_mask.to(torch_device),
+         output_attentions=None,
+         output_hidden_states=True, # We want the output embs not the final output
+         return_dict=None,
+     )
+
+     # We're interested in the output hidden state only
+     output = encoder_outputs[0]
+
+     # There is a final layer norm we need to pass these through
+     output = text_encoder.text_model.final_layer_norm(output)
+
+     # And now they're ready!
+     return output
+
+ def get_style_embeddings(style_file):
+     style_embed = torch.load(style_file)
+     style_name = list(style_embed.keys())[0]
+     return style_embed[style_name]
+
+ def vibrance_loss(image):
+     # Calculate the standard deviation of color channels
+     std_dev = torch.std(image, dim=(2, 3)) # Compute standard deviation over height and width
+     # Calculate the mean standard deviation across the batch
+     mean_std_dev = torch.mean(std_dev)
+     # You can adjust a scale factor to control the strength of vibrance regularization
+     scale_factor = 100.0
+     # Calculate the vibrance loss
+     loss = -scale_factor * mean_std_dev
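+     # The loss is the negative mean channel spread, so a gradient step that lowers it pushes the decoded image toward more color variation.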
+     return loss
+
+
+ def pil_to_latent(input_im):
+     # Single image -> single latent in a batch (so size 1, 4, 64, 64)
+     with torch.no_grad():
+         latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1) # Note scaling
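+     # 0.18215 is the latent scaling factor used by the Stable Diffusion v1 VAE.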
+     return 0.18215 * latent.latent_dist.sample()
+
+ def latents_to_pil(latents):
+     # batch of latents -> list of images
+     latents = (1 / 0.18215) * latents
+     with torch.no_grad():
+         image = vae.decode(latents).sample
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
+
+ def additional_guidance(latents, scheduler, noise_pred, t, sigma, custom_loss_fn):
+     #### ADDITIONAL GUIDANCE ###
+     # Requires grad on the latents
+     latents = latents.detach().requires_grad_()
+
+     # Get the predicted x0:
+     latents_x0 = latents - sigma * noise_pred
+     #print(f"latents: {latents.shape}, noise_pred:{noise_pred.shape}")
+     #latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample
+
+     # Decode to image space
+     denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
+
+     # Calculate loss
+     loss = custom_loss_fn(denoised_images)
+
+     # Get gradient
+     cond_grad = torch.autograd.grad(loss, latents, allow_unused=False)[0]
+
+     # Modify the latents based on this gradient
+     latents = latents.detach() - cond_grad * sigma**2
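+     # Scaling the gradient step by sigma**2 ties the size of the nudge to the current noise level, so earlier (noisier) steps are adjusted more strongly.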
+     return latents, loss
+
+
+ def generate_with_embs(text_embeddings, max_length, random_seed, loss_fn = None):
+     generator = torch.manual_seed(random_seed)   # Seed generator to create the initial latent noise
+     batch_size = 1
+
+     uncond_input = tokenizer(
+         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+     )
+     with torch.no_grad():
+         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
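+     # Stacking the unconditional and prompt embeddings lets both classifier-free guidance branches run in a single UNet forward pass.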
+
+     # Prep Scheduler
+     set_timesteps(scheduler, num_inference_steps)
+
+     # Prep latents
+     latents = torch.randn(
+         (batch_size, unet.in_channels, height // 8, width // 8),
+         generator=generator,
+     )
+     latents = latents.to(torch_device)
+     latents = latents * scheduler.init_noise_sigma
+
+     # Loop
+     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
+         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+         latent_model_input = torch.cat([latents] * 2)
+         sigma = scheduler.sigmas[i]
+         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+         # predict the noise residual
+         with torch.no_grad():
+             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+         # perform guidance
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
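+         # Classifier-free guidance: start from the unconditional prediction and move toward the text-conditioned one, scaled by guidance_scale.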
+         if loss_fn is not None:
+             if i%2 == 0:
+                 latents, custom_loss = additional_guidance(latents, scheduler, noise_pred, t, sigma, loss_fn)
+
+         # compute the previous noisy sample x_t -> x_t-1
+         latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+     return latents_to_pil(latents)[0]
+
+ def generate_images(prompt, style_num=None, random_seed=41, custom_loss_fn = None):
+     eos_pos = len(prompt.split())+1
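+     # Assumes every word becomes exactly one CLIP token; the +1 accounts for the start-of-text token, putting eos_pos at the first end-of-text slot.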
+
+     style_token_embedding = None
+     if style_num is not None:
+         style_token_embedding = get_style_embeddings(style_files[style_num])
+
+     # tokenize
+     text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+     max_length = text_input.input_ids.shape[-1]
+     input_ids = text_input.input_ids.to(torch_device)
+
+     # get token embeddings
+     token_emb_layer = text_encoder.text_model.embeddings.token_embedding
+     token_embeddings = token_emb_layer(input_ids)
+
+     # Overwrite the embedding at the end-of-prompt position with the learned style embedding
+     if style_token_embedding is not None:
+         token_embeddings[-1, eos_pos, :] = style_token_embedding
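+     # This slot acts as the pseudo-token learned by textual inversion; the rest of the pipeline treats it like any other prompt token.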
+
+     # combine with pos embs
+     pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
+     position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
+     position_embeddings = pos_emb_layer(position_ids)
+     input_embeddings = token_embeddings + position_embeddings
+
+     # Feed through to get final output embs
+     modified_output_embeddings = get_output_embeds(input_embeddings)
+
+     # And generate an image with this:
+     generated_image = generate_with_embs(modified_output_embeddings, max_length, random_seed, custom_loss_fn)
+     return generated_image
+
+ def display_images_in_rows(images_with_titles, titles):
+     num_images = len(images_with_titles)
+     rows = 5 # Display 5 rows always
+     columns = 1 if num_images == 5 else 2 # Use 1 column if there are 5 images, otherwise 2 columns
+     fig, axes = plt.subplots(rows, columns + 1, figsize=(15, 5 * rows)) # Add an extra column for titles
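+     # Layout: one row per style; column 0 carries the row title, the remaining column(s) hold the without-loss and (optionally) with-loss images.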
+
+     for r in range(rows):
+         # Add the title on the extreme left in the middle of each picture
+         axes[r, 0].text(0.5, 0.5, titles[r], ha='center', va='center')
+         axes[r, 0].axis('off')
+
+         # Add "Without Loss" label above the first column and "With Loss" label above the second column (if applicable)
+         if columns == 2:
+             axes[r, 1].set_title("Without Loss", pad=10)
+             axes[r, 2].set_title("With Loss", pad=10)
+
+         for c in range(1, columns + 1):
+             index = r * columns + c - 1
+             if index < num_images:
+                 image, _ = images_with_titles[index]
+                 axes[r, c].imshow(image)
+                 axes[r, c].axis('off')
+
+     return fig
+
+
+ def image_generator(prompt = "dog", loss_function=None):
+     images_without_loss = []
+     images_with_loss = []
+     if loss_function == "Yes":
+         loss_function = vibrance_loss
+     else:
+         loss_function = None
+
+     for i in range(num_styles):
+         generated_img = generate_images(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=None)
+         images_without_loss.append(generated_img)
+         if loss_function:
+             generated_img = generate_images(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=loss_function)
+             images_with_loss.append(generated_img)
+
+     generated_sd_images = []
+     titles = ["Arcane Style", "Birb Style", "Dr Strange Style", "Midjourney Style", "Oil Style"]
+
+     for i in range(len(titles)):
+         generated_sd_images.append((images_without_loss[i], titles[i]))
+         if images_with_loss != []:
+             generated_sd_images.append((images_with_loss[i], titles[i]))
+
+     return display_images_in_rows(generated_sd_images, titles)
+
+ description = "Generate an image from a prompt and optionally apply a vibrance loss. Note that the app is hosted on a CPU, so it takes at least 15 minutes to generate the images without loss. Feel free to clone the Space and run it on a GPU, increasing the inference steps to more than 10 for better results."
+
+ demo = gr.Interface(image_generator,
+                     inputs=[gr.Textbox(label="Enter prompt for generation", type="text", value="dog sitting on a bench"),
+                             gr.Radio(["Yes", "No"], value="No", label="Apply vibrance loss")],
+                     outputs=gr.Plot(label="Generated Images"), title="Stable Diffusion using Textual Inversion", description=description)
+ demo.launch()