Spaces:
Runtime error
Runtime error
Commit
·
8811dd9
1
Parent(s):
e85c7e3
Update app.py
Browse files
app.py
CHANGED
@@ -235,14 +235,6 @@ def show_images(images_list):
|
|
235 |
axs[c].imshow(images_list[c])
|
236 |
plt.show()
|
237 |
|
238 |
-
|
239 |
-
def invert_loss(gen_image):
|
240 |
-
inverter = T.RandomInvert(p=1.0)
|
241 |
-
inverted_img = inverter(gen_image)
|
242 |
-
#loss = torch.abs(gen_image - inverted_img).sum()
|
243 |
-
loss = torch.nn.functional.mse_loss(gen_image[:,0], gen_image[:,2]) + torch.nn.functional.mse_loss(gen_image[:,2], gen_image[:,1]) + torch.nn.functional.mse_loss(gen_image[:,0], gen_image[:,1])
|
244 |
-
return loss
|
245 |
-
|
246 |
def brilliance_loss(image, target_brilliance=10):
|
247 |
# Calculate the standard deviation of color channels
|
248 |
std_dev = torch.std(image, dim=(2, 3))
|
@@ -252,6 +244,42 @@ def brilliance_loss(image, target_brilliance=10):
|
|
252 |
loss = torch.abs(mean_std_dev - target_brilliance)
|
253 |
return loss
|
254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
256 |
def display_images_in_rows(images_with_titles, titles):
|
257 |
num_images = len(images_with_titles)
|
@@ -280,41 +308,46 @@ def display_images_in_rows(images_with_titles, titles):
|
|
280 |
# plt.show()
|
281 |
|
282 |
|
283 |
-
def image_generator(prompt
|
284 |
-
|
285 |
-
|
286 |
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
if loss_function:
|
291 |
-
generated_img = generate_image_custom_style(prompt,style_num = i,random_seed = seed_values[i],custom_loss_fn = loss_function)
|
292 |
-
images_with_loss.append(generated_img)
|
293 |
|
294 |
-
|
295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
|
297 |
-
|
298 |
-
|
299 |
-
if images_with_loss != []:
|
300 |
-
generated_sd_images.append((images_with_loss[i], titles[i]))
|
301 |
|
302 |
-
|
|
|
|
|
|
|
303 |
|
304 |
-
|
305 |
-
def image_generator_wrapper(prompt = "dog", loss_function=None):
|
306 |
-
if loss_function == "Yes":
|
307 |
-
loss_function = brilliance_loss
|
308 |
-
else:
|
309 |
-
loss_function = None
|
310 |
|
311 |
-
|
|
|
|
|
312 |
|
313 |
description = 'Stable Diffusion is a generative artificial intelligence (generative AI) model that produces unique photorealistic images from text and image prompts.'
|
314 |
title = 'Image Generation using Stable Diffusion'
|
315 |
|
316 |
demo = gr.Interface(image_generator_wrapper,
|
317 |
inputs=[gr.Textbox(label="Enter prompt for generation", type="text", value="A ballerina cat dancing in space"),
|
318 |
-
gr.Radio(["
|
319 |
-
outputs=gr.Plot(label="Generated Images"),
|
|
|
|
|
320 |
demo.launch()
|
|
|
235 |
axs[c].imshow(images_list[c])
|
236 |
plt.show()
|
237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
def brilliance_loss(image, target_brilliance=10):
|
239 |
# Calculate the standard deviation of color channels
|
240 |
std_dev = torch.std(image, dim=(2, 3))
|
|
|
244 |
loss = torch.abs(mean_std_dev - target_brilliance)
|
245 |
return loss
|
246 |
|
247 |
+
import numpy as np
|
248 |
+
from PIL import Image
|
249 |
+
|
250 |
+
import torch
|
251 |
+
from scipy.stats import wasserstein_distance
|
252 |
+
|
253 |
+
def exposure_loss(image, target_exposure = 3):
|
254 |
+
# Calculate the brightness (exposure) of the image.
|
255 |
+
image_brightness = torch.mean(image)
|
256 |
+
|
257 |
+
# Calculate the loss as the absolute difference from the target exposure.
|
258 |
+
loss = torch.abs(image_brightness - target_exposure)
|
259 |
+
return loss
|
260 |
+
|
261 |
+
def color_diversity_loss(images):
|
262 |
+
# Calculate color diversity by measuring the variance of color channels (R, G, B).
|
263 |
+
color_variance = torch.var(images, dim=(2, 3), keepdim=True)
|
264 |
+
# Sum the color variances for each channel to get the total color diversity.
|
265 |
+
total_color_diversity = torch.sum(color_variance, dim=1)
|
266 |
+
return total_color_diversity
|
267 |
+
|
268 |
+
def sharpness_loss(images):
|
269 |
+
# Apply the Laplacian filter to the images to measure sharpness.
|
270 |
+
laplacian_filter = torch.Tensor([[-1, -1, -1],
|
271 |
+
[-1, 8, -1],
|
272 |
+
[-1, -1, -1]]).view(1, 1, 3, 3).to(images.device)
|
273 |
+
|
274 |
+
# Expand the filter to match the number of channels in the input image.
|
275 |
+
laplacian_filter = laplacian_filter.expand(-1, images.shape[1], -1, -1)
|
276 |
+
|
277 |
+
# Apply the convolution operation.
|
278 |
+
laplacian = torch.abs(F.conv2d(images, laplacian_filter))
|
279 |
+
|
280 |
+
# Calculate sharpness as the negative of the Laplacian variance.
|
281 |
+
sharpness = torch.var(laplacian)
|
282 |
+
return sharpness
|
283 |
|
284 |
def display_images_in_rows(images_with_titles, titles):
|
285 |
num_images = len(images_with_titles)
|
|
|
308 |
# plt.show()
|
309 |
|
310 |
|
311 |
+
def image_generator(prompt="cat", loss_function=None):
|
312 |
+
images_without_loss = []
|
313 |
+
images_with_loss = []
|
314 |
|
315 |
+
for i in range(num_styles):
|
316 |
+
generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=None)
|
317 |
+
images_without_loss.append(generated_img)
|
|
|
|
|
|
|
318 |
|
319 |
+
if loss_function:
|
320 |
+
if loss_function == "exposure_loss":
|
321 |
+
generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=exposure_loss)
|
322 |
+
elif loss_function == "color_diversity_loss":
|
323 |
+
generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=color_diversity_loss)
|
324 |
+
elif loss_function == "sharpness_loss":
|
325 |
+
generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=sharpness_loss)
|
326 |
+
elif loss_function == "brilliance_loss":
|
327 |
+
generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=brilliance_loss)
|
328 |
+
images_with_loss.append(generated_img)
|
329 |
|
330 |
+
generated_sd_images = []
|
331 |
+
titles = ["animal toy", "fft style", "mid journey", "oil style", "Space style"]
|
|
|
|
|
332 |
|
333 |
+
for i in range(len(titles)):
|
334 |
+
generated_sd_images.append((images_without_loss[i], titles[i]))
|
335 |
+
if images_with_loss:
|
336 |
+
generated_sd_images.append((images_with_loss[i], titles[i]))
|
337 |
|
338 |
+
return generated_sd_images
|
|
|
|
|
|
|
|
|
|
|
339 |
|
340 |
+
# Create a wrapper function for image_generator()
|
341 |
+
def image_generator_wrapper(prompt="dog", selected_loss="None"):
|
342 |
+
return image_generator(prompt, selected_loss)
|
343 |
|
344 |
description = 'Stable Diffusion is a generative artificial intelligence (generative AI) model that produces unique photorealistic images from text and image prompts.'
|
345 |
title = 'Image Generation using Stable Diffusion'
|
346 |
|
347 |
demo = gr.Interface(image_generator_wrapper,
|
348 |
inputs=[gr.Textbox(label="Enter prompt for generation", type="text", value="A ballerina cat dancing in space"),
|
349 |
+
gr.Radio(["None", "exposure_loss", "color_diversity_loss", "sharpness_loss", "brilliance_loss"], value="None", label="Select Loss")],
|
350 |
+
outputs=gr.Plot(label="Generated Images"),
|
351 |
+
title=title,
|
352 |
+
description=description)
|
353 |
demo.launch()
|