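"""Gradio app: Stable Diffusion v1.4 with textual-inversion styles and loss-based guidance.

The app loads several style concepts from the sd-concepts-library, appends the chosen
style token to the user's prompt, and can steer the denoising loop with a simple image
loss (blue, edge, contrast, brightness, sharpness, or saturation).
"""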
import gradio as gr
import PIL
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from diffusers import LMSDiscreteScheduler, DiffusionPipeline
# Configuration
torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
height, width = 512, 512
guidance_scale = 8          # classifier-free guidance strength
loss_scale = 200            # weight applied to the extra image loss during guidance
num_inference_steps = 10
model_path = "CompVis/stable-diffusion-v1-4"
sd_pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
).to(torch_device)

# The sampling loop below relies on scheduler.sigmas, so swap in the LMS discrete scheduler.
sd_pipeline.scheduler = LMSDiscreteScheduler.from_config(sd_pipeline.scheduler.config)

sd_pipeline.load_textual_inversion("sd-concepts-library/illustration-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/line-art")
sd_pipeline.load_textual_inversion("sd-concepts-library/hitokomoru-style-nao")
sd_pipeline.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
sd_pipeline.load_textual_inversion("sd-concepts-library/midjourney-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/hanfu-anime-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/birb-style")
styles_mapping = {
    "Illustration Style": "<illustration-style>",
    "Line Art": "<line-art>",
    "Hitokomoru Style": "<hitokomoru-style-nao>",
    "Marc Allante": "<Marc_Allante>",
    "Midjourney": "<midjourney-style>",
    "Hanfu Anime": "<hanfu-anime-style>",
    "Birb Style": "<birb-style>",
}

# Seeds associated with each style (one entry per key in styles_mapping)
seed_list = [11, 56, 110, 65, 5, 29, 47]
# Loss computation functions
def edge_detection(image):
    """Per-channel Sobel edge magnitude computed with a grouped convolution."""
    channels = image.shape[1]
    # Sobel kernels: the first three rows form the vertical-gradient kernel, the last three the horizontal one.
    kernels = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1],
                            [-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], device=image.device).float()
    kernels = kernels.view(2, 1, 3, 3).repeat(channels, 1, 1, 1)
    padded_image = F.pad(image, (1, 1, 1, 1), mode='replicate')
    edge = F.conv2d(padded_image, kernels, groups=channels)
    # With groups=channels the output channels are interleaved as (y, x) pairs per input channel,
    # so combine them with strided slices rather than a half/half split.
    return torch.sqrt(edge[:, 0::2] ** 2 + edge[:, 1::2] ** 2)
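# compute_loss maps a guidance-type name to a scalar loss on the decoded image batch:
#   'blue'  - pushes the blue channel toward 0.9,
#   'edge'  - pushes Sobel edge magnitudes toward a binarised (> 0.5) edge map,
#   others  - penalise the distance to a contrast/brightness/sharpness/saturation-adjusted copy.
# Unknown types return a zero loss, which leaves the sampling loop unguided.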
def compute_loss(original_image, loss_type: str):
    if loss_type == 'blue':
        return torch.abs(original_image[:, 2] - 0.9).mean()
    elif loss_type == 'edge':
        ed_value = edge_detection(original_image)
        return F.mse_loss(ed_value, (ed_value > 0.5).float())
    elif loss_type == 'contrast':
        transformed_image = TF.adjust_contrast(original_image, contrast_factor=2.0)
        return torch.abs(transformed_image - original_image).mean()
    elif loss_type == 'brightness':
        transformed_image = TF.adjust_brightness(original_image, brightness_factor=2.0)
        return torch.abs(transformed_image - original_image).mean()
    elif loss_type == 'sharpness':
        transformed_image = TF.adjust_sharpness(original_image, sharpness_factor=2.0)
        return torch.abs(transformed_image - original_image).mean()
    elif loss_type == 'saturation':
        transformed_image = TF.adjust_saturation(original_image, saturation_factor=10.0)
        return torch.abs(transformed_image - original_image).mean()
    else:
        return torch.tensor(0.0, device=original_image.device)
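# generate_latents runs the denoising loop manually: classifier-free guidance combines the
# unconditional and text-conditioned noise predictions, and every fifth step (when loss_flag
# is set) the latents are nudged along the gradient of compute_loss evaluated on a decode
# of the current denoised estimate.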
# Latent generation with optional loss-based guidance
def generate_latents(seed, prompt, loss_type, loss_flag=False):
    generator = torch.manual_seed(seed)
    batch_size = 1
    # Returns the [unconditional, conditional] text embeddings used for classifier-free guidance.
    text_embeddings = sd_pipeline._encode_prompt(prompt, sd_pipeline.device, 1, True)
    sd_pipeline.scheduler.set_timesteps(num_inference_steps)
    latents = torch.randn(
        (batch_size, sd_pipeline.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(sd_pipeline.device)
    latents = latents * sd_pipeline.scheduler.init_noise_sigma
    for i, t in enumerate(tqdm(sd_pipeline.scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = sd_pipeline.scheduler.scale_model_input(latent_model_input, t)
        with torch.no_grad():
            noise_pred = sd_pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        if loss_flag and i % 5 == 0:
            latents = latents.detach().requires_grad_()
            sigma = sd_pipeline.scheduler.sigmas[i]
            # Predict the fully denoised latents (x0) directly, so the scheduler's multistep
            # state is not advanced twice in the same iteration.
            latents_x0 = latents - sigma * noise_pred
            # Decode with gradients enabled so the loss can be backpropagated to the latents.
            denoised_images = sd_pipeline.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
            loss = compute_loss(denoised_images, loss_type) * loss_scale
            print(f"Step {i}, Loss: {loss.item():.4f}")
            cond_grad = torch.autograd.grad(loss, latents)[0]
            latents = latents.detach() - cond_grad * sigma ** 2
        latents = sd_pipeline.scheduler.step(noise_pred, t, latents).prev_sample
    return latents
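# Gradio callback: builds a style-conditioned prompt, generates guided latents with a random
# seed, and decodes them to a PIL image.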
def generate_image(prompt, style, guidance_type):
    styled_prompt = f"{prompt} in the style of {styles_mapping[style]}"
    seed = torch.randint(0, 1000000, (1,)).item()
    latents = generate_latents(seed, styled_prompt, guidance_type, loss_flag=True)
    with torch.no_grad():
        image = sd_pipeline.decode_latents(latents)
        image = sd_pipeline.numpy_to_pil(image)[0]
    return image
def get_examples():
    examples = [
        ["A bird sitting on a tree", "Midjourney", "edge"],
        ["Cats fighting on the road", "Marc Allante", "brightness"],
        ["A mouse with the head of a puppy", "Hitokomoru Style", "contrast"],
        ["A woman with a smiling face in front of an Italian Pizza", "Hanfu Anime", "brightness"],
        ["A campfire (oil on canvas)", "Birb Style", "blue"],
    ]
    return examples
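# The three inputs map one-to-one onto generate_image(prompt, style, guidance_type);
# the dropdown choices must match the keys of styles_mapping and the loss types in compute_loss.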
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Dropdown(list(styles_mapping.keys()), label="Style"),
        gr.Dropdown(["blue", "edge", "contrast", "brightness", "sharpness", "saturation"], label="Guidance Type"),
    ],
    outputs=gr.Image(label="Generated Image"),
    title="Stable Diffusion with Custom Styles",
    description="Generate images using a custom Stable Diffusion model with various styles and guidance types.",
    examples=get_examples(),
)

iface.launch()