import os

# PYTORCH_ENABLE_MPS_FALLBACK is read when torch is first imported, so it must
# be set before the import (it is ignored on non-MPS devices).
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch
import gradio as gr
from tqdm import tqdm
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms as tfms
from diffusers import DiffusionPipeline

torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# Load the pipeline
model_path = "CompVis/stable-diffusion-v1-4"
sd_pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32
).to(torch_device)

# Load textual inversion concepts (each adds a style token to the tokenizer)
for concept in [
    "sd-concepts-library/illustration-style",
    "sd-concepts-library/line-art",
    "sd-concepts-library/hitokomoru-style-nao",
    "sd-concepts-library/style-of-marc-allante",
    "sd-concepts-library/midjourney-style",
    "sd-concepts-library/hanfu-anime-style",
    "sd-concepts-library/birb-style",
]:
    sd_pipeline.load_textual_inversion(concept)

# Map the style names shown in the UI to their textual-inversion tokens
style_token_dict = {
    "Illustration Style": "<illustration-style>",
    "Line Art": "<line-art>",
    "Hitokomoru Style": "<hitokomoru-style-nao>",
    "Marc Allante": "<Marc_Allante>",
    "Midjourney": "<midjourney-style>",
    "Hanfu Anime": "<hanfu-anime-style>",
    "Birb Style": "<birb-style>",
}
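
# apply_guidance below converts between latents and PIL images, but the
# helpers it calls were never defined in this file. The definitions below are
# the standard SD v1 VAE encode/decode routines (latent scaling factor
# 0.18215); a minimal sketch assuming that convention.
def pil_to_latent(input_im):
    # Encode a PIL image to scaled VAE latents.
    with torch.no_grad():
        latent = sd_pipeline.vae.encode(
            tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1
        )
    return 0.18215 * latent.latent_dist.sample()

def latents_to_pil(latents):
    # Decode scaled VAE latents back to a list of PIL images.
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = sd_pipeline.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    return [Image.fromarray(im) for im in images]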

def apply_guidance(latents, guidance_method, loss_scale):
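    """Blend the latents toward a transformed copy of themselves.

    loss_scale acts as a linear interpolation weight: 0 leaves the latents
    untouched, 1 replaces them with the fully transformed version.
    """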
    if guidance_method == 'Grayscale':
        rgb = latents_to_pil(latents)[0]
        gray = rgb.convert('L')
        gray_latents = pil_to_latent(gray.convert('RGB'))
        return latents + (gray_latents - latents) * loss_scale
    elif guidance_method == 'Bright':
        bright_latents = F.relu(latents)  # Simple brightness increase
        return latents + (bright_latents - latents) * loss_scale
    elif guidance_method == 'Contrast':
        mean = latents.mean()
        contrast_latents = (latents - mean) * 2 + mean
        return latents + (contrast_latents - latents) * loss_scale
    elif guidance_method == 'Symmetry':
        flipped_latents = torch.flip(latents, [3])  # Flip horizontally
        return latents + (flipped_latents - latents) * loss_scale
    elif guidance_method == 'Saturation':
        rgb = latents_to_pil(latents)[0]
        saturated = tfms.functional.adjust_saturation(tfms.ToTensor()(rgb), 2)
        saturated_latents = pil_to_latent(tfms.ToPILImage()(saturated))
        return latents + (saturated_latents - latents) * loss_scale
    else:
        return latents

def generate_with_guidance(prompt, num_inference_steps, guidance_scale, seed, guidance_method, loss_scale):
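    """Manual denoising loop that applies apply_guidance to the latents at every step."""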
    generator = torch.Generator(device=torch_device).manual_seed(seed)
    
    # Embed the prompt and an empty prompt; classifier-free guidance needs both halves (uncond first)
    text_input = sd_pipeline.tokenizer(prompt, padding="max_length", max_length=sd_pipeline.tokenizer.model_max_length, truncation=True, return_tensors="pt")
    uncond_input = sd_pipeline.tokenizer("", padding="max_length", max_length=sd_pipeline.tokenizer.model_max_length, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = torch.cat([
            sd_pipeline.text_encoder(uncond_input.input_ids.to(torch_device))[0],
            sd_pipeline.text_encoder(text_input.input_ids.to(torch_device))[0],
        ])
    
    # Set the timesteps
    sd_pipeline.scheduler.set_timesteps(num_inference_steps)
    
    # Prepare latents (64x64 is the latent resolution for 512x512 output, i.e. 512 / 8)
    latents = torch.randn(
        (1, sd_pipeline.unet.config.in_channels, 64, 64),
        generator=generator,
        device=torch_device,
    )
    latents = latents * sd_pipeline.scheduler.init_noise_sigma
    
    # Denoising loop
    for t in tqdm(sd_pipeline.scheduler.timesteps):
        # Expand the latents for classifier-free guidance
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = sd_pipeline.scheduler.scale_model_input(latent_model_input, timestep=t)
        
        # Predict the noise residual
        with torch.no_grad():
            noise_pred = sd_pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        
        # Perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        
        # Apply custom guidance; dividing by 10000 maps the UI loss scale to a small blend weight
        latents = apply_guidance(latents, guidance_method, loss_scale / 10000)
        
        # Compute the previous noisy sample x_t -> x_t-1
        latents = sd_pipeline.scheduler.step(noise_pred, t, latents).prev_sample
    
    # Scale and decode the image latents with the VAE (0.18215 is the SD v1 latent scaling factor)
    latents = 1 / 0.18215 * latents
    with torch.no_grad():
        image = sd_pipeline.vae.decode(latents).sample
    
    # Convert to PIL Image
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).round().astype("uint8")[0]
    image = Image.fromarray(image)
    
    return image

def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale):
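    """Generate the prompt twice: once with the stock pipeline, once with the custom guided loop."""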
    prompt = text + " " + style_token_dict[style]

    # Generate image with pipeline
    image_pipeline = sd_pipeline(
        prompt,
        num_inference_steps=inference_step,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device=torch_device).manual_seed(seed)
    ).images[0]

    # Generate image with guidance
    image_guide = generate_with_guidance(prompt, inference_step, guidance_scale, seed, guidance_method, loss_scale)

    return image_pipeline, image_guide

title = "Image Generation with Textual Inversion and Guidance"
description = "A Gradio interface for running Stable Diffusion with different textual-inversion art styles and latent guidance methods"
examples = [
    ["A majestic castle on a floating island", 'Illustration Style', 20, 7.5, 42, 'Grayscale', 200],
    ["A cyberpunk cityscape at night", 'Midjourney', 25, 8.0, 123, 'Contrast', 300]
]

demo = gr.Interface(
    inference,
    inputs=[
        gr.Textbox(label="Prompt", type="text"),
        gr.Dropdown(label="Style", choices=list(style_token_dict.keys()), value="Illustration Style"),
        gr.Slider(1, 50, 10, step=1, label="Inference steps"),
        gr.Slider(1, 10, 7.5, step=0.1, label="Guidance scale"),
        gr.Slider(0, 10000, 42, step=1, label="Seed"),
        gr.Dropdown(label="Guidance method",
                    choices=["Grayscale", "Bright", "Contrast", "Symmetry", "Saturation"],
                    value="Grayscale"),
        gr.Slider(100, 10000, 200, step=100, label="Loss scale"),
    ],
    outputs=[
        gr.Image(width=512, height=512, label="Generated art"),
        gr.Image(width=512, height=512, label="Generated art with guidance"),
    ],
    title=title,
    description=description,
    examples=examples,
)

demo.launch()