import os
import torch
import gradio as gr
from tqdm import tqdm
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms as tfms
from diffusers import DiffusionPipeline
torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
if "mps" == torch_device: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
# Load the pipeline
model_path = "CompVis/stable-diffusion-v1-4"
sd_pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32
).to(torch_device)
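# Note: float32 is the safe default across CPU, MPS, and CUDA backends; on a
# CUDA GPU you could instead pass torch_dtype=torch.float16 to roughly halve
# memory use (assumption: the small quality tradeoff is acceptable here).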
# Load textual inversions
sd_pipeline.load_textual_inversion("sd-concepts-library/illustration-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/line-art")
sd_pipeline.load_textual_inversion("sd-concepts-library/hitokomoru-style-nao")
sd_pipeline.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
sd_pipeline.load_textual_inversion("sd-concepts-library/midjourney-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/hanfu-anime-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/birb-style")
# Map style display names to their textual-inversion tokens
style_token_dict = {
    "Illustration Style": '<illustration-style>',
    "Line Art": '<line-art>',
    "Hitokomoru Style": '<hitokomoru-style-nao>',
    "Marc Allante": '<Marc_Allante>',
    "Midjourney": '<midjourney-style>',
    "Hanfu Anime": '<hanfu-anime-style>',
    "Birb Style": '<birb-style>'
}
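# The guidance methods below round-trip latents through pixel space, so they
# need VAE encode/decode helpers. These were referenced but not defined in the
# original file; this is a minimal sketch using the standard SD v1 latent
# scaling factor (0.18215).
def pil_to_latent(input_im):
    # Encode a 512x512 RGB PIL image into the scaled VAE latent space
    with torch.no_grad():
        latent = sd_pipeline.vae.encode(
            tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1
        )
    return 0.18215 * latent.latent_dist.sample()

def latents_to_pil(latents):
    # Decode a batch of latents into a list of PIL images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = sd_pipeline.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    return [Image.fromarray(img) for img in images]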
def apply_guidance(latents, guidance_method, loss_scale):
    # Nudge the latents toward a target derived from the chosen guidance method
    if guidance_method == 'Grayscale':
        rgb = latents_to_pil(latents)[0]
        gray = rgb.convert('L')
        gray_latents = pil_to_latent(gray.convert('RGB'))
        return latents + (gray_latents - latents) * loss_scale
    elif guidance_method == 'Bright':
        bright_latents = F.relu(latents)  # Simple brightness increase
        return latents + (bright_latents - latents) * loss_scale
    elif guidance_method == 'Contrast':
        mean = latents.mean()
        contrast_latents = (latents - mean) * 2 + mean
        return latents + (contrast_latents - latents) * loss_scale
    elif guidance_method == 'Symmetry':
        flipped_latents = torch.flip(latents, [3])  # Flip horizontally
        return latents + (flipped_latents - latents) * loss_scale
    elif guidance_method == 'Saturation':
        rgb = latents_to_pil(latents)[0]
        saturated = tfms.functional.adjust_saturation(tfms.ToTensor()(rgb), 2)
        saturated_latents = pil_to_latent(tfms.ToPILImage()(saturated))
        return latents + (saturated_latents - latents) * loss_scale
    else:
        return latents
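# Example of nudging latents toward a grayscale target outside the loop
# (assumes the helpers above; loss_scale is pre-normalized by the caller):
#   latents = torch.randn(1, 4, 64, 64, device=torch_device)
#   latents = apply_guidance(latents, 'Grayscale', 200 / 10000)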
def generate_with_guidance(prompt, num_inference_steps, guidance_scale, seed, guidance_method, loss_scale):
    generator = torch.Generator(device=torch_device).manual_seed(seed)

    # Get the text embeddings for the prompt and for an empty (unconditional)
    # prompt, concatenated so a single UNet pass serves classifier-free guidance
    text_input = sd_pipeline.tokenizer(prompt, padding="max_length", max_length=sd_pipeline.tokenizer.model_max_length, truncation=True, return_tensors="pt")
    uncond_input = sd_pipeline.tokenizer("", padding="max_length", max_length=sd_pipeline.tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = sd_pipeline.text_encoder(text_input.input_ids.to(torch_device))[0]
        uncond_embeddings = sd_pipeline.text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Set the timesteps
    sd_pipeline.scheduler.set_timesteps(num_inference_steps)

    # Prepare latents
    latents = torch.randn(
        (1, sd_pipeline.unet.config.in_channels, 64, 64),
        generator=generator,
        device=torch_device
    )
    latents = latents * sd_pipeline.scheduler.init_noise_sigma

    # Denoising loop
    for t in tqdm(sd_pipeline.scheduler.timesteps):
        # Expand the latents for classifier-free guidance
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = sd_pipeline.scheduler.scale_model_input(latent_model_input, timestep=t)

        # Predict the noise residual
        with torch.no_grad():
            noise_pred = sd_pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Perform classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Apply custom guidance (normalize loss_scale from the UI range)
        latents = apply_guidance(latents, guidance_method, loss_scale / 10000)

        # Compute the previous noisy sample x_t -> x_t-1
        latents = sd_pipeline.scheduler.step(noise_pred, t, latents).prev_sample

    # Scale and decode the image latents with the VAE
    latents = 1 / 0.18215 * latents
    with torch.no_grad():
        image = sd_pipeline.vae.decode(latents).sample

    # Convert to a PIL image
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).round().astype("uint8")[0]
    return Image.fromarray(image)
def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale):
    prompt = f"{text} {style_token_dict[style]}"

    # Generate an image with the stock pipeline
    image_pipeline = sd_pipeline(
        prompt,
        num_inference_steps=inference_step,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device=torch_device).manual_seed(seed)
    ).images[0]

    # Generate an image with custom guidance
    image_guide = generate_with_guidance(prompt, inference_step, guidance_scale, seed, guidance_method, loss_scale)

    return image_pipeline, image_guide
title = "Generative with Textual Inversion and Guidance"
description = "A Gradio interface to infer Stable Diffusion and generate images with different art styles and guidance methods"
examples = [
    ["A majestic castle on a floating island", 'Illustration Style', 20, 7.5, 42, 'Grayscale', 200],
    ["A cyberpunk cityscape at night", 'Midjourney', 25, 8.0, 123, 'Contrast', 300]
]
demo = gr.Interface(
    inference,
    inputs=[
        gr.Textbox(label="Prompt", type="text"),
        gr.Dropdown(label="Style", choices=list(style_token_dict.keys()), value="Illustration Style"),
        gr.Slider(1, 50, 10, step=1, label="Inference steps"),
        gr.Slider(1, 10, 7.5, step=0.1, label="Guidance scale"),
        gr.Slider(0, 10000, 42, step=1, label="Seed"),
        gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast', 'Symmetry', 'Saturation'], value="Grayscale"),
        gr.Slider(100, 10000, 200, step=100, label="Loss scale"),
    ],
    outputs=[
        gr.Image(width=512, height=512, label="Generated art"),
        gr.Image(width=512, height=512, label="Generated art with guidance"),
    ],
    title=title,
    description=description,
    examples=examples,
)

demo.launch()
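# On a Hugging Face Space this file is served automatically as app.py;
# locally, `python app.py` starts the server (assumption: saved as app.py).
# Pass share=True to launch() for a temporary public URL.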