# ------------------------------------------------------------------------------
# Copyright (c) 2023, Andres Chait. All rights reserved.
# ------------------------------------------------------------------------------

from __future__ import annotations

import math
import cv2
import random
from fnmatch import fnmatch
import numpy as np

import gradio as gr
import torch
from PIL import Image, ImageOps
from diffusers import StableDiffusionInstructPix2PixPipeline

title = "Gradio-TTI"

description = """
<p style='text-align: center'> Andres Chait, Tamir Babil, Yaron Schnitman and Avi Rotem<br>
<a href='https://huggingface.co/spaces/andreschait/Gradio-TTI' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/2310.00390'>Paper</a> | <a href='https://github.com/andreschait/Gradio-TTI' target='_blank'>Code</a></p>
Demo for Gradio-TTI: Instruction-Tuned Text-to-Image Diffusion Models. \n
Please upload an image and provide an instruction describing the vision task you want Gradio-TTI to perform (e.g., “Segment the dog”, “Detect the dog”, “Estimate the depth map of this image”). \n
"""  # noqa


example_instructions = [
                        "Please help me detect Buzz.",
                        "Please help me detect Woody's face.",
                        "Create a monocular depth map.",
]

model_id = "andreschait/Kapara-K9"

def main():
    # Load the fine-tuned InstructPix2Pix pipeline in half precision on the GPU (safety checker disabled).
    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None).to("cuda")
    example_image = Image.open("imgs/example2.jpg").convert("RGB")
    

    def load_example(
        seed: int, 
        randomize_seed: bool,
        text_cfg_scale: float,
        image_cfg_scale: float,
    ):
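        # Pick a random example instruction and run it on the bundled example image with a fixed seed (randomize_seed=0).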
        example_instruction = random.choice(example_instructions)
        return [example_image, example_instruction] + generate(
            example_image,
            example_instruction,
            seed,
            0,
            text_cfg_scale,
            image_cfg_scale,
        )

    def generate(
        input_image: Image.Image,
        instruction: str,
        seed: int,
        randomize_seed: bool,
        text_cfg_scale: float,
        image_cfg_scale: float,
    ):
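        # Run the InstructPix2Pix pipeline on the input image with the given instruction, then post-process the output for display.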
        # "Randomize Seed" (radio index 1) draws a fresh seed; "Fix Seed" (index 0) keeps the value from the UI.
        seed = random.randint(0, 100000) if randomize_seed else seed
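        # Resize so both sides are multiples of 64 and the long side is about 512 px, matching the resolution the pipeline expects.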
        width, height = input_image.size
        factor = 512 / max(width, height)
        factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
        width = int((width * factor) // 64) * 64
        height = int((height * factor) // 64) * 64
        input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
        
        if instruction == "":
            # Nothing to generate: return the current values so the outputs [seed, text_cfg_scale, image_cfg_scale, image] stay consistent.
            return [seed, text_cfg_scale, image_cfg_scale, input_image]

        generator = torch.manual_seed(seed)
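        # Run 25 denoising steps; guidance_scale weights the text instruction, image_guidance_scale weights fidelity to the input image.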
        edited_image = pipe(
            instruction, image=input_image,
            guidance_scale=text_cfg_scale, image_guidance_scale=image_cfg_scale,
            num_inference_steps=25, generator=generator,
        ).images[0]
        instruction_ = instruction.lower()
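        # Simple keyword routing on the instruction: segmentation-style prompts get a contour overlay, depth prompts get a colormapped rendering.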
        
        if fnmatch(instruction_, "*segment*") or fnmatch(instruction_, "*split*") or fnmatch(instruction_, "*divide*"):
            input_image  = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR) #numpy.ndarray
            edited_image = cv2.cvtColor(np.array(edited_image), cv2.COLOR_RGB2GRAY)
            ret, thresh  = cv2.threshold(edited_image, 127, 255, cv2.THRESH_BINARY)
            img2         = input_image.copy()
            seed_seg     = np.random.randint(0,10000)
            np.random.seed(seed_seg)
            colors       = np.random.randint(0,255,(3))
            colors2      = np.random.randint(0,255,(3))
            contours,_   = cv2.findContours(thresh,cv2.RETR_LIST,cv2.CHAIN_APPROX_NONE)
            edited_image = cv2.drawContours(input_image,contours,-1,(int(colors[0]),int(colors[1]),int(colors[2])),3)
            for j in range(len(contours)):
                edited_image_2 = cv2.fillPoly(img2, [contours[j]], (int(colors2[0]),int(colors2[1]),int(colors2[2])))
            img_merge = cv2.addWeighted(edited_image, 0.5,edited_image_2, 0.5, 0)
            edited_image  = Image.fromarray(cv2.cvtColor(img_merge, cv2.COLOR_BGR2RGB))
        
        if fnmatch(instruction_, "*depth*"):
            edited_image = cv2.cvtColor(np.array(edited_image), cv2.COLOR_RGB2GRAY)
            n_min    = np.min(edited_image)
            n_max    = np.max(edited_image)
    
            edited_image = (edited_image-n_min)/(n_max-n_min+1e-8)
            edited_image = (255*edited_image).astype(np.uint8)
            edited_image = cv2.applyColorMap(edited_image, cv2.COLORMAP_JET)
            edited_image = Image.fromarray(cv2.cvtColor(edited_image, cv2.COLOR_BGR2RGB))

        return [seed, text_cfg_scale, image_cfg_scale, edited_image]


    with gr.Blocks() as demo:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
        gr.Markdown(description)
        with gr.Row():
            with gr.Column(scale=1.5, min_width=100):
                generate_button = gr.Button("Generate result")
            with gr.Column(scale=1.5, min_width=100):
                load_button = gr.Button("Load example")
            with gr.Column(scale=3):
                instruction = gr.Textbox(lines=1, label="Instruction", interactive=True)

        with gr.Row():
            input_image = gr.Image(label="Input Image", type="pil", interactive=True)
            edited_image = gr.Image(label=f"Output Image", type="pil", interactive=False)
            input_image.style(height=512, width=512)
            edited_image.style(height=512, width=512)
        
        with gr.Row(): 
            randomize_seed = gr.Radio(
                ["Fix Seed", "Randomize Seed"],
                value="Randomize Seed",
                type="index",
                show_label=False,
                interactive=True,
            )
            
            seed = gr.Number(value=90, precision=0, label="Seed", interactive=True)
            text_cfg_scale = gr.Number(value=7.5, label="Text weight", interactive=True)
            image_cfg_scale = gr.Number(value=1.5, label="Image weight", interactive=True)


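        # Wire up the buttons: "Load example" fills in a sample image and instruction, "Generate result" runs the model on the current inputs.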
        load_button.click(
            fn=load_example,
            inputs=[
                seed,
                randomize_seed,
                text_cfg_scale,
                image_cfg_scale,
            ],
            outputs=[input_image, instruction, seed, text_cfg_scale, image_cfg_scale, edited_image],
        )
        generate_button.click(
            fn=generate,
            inputs=[
                input_image,
                instruction,
                seed,
                randomize_seed,
                text_cfg_scale,
                image_cfg_scale,
            ],
            outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image],
        )

    demo.queue(concurrency_count=1)
    demo.launch(share=False)


if __name__ == "__main__":
    main()