PrarthanaTS commited on
Commit
8811dd9
·
1 Parent(s): e85c7e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -33
app.py CHANGED
@@ -235,14 +235,6 @@ def show_images(images_list):
235
  axs[c].imshow(images_list[c])
236
  plt.show()
237
 
238
-
239
- def invert_loss(gen_image):
240
- inverter = T.RandomInvert(p=1.0)
241
- inverted_img = inverter(gen_image)
242
- #loss = torch.abs(gen_image - inverted_img).sum()
243
- loss = torch.nn.functional.mse_loss(gen_image[:,0], gen_image[:,2]) + torch.nn.functional.mse_loss(gen_image[:,2], gen_image[:,1]) + torch.nn.functional.mse_loss(gen_image[:,0], gen_image[:,1])
244
- return loss
245
-
246
  def brilliance_loss(image, target_brilliance=10):
247
  # Calculate the standard deviation of color channels
248
  std_dev = torch.std(image, dim=(2, 3))
@@ -252,6 +244,42 @@ def brilliance_loss(image, target_brilliance=10):
252
  loss = torch.abs(mean_std_dev - target_brilliance)
253
  return loss
254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
  def display_images_in_rows(images_with_titles, titles):
257
  num_images = len(images_with_titles)
@@ -280,41 +308,46 @@ def display_images_in_rows(images_with_titles, titles):
280
  # plt.show()
281
 
282
 
283
- def image_generator(prompt = "dog", loss_function=None):
284
- images_without_loss = []
285
- images_with_loss = []
286
 
287
- for i in range(num_styles):
288
- generated_img = generate_image_custom_style(prompt,style_num = i,random_seed = seed_values[i],custom_loss_fn = None)
289
- images_without_loss.append(generated_img)
290
- if loss_function:
291
- generated_img = generate_image_custom_style(prompt,style_num = i,random_seed = seed_values[i],custom_loss_fn = loss_function)
292
- images_with_loss.append(generated_img)
293
 
294
- generated_sd_images = []
295
- titles = ["animal toy","fft style","mid journey","oil style","Space style"]
 
 
 
 
 
 
 
 
296
 
297
- for i in range(len(titles)):
298
- generated_sd_images.append((images_without_loss[i], titles[i]))
299
- if images_with_loss != []:
300
- generated_sd_images.append((images_with_loss[i], titles[i]))
301
 
302
- return display_images_in_rows(generated_sd_images, titles)
 
 
 
303
 
304
- # Create a wrapper function for show_misclassified_images()
305
- def image_generator_wrapper(prompt = "dog", loss_function=None):
306
- if loss_function == "Yes":
307
- loss_function = brilliance_loss
308
- else:
309
- loss_function = None
310
 
311
- return image_generator(prompt, loss_function)
 
 
312
 
313
  description = 'Stable Diffusion is a generative artificial intelligence (generative AI) model that produces unique photorealistic images from text and image prompts.'
314
  title = 'Image Generation using Stable Diffusion'
315
 
316
  demo = gr.Interface(image_generator_wrapper,
317
  inputs=[gr.Textbox(label="Enter prompt for generation", type="text", value="A ballerina cat dancing in space"),
318
- gr.Radio(["Yes", "No"], value="No" , label="Apply Brilliance Loss")],
319
- outputs=gr.Plot(label="Generated Images"), title = "Stable Diffusion", description=description)
 
 
320
  demo.launch()
 
235
  axs[c].imshow(images_list[c])
236
  plt.show()
237
 
 
 
 
 
 
 
 
 
238
  def brilliance_loss(image, target_brilliance=10):
239
  # Calculate the standard deviation of color channels
240
  std_dev = torch.std(image, dim=(2, 3))
 
244
  loss = torch.abs(mean_std_dev - target_brilliance)
245
  return loss
246
 
247
+ import numpy as np
248
+ from PIL import Image
249
+
250
+ import torch
251
+ from scipy.stats import wasserstein_distance
252
+
253
+ def exposure_loss(image, target_exposure = 3):
254
+ # Calculate the brightness (exposure) of the image.
255
+ image_brightness = torch.mean(image)
256
+
257
+ # Calculate the loss as the absolute difference from the target exposure.
258
+ loss = torch.abs(image_brightness - target_exposure)
259
+ return loss
260
+
261
+ def color_diversity_loss(images):
262
+ # Calculate color diversity by measuring the variance of color channels (R, G, B).
263
+ color_variance = torch.var(images, dim=(2, 3), keepdim=True)
264
+ # Sum the color variances for each channel to get the total color diversity.
265
+ total_color_diversity = torch.sum(color_variance, dim=1)
266
+ return total_color_diversity
267
+
268
+ def sharpness_loss(images):
269
+ # Apply the Laplacian filter to the images to measure sharpness.
270
+ laplacian_filter = torch.Tensor([[-1, -1, -1],
271
+ [-1, 8, -1],
272
+ [-1, -1, -1]]).view(1, 1, 3, 3).to(images.device)
273
+
274
+ # Expand the filter to match the number of channels in the input image.
275
+ laplacian_filter = laplacian_filter.expand(-1, images.shape[1], -1, -1)
276
+
277
+ # Apply the convolution operation.
278
+ laplacian = torch.abs(F.conv2d(images, laplacian_filter))
279
+
280
+ # Calculate sharpness as the negative of the Laplacian variance.
281
+ sharpness = torch.var(laplacian)
282
+ return sharpness
283
 
284
  def display_images_in_rows(images_with_titles, titles):
285
  num_images = len(images_with_titles)
 
308
  # plt.show()
309
 
310
 
311
+ def image_generator(prompt="cat", loss_function=None):
312
+ images_without_loss = []
313
+ images_with_loss = []
314
 
315
+ for i in range(num_styles):
316
+ generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=None)
317
+ images_without_loss.append(generated_img)
 
 
 
318
 
319
+ if loss_function:
320
+ if loss_function == "exposure_loss":
321
+ generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=exposure_loss)
322
+ elif loss_function == "color_diversity_loss":
323
+ generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=color_diversity_loss)
324
+ elif loss_function == "sharpness_loss":
325
+ generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=sharpness_loss)
326
+ elif loss_function == "brilliance_loss":
327
+ generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=brilliance_loss)
328
+ images_with_loss.append(generated_img)
329
 
330
+ generated_sd_images = []
331
+ titles = ["animal toy", "fft style", "mid journey", "oil style", "Space style"]
 
 
332
 
333
+ for i in range(len(titles)):
334
+ generated_sd_images.append((images_without_loss[i], titles[i]))
335
+ if images_with_loss:
336
+ generated_sd_images.append((images_with_loss[i], titles[i]))
337
 
338
+ return generated_sd_images
 
 
 
 
 
339
 
340
+ # Create a wrapper function for image_generator()
341
+ def image_generator_wrapper(prompt="dog", selected_loss="None"):
342
+ return image_generator(prompt, selected_loss)
343
 
344
  description = 'Stable Diffusion is a generative artificial intelligence (generative AI) model that produces unique photorealistic images from text and image prompts.'
345
  title = 'Image Generation using Stable Diffusion'
346
 
347
  demo = gr.Interface(image_generator_wrapper,
348
  inputs=[gr.Textbox(label="Enter prompt for generation", type="text", value="A ballerina cat dancing in space"),
349
+ gr.Radio(["None", "exposure_loss", "color_diversity_loss", "sharpness_loss", "brilliance_loss"], value="None", label="Select Loss")],
350
+ outputs=gr.Plot(label="Generated Images"),
351
+ title=title,
352
+ description=description)
353
  demo.launch()