Andres Chait committed on
Commit
98f5fb8
·
1 Parent(s): 1a36a9f

Update ReadMe

Files changed (5)
  1. README.md +6 -5
  2. app.py +182 -0
  3. imgs/example2.jpg +0 -0
  4. requirements.txt +8 -0
  5. requirments.txt +8 -0
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: Gradio TTI
+ title: Gradio-TTI
- emoji: 🐒
+ emoji: 🌖
- colorFrom: pink
+ colorFrom: red
- colorTo: red
+ colorTo: pink
  sdk: gradio
- sdk_version: 4.1.1
+ sdk_version: 3.34.0
  app_file: app.py
  pinned: false
+ license: apache-2.0
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
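One detail worth noting: `sdk_version` is pinned down from 4.1.1 to 3.34.0, and the `app.py` added below depends on that, since it calls Gradio 3.x APIs that were removed in Gradio 4. A minimal sketch of the two affected calls (assuming a Gradio 3.34 install):

```python
# Gradio 3.x-only calls used by app.py below; both were removed in Gradio 4.
import gradio as gr

with gr.Blocks() as demo:
    img = gr.Image(type="pil")
    img.style(height=512, width=512)  # 4.x moved sizing into the constructor

demo.queue(concurrency_count=1)  # 4.x renamed this (default_concurrency_limit)
```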
app.py ADDED
@@ -0,0 +1,182 @@
+ # ------------------------------------------------------------------------------
+ # Copyright (c) 2023, Andres Chait. All rights reserved.
+ # ------------------------------------------------------------------------------
+
+ from __future__ import annotations
+
+ import math
+ import random
+ from fnmatch import fnmatch
+
+ import cv2
+ import numpy as np
+ import gradio as gr
+ import torch
+ from PIL import Image, ImageOps
+ from diffusers import StableDiffusionInstructPix2PixPipeline
+
+ title = "Gradio-TTI"
+
+ description = """
+ <p style='text-align: center'> Andres Chait, Tamir Babil, Yaron Schnitman and Avi Rotem<br>
+ <a href='https://huggingface.co/spaces/andreschait/Gradio-TTI' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/2310.00390'>Paper</a> | <a href='https://github.com/andreschait/Gradio-TTI' target='_blank'>Code</a></p>
+ Demo for Gradio-TTI: Instruction-Tuned Text-to-Image Diffusion Models. \n
+ Please upload an image and provide an instruction describing the vision task you wish Gradio-TTI to perform (e.g., "Segment the dog", "Detect the dog", "Estimate the depth map of this image"). \n
+ """  # noqa
+
+ example_instructions = [
+     "Please help me detect Buzz.",
+     "Please help me detect Woody's face.",
+     "Create a monocular depth map.",
+ ]
+
+ model_id = "andreschait/Gradio-TTI"
+
+
+ def main():
+     # Use .to("cpu") instead when no GPU is available (and drop float16).
+     pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+         model_id, torch_dtype=torch.float16, safety_checker=None
+     ).to("cuda")
+     example_image = Image.open("imgs/example2.jpg").convert("RGB")
+
+     def load_example(
+         seed: int,
+         randomize_seed: bool,
+         text_cfg_scale: float,
+         image_cfg_scale: float,
+     ):
+         example_instruction = random.choice(example_instructions)
+         # Pass 0 ("Fix Seed") so the example output is reproducible.
+         return [example_image, example_instruction] + generate(
+             example_image,
+             example_instruction,
+             seed,
+             0,
+             text_cfg_scale,
+             image_cfg_scale,
+         )
+
+     def generate(
+         input_image: Image.Image,
+         instruction: str,
+         seed: int,
+         randomize_seed: bool,
+         text_cfg_scale: float,
+         image_cfg_scale: float,
+     ):
+         seed = random.randint(0, 100000) if randomize_seed else seed
+
+         # Resize so the longer side lands near 512 and both sides are
+         # multiples of 64, as the diffusion model expects.
+         width, height = input_image.size
+         factor = 512 / max(width, height)
+         factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
+         width = int((width * factor) // 64) * 64
+         height = int((height * factor) // 64) * 64
+         input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
+
+         if instruction == "":
+             # Keep the output arity consistent with the non-empty case.
+             return [seed, text_cfg_scale, image_cfg_scale, input_image]
+
+         generator = torch.manual_seed(seed)
+         edited_image = pipe(
+             instruction, image=input_image,
+             guidance_scale=text_cfg_scale, image_guidance_scale=image_cfg_scale,
+             num_inference_steps=25, generator=generator,
+         ).images[0]
+         instruction_ = instruction.lower()
+
+         # Segmentation-style instructions: threshold the model output, then
+         # overlay the resulting contours and fills on the input image.
+         if fnmatch(instruction_, "*segment*") or fnmatch(instruction_, "*split*") or fnmatch(instruction_, "*divide*"):
+             input_image = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)  # numpy.ndarray
+             edited_image = cv2.cvtColor(np.array(edited_image), cv2.COLOR_RGB2GRAY)
+             _, thresh = cv2.threshold(edited_image, 127, 255, cv2.THRESH_BINARY)
+             img2 = input_image.copy()
+             np.random.seed(np.random.randint(0, 10000))
+             colors = np.random.randint(0, 255, 3)
+             colors2 = np.random.randint(0, 255, 3)
+             contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
+             edited_image = cv2.drawContours(input_image, contours, -1, (int(colors[0]), int(colors[1]), int(colors[2])), 3)
+             edited_image_2 = img2  # fallback when no contours are found
+             for contour in contours:
+                 edited_image_2 = cv2.fillPoly(img2, [contour], (int(colors2[0]), int(colors2[1]), int(colors2[2])))
+             img_merge = cv2.addWeighted(edited_image, 0.5, edited_image_2, 0.5, 0)
+             edited_image = Image.fromarray(cv2.cvtColor(img_merge, cv2.COLOR_BGR2RGB))
+
+         # Depth-style instructions: min-max normalize to [0, 255] and render
+         # the grayscale prediction with a JET colormap.
+         if fnmatch(instruction_, "*depth*"):
+             edited_image = cv2.cvtColor(np.array(edited_image), cv2.COLOR_RGB2GRAY)
+             n_min = np.min(edited_image)
+             n_max = np.max(edited_image)
+             edited_image = (edited_image - n_min) / (n_max - n_min + 1e-8)
+             edited_image = (255 * edited_image).astype(np.uint8)
+             edited_image = cv2.applyColorMap(edited_image, cv2.COLORMAP_JET)
+             edited_image = Image.fromarray(cv2.cvtColor(edited_image, cv2.COLOR_BGR2RGB))
+
+         return [seed, text_cfg_scale, image_cfg_scale, edited_image]
+
+     with gr.Blocks() as demo:
+         gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
+         gr.Markdown(description)
+         with gr.Row():
+             with gr.Column(scale=1.5, min_width=100):
+                 generate_button = gr.Button("Generate result")
+             with gr.Column(scale=1.5, min_width=100):
+                 load_button = gr.Button("Load example")
+             with gr.Column(scale=3):
+                 instruction = gr.Textbox(lines=1, label="Instruction", interactive=True)
+
+         with gr.Row():
+             input_image = gr.Image(label="Input Image", type="pil", interactive=True)
+             edited_image = gr.Image(label="Output Image", type="pil", interactive=False)
+             input_image.style(height=512, width=512)
+             edited_image.style(height=512, width=512)
+
+         with gr.Row():
+             # type="index" maps "Fix Seed" -> 0 and "Randomize Seed" -> 1.
+             randomize_seed = gr.Radio(
+                 ["Fix Seed", "Randomize Seed"],
+                 value="Randomize Seed",
+                 type="index",
+                 show_label=False,
+                 interactive=True,
+             )
+             seed = gr.Number(value=90, precision=0, label="Seed", interactive=True)
+             text_cfg_scale = gr.Number(value=7.5, label="Text weight", interactive=True)
+             image_cfg_scale = gr.Number(value=1.5, label="Image weight", interactive=True)
+
+         load_button.click(
+             fn=load_example,
+             inputs=[seed, randomize_seed, text_cfg_scale, image_cfg_scale],
+             outputs=[input_image, instruction, seed, text_cfg_scale, image_cfg_scale, edited_image],
+         )
+         generate_button.click(
+             fn=generate,
+             inputs=[input_image, instruction, seed, randomize_seed, text_cfg_scale, image_cfg_scale],
+             outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image],
+         )
+
+     demo.queue(concurrency_count=1)
+     demo.launch(share=False)
+
+
+ if __name__ == "__main__":
+     main()
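The resizing arithmetic in `generate` is the least obvious part of the file: the first `factor` scales the longer side toward 512, the second rounds the shorter side up to a multiple of 64, and the final floor-divisions force both dimensions onto the 64-pixel grid the model expects. A standalone sketch of that computation (the `target_size` helper is hypothetical, extracted here only for illustration):

```python
import math

def target_size(width: int, height: int) -> tuple[int, int]:
    # Mirrors the resize logic in app.py's generate():
    # longer side near 512, both sides multiples of 64.
    factor = 512 / max(width, height)
    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
    return int((width * factor) // 64) * 64, int((height * factor) // 64) * 64

print(target_size(640, 480))  # (512, 384)
print(target_size(480, 640))  # (384, 512)
print(target_size(512, 512))  # (512, 512)
```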
imgs/example2.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ --extra-index-url https://download.pytorch.org/whl/cu116
+ torch
+ torchvision
+ numpy
+ transformers
+ accelerate
+ opencv-python
+ git+https://github.com/huggingface/diffusers
requirments.txt ADDED
@@ -0,0 +1,8 @@
+ --extra-index-url https://download.pytorch.org/whl/cu116
+ torch
+ torchvision
+ numpy
+ transformers
+ accelerate
+ opencv-python
+ git+https://github.com/huggingface/diffusers
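Both requirements files point pip at the PyTorch CUDA 11.6 wheel index via `--extra-index-url`. A quick sanity check (a sketch, assuming the packages above installed cleanly on a GPU machine) that the CUDA build of torch was actually selected, which `app.py` relies on for its `.to("cuda")` call:

```python
# Confirm the cu116 wheel of torch is installed and a GPU is visible;
# app.py moves the InstructPix2Pix pipeline to "cuda" and fails without one.
import torch

print(torch.__version__)          # a "+cu116" suffix indicates the CUDA 11.6 wheel
print(torch.version.cuda)         # e.g. "11.6"
print(torch.cuda.is_available())  # True when a compatible GPU and driver are present
```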