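"""Gradio demo for ControlNetMediaPipeFace: control Stable Diffusion 2.1 image generation with a MediaPipe facial-landmark annotation."""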
import os
import random

import gradio as gr
import numpy
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from mediapipe_face_common import generate_annotation

# Download the ControlNet MediaPipe Face checkpoint (built on SD 2.1) and its config from the Hugging Face Hub.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = hf_hub_download(repo_id="CrucibleAI/ControlNetMediaPipeFace", filename="models/controlnet_sd21_laion_face_v2_full.ckpt", repo_type="model", revision="568dc2c9980572262d48cff1ef2a7e4a03fadeb6")
config_path = hf_hub_download(repo_id="CrucibleAI/ControlNetMediaPipeFace", filename="models/cldm_v21.yaml", repo_type="model", revision="568dc2c9980572262d48cff1ef2a7e4a03fadeb6")
model = create_model(config_path).cpu()
model.load_state_dict(load_state_dict(model_path, location=device))
model = model.to(device)
ddim_sampler = DDIMSampler(model)  # ControlNet _only_ works with DDIM.


def process(input_image: Image.Image, prompt, a_prompt, n_prompt, max_faces: int, min_confidence: float, num_samples, ddim_steps, guess_mode, strength, scale, seed: int, eta):
    with torch.no_grad():
        # Resize so the short edge is at least 512, then center-crop to 512x512.
        img_size = input_image.size
        scale_factor = 512/min(img_size)
        input_image = input_image.resize((1+int(img_size[0]*scale_factor), 1+int(img_size[1]*scale_factor)))
        img_size = input_image.size
        left_padding = (img_size[0] - 512)//2
        top_padding = (img_size[1] - 512)//2
        input_image = input_image.crop((left_padding, top_padding, left_padding+512, top_padding+512))
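        # The short edge was scaled up to at least 512 above, so this crop yields exactly 512x512.
        assert input_image.size == (512, 512)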
        
        # Generate annotation
        input_image = numpy.asarray(input_image)
        empty = generate_annotation(input_image, max_faces, min_confidence)
        visualization = Image.fromarray(empty)  # Keep a copy to show in the output gallery (also useful for debugging).

        # Prep for network:
        empty = numpy.moveaxis(empty, 2, 0)  # h, w, c -> c, h, w
        control = torch.from_numpy(empty.copy()).float().to(device) / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        # control = einops.rearrange(control, 'b h w c -> b c h w').clone()
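        # control: float tensor of shape (num_samples, 3, H, W), values scaled to [0, 1].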

        # Sanity check the dimensions.
        B, C, H, W = control.shape
        assert C == 3
        assert B == num_samples

        if seed != -1:  # A seed of -1 means "don't seed"; leave the RNGs non-deterministic.
            random.seed(seed)
            os.environ['PYTHONHASHSEED'] = str(seed)
            numpy.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        # model.low_vram_shift(is_diffusing=False)

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)  # 4 latent channels at 1/8 the pixel resolution (the SD VAE's downsampling factor).

        # model.low_vram_shift(is_diffusing=True)

        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # In guess mode the 13 control scales ramp geometrically from strength * 0.825**12 (about 0.1 * strength) up to full strength.
        samples, intermediates = ddim_sampler.sample(
            ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=un_cond
        )

        # model.low_vram_shift(is_diffusing=False)

        x_samples = model.decode_first_stage(samples)
        # x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(numpy.uint8)
        x_samples = numpy.moveaxis((x_samples * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(numpy.uint8), 1, -1)  # b, c, h, w -> b, h, w, c
        results = [visualization] + [x_samples[i] for i in range(num_samples)]

    return results
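
# Minimal smoke-test sketch, disabled by default: it calls process() directly with the same positional
# argument order the Gradio wiring uses below. Both FACE2IMAGE_SMOKE_TEST and "example_face.png" are
# illustrative names; substitute any RGB photo that contains a face.
if __name__ == "__main__" and os.environ.get("FACE2IMAGE_SMOKE_TEST"):
    demo_image = Image.open("example_face.png").convert("RGB")
    outputs = process(
        demo_image,
        "a portrait photo of a person",        # prompt
        "best quality, extremely detailed",    # a_prompt
        "lowres, bad anatomy, worst quality",  # n_prompt
        1,       # max_faces
        0.5,     # min_confidence
        1,       # num_samples
        20,      # ddim_steps
        False,   # guess_mode
        1.0,     # strength
        9.0,     # scale
        42,      # seed
        0.0,     # eta
    )
    outputs[0].save("annotation.png")  # outputs[0] is the landmark visualization; outputs[1:] are generated samples.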


block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with a Facial Pose")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="pil")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                max_faces = gr.Slider(label="Max Faces", minimum=1, maximum=10, value=5, step=1)
                min_confidence = gr.Slider(label="Min Confidence", minimum=0.01, maximum=1.0, value=0.5, step=0.01)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    # The order here must match process()'s positional parameters.
    ips = [input_image, prompt, a_prompt, n_prompt, max_faces, min_confidence, num_samples, ddim_steps, guess_mode, strength, scale, seed, eta]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])


block.launch(server_name='0.0.0.0')  # Bind to all interfaces so the app is reachable from outside localhost (e.g. inside a container).
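# Passing share=True to launch() would additionally create a temporary public gradio.live URL.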