# Hugging Face Space status when this file was captured: Runtime error
import os | |
import torch | |
import gradio as gr | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import torch.nn.functional as F | |
from torchvision import transforms | |
from diffusers import StableDiffusionPipeline | |
# Loss functions used to score generated images.
def edge_loss(image_tensor):
    """Return the negative mean Sobel gradient magnitude of an image.

    Images with more / stronger edges have larger gradient magnitudes and so
    score a lower (more negative) value.

    Args:
        image_tensor: Floating-point image tensor of shape (C, H, W) or
            batched (B, C, H, W). Channels are averaged into one grayscale
            plane before filtering.

    Returns:
        Scalar tensor ``-mean(sqrt(gx**2 + gy**2))`` over all pixels (and all
        batch items).

    Raises:
        ValueError: If ``image_tensor`` is not 3- or 4-dimensional.
    """
    if image_tensor.dim() == 3:
        batched = image_tensor.unsqueeze(0)
    elif image_tensor.dim() == 4:
        batched = image_tensor
    else:
        raise ValueError(
            "edge_loss expects a (C, H, W) or (B, C, H, W) tensor, "
            f"got shape {tuple(image_tensor.shape)}"
        )
    # (B, 1, H, W): average the channels into a single grayscale plane.
    grayscale = batched.mean(dim=1, keepdim=True)
    # Build the kernel on the input's device AND dtype; a hard-coded float32
    # kernel makes conv2d fail for float16/float64 inputs.
    sobel_x = torch.tensor(
        [[1, 0, -1], [2, 0, -2], [1, 0, -1]],
        device=image_tensor.device,
        dtype=image_tensor.dtype,
    ).view(1, 1, 3, 3)
    sobel_y = sobel_x.transpose(2, 3)
    # padding=1 keeps the output H x W; the implicit zero padding means the
    # image border itself registers as an edge.
    gx = F.conv2d(grayscale, sobel_x, padding=1)
    gy = F.conv2d(grayscale, sobel_y, padding=1)
    return -torch.mean(torch.sqrt(gx ** 2 + gy ** 2))
def texture_loss(image_tensor):
    """Return the MSE between ``image_tensor`` and a fresh uniform-noise target.

    The noise comes from ``torch.rand_like`` and is redrawn on every call, so
    the result is random: repeated calls on the same image differ unless the
    global torch RNG is reseeded beforehand.
    """
    noise = torch.rand_like(image_tensor, device=image_tensor.device)
    return F.mse_loss(image_tensor, noise)
def entropy_loss(image_tensor, bins=256, value_range=None):
    """Return the Shannon entropy (natural log) of the intensity histogram.

    Args:
        image_tensor: Floating-point tensor of any shape.
        bins: Number of histogram bins.
        value_range: ``(low, high)`` interval covered by the histogram; values
            outside it are ignored by ``torch.histc``. When ``None`` the range
            is inferred: ``(0, 1)`` if no value exceeds 1 (as for tensors from
            ``transforms.ToTensor``), otherwise ``(0, 255)``.

    Returns:
        Scalar tensor ``-sum(p * log(p + 1e-7))`` over the normalized
        histogram ``p``. Returns 0 for an empty tensor or when no value falls
        inside ``value_range``, instead of NaN from a 0/0 division.
    """
    if image_tensor.numel() == 0:
        return image_tensor.new_zeros(())
    if value_range is None:
        # A fixed (0, 255) range puts every value of a [0, 1] image into the
        # first one or two bins, which makes the entropy ~0 for real images.
        is_unit_range = image_tensor.max().item() <= 1.0
        value_range = (0.0, 1.0) if is_unit_range else (0.0, 255.0)
    low, high = value_range
    hist = torch.histc(image_tensor, bins=bins, min=low, max=high)
    total = hist.sum()
    if total.item() == 0:
        return hist.new_zeros(())
    probs = hist / total
    return -torch.sum(probs * torch.log(probs + 1e-7))
def symmetry_loss(image_tensor):
    """Return the MSE between the left half and the mirrored right half.

    A perfectly left-right symmetric image scores 0. For odd widths the centre
    column (which mirrors onto itself) is excluded, so both halves have the
    same size.

    Args:
        image_tensor: Tensor whose last dimension is the image width, e.g.
            (C, H, W) or (B, C, H, W).

    Returns:
        Scalar tensor; 0 when the width is smaller than 2.
    """
    half = image_tensor.shape[-1] // 2
    if half == 0:
        # A zero- or one-column image is trivially symmetric (and mse_loss on
        # empty halves would return NaN).
        return image_tensor.new_zeros(())
    # Slice the last axis with `...` so batched inputs work as well.
    left_half = image_tensor[..., :half]
    # Take exactly the last `half` columns; slicing from width // 2 would give
    # one extra column for odd widths and break mse_loss.
    right_half = torch.flip(image_tensor[..., -half:], dims=[-1])
    return F.mse_loss(left_half, right_half)
def contrast_loss(image_tensor):
    """Return the negative mean of the min-max normalized image.

    The tensor is rescaled so its minimum maps to 0 and its maximum to
    (almost) 1; the 1e-7 term keeps a constant image from dividing by zero.
    """
    lo = image_tensor.min()
    hi = image_tensor.max()
    normalized = (image_tensor - lo) / (hi - lo + 1e-7)
    return -normalized.mean()
# Set up the Stable Diffusion pipeline, on the GPU when one is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to(device)

# Converter from PIL image to float tensor, applied before computing losses.
transform = transforms.ToTensor()
# Registry mapping loss name -> scoring function. Insertion order is the order
# in which generate_images iterates (and reports) the losses.
losses = dict(
    edge=edge_loss,
    texture=texture_loss,
    entropy=entropy_loss,
    symmetry=symmetry_loss,
    contrast=contrast_loss,
)
# Generate an image for a seed and score it with every registered loss.
def generate_images(seed, prompt="A futuristic city skyline at sunset", thumbnail_size=(128, 128)):
    """Generate one image for ``seed`` and score it with every entry in ``losses``.

    Args:
        seed: Integer used to seed the ``torch.Generator`` passed to the
            diffusion pipeline.
        prompt: Text prompt passed to the Stable Diffusion pipeline.
        thumbnail_size: Maximum (width, height) of the returned thumbnails,
            applied with ``PIL.Image.thumbnail``.

    Returns:
        ``(loss_images, loss_values)``: one thumbnail of the generated image
        per loss, and one ``"<name>: <value>"`` string per loss, both in
        ``losses`` order. The thumbnails are all copies of the same image; the
        losses only score the image, they do not modify it.
    """
    generator = torch.Generator(device).manual_seed(seed)
    output_image = pipe(prompt, generator=generator).images[0]
    image_tensor = transform(output_image).to(device)

    # The thumbnail does not depend on the loss, so build it once.
    thumbnail = output_image.copy()
    thumbnail.thumbnail(thumbnail_size)

    loss_images = []
    loss_values = []
    # Scoring only: no autograd graph is needed for the loss values.
    with torch.no_grad():
        for loss_name, loss_fn in losses.items():
            loss_value = loss_fn(image_tensor)
            loss_images.append(thumbnail.copy())
            loss_values.append(f"{loss_name}: {loss_value.item():.4f}")
    return loss_images, loss_values
# Gradio callback wiring the UI to generate_images.
def gradio_interface(seed):
    """Parse ``seed``, run the generator, and format the outputs for the UI.

    Args:
        seed: Seed as entered in the textbox (a string) or a number.

    Returns:
        ``(loss_images, loss_text)``: the gallery images, plus the loss values
        joined one per line for the output ``Textbox``.

    Raises:
        gr.Error: If ``seed`` cannot be converted to an integer. Gradio shows
            this message in the UI instead of a generic error.
    """
    try:
        seed_value = int(seed)
    except (TypeError, ValueError):
        raise gr.Error(f"Seed must be an integer, got {seed!r}.") from None
    loss_images, loss_values = generate_images(seed_value)
    # A Textbox output expects a string; a raw list would be shown as its
    # Python repr (brackets and quotes).
    return loss_images, "\n".join(loss_values)
# Set up Gradio UI.
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4
# (accessing them raises AttributeError at startup); the top-level component
# classes are the supported replacements.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(label="Enter Seed"),
    outputs=[gr.Gallery(label="Loss Images"), gr.Textbox(label="Loss Values")],
)
# Launch the interface
interface.launch()