Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import gradio as gr
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from PIL import Image
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torchvision import transforms
|
9 |
+
from diffusers import StableDiffusionPipeline
|
10 |
+
|
11 |
+
# Define Loss Functions (same as in your code)
|
12 |
+
def edge_loss(image_tensor):
|
13 |
+
grayscale = image_tensor.mean(dim=0, keepdim=True)
|
14 |
+
grayscale = grayscale.unsqueeze(0)
|
15 |
+
sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], device=image_tensor.device).float().unsqueeze(0).unsqueeze(0)
|
16 |
+
sobel_y = sobel_x.transpose(2, 3)
|
17 |
+
gx = F.conv2d(grayscale, sobel_x, padding=1)
|
18 |
+
gy = F.conv2d(grayscale, sobel_y, padding=1)
|
19 |
+
return -torch.mean(torch.sqrt(gx ** 2 + gy ** 2))
|
20 |
+
|
21 |
+
def texture_loss(image_tensor):
|
22 |
+
return F.mse_loss(image_tensor, torch.rand_like(image_tensor, device=image_tensor.device))
|
23 |
+
|
24 |
+
def entropy_loss(image_tensor):
|
25 |
+
hist = torch.histc(image_tensor, bins=256, min=0, max=255)
|
26 |
+
hist = hist / hist.sum()
|
27 |
+
return -torch.sum(hist * torch.log(hist + 1e-7))
|
28 |
+
|
29 |
+
def symmetry_loss(image_tensor):
|
30 |
+
width = image_tensor.shape[-1]
|
31 |
+
left_half = image_tensor[:, :, :width // 2]
|
32 |
+
right_half = torch.flip(image_tensor[:, :, width // 2:], dims=[-1])
|
33 |
+
return F.mse_loss(left_half, right_half)
|
34 |
+
|
35 |
+
def contrast_loss(image_tensor):
|
36 |
+
min_val = image_tensor.min()
|
37 |
+
max_val = image_tensor.max()
|
38 |
+
return -torch.mean((image_tensor - min_val) / (max_val - min_val + 1e-7))
|
39 |
+
|
40 |
+
# Setup Stable Diffusion Pipeline
|
41 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
42 |
+
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
|
43 |
+
|
44 |
+
# Image transform to tensor
|
45 |
+
transform = transforms.ToTensor()
|
46 |
+
|
47 |
+
# Loss functions dictionary
|
48 |
+
losses = {
|
49 |
+
"edge": edge_loss,
|
50 |
+
"texture": texture_loss,
|
51 |
+
"entropy": entropy_loss,
|
52 |
+
"symmetry": symmetry_loss,
|
53 |
+
"contrast": contrast_loss
|
54 |
+
}
|
55 |
+
|
56 |
+
# Define function to generate images for a given seed
|
57 |
+
def generate_images(seed):
|
58 |
+
generator = torch.Generator(device).manual_seed(seed)
|
59 |
+
output_image = pipe("A futuristic city skyline at sunset", generator=generator).images[0]
|
60 |
+
|
61 |
+
# Convert to tensor
|
62 |
+
image_tensor = transform(output_image).to(device)
|
63 |
+
|
64 |
+
loss_images = []
|
65 |
+
loss_values = []
|
66 |
+
|
67 |
+
# Compute losses and generate modified images
|
68 |
+
for loss_name, loss_fn in losses.items():
|
69 |
+
loss_value = loss_fn(image_tensor)
|
70 |
+
|
71 |
+
# Resize to thumbnail size
|
72 |
+
thumbnail_image = output_image.copy()
|
73 |
+
thumbnail_image.thumbnail((128, 128))
|
74 |
+
|
75 |
+
# Save loss image with thumbnail
|
76 |
+
loss_images.append(thumbnail_image)
|
77 |
+
loss_values.append(f"{loss_name}: {loss_value.item():.4f}")
|
78 |
+
|
79 |
+
return loss_images, loss_values
|
80 |
+
|
81 |
+
# Gradio Interface
|
82 |
+
def gradio_interface(seed):
|
83 |
+
loss_images, loss_values = generate_images(int(seed))
|
84 |
+
return loss_images, loss_values
|
85 |
+
|
86 |
+
# Set up Gradio UI
|
87 |
+
interface = gr.Interface(
|
88 |
+
fn=gradio_interface,
|
89 |
+
inputs=gr.inputs.Textbox(label="Enter Seed"),
|
90 |
+
outputs=[gr.outputs.Gallery(label="Loss Images"), gr.outputs.Textbox(label="Loss Values")]
|
91 |
+
)
|
92 |
+
|
93 |
+
# Launch the interface
|
94 |
+
interface.launch()
|