Spaces:
Runtime error
Runtime error
File size: 3,081 Bytes
1920232 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import os
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
from diffusers import StableDiffusionPipeline
# Loss functions: each maps an image tensor to a scalar score that is shown in the UI.
def edge_loss(image_tensor):
    """Return the negative mean Sobel gradient magnitude of an image.

    The channels are averaged into one grayscale plane, which is convolved
    with horizontal and vertical Sobel kernels (zero padding keeps the
    spatial size). Stronger edges give a larger magnitude and therefore a
    more negative loss.

    Args:
        image_tensor: 3-D image tensor whose first dimension is the channel
            axis, e.g. the (C, H, W) output of ``transforms.ToTensor``.

    Returns:
        A 0-dim tensor in the dtype of the grayscale image, on the input's device.
    """
    # (C, H, W) -> (1, H, W) -> (1, 1, H, W): conv2d wants batch and channel dims.
    grayscale = image_tensor.mean(dim=0, keepdim=True).unsqueeze(0)
    # Build the kernel in the input's dtype so float64 / float16 images do not
    # trigger a dtype mismatch inside conv2d.
    sobel_x = torch.tensor(
        [[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]],
        dtype=grayscale.dtype,
        device=image_tensor.device,
    ).view(1, 1, 3, 3)
    sobel_y = sobel_x.transpose(2, 3)
    gx = F.conv2d(grayscale, sobel_x, padding=1)
    gy = F.conv2d(grayscale, sobel_y, padding=1)
    return -torch.mean(torch.sqrt(gx ** 2 + gy ** 2))
def texture_loss(image_tensor):
    """Return the mean squared error between the image and uniform noise.

    A fresh noise tensor (U[0, 1) samples with the input's shape, dtype and
    device) is drawn on every call from torch's default RNG for that device,
    so repeated calls on the same image differ unless the RNG is re-seeded.
    """
    noise = torch.rand_like(image_tensor)
    return F.mse_loss(image_tensor, noise)
def entropy_loss(image_tensor, bins=256, value_range=(0.0, 1.0)):
    """Return the Shannon entropy (in nats) of the image's intensity histogram.

    Args:
        image_tensor: Floating-point image tensor. The default ``value_range``
            of (0, 1) matches tensors produced by ``transforms.ToTensor``;
            pass ``value_range=(0, 255)`` for 8-bit-scaled data.
        bins: Number of histogram bins.
        value_range: ``(low, high)`` bounds of the histogram. ``torch.histc``
            ignores values outside this range.

    Returns:
        A 0-dim tensor. It is 0 when no value falls inside ``value_range``.
    """
    low, high = value_range
    # The range must match the data's scale: a 0..255 range applied to
    # [0, 1] data puts almost every pixel in the first bin, so the entropy
    # is always close to zero.
    hist = torch.histc(image_tensor, bins=bins, min=low, max=high)
    total = hist.sum()
    if total == 0:
        # Nothing fell inside the range; avoid 0 / 0 -> NaN.
        return hist.new_zeros(())
    probs = hist / total
    # The epsilon keeps log() finite for empty bins (their p * log(p) term is 0).
    return -torch.sum(probs * torch.log(probs + 1e-7))
def symmetry_loss(image_tensor):
    """Return the MSE between the left half of an image and its mirrored right half.

    The last dimension is treated as the width, so (C, H, W) and batched
    (B, C, H, W) tensors are both handled. For odd widths the centre column
    is its own mirror image and is left out, so both halves always have
    ``width // 2`` columns.

    Args:
        image_tensor: Floating-point tensor whose last dimension is the width.

    Returns:
        A 0-dim tensor. It is 0 for a perfectly left-right symmetric image, and
        for images narrower than 2 pixels, which have nothing to compare.
    """
    width = image_tensor.shape[-1]
    half = width // 2
    if half == 0:
        return image_tensor.new_zeros(())
    left_half = image_tensor[..., :half]
    # Start at `width - half` rather than `width // 2` so the odd-width
    # centre column is skipped and both halves have identical shapes.
    right_half = torch.flip(image_tensor[..., width - half:], dims=[-1])
    return F.mse_loss(left_half, right_half)
def contrast_loss(image_tensor):
    """Return the negated mean of the min-max normalised image.

    Values are rescaled using the tensor's global minimum and maximum. A
    1e-7 epsilon in the denominator keeps a constant image from dividing by
    zero. The negative average of the rescaled values is returned.
    """
    lo, hi = image_tensor.min(), image_tensor.max()
    span = hi - lo + 1e-7
    normalised = (image_tensor - lo) / span
    return -normalised.mean()
# --- Stable Diffusion pipeline setup ---------------------------------------
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load the Stable Diffusion v1.4 text-to-image pipeline and move it onto `device`.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to(device)

# Converts a PIL image into a float tensor (torchvision ToTensor).
transform = transforms.ToTensor()

# Name -> scoring function. Iteration order is the order in which scores are reported.
losses = dict(
    edge=edge_loss,
    texture=texture_loss,
    entropy=entropy_loss,
    symmetry=symmetry_loss,
    contrast=contrast_loss,
)
# Generate one image for a seed and score it with every function in `losses`.
def generate_images(seed, *, prompt="A futuristic city skyline at sunset", thumbnail_size=(128, 128)):
    """Generate an image with a seeded generator and score it with each loss.

    Args:
        seed: Integer used to seed the ``torch.Generator`` passed to ``pipe``.
        prompt: Text prompt passed to the pipeline.
        thumbnail_size: Maximum (width, height) of the preview images.
            ``PIL.Image.thumbnail`` preserves the aspect ratio.

    Returns:
        ``(loss_images, loss_values)``: one thumbnail and one
        ``"name: value"`` string per entry in ``losses``, in the same order.
        The thumbnails are copies of the same generated image; the losses
        only score the image and do not modify it.
    """
    generator = torch.Generator(device).manual_seed(seed)
    output_image = pipe(prompt, generator=generator).images[0]

    # The preview does not depend on which loss is computed, so build it once.
    thumbnail = output_image.copy()
    thumbnail.thumbnail(thumbnail_size)

    image_tensor = transform(output_image).to(device)

    loss_images = []
    loss_values = []
    # Scores are only displayed, never back-propagated, so skip autograd bookkeeping.
    with torch.no_grad():
        for loss_name, loss_fn in losses.items():
            loss_value = loss_fn(image_tensor)
            loss_images.append(thumbnail.copy())
            loss_values.append(f"{loss_name}: {loss_value.item():.4f}")
    return loss_images, loss_values
# Gradio callback: validate the seed, run the generator, format the results.
def gradio_interface(seed):
    """Gradio callback for the seed input.

    Args:
        seed: Seed entered by the user. Anything ``int()`` accepts works,
            including text with surrounding whitespace such as ``" 42 "``.

    Returns:
        ``(loss_images, loss_text)``: the gallery images, and one
        ``"name: value"`` line per loss joined into a single string for the
        output textbox.

    Raises:
        gr.Error: If ``seed`` is not an integer. Gradio then shows a readable
            message instead of a bare ValueError.
    """
    try:
        seed_value = int(seed)
    except (TypeError, ValueError):
        raise gr.Error(f"Seed must be an integer, got {seed!r}.") from None
    loss_images, loss_values = generate_images(seed_value)
    # A Textbox displays one string; passing the list would show its repr.
    return loss_images, "\n".join(loss_values)
# --- Gradio UI -------------------------------------------------------------
# `gr.inputs` / `gr.outputs` were deprecated in Gradio 3 and removed in
# Gradio 4 (AttributeError at import time), so use the top-level components.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(label="Enter Seed"),
    outputs=[gr.Gallery(label="Loss Images"), gr.Textbox(label="Loss Values")],
)

# Launch the web app.
interface.launch()