import spaces  # ZeroGPU: must be imported before torch

import os

import fire
import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetPipeline
from PIL import Image
from torchvision import transforms
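# Assumed dependency (hedged): get_normal_map() below references a
# `feature_extractor` and `depth_estimator` that this file never defines;
# the usage matches the DPT depth helpers from the diffusers ControlNet examples.
from transformers import DPTForDepthEstimation, DPTImageProcessor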

# Surface-normal ControlNet weights released with HumanWild
controlnet = ControlNetModel.from_pretrained(
    "geyongtao/HumanWild",
    torch_dtype=torch.float16,
).to("cuda")

# fp16-safe SDXL VAE (the stock SDXL VAE can overflow to NaNs in float16)
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
).to("cuda")

# SDXL base pipeline conditioned on the surface-normal ControlNet
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    low_cpu_mem_usage=True,
    offload_state_dict=True,
).to("cuda")
pipe.controlnet.to(memory_format=torch.channels_last)

# pipe.enable_xformers_memory_efficient_attention()
pipe.force_zeros_for_empty_prompt = False
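
# Hedged sketch: the DPT depth model assumed by get_normal_map() below.
# These objects are undefined in the original file; "Intel/dpt-hybrid-midas"
# follows the standard diffusers ControlNet-depth example and may differ
# from the authors' actual setup.
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")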


def resize_image(image):
    """Center-crop the input to a square and resize it to 1024x1024."""
    image = image.convert('RGB')
    width, height = image.size
    side = min(width, height)
    image = transforms.functional.center_crop(image, (side, side))
    return transforms.functional.resize(image, (1024, 1024))

def get_normal_map(image):
    """Estimate a depth map with DPT and replicate it to three channels.

    Note: despite the name, this mirrors the `get_depth_map` helper from the
    diffusers ControlNet examples; the demo currently feeds the resized input
    image to the pipeline directly and does not call this function.
    """
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
    # Min-max normalize to [0, 1], then replicate to a 3-channel image
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)
    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    return Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))


@spaces.GPU
def generate_(prompt, negative_prompt, normal_image, num_steps, controlnet_conditioning_scale, seed):
    generator = torch.Generator("cuda").manual_seed(int(seed))  # cast: Gradio sliders may deliver floats
    images = pipe(
        prompt,
        negative_prompt=negative_prompt,
        image=normal_image,
        num_inference_steps=int(num_steps),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        num_images_per_prompt=2,  # two candidates per run, shown in the gallery
        generator=generator,
    ).images
    return images

@spaces.GPU
def process(normal_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed):
    # Center-crop and resize the input normal map to 1024x1024
    normal_image = resize_image(normal_image)
    # normal_image = get_normal_map(normal_image)  # optional: estimate a map from an RGB photo
    images = generate_(prompt, negative_prompt, normal_image, num_steps, controlnet_conditioning_scale, seed)

    return [images[0], images[1]]


def run_demo():

    _TITLE = '3D Human Reconstruction in the Wild with Synthetic Data Using Generative Models'

    block = gr.Blocks().queue()

    with block:
        gr.Markdown('# ' + _TITLE)
        gr.HTML('''
          <p style="margin-bottom: 10px; font-size: 94%">
            This is a demo of the Surface Normal ControlNet.
          </p>
        ''')
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(sources=None, type="pil")  # sources=None: upload, paste (Ctrl+V), and webcam

                example_folder = os.path.join(os.path.dirname(__file__), "assets")
                example_fns = [os.path.join(example_folder, example) for example in sorted(os.listdir(example_folder))]
                gr.Examples(
                    examples=example_fns,
                    inputs=[input_image],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=30
                )
            
                prompt = gr.Textbox(label="Prompt", value="a person, in the wild")
                negative_prompt = gr.Textbox(visible=False, label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers")
                num_steps = gr.Slider(label="Number of steps", minimum=25, maximum=50, value=30, step=1)
                controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=1.0, value=0.95, step=0.05)
                seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True,)
                run_button = gr.Button(value="Run")

            with gr.Column():
                result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[2], height='auto')
        ips = [input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed]
        
        run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
    
    block.launch(debug=True)

if __name__ == '__main__':
    fire.Fire(run_demo)