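# Gradio demo: Stable Diffusion v1.4 with textual-inversion art styles and
# simple post-hoc "guidance" adjustments (grayscale, brightness, contrast,
# symmetry, saturation) applied to the generated image.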
import os
import torch
import gradio as gr
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms as tfms
from diffusers import DiffusionPipeline

# Pick the best available device; the MPS fallback lets unsupported ops run on the CPU
torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
if torch_device == "mps":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Load the pipeline
model_path = "CompVis/stable-diffusion-v1-4"
sd_pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32
).to(torch_device)

# Load textual inversions
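# Each concept registers a placeholder token (e.g. <illustration-style>) that can be used in prompts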
sd_pipeline.load_textual_inversion("sd-concepts-library/illustration-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/line-art")
sd_pipeline.load_textual_inversion("sd-concepts-library/hitokomoru-style-nao")
sd_pipeline.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
sd_pipeline.load_textual_inversion("sd-concepts-library/midjourney-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/hanfu-anime-style")
sd_pipeline.load_textual_inversion("sd-concepts-library/birb-style")

# Map the UI style names to their textual-inversion placeholder tokens
style_token_dict = {
    "Illustration Style": '<illustration-style>',
    "Line Art":'<line-art>',
    "Hitokomoru Style":'<hitokomoru-style-nao>',
    "Marc Allante": '<Marc_Allante>',
    "Midjourney":'<midjourney-style>',
    "Hanfu Anime": '<hanfu-anime-style>',
    "Birb Style": '<birb-style>'
}

def apply_guidance(image, guidance_method, loss_scale):
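    # Blend the original image toward an adjusted version; loss_scale / 10000
    # acts as the blend weight (the 100-10000 slider range maps to 0.01-1.0).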
    # Convert PIL Image to tensor
    img_tensor = tfms.ToTensor()(image).unsqueeze(0).to(torch_device)
    
    if guidance_method == 'Grayscale':
        gray = tfms.Grayscale(3)(img_tensor)
        guided = img_tensor + (gray - img_tensor) * (loss_scale / 10000)
    elif guidance_method == 'Bright':
        # Boost brightness with a fixed 1.5x factor (F.relu on a [0, 1] tensor is a no-op)
        bright = torch.clamp(img_tensor * 1.5, 0, 1)
        guided = img_tensor + (bright - img_tensor) * (loss_scale / 10000)
    elif guidance_method == 'Contrast':
        mean = img_tensor.mean()
        contrast = (img_tensor - mean) * 2 + mean
        guided = img_tensor + (contrast - img_tensor) * (loss_scale / 10000)
    elif guidance_method == 'Symmetry':
        flipped = torch.flip(img_tensor, [3])  # Flip horizontally
        guided = img_tensor + (flipped - img_tensor) * (loss_scale / 10000)
    elif guidance_method == 'Saturation':
        saturated = tfms.functional.adjust_saturation(img_tensor, 2)
        guided = img_tensor + (saturated - img_tensor) * (loss_scale / 10000)
    else:
        return image

    # Convert back to PIL Image
    guided = guided.squeeze(0).clamp(0, 1)
    guided = (guided * 255).byte().cpu().permute(1, 2, 0).numpy()
    return Image.fromarray(guided)

def generate_with_guidance(prompt, num_inference_steps, guidance_scale, seed, guidance_method, loss_scale):
    # Generate image with pipeline
    generator = torch.Generator(device=torch_device).manual_seed(seed)
    image = sd_pipeline(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    ).images[0]

    # Apply guidance
    guided_image = apply_guidance(image, guidance_method, loss_scale)
    
    return guided_image

def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale):
    prompt = text + " " + style_token_dict[style]

    # Generate image with pipeline
    image_pipeline = sd_pipeline(
        prompt,
        num_inference_steps=inference_step,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device=torch_device).manual_seed(seed)
    ).images[0]

    # Generate the guided variant (re-runs the pipeline with the same seed, then post-processes)
    image_guide = generate_with_guidance(prompt, inference_step, guidance_scale, seed, guidance_method, loss_scale)

    return image_pipeline, image_guide

title = "Generative Art with Textual Inversion and Guidance"
description = "A Gradio interface for running Stable Diffusion inference to generate images in different art styles, with optional guidance methods applied to the output"
examples = [
    ["A majestic castle on a floating island", 'Illustration Style', 20, 7.5, 42, 'Grayscale', 200]
]

demo = gr.Interface(inference,
                    inputs=[gr.Textbox(label="Prompt", type="text"),
                            gr.Dropdown(label="Style", choices=list(style_token_dict.keys()), value="Illustration Style"),
                            gr.Slider(1, 50, 10, step=1, label="Inference steps"),
                            gr.Slider(1, 10, 7.5, step=0.1, label="Guidance scale"),
                            gr.Slider(0, 10000, 42, step=1, label="Seed"),
                            gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast',
                                                                          'Symmetry', 'Saturation'], value="Grayscale"),
                            gr.Slider(100, 10000, 200, step=100, label="Loss scale")],
                    outputs=[gr.Image(width=512, height=512, label="Generated art"),
                             gr.Image(width=512, height=512, label="Generated art with guidance")],
                    title=title,
                    description=description,
                    examples=examples)

demo.launch()
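# launch() serves the interface locally; demo.launch(share=True) would create a temporary public link.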