Shriti09 commited on
Commit
1920232
·
verified ·
1 Parent(s): 2d9bf5d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from PIL import Image
7
+ import torch.nn.functional as F
8
+ from torchvision import transforms
9
+ from diffusers import StableDiffusionPipeline
10
+
11
+ # Define Loss Functions (same as in your code)
12
+ def edge_loss(image_tensor):
13
+ grayscale = image_tensor.mean(dim=0, keepdim=True)
14
+ grayscale = grayscale.unsqueeze(0)
15
+ sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], device=image_tensor.device).float().unsqueeze(0).unsqueeze(0)
16
+ sobel_y = sobel_x.transpose(2, 3)
17
+ gx = F.conv2d(grayscale, sobel_x, padding=1)
18
+ gy = F.conv2d(grayscale, sobel_y, padding=1)
19
+ return -torch.mean(torch.sqrt(gx ** 2 + gy ** 2))
20
+
21
+ def texture_loss(image_tensor):
22
+ return F.mse_loss(image_tensor, torch.rand_like(image_tensor, device=image_tensor.device))
23
+
24
+ def entropy_loss(image_tensor):
25
+ hist = torch.histc(image_tensor, bins=256, min=0, max=255)
26
+ hist = hist / hist.sum()
27
+ return -torch.sum(hist * torch.log(hist + 1e-7))
28
+
29
+ def symmetry_loss(image_tensor):
30
+ width = image_tensor.shape[-1]
31
+ left_half = image_tensor[:, :, :width // 2]
32
+ right_half = torch.flip(image_tensor[:, :, width // 2:], dims=[-1])
33
+ return F.mse_loss(left_half, right_half)
34
+
35
+ def contrast_loss(image_tensor):
36
+ min_val = image_tensor.min()
37
+ max_val = image_tensor.max()
38
+ return -torch.mean((image_tensor - min_val) / (max_val - min_val + 1e-7))
39
+
40
+ # Setup Stable Diffusion Pipeline
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
43
+
44
+ # Image transform to tensor
45
+ transform = transforms.ToTensor()
46
+
47
+ # Loss functions dictionary
48
+ losses = {
49
+ "edge": edge_loss,
50
+ "texture": texture_loss,
51
+ "entropy": entropy_loss,
52
+ "symmetry": symmetry_loss,
53
+ "contrast": contrast_loss
54
+ }
55
+
56
+ # Define function to generate images for a given seed
57
+ def generate_images(seed):
58
+ generator = torch.Generator(device).manual_seed(seed)
59
+ output_image = pipe("A futuristic city skyline at sunset", generator=generator).images[0]
60
+
61
+ # Convert to tensor
62
+ image_tensor = transform(output_image).to(device)
63
+
64
+ loss_images = []
65
+ loss_values = []
66
+
67
+ # Compute losses and generate modified images
68
+ for loss_name, loss_fn in losses.items():
69
+ loss_value = loss_fn(image_tensor)
70
+
71
+ # Resize to thumbnail size
72
+ thumbnail_image = output_image.copy()
73
+ thumbnail_image.thumbnail((128, 128))
74
+
75
+ # Save loss image with thumbnail
76
+ loss_images.append(thumbnail_image)
77
+ loss_values.append(f"{loss_name}: {loss_value.item():.4f}")
78
+
79
+ return loss_images, loss_values
80
+
81
+ # Gradio Interface
82
+ def gradio_interface(seed):
83
+ loss_images, loss_values = generate_images(int(seed))
84
+ return loss_images, loss_values
85
+
86
+ # Set up Gradio UI
87
+ interface = gr.Interface(
88
+ fn=gradio_interface,
89
+ inputs=gr.inputs.Textbox(label="Enter Seed"),
90
+ outputs=[gr.outputs.Gallery(label="Loss Images"), gr.outputs.Textbox(label="Loss Values")]
91
+ )
92
+
93
+ # Launch the interface
94
+ interface.launch()