sagar007 committed
Commit dbedcec · verified · 1 Parent(s): 0e59359

Update app.py

Files changed (1)
  1. app.py +43 -135
app.py CHANGED
@@ -42,174 +42,82 @@ styles_mapping = {
 
 # Define seeds for all the styles
 seed_list = [11, 56, 110, 65, 5, 29, 47]
-
-# Loss Function based on Edge Detection
 def edge_detection(image):
     channels = image.shape[1]
-
-    # Define the kernels for Edge Detection
-    ed_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
-    ed_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
-
-    # Replicate the Edge detection kernels for each channel
-    ed_x = ed_x.repeat(channels, 1, 1, 1).to(image.device)
-    ed_y = ed_y.repeat(channels, 1, 1, 1).to(image.device)
-
-    # ed_x = ed_x.to(torch.float16)
-    # ed_y = ed_y.to(torch.float16)
-
-    # Convolve the image with the Edge detection kernels
-    conv_ed_x = F.conv2d(image, ed_x, padding=1, groups=channels)
-    conv_ed_y = F.conv2d(image, ed_y, padding=1, groups=channels)
-
-    # Combine the x and y gradients after convolution
-    ed_value = torch.sqrt(conv_ed_x**2 + conv_ed_y**2)
-
-    return ed_value
-
-def edge_loss(image):
-    ed_value = edge_detection(image)
-    ed_capped = (ed_value > 0.5).to(torch.float32)
-    return F.mse_loss(ed_value, ed_capped)
-
-def compute_loss(original_image, loss_type):
-
     if loss_type == 'blue':
-        # blue loss
-        # [:,2] -> all images in batch, only the blue channel
-        error = torch.abs(original_image[:,2] - 0.9).mean()
     elif loss_type == 'edge':
-        # edge loss
-        error = edge_loss(original_image)
     elif loss_type == 'contrast':
-        # RGB to Gray loss
-        transformed_image = T.functional.adjust_contrast(original_image, contrast_factor = 2)
-        error = torch.abs(transformed_image - original_image).mean()
     elif loss_type == 'brightness':
-        # brightness loss
-        transformed_image = T.functional.adjust_brightness(original_image, brightness_factor = 2)
-        error = torch.abs(transformed_image - original_image).mean()
     elif loss_type == 'sharpness':
-        # sharpness loss
-        transformed_image = T.functional.adjust_sharpness(original_image, sharpness_factor = 2)
-        error = torch.abs(transformed_image - original_image).mean()
     elif loss_type == 'saturation':
-        # saturation loss
-        transformed_image = T.functional.adjust_saturation(original_image, saturation_factor = 10)
-        error = torch.abs(transformed_image - original_image).mean()
     else:
-        print("error. Loss not defined")
-
-    return error
-
-
-
-def get_examples():
-    examples = [
-        ['A bird sitting on a tree', 'Midjourney', 'edge'],
-        ['Cats fighting on the road', 'Marc Allante', 'brightness'],
-        ['A mouse with the head of a puppy', 'Hitokomoru Style', 'contrast'],
-        ['A woman with a smiling face in front of an Italian Pizza', 'Hanfu Anime', 'brightness'],
-        ['A campfire (oil on canvas)', 'Birb Style', 'blue'],
-    ]
-    return examples
-
-# Existing functions (latents_to_pil, show_image, generate_image)
-# ... (Copy all the existing functions here)
-def latents_to_pil(latents):
-    # batch of latents -> list of images
-    latents = (1 / 0.18215) * latents
-    with torch.no_grad():
-        image = sd_pipeline.vae.decode(latents).sample
-    image = (image / 2 + 0.5).clamp(0, 1)  # 0 to 1
-    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-    image = (image * 255).round().astype("uint8")
-    return Image.fromarray(image[0])
-
-
-def show_image(prompt, concept, guidance_type):
-
-    for idx, sd in enumerate(styles_mapping.keys()):
-        if(sd == concept):
-            break
-    seed = seed_list[idx]
-    prompt = f"{prompt} in the style of {styles_mapping[sd]}"
-    styled_image_without_loss = latents_to_pil(generate_image(seed, prompt, guidance_type, loss_flag=False))
-    styled_image_with_loss = latents_to_pil(generate_image(seed, prompt, guidance_type, loss_flag=True))
-    return([styled_image_without_loss, styled_image_with_loss])
-
 
 def generate_image(seed, prompt, loss_type, loss_flag=False):
 
-    generator = torch.manual_seed(seed)
-    batch_size = 1
-
-    # scheduler
-    scheduler = LMSDiscreteScheduler(beta_start = 0.00085, beta_end = 0.012, beta_schedule = "scaled_linear", num_train_timesteps = 1000)
-    scheduler.set_timesteps(num_inference_steps)
-    scheduler.timesteps = scheduler.timesteps.to(torch.float32)
-
-    # text embeddings of the prompt
-    text_input = sd_pipeline.tokenizer(prompt, padding='max_length', max_length = sd_pipeline.tokenizer.model_max_length, truncation= True, return_tensors="pt")
-    input_ids = text_input.input_ids.to(torch_device)
-
-    with torch.no_grad():
-        text_embeddings = sd_pipeline.text_encoder(text_input.input_ids.to(torch_device))[0]
 
-    max_length = text_input.input_ids.shape[-1]
-    uncond_input = sd_pipeline.tokenizer(
-        [""] * batch_size, padding="max_length", max_length= max_length, return_tensors="pt"
-    )
-
-    with torch.no_grad():
-        uncond_embeddings = sd_pipeline.text_encoder(uncond_input.input_ids.to(torch_device))[0]
-
-    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])  # shape: 2,77,768
-
-    # random latent
     latents = torch.randn(
-        (batch_size, sd_pipeline.unet.config.in_channels, height// 8, width //8),
-        generator = generator,
-    ).to(torch.float32)
-
 
-    latents = latents.to(torch_device)
-    latents = latents * scheduler.init_noise_sigma
 
-    for i, t in tqdm(enumerate(scheduler.timesteps), total = len(scheduler.timesteps)):
 
         latent_model_input = torch.cat([latents] * 2)
-        sigma = scheduler.sigmas[i]
-        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
 
         with torch.no_grad():
-            noise_pred = sd_pipeline.unet(latent_model_input.to(torch.float32), t, encoder_hidden_states=text_embeddings)["sample"]
 
         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-        if loss_flag and i%5 == 0:
-
             latents = latents.detach().requires_grad_()
-            # the following line alone does not work, it requires change to reduce step only once
-            # hence commenting it out
-            #latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample
-            latents_x0 = latents - sigma * noise_pred
-
-            # use vae to decode the image
-            denoised_images = sd_pipeline.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range(0,1)
 
             loss = compute_loss(denoised_images, loss_type) * loss_scale
-            #loss = loss.to(torch.float16)
-            print(f"{i} loss {loss}")
 
             cond_grad = torch.autograd.grad(loss, latents)[0]
-            latents = latents.detach() - cond_grad * sigma**2
 
-        latents = scheduler.step(noise_pred, t, latents).prev_sample
 
     return latents
 
 
 # Gradio interface function
 def generate_images(prompt, style, guidance_type):
     images = show_image(prompt, style, guidance_type)
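Both the removed and the rewritten generate_image apply the same loss-guidance idea: every few denoising steps, estimate the fully denoised latents, score that estimate with compute_loss, and push the noisy latents against the gradient scaled by sigma squared. A minimal, self-contained sketch of just that update, with random tensors standing in for the UNet prediction and a toy loss standing in for the VAE decode plus compute_loss (illustrative only, not part of the commit; sigma and loss_scale are toy values):

    import torch

    latents = torch.randn(1, 4, 64, 64)                # current noisy latents
    noise_pred = torch.randn_like(latents)             # stand-in for the guided UNet noise prediction
    sigma = 1.5                                        # current noise level from the scheduler (toy value)
    loss_scale = 200                                   # weight of the extra guidance term (toy value)

    latents = latents.detach().requires_grad_()
    latents_x0 = latents - sigma * noise_pred          # rough estimate of the denoised latents
    loss = latents_x0.abs().mean() * loss_scale        # toy loss in place of VAE decode + compute_loss
    cond_grad = torch.autograd.grad(loss, latents)[0]
    latents = latents.detach() - cond_grad * sigma**2  # nudge the latents to lower the loss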
 
 
 # Define seeds for all the styles
 seed_list = [11, 56, 110, 65, 5, 29, 47]
+# Optimized loss computation functions
 def edge_detection(image):
     channels = image.shape[1]
+    kernels = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1],
+                            [-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], device=image.device).float()
+    kernels = kernels.view(2, 1, 3, 3).repeat(channels, 1, 1, 1)
+    padded_image = F.pad(image, (1, 1, 1, 1), mode='replicate')
+    edge = F.conv2d(padded_image, kernels, groups=channels)
+    return torch.sqrt(edge[:, :channels]**2 + edge[:, channels:]**2)
+
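A quick, illustrative sanity check (not part of the commit) for the rewritten edge_detection: with the replicate padding and 3x3 kernels, the returned edge map keeps the input's shape. It assumes edge_detection and its imports from app.py are already in scope:

    import torch

    dummy = torch.rand(1, 3, 64, 64)   # (batch, channels, height, width) in [0, 1]
    edges = edge_detection(dummy)      # edge-strength map with the same layout as the input
    print(edges.shape)                 # expected: torch.Size([1, 3, 64, 64])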
+@torch.jit.script
+def compute_loss(original_image, loss_type: str):
     if loss_type == 'blue':
+        return torch.abs(original_image[:,2] - 0.9).mean()
     elif loss_type == 'edge':
+        ed_value = edge_detection(original_image)
+        return F.mse_loss(ed_value, (ed_value > 0.5).float())
     elif loss_type == 'contrast':
+        transformed_image = T.functional.adjust_contrast(original_image, contrast_factor=2)
+        return torch.abs(transformed_image - original_image).mean()
     elif loss_type == 'brightness':
+        transformed_image = T.functional.adjust_brightness(original_image, brightness_factor=2)
+        return torch.abs(transformed_image - original_image).mean()
     elif loss_type == 'sharpness':
+        transformed_image = T.functional.adjust_sharpness(original_image, sharpness_factor=2)
+        return torch.abs(transformed_image - original_image).mean()
     elif loss_type == 'saturation':
+        transformed_image = T.functional.adjust_saturation(original_image, saturation_factor=10)
+        return torch.abs(transformed_image - original_image).mean()
     else:
+        return torch.tensor(0.0, device=original_image.device)
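Each branch of compute_loss reduces the decoded image batch to a differentiable scalar, so its gradient can later be taken with respect to the latents. For illustration only (not part of the commit), the 'blue' branch on a dummy batch:

    import torch

    img = torch.rand(1, 3, 64, 64, requires_grad=True)   # dummy decoded images in [0, 1]
    blue_loss = torch.abs(img[:, 2] - 0.9).mean()         # penalise blue-channel values far from 0.9
    blue_loss.backward()                                   # gradients flow back to the image
    print(blue_loss.item(), img.grad.shape)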
 
+# Optimized generate_image function
+@torch.cuda.amp.autocast()
 def generate_image(seed, prompt, loss_type, loss_flag=False):
+    generator = torch.manual_seed(seed)
+    batch_size = 1
 
+    text_embeddings = sd_pipeline._encode_prompt(prompt, sd_pipeline.device, 1, True)
 
     latents = torch.randn(
+        (batch_size, sd_pipeline.unet.config.in_channels, height // 8, width // 8),
+        generator=generator,
+    ).to(sd_pipeline.device)
 
+    latents = latents * sd_pipeline.scheduler.init_noise_sigma
 
+    sd_pipeline.scheduler.set_timesteps(num_inference_steps)
 
+    for i, t in enumerate(tqdm(sd_pipeline.scheduler.timesteps)):
         latent_model_input = torch.cat([latents] * 2)
+        latent_model_input = sd_pipeline.scheduler.scale_model_input(latent_model_input, t)
 
         with torch.no_grad():
+            noise_pred = sd_pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
+        if loss_flag and i % 5 == 0:
             latents = latents.detach().requires_grad_()
+            latents_x0 = sd_pipeline.scheduler.step(noise_pred, t, latents).pred_original_sample
+            with torch.no_grad():
+                denoised_images = sd_pipeline.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
 
             loss = compute_loss(denoised_images, loss_type) * loss_scale
+            print(f"Step {i}, Loss: {loss.item():.4f}")
 
             cond_grad = torch.autograd.grad(loss, latents)[0]
+            latents = latents.detach() - cond_grad * sd_pipeline.scheduler.sigmas[i] ** 2
 
+        latents = sd_pipeline.scheduler.step(noise_pred, t, latents).prev_sample
 
     return latents
 
+
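generate_image still returns latents rather than a finished image, so decoding to a PIL image follows the same recipe as the removed latents_to_pil above. A minimal sketch of that decoding, assuming sd_pipeline is the loaded Stable Diffusion pipeline as elsewhere in app.py (illustrative, not part of the commit):

    import torch
    from PIL import Image

    def latents_to_pil(latents):
        # Undo the VAE scaling factor, decode, and convert the first image to PIL
        latents = (1 / 0.18215) * latents
        with torch.no_grad():
            image = sd_pipeline.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).round().astype("uint8")
        return Image.fromarray(image[0])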
 # Gradio interface function
 def generate_images(prompt, style, guidance_type):
     images = show_image(prompt, style, guidance_type)