PrarthanaTS committed
Commit 1a4dc5b · 1 Parent(s): af36cc8

Update app.py

Files changed (1)
  1. app.py +320 -0
app.py CHANGED
@@ -0,0 +1,320 @@
+ import gradio as gr
+ import os
+ import sys
+ from base64 import b64encode
+
+ import numpy as np
+ import torch
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
+
+ from matplotlib import pyplot as plt
+ from pathlib import Path
+ from PIL import Image
+ from torch import autocast
+ from torchvision import transforms as tfms
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
+ import cv2
+ import torchvision.transforms as T
+
+ torch.manual_seed(1)
+ logging.set_verbosity_error()
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ # Load the autoencoder
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder='vae')
+
+ # Load tokenizer and text encoder to tokenize and encode the text
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+ # UNet model for generating latents
+ unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder='unet')
+
+ # Noise scheduler
+ scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+
+ # Move everything to the chosen device (GPU if available)
+ vae = vae.to(torch_device)
+ text_encoder = text_encoder.to(torch_device)
+ unet = unet.to(torch_device)
+
+ def get_output_embeds(input_embeddings):
+     # CLIP's text model uses a causal mask, so we prepare it here (via a private helper of the
+     # transformers version this app targets):
+     bsz, seq_len = input_embeddings.shape[:2]
+     causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(bsz, seq_len, dtype=input_embeddings.dtype)
+
+     # Getting the output embeddings involves calling the model with output_hidden_states=True
+     # so that it doesn't just return the pooled final predictions:
+     encoder_outputs = text_encoder.text_model.encoder(
+         inputs_embeds=input_embeddings,
+         attention_mask=None,  # We aren't using an attention mask, so that can be None
+         causal_attention_mask=causal_attention_mask.to(torch_device),
+         output_attentions=None,
+         output_hidden_states=True,  # We want the output embeddings, not the final output
+         return_dict=None,
+     )
+
+     # We're interested in the output hidden state only
+     output = encoder_outputs[0]
+
+     # There is a final layer norm we need to pass these through
+     output = text_encoder.text_model.final_layer_norm(output)
+
+     # And now they're ready!
+     return output
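+ # Example (illustrative, not exercised by the app): push a prompt's input embeddings through
+ # the encoder to get per-token output embeddings of shape (1, 77, 768), e.g.
+ #   tok = tokenizer("a photo of a dog", padding="max_length",
+ #                   max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+ #   input_embs = text_encoder.text_model.embeddings(tok.input_ids.to(torch_device))
+ #   output_embs = get_output_embeds(input_embs)  # same shape, ready for the UNet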
+
+ # Prep Scheduler
+ def set_timesteps(scheduler, num_inference_steps):
+     scheduler.set_timesteps(num_inference_steps)
+     scheduler.timesteps = scheduler.timesteps.to(torch.float32)  # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925
+
+
+ style_files = ['learned_embeds_animal_toys.bin', 'learned_embeds_fftstyle.bin',
+                'learned_embeds_midjourney_style.bin', 'learned_embeds_oil_style.bin', 'learned_embeds_space-style.bin']
+
+ seed_values = [8, 16, 50, 80, 128]
+ height = 512                 # default height of Stable Diffusion
+ width = 512                  # default width of Stable Diffusion
+ num_inference_steps = 5      # Number of denoising steps
+ guidance_scale = 7.5         # Scale for classifier-free guidance
+ num_styles = len(style_files)
+
+ def get_style_embeddings(style_file):
+     style_embed = torch.load(style_file)
+     style_name = list(style_embed.keys())[0]
+     return style_embed[style_name]
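+ # Example (illustrative): each learned_embeds_*.bin holds a single {token_name: tensor} pair
+ # produced by textual inversion; for this text encoder that is typically a 768-dim vector, e.g.
+ #   emb = get_style_embeddings(style_files[0])
+ #   emb.shape   # -> torch.Size([768])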
+
+ def get_EOS_pos_in_prompt(prompt):
+     # Index of the EOS token: BOS sits at position 0, then one token per word
+     # (this assumes each word in the prompt maps to a single CLIP token)
+     return len(prompt.split()) + 1
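+ # Example (illustrative): for the default prompt "astronaut riding a cycle" (4 words) this
+ # returns 5, i.e. the slot right after the four word tokens.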
+
+
+ import torch.nn.functional as F
+ """
+ def gradient_loss(images):
+     # Compute gradient magnitude using Sobel filters.
+     gradient_x = F.conv2d(images, torch.Tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).view(1, 1, 3, 3).to(images.device))
+     gradient_y = F.conv2d(images, torch.Tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).view(1, 1, 3, 3).to(images.device))
+     gradient_magnitude = torch.sqrt(gradient_x**2 + gradient_y**2)
+     return gradient_magnitude.mean()
+ """
+
+ from torchvision.transforms import ToTensor
+ def pil_to_latent(input_im):
+     # Single image -> single latent in a batch (so size 1, 4, 64, 64)
+     with torch.no_grad():
+         latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1)  # Note scaling to [-1, 1]
+     return 0.18215 * latent.latent_dist.sample()
+
+ def latents_to_pil(latents):
+     # Batch of latents -> list of images
+     latents = (1 / 0.18215) * latents
+     with torch.no_grad():
+         image = vae.decode(latents).sample
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
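+ # Example (illustrative): the two helpers round-trip through the scaled latent space, e.g.
+ #   im = Image.open("some_512x512_image.png").convert("RGB")   # hypothetical input file
+ #   lat = pil_to_latent(im)        # -> tensor of shape (1, 4, 64, 64)
+ #   rec = latents_to_pil(lat)[0]   # approximate reconstruction of `im`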
+
+
+ def additional_guidance(latents, scheduler, noise_pred, t, sigma, custom_loss_fn, custom_loss_scale):
+     #### ADDITIONAL GUIDANCE ###
+     # Requires grad on the latents
+     latents = latents.detach().requires_grad_()
+
+     # Get the predicted x0:
+     latents_x0 = latents - sigma * noise_pred
+
+     # Decode to image space
+     denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)
+
+     # Calculate loss
+     loss = custom_loss_fn(denoised_images) * custom_loss_scale
+
+     # Get gradient
+     cond_grad = torch.autograd.grad(loss, latents, allow_unused=False)[0]
+
+     # Modify the latents based on this gradient
+     latents = latents.detach() - cond_grad * sigma**2
+     return latents, loss
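+ # Any callable that maps a (batch, 3, H, W) image tensor in [0, 1] to a scalar can serve as
+ # custom_loss_fn. A minimal sketch (not used by the app) that pushes images towards blue:
+ #   def blue_loss(images):
+ #       return torch.abs(images[:, 2] - 0.9).mean()   # distance of the blue channel from 0.9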
+
+
+ def generate_with_embs(text_embeddings, max_length, random_seed, loss_fn=None, custom_loss_scale=1.0):
+
+     height = 512                 # default height of Stable Diffusion
+     width = 512                  # default width of Stable Diffusion
+     num_inference_steps = 5      # Number of denoising steps
+     guidance_scale = 7.5         # Scale for classifier-free guidance
+
+     generator = torch.manual_seed(random_seed)   # Seed generator to create the initial latent noise
+     batch_size = 1
+
+     uncond_input = tokenizer(
+         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+     )
+     with torch.no_grad():
+         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
+     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+     # Prep Scheduler
+     set_timesteps(scheduler, num_inference_steps)
+
+     # Prep latents
+     latents = torch.randn(
+         (batch_size, unet.in_channels, height // 8, width // 8),
+         generator=generator,
+     )
+     latents = latents.to(torch_device)
+     latents = latents * scheduler.init_noise_sigma
+
+     # Loop
+     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
+         # Expand the latents if we are doing classifier-free guidance, to avoid doing two forward passes
+         latent_model_input = torch.cat([latents] * 2)
+         sigma = scheduler.sigmas[i]
+         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+         # Predict the noise residual
+         with torch.no_grad():
+             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+         # Perform classifier-free guidance
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+         if loss_fn is not None:
+             if i % 2 == 0:
+                 latents, custom_loss = additional_guidance(latents, scheduler, noise_pred, t, sigma, loss_fn, custom_loss_scale)
+                 print(i, 'loss:', custom_loss.item())
+
+         # Compute the previous noisy sample x_t -> x_t-1
+         latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+     return latents_to_pil(latents)[0]
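+ # Example (illustrative): generate straight from unmodified prompt embeddings, e.g.
+ #   tok = tokenizer("a watercolor landscape", padding="max_length",
+ #                   max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+ #   with torch.no_grad():
+ #       embs = text_encoder(tok.input_ids.to(torch_device))[0]
+ #   img = generate_with_embs(embs, tok.input_ids.shape[-1], random_seed=42)   # -> PIL.Image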
+
+ def generate_image_custom_style(prompt, style_num=None, random_seed=41, custom_loss_fn=None, custom_loss_scale=1.0):
+     eos_pos = get_EOS_pos_in_prompt(prompt)
+
+     style_token_embedding = None
+     if style_num is not None:  # 0 is a valid style index, so compare against None explicitly
+         style_token_embedding = get_style_embeddings(style_files[style_num])
+
+     # Tokenize
+     text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+     max_length = text_input.input_ids.shape[-1]
+     input_ids = text_input.input_ids.to(torch_device)
+
+     # Get token embeddings
+     token_emb_layer = text_encoder.text_model.embeddings.token_embedding
+     token_embeddings = token_emb_layer(input_ids)
+
+     # Overwrite the embedding at the EOS position with the learned style embedding
+     if style_token_embedding is not None:
+         token_embeddings[-1, eos_pos, :] = style_token_embedding
+
+     # Combine with position embeddings
+     pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
+     position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
+     position_embeddings = pos_emb_layer(position_ids)
+     input_embeddings = token_embeddings + position_embeddings
+
+     # Feed through to get the final output embeddings
+     modified_output_embeddings = get_output_embeds(input_embeddings)
+
+     # And generate an image with this:
+     generated_image = generate_with_embs(modified_output_embeddings, max_length, random_seed, custom_loss_fn, custom_loss_scale)
+     return generated_image
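+ # Example (illustrative):
+ #   img = generate_image_custom_style("astronaut riding a cycle", style_num=2,
+ #                                     random_seed=50, custom_loss_fn=None)
+ #   img.save("astronaut_midjourney.png")   # hypothetical output path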
+
+
+ def show_images(images_list):
+     # Visualize a list of images side by side
+     fig, axs = plt.subplots(1, len(images_list), figsize=(16, 4))
+     for c in range(len(images_list)):
+         axs[c].imshow(images_list[c])
+     plt.show()
+
+
+ def invert_loss(gen_image):
+     inverter = T.RandomInvert(p=1.0)
+     inverted_img = inverter(gen_image)   # note: currently unused below
+     # loss = torch.abs(gen_image - inverted_img).sum()
+     # Penalise differences between the three colour channels
+     loss = torch.nn.functional.mse_loss(gen_image[:, 0], gen_image[:, 2]) + torch.nn.functional.mse_loss(gen_image[:, 2], gen_image[:, 1]) + torch.nn.functional.mse_loss(gen_image[:, 0], gen_image[:, 1])
+     return loss
+
+ def brilliance_loss(image, target_brilliance=10):
+     # Calculate the standard deviation of the color channels
+     std_dev = torch.std(image, dim=(2, 3))
+     # Calculate the mean standard deviation across the batch
+     mean_std_dev = torch.mean(std_dev)
+     # Calculate the loss as the absolute difference from the target brilliance
+     loss = torch.abs(mean_std_dev - target_brilliance)
+     return loss
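+ # Quick sanity check (illustrative): a uniform grey batch has zero per-channel spread, so the
+ # loss equals target_brilliance:
+ #   brilliance_loss(torch.full((1, 3, 64, 64), 0.5))   # -> tensor(10.)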
+
+
+ def display_images_in_rows(images_with_titles, titles):
+     num_images = len(images_with_titles)
+     rows = 5  # Always display 5 rows (one per style)
+     columns = 1 if num_images == 5 else 2  # Use 1 image column if there are 5 images, otherwise 2 columns
+     fig, axes = plt.subplots(rows, columns + 1, figsize=(15, 5 * rows))  # Add an extra column for titles
+
+     for r in range(rows):
+         # Add the title on the extreme left, in the middle of each picture
+         axes[r, 0].text(0.5, 0.5, titles[r], ha='center', va='center')
+         axes[r, 0].axis('off')
+
+         # Add "Without Loss" above the first image column and "With Loss" above the second (if applicable)
+         if columns == 2:
+             axes[r, 1].set_title("Without Loss", pad=10)
+             axes[r, 2].set_title("With Loss", pad=10)
+
+         for c in range(1, columns + 1):
+             index = r * columns + c - 1
+             if index < num_images:
+                 image, _ = images_with_titles[index]
+                 axes[r, c].imshow(image)
+                 axes[r, c].axis('off')
+
+     return fig
+     # plt.show()
+
+
+ def image_generator(prompt="dog", loss_function=None):
+     images_without_loss = []
+     images_with_loss = []
+
+     for i in range(num_styles):
+         generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=None)
+         images_without_loss.append(generated_img)
+         if loss_function:
+             generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=loss_function)
+             images_with_loss.append(generated_img)
+
+     generated_sd_images = []
+     titles = ["animal toy", "fft style", "mid journey", "oil style", "Space style"]
+
+     for i in range(len(titles)):
+         generated_sd_images.append((images_without_loss[i], titles[i]))
+         if images_with_loss:
+             generated_sd_images.append((images_with_loss[i], titles[i]))
+
+     return display_images_in_rows(generated_sd_images, titles)
+
+ # Wrapper for image_generator() that maps the Gradio radio choice to an actual loss function
+ def image_generator_wrapper(prompt="dog", loss_function=None):
+     if loss_function == "Yes":
+         loss_function = brilliance_loss
+     else:
+         loss_function = None
+
+     return image_generator(prompt, loss_function)
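+ # Example (illustrative): this wrapper is what the Gradio interface below calls, e.g.
+ #   fig = image_generator_wrapper("astronaut riding a cycle", "Yes")   # -> matplotlib Figure with 5 rows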
+
+ description = 'Stable Diffusion is a generative AI model that produces unique, photorealistic images from text and image prompts. This demo applies five learned textual-inversion style embeddings to the prompt and can optionally steer generation with a brilliance loss.'
+ title = 'Image Generation using Stable Diffusion'
+
+ demo = gr.Interface(image_generator_wrapper,
+                     inputs=[gr.Textbox(label="Enter prompt for generation", type="text", value="astronaut riding a cycle"),
+                             gr.Radio(["Yes", "No"], value="No", label="Apply Brilliance Loss")],
+                     outputs=gr.Plot(label="Generated Images"),
+                     title=title, description=description)
+ demo.launch()