sagar007's picture
Update app.py
7ad86ac verified
raw
history blame
5.41 kB
import os
import torch
import gradio as gr
from tqdm import tqdm
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms as tfms
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, DiffusionPipeline
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# otherwise plain CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
elif torch.backends.mps.is_available():
    torch_device = "mps"
else:
    torch_device = "cpu"

if torch_device == "mps":
    # Ask PyTorch to run ops that MPS does not implement on the CPU instead.
    # NOTE(review): PyTorch may only read this variable when torch is first
    # imported, in which case setting it here has no effect — confirm.
    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
# Stable Diffusion v1.4 from the Hugging Face Hub, loaded once at import time
# in float32 and then moved onto the device chosen for this process.
model_path = "CompVis/stable-diffusion-v1-4"
sd_pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)
sd_pipeline = sd_pipeline.to(torch_device)
# Register every textual-inversion concept from sd-concepts-library with the
# pipeline so its placeholder token can be used in prompts. Loaded in the
# listed order, one Hub repo per call.
for _concept_repo in (
    "sd-concepts-library/illustration-style",
    "sd-concepts-library/line-art",
    "sd-concepts-library/hitokomoru-style-nao",
    "sd-concepts-library/style-of-marc-allante",
    "sd-concepts-library/midjourney-style",
    "sd-concepts-library/hanfu-anime-style",
    "sd-concepts-library/birb-style",
):
    sd_pipeline.load_textual_inversion(_concept_repo)
del _concept_repo  # keep the loop variable out of the module namespace
# Display name (shown in the "Style" dropdown, in this insertion order) ->
# placeholder token appended to the user's prompt.
# NOTE(review): each token is expected to match the one registered by the
# corresponding textual-inversion concept loaded above — verify per repo.
style_token_dict = {
    "Illustration Style": "<illustration-style>",
    "Line Art":           "<line-art>",
    "Hitokomoru Style":   "<hitokomoru-style-nao>",
    "Marc Allante":       "<Marc_Allante>",
    "Midjourney":         "<midjourney-style>",
    "Hanfu Anime":        "<hanfu-anime-style>",
    "Birb Style":         "<birb-style>",
}
def apply_guidance(image, guidance_method, loss_scale):
    """Blend a simple image-space adjustment into a finished image.

    The image is turned into a float tensor in [0, 1] and a "target" variant
    is built according to ``guidance_method``. The result is the linear blend
    ``img + (target - img) * loss_scale / 10000``, clamped to [0, 1]. This is
    a post-hoc edit of the generated image; it does not steer the diffusion
    sampling itself.

    Args:
        image: PIL image to adjust. Non-RGB PIL images are converted to RGB
            first so every branch sees three channels.
        guidance_method: One of 'Grayscale', 'Bright', 'Contrast',
            'Symmetry', 'Saturation'. Any other value returns ``image``
            unchanged.
        loss_scale: Blend strength. 0 leaves the image as is, and 10000
            yields the full target.

    Returns:
        A new PIL image, or the original ``image`` for an unknown method.
    """
    if isinstance(image, Image.Image) and image.mode != "RGB":
        rgb = image.convert("RGB")
    else:
        rgb = image
    # ToTensor gives (C, H, W) in [0, 1]; add a batch dim -> (1, C, H, W).
    img_tensor = tfms.ToTensor()(rgb).unsqueeze(0).to(torch_device)

    if guidance_method == 'Grayscale':
        target = tfms.Grayscale(3)(img_tensor)
    elif guidance_method == 'Bright':
        # Pixels are already non-negative, so ReLU would be a no-op. Scale
        # intensities instead (adjust_brightness clamps back into [0, 1]).
        target = tfms.functional.adjust_brightness(img_tensor, 2)
    elif guidance_method == 'Contrast':
        mean = img_tensor.mean()
        target = (img_tensor - mean) * 2 + mean
    elif guidance_method == 'Symmetry':
        target = torch.flip(img_tensor, [3])  # mirror along the width axis
    elif guidance_method == 'Saturation':
        target = tfms.functional.adjust_saturation(img_tensor, 2)
    else:
        return image

    strength = loss_scale / 10000
    guided = img_tensor + (target - img_tensor) * strength

    # Back to an (H, W, C) uint8 array. Round rather than truncate so values
    # are not biased one level downward.
    guided = guided.squeeze(0).clamp(0, 1)
    guided = (guided * 255).round().to(torch.uint8).permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(guided)
def generate_with_guidance(prompt, num_inference_steps, guidance_scale, seed, guidance_method, loss_scale):
    """Generate one image with ``sd_pipeline`` and post-process it.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale for the pipeline.
        seed: Seed for the ``torch.Generator`` on ``torch_device``.
        guidance_method: Adjustment name forwarded to ``apply_guidance``.
        loss_scale: Blend strength forwarded to ``apply_guidance``.

    Returns:
        The PIL image returned by ``apply_guidance``.
    """
    # Cast defensively: UI widgets may deliver numbers as floats (e.g. 42.0),
    # while Generator.manual_seed and the step count expect ints.
    generator = torch.Generator(device=torch_device).manual_seed(int(seed))
    image = sd_pipeline(
        prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]
    return apply_guidance(image, guidance_method, loss_scale)
def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale):
    """Gradio callback: return the plain and the post-processed image.

    The style's placeholder token is appended to the prompt. One image is
    generated, and that same image is then adjusted with ``apply_guidance``.
    Generating only once matters: running the pipeline a second time with
    the same prompt, seed, steps and scale would reproduce the same image
    at twice the cost.

    Args:
        text: User prompt (``None`` is treated as empty).
        style: Key of ``style_token_dict``. An unknown or cleared value adds
            no style token.
        inference_step: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        seed: Generator seed.
        guidance_method: Adjustment name for ``apply_guidance``.
        loss_scale: Blend strength for ``apply_guidance``.

    Returns:
        Tuple ``(image_pipeline, image_guide)`` of PIL images.
    """
    style_token = style_token_dict.get(style, "")
    prompt = f"{text or ''} {style_token}".strip()

    # Cast defensively: slider values may arrive as floats.
    generator = torch.Generator(device=torch_device).manual_seed(int(seed))
    image_pipeline = sd_pipeline(
        prompt,
        num_inference_steps=int(inference_step),
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]

    image_guide = apply_guidance(image_pipeline, guidance_method, loss_scale)
    return image_pipeline, image_guide
title = "Generative with Textual Inversion and Guidance"
description = "A Gradio interface to infer Stable Diffusion and generate images with different art styles and guidance methods"
# One example row; values are in the same order as the input widgets below.
examples = [
    ["A majestic castle on a floating island", 'Illustration Style', 10, 7.5, 42, 'Grayscale', 200],
]

# Input widgets, positionally matching inference()'s parameters.
_input_components = [
    gr.Textbox(label="Prompt", type="text"),
    gr.Dropdown(label="Style", choices=list(style_token_dict), value="Illustration Style"),
    gr.Slider(1, 50, 10, step=1, label="Inference steps"),
    gr.Slider(1, 10, 7.5, step=0.1, label="Guidance scale"),
    gr.Slider(0, 10000, 42, step=1, label="Seed"),
    gr.Dropdown(
        label="Guidance method",
        choices=['Grayscale', 'Bright', 'Contrast', 'Symmetry', 'Saturation'],
        value="Grayscale",
    ),
    gr.Slider(100, 10000, 200, step=100, label="Loss scale"),
]

# Output widgets, matching inference()'s (plain, guided) return tuple.
_output_components = [
    gr.Image(width=512, height=512, label="Generated art"),
    gr.Image(width=512, height=512, label="Generated art with guidance"),
]

demo = gr.Interface(
    fn=inference,
    inputs=_input_components,
    outputs=_output_components,
    title=title,
    description=description,
    examples=examples,
)
demo.launch()