import spaces
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler
from diffusers.utils import load_image
from PIL import Image
import torch
import numpy as np
import cv2
import gradio as gr
from torchvision import transforms 
import fire
import os

controlnet = ControlNetModel.from_pretrained(
    "geyongtao/HumanWild",
    torch_dtype=torch.float16
).to('cuda')

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", 
    torch_dtype=torch.float16).to("cuda")

pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    low_cpu_mem_usage=True,
    offload_state_dict=True,
).to('cuda')
pipe.controlnet.to(memory_format=torch.channels_last)

# pipe.enable_xformers_memory_efficient_attention()
pipe.force_zeros_for_empty_prompt = False
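
# NOTE (assumption): `get_normal_map` below expects `feature_extractor` and
# `depth_estimator`, which are not defined anywhere in this script. A minimal
# sketch using a DPT monocular depth model, as in the diffusers depth-ControlNet
# examples; the Intel/dpt-hybrid-midas checkpoint is an assumption, not part of
# the original code.
from transformers import DPTImageProcessor, DPTForDepthEstimation

feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")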


def resize_image(image):
    image = image.convert('RGB')
    current_size = image.size
    if current_size[0] > current_size[1]:
        center_cropped_image = transforms.functional.center_crop(image, (current_size[1], current_size[1]))
    else:
        center_cropped_image = transforms.functional.center_crop(image, (current_size[0], current_size[0]))
    resized_image = transforms.functional.resize(center_cropped_image, (1024, 1024))
    return resized_image

def get_normal_map(image):
    # Despite its name, this helper estimates a depth map, normalizes it to [0, 1],
    # and tiles it to three channels as the ControlNet conditioning image.
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth
    image = transforms.functional.center_crop(image, min(image.shape[-2:]))
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)
    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image


@spaces.GPU
def generate_(prompt, negative_prompt, canny_image, num_steps, controlnet_conditioning_scale, seed):
    generator = torch.Generator("cuda").manual_seed(seed)
    images = pipe(
        prompt,
        negative_prompt=negative_prompt,
        image=canny_image,
        num_inference_steps=num_steps,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        generator=generator,
    ).images
    return images

@spaces.GPU
def process(normal_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed):
    # center-crop and resize the input image to 1024x1024
    normal_image = resize_image(normal_image)
    # depth_image = get_depth_map(input_image)
    images = generate_(prompt, negative_prompt, normal_image, num_steps, controlnet_conditioning_scale, seed)

    return [normal_image, images[0]]
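
# Programmatic usage sketch (hypothetical file names), bypassing the Gradio UI:
#   normal = Image.open("assets/example.png")          # hypothetical input image
#   normal, result = process(normal, "photo of a person hiking, 8k", "", 50, 1.0, 42)
#   result.save("result.png")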


def run_demo():
    block = gr.Blocks().queue()
    
    with block:
        gr.Markdown("## Surface Normal ControlNet ")
        gr.HTML('''
          <p style="margin-bottom: 10px; font-size: 94%">
            This is a demo of the Surface Normal ControlNet, which uses the
            <a href="https://huggingface.co/geyongtao/HumanWild" target="_blank">HumanWild model</a> as its backbone.
          </p>
        ''')
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam

                example_folder = os.path.join(os.path.dirname(__file__), "./assets")
                example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)]
                gr.Examples(
                    examples=example_fns,
                    inputs=[input_image],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=30
                )
            
                prompt = gr.Textbox(label="Prompt")
                negative_prompt = gr.Textbox(label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers")
                num_steps = gr.Slider(label="Number of steps", minimum=25, maximum=100, value=50, step=1)
                controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=1.0, step=0.05)
                seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True,)
                run_button = gr.Button(value="Run")


            with gr.Column():
                result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[2], height='auto')
        ips = [input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed]
        
        run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
    
    block.launch(debug=True)

if __name__ == '__main__':
    fire.Fire(run_demo)
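
# The script is launched via python-fire, e.g. (assuming it is saved as app.py):
#   python app.py
# which calls run_demo() and starts the Gradio interface in debug mode.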