Commit f1a2810 by ttengwang (1 parent: 89e01b9): "update"

Files changed:
- app.py: +27 -6
- app_wo_langchain.py: +0 -588
- caption_anything/captioner/base_captioner.py: +56 -63
- caption_anything/captioner/blip.py: +15 -16
- caption_anything/captioner/blip2.py: +3 -4
- caption_anything/captioner/git.py: +12 -9
- caption_anything/model.py: +6 -0
- caption_anything/segmenter/base_segmenter.py: +4 -16
- caption_anything/utils/chatbot.py: +10 -21
- caption_anything/utils/utils.py: +144 -98
app.py
CHANGED
@@ -7,6 +7,7 @@ from gradio import processing_utils
 
 from packaging import version
 from PIL import Image, ImageDraw
+import functools
 
 from caption_anything.model import CaptionAnything
 from caption_anything.utils.image_editing_utils import create_bubble_frame
@@ -22,7 +23,6 @@ from segment_anything import sam_model_registry
 args = parse_augment()
 args.segmenter = "huge"
 args.segmenter_checkpoint = "sam_vit_h_4b8939.pth"
-
 if args.segmenter_checkpoint is None:
     _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
 else:
@@ -131,7 +131,7 @@ def chat_input_callback(*args):
     return state, state
 
 def upload_callback(image_input, state, visual_chatgpt=None):
-
+
     if isinstance(image_input, dict):  # if upload from sketcher_input, input contains image and mask
         image_input, mask = image_input['image'], image_input['mask']
 
@@ -162,7 +162,8 @@ def upload_callback(image_input, state, visual_chatgpt=None):
     img_caption, _ = model.captioner.inference_seg(image_input)
     Human_prompt = f'\nHuman: provide a new figure with path {new_image_path}. The description is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
     AI_prompt = "Received."
-    visual_chatgpt.
+    visual_chatgpt.global_prompt = Human_prompt + 'AI: ' + AI_prompt
+    visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + visual_chatgpt.global_prompt
     state = [(None, 'Received new image, resize it to width {} and height {}: '.format(image_input.size[0], image_input.size[1]))]
 
     return state, state, image_input, click_state, image_input, image_input, image_embedding, \
@@ -309,12 +310,16 @@ def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuali
 
     yield state, state, refined_image_input, wiki
 
-def clear_chat_memory(visual_chatgpt):
+def clear_chat_memory(visual_chatgpt, keep_global=False):
     if visual_chatgpt is not None:
         visual_chatgpt.memory.clear()
-        visual_chatgpt.current_image = None
         visual_chatgpt.point_prompt = ""
-
+        if keep_global:
+            visual_chatgpt.agent.memory.buffer = visual_chatgpt.global_prompt
+        else:
+            visual_chatgpt.current_image = None
+            visual_chatgpt.global_prompt = ""
+
 def get_style():
     current_version = version.parse(gr.__version__)
     if current_version <= version.parse('3.24.1'):
@@ -465,6 +470,21 @@ def create_ui():
                                              modules_not_need_gpt,
                                              modules_not_need_gpt2, modules_not_need_gpt3, text_refiner, visual_chatgpt])
 
+        enable_chatGPT_button.click(
+            lambda: (None, [], [], [[], [], []], "", "", ""),
+            [],
+            [image_input, chatbot, state, click_state, wiki_output, origin_image],
+            queue=False,
+            show_progress=False
+        )
+        openai_api_key.submit(
+            lambda: (None, [], [], [[], [], []], "", "", ""),
+            [],
+            [image_input, chatbot, state, click_state, wiki_output, origin_image],
+            queue=False,
+            show_progress=False
+        )
+
         clear_button_click.click(
             lambda x: ([[], [], []], x, ""),
             [origin_image],
@@ -472,6 +492,7 @@ def create_ui():
             queue=False,
             show_progress=False
         )
+        clear_button_click.click(functools.partial(clear_chat_memory, keep_global=True), inputs=[visual_chatgpt])
         clear_button_image.click(
             lambda: (None, [], [], [[], [], []], "", "", ""),
             [],
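Note on the clear_chat_memory hunk above: the new keep_global flag lets "Clear Clicks" wipe per-click prompts while keeping the image-level prompt, whereas a full reset drops both. Below is a minimal, self-contained sketch of that behaviour; FakeBot is a stand-in for the real ConversationBot and omits the LangChain memory objects, so it is an illustration of the control flow only, not the project's implementation.

import functools

class FakeBot:
    # Stand-in for ConversationBot: only the attributes touched by clear_chat_memory.
    def __init__(self):
        self.global_prompt = "Human: provide a new figure ... AI: Received."
        self.point_prompt = "click at (120, 80)"

def clear_chat_memory(bot, keep_global=False):
    # Same branching as the new app.py version, minus memory.clear() and the agent buffer.
    if bot is not None:
        bot.point_prompt = ""
        if not keep_global:
            bot.global_prompt = ""

bot = FakeBot()
functools.partial(clear_chat_memory, keep_global=True)(bot)  # how clear_button_click wires it
assert bot.global_prompt != ""   # image context survives "Clear Clicks"
clear_chat_memory(bot)           # a full reset also drops the image context
assert bot.global_prompt == ""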
app_wo_langchain.py
DELETED
@@ -1,588 +0,0 @@
import os
import json
from typing import List

import PIL
import gradio as gr
import numpy as np
from gradio import processing_utils

from packaging import version
from PIL import Image, ImageDraw

from caption_anything.model import CaptionAnything
from caption_anything.utils.image_editing_utils import create_bubble_frame
from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter
from caption_anything.utils.parser import parse_augment
from caption_anything.captioner import build_captioner
from caption_anything.text_refiner import build_text_refiner
from caption_anything.segmenter import build_segmenter
from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
from segment_anything import sam_model_registry


args = parse_augment()

args = parse_augment()
if args.segmenter_checkpoint is None:
    _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
else:
    segmenter_checkpoint = args.segmenter_checkpoint

shared_captioner = build_captioner(args.captioner, args.device, args)
shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)


class ImageSketcher(gr.Image):
    """
    Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
    """

    is_template = True  # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.

    def __init__(self, **kwargs):
        super().__init__(tool="sketch", **kwargs)

    def preprocess(self, x):
        if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
            assert isinstance(x, dict)
            if x['mask'] is None:
                decode_image = processing_utils.decode_base64_to_image(x['image'])
                width, height = decode_image.size
                mask = np.zeros((height, width, 4), dtype=np.uint8)
                mask[..., -1] = 255
                mask = self.postprocess(mask)

                x['mask'] = mask

        return super().preprocess(x)


def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None,
                                       session_id=None):
    segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
    captioner = captioner
    if session_id is not None:
        print('Init caption anything for session {}'.format(session_id))
    return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)


def init_openai_api_key(api_key=""):
    text_refiner = None
    if api_key and len(api_key) > 30:
        try:
            text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
            text_refiner.llm('hi')  # test
        except:
            text_refiner = None
    openai_available = text_refiner is not None
    return gr.update(visible=openai_available), gr.update(visible=openai_available), gr.update(
        visible=openai_available), gr.update(visible=True), gr.update(visible=True), gr.update(
        visible=True), text_refiner


def get_click_prompt(chat_input, click_state, click_mode):
    inputs = json.loads(chat_input)
    if click_mode == 'Continuous':
        points = click_state[0]
        labels = click_state[1]
        for input in inputs:
            points.append(input[:2])
            labels.append(input[2])
    elif click_mode == 'Single':
        points = []
        labels = []
        for input in inputs:
            points.append(input[:2])
            labels.append(input[2])
        click_state[0] = points
        click_state[1] = labels
    else:
        raise NotImplementedError

    prompt = {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        "multimask_output": "True",
    }
    return prompt


def update_click_state(click_state, caption, click_mode):
    if click_mode == 'Continuous':
        click_state[2].append(caption)
    elif click_mode == 'Single':
        click_state[2] = [caption]
    else:
        raise NotImplementedError


def chat_with_points(chat_input, click_state, chat_state, state, text_refiner, img_caption):
    if text_refiner is None:
        response = "Text refiner is not initilzed, please input openai api key."
        state = state + [(chat_input, response)]
        return state, state, chat_state

    points, labels, captions = click_state
    # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!"
    suffix = '\nHuman: {chat_input}\nAI: '
    qa_template = '\nHuman: {q}\nAI: {a}'
    # # "The image is of width {width} and height {height}."
    point_chat_prompt = "I am an AI trained to chat with you about an image. I am greate at what is going on in any image based on the image information your provide. The overall image description is \"{img_caption}\". You will also provide me objects in the image in details, i.e., their location and visual descriptions. Here are the locations and descriptions of events that happen in the image: {points_with_caps} \nYou are required to use language instead of number to describe these positions. Now, let's chat!"
    prev_visual_context = ""
    pos_points = []
    pos_captions = []

    for i in range(len(points)):
        if labels[i] == 1:
            pos_points.append(f"(X:{points[i][0]}, Y:{points[i][1]})")
            pos_captions.append(captions[i])
            prev_visual_context = prev_visual_context + '\n' + 'There is an event described as \"{}\" locating at {}'.format(
                pos_captions[-1], ', '.join(pos_points))

    context_length_thres = 500
    prev_history = ""
    for i in range(len(chat_state)):
        q, a = chat_state[i]
        if len(prev_history) < context_length_thres:
            prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
        else:
            break
    chat_prompt = point_chat_prompt.format(
        **{"img_caption": img_caption, "points_with_caps": prev_visual_context}) + prev_history + suffix.format(
        **{"chat_input": chat_input})
    print('\nchat_prompt: ', chat_prompt)
    response = text_refiner.llm(chat_prompt)
    state = state + [(chat_input, response)]
    chat_state = chat_state + [(chat_input, response)]
    return state, state, chat_state


def upload_callback(image_input, state):
    if isinstance(image_input, dict):  # if upload from sketcher_input, input contains image and mask
        image_input, mask = image_input['image'], image_input['mask']

    chat_state = []
    click_state = [[], [], []]
    res = 1024
    width, height = image_input.size
    ratio = min(1.0 * res / max(width, height), 1.0)
    if ratio < 1.0:
        image_input = image_input.resize((int(width * ratio), int(height * ratio)))
        print('Scaling input image to {}'.format(image_input.size))
    state = [] + [(None, 'Image size: ' + str(image_input.size))]
    model = build_caption_anything_with_models(
        args,
        api_key="",
        captioner=shared_captioner,
        sam_model=shared_sam_model,
        session_id=iface.app_id
    )
    model.segmenter.set_image(image_input)
    image_embedding = model.image_embedding
    original_size = model.original_size
    input_size = model.input_size
    img_caption, _ = model.captioner.inference_seg(image_input)

    return state, state, chat_state, image_input, click_state, image_input, image_input, image_embedding, \
        original_size, input_size, img_caption


def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
                    length, image_embedding, state, click_state, original_size, input_size, text_refiner,
                    evt: gr.SelectData):
    click_index = evt.index

    if point_prompt == 'Positive':
        coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
    else:
        coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))

    prompt = get_click_prompt(coordinate, click_state, click_mode)
    input_points = prompt['input_point']
    input_labels = prompt['input_label']

    controls = {'length': length,
                'sentiment': sentiment,
                'factuality': factuality,
                'language': language}

    model = build_caption_anything_with_models(
        args,
        api_key="",
        captioner=shared_captioner,
        sam_model=shared_sam_model,
        text_refiner=text_refiner,
        session_id=iface.app_id
    )

    model.setup(image_embedding, original_size, input_size, is_image_set=True)

    enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
    out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)

    state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
    state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
    wiki = out['generated_captions'].get('wiki', "")
    update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
    text = out['generated_captions']['raw_caption']
    input_mask = np.array(out['mask'].convert('P'))
    image_input = mask_painter(np.array(image_input), input_mask)
    origin_image_input = image_input
    image_input = create_bubble_frame(image_input, text, (click_index[0], click_index[1]), input_mask,
                                      input_points=input_points, input_labels=input_labels)
    yield state, state, click_state, image_input, wiki
    if not args.disable_gpt and model.text_refiner:
        refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
                                                       enable_wiki=enable_wiki)
        # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
        new_cap = refined_caption['caption']
        wiki = refined_caption['wiki']
        state = state + [(None, f"caption: {new_cap}")]
        refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
                                                  input_mask,
                                                  input_points=input_points, input_labels=input_labels)
        yield state, state, click_state, refined_image_input, wiki


def get_sketch_prompt(mask: PIL.Image.Image, multi_mask=True):
    """
    Get the prompt for the sketcher.
    TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
    """

    mask = np.array(np.asarray(mask)[..., 0])
    mask[mask > 0] = 1  # Refine the mask, let all nonzero values be 1

    if not multi_mask:
        y, x = np.where(mask == 1)
        x1, y1 = np.min(x), np.min(y)
        x2, y2 = np.max(x), np.max(y)

        prompt = {
            'prompt_type': ['box'],
            'input_boxes': [
                [x1, y1, x2, y2]
            ]
        }

        return prompt

    traversed = np.zeros_like(mask)
    groups = np.zeros_like(mask)
    max_group_id = 1

    # Iterate over all pixels
    for x in range(mask.shape[0]):
        for y in range(mask.shape[1]):
            if traversed[x, y] == 1:
                continue

            if mask[x, y] == 0:
                traversed[x, y] = 1
            else:
                # If pixel is part of mask
                groups[x, y] = max_group_id
                stack = [(x, y)]
                while stack:
                    i, j = stack.pop()
                    if traversed[i, j] == 1:
                        continue
                    traversed[i, j] = 1
                    if mask[i, j] == 1:
                        groups[i, j] = max_group_id
                        for di, dj in [(1, 0), (-1, 0), (0, 1), (0, -1), (1, 1), (1, -1), (-1, 1), (-1, -1)]:
                            ni, nj = i + di, j + dj
                            traversed[i, j] = 1
                            if 0 <= nj < mask.shape[1] and mask.shape[0] > ni >= 0 == traversed[ni, nj]:
                                stack.append((i + di, j + dj))
                max_group_id += 1

    # get the bounding box of each group
    boxes = []
    for group in range(1, max_group_id):
        y, x = np.where(groups == group)
        x1, y1 = np.min(x), np.min(y)
        x2, y2 = np.max(x), np.max(y)
        boxes.append([x1, y1, x2, y2])

    prompt = {
        'prompt_type': ['box'],
        'input_boxes': boxes
    }

    return prompt


def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
                      original_size, input_size, text_refiner):
    image_input, mask = sketcher_image['image'], sketcher_image['mask']

    prompt = get_sketch_prompt(mask, multi_mask=False)
    boxes = prompt['input_boxes']

    controls = {'length': length,
                'sentiment': sentiment,
                'factuality': factuality,
                'language': language}

    model = build_caption_anything_with_models(
        args,
        api_key="",
        captioner=shared_captioner,
        sam_model=shared_sam_model,
        text_refiner=text_refiner,
        session_id=iface.app_id
    )

    model.setup(image_embedding, original_size, input_size, is_image_set=True)

    enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
    out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)

    # Update components and states
    state.append((f'Box: {boxes}', None))
    state.append((None, f'raw_caption: {out["generated_captions"]["raw_caption"]}'))
    wiki = out['generated_captions'].get('wiki', "")
    text = out['generated_captions']['raw_caption']
    input_mask = np.array(out['mask'].convert('P'))
    image_input = mask_painter(np.array(image_input), input_mask)

    origin_image_input = image_input

    fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
    image_input = create_bubble_frame(image_input, text, fake_click_index, input_mask)

    yield state, state, image_input, wiki

    if not args.disable_gpt and model.text_refiner:
        refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
                                                       enable_wiki=enable_wiki)

        new_cap = refined_caption['caption']
        wiki = refined_caption['wiki']
        state = state + [(None, f"caption: {new_cap}")]
        refined_image_input = create_bubble_frame(origin_image_input, new_cap, fake_click_index, input_mask)

        yield state, state, refined_image_input, wiki


def get_style():
    current_version = version.parse(gr.__version__)
    if current_version <= version.parse('3.24.1'):
        style = '''
        #image_sketcher{min-height:500px}
        #image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
        #image_upload{min-height:500px}
        #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
        '''
    elif current_version <= version.parse('3.27'):
        style = '''
        #image_sketcher{min-height:500px}
        #image_upload{min-height:500px}
        '''
    else:
        style = None

    return style


def create_ui():
    title = """<p><h1 align="center">Caption-Anything</h1></p>
    """
    description = """<p>Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: <a href="https://github.com/ttengwang/Caption-Anything">https://github.com/ttengwang/Caption-Anything</a> <a href="https://huggingface.co/spaces/TencentARC/Caption-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""

    examples = [
        ["test_images/img35.webp"],
        ["test_images/img2.jpg"],
        ["test_images/img5.jpg"],
        ["test_images/img12.jpg"],
        ["test_images/img14.jpg"],
        ["test_images/qingming3.jpeg"],
        ["test_images/img1.jpg"],
    ]

    with gr.Blocks(
            css=get_style()
    ) as iface:
        state = gr.State([])
        click_state = gr.State([[], [], []])
        chat_state = gr.State([])
        origin_image = gr.State(None)
        image_embedding = gr.State(None)
        text_refiner = gr.State(None)
        original_size = gr.State(None)
        input_size = gr.State(None)
        img_caption = gr.State(None)

        gr.Markdown(title)
        gr.Markdown(description)

        with gr.Row():
            with gr.Column(scale=1.0):
                with gr.Column(visible=False) as modules_not_need_gpt:
                    with gr.Tab("Click"):
                        image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
                        example_image = gr.Image(type="pil", interactive=False, visible=False)
                        with gr.Row(scale=1.0):
                            with gr.Row(scale=0.4):
                                point_prompt = gr.Radio(
                                    choices=["Positive", "Negative"],
                                    value="Positive",
                                    label="Point Prompt",
                                    interactive=True)
                                click_mode = gr.Radio(
                                    choices=["Continuous", "Single"],
                                    value="Continuous",
                                    label="Clicking Mode",
                                    interactive=True)
                            with gr.Row(scale=0.4):
                                clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
                                clear_button_image = gr.Button(value="Clear Image", interactive=True)
                    with gr.Tab("Trajectory (Beta)"):
                        sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=20,
                                                       elem_id="image_sketcher")
                        with gr.Row():
                            submit_button_sketcher = gr.Button(value="Submit", interactive=True)

                with gr.Column(visible=False) as modules_need_gpt:
                    with gr.Row(scale=1.0):
                        language = gr.Dropdown(
                            ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
                            value="English", label="Language", interactive=True)
                        sentiment = gr.Radio(
                            choices=["Positive", "Natural", "Negative"],
                            value="Natural",
                            label="Sentiment",
                            interactive=True,
                        )
                    with gr.Row(scale=1.0):
                        factuality = gr.Radio(
                            choices=["Factual", "Imagination"],
                            value="Factual",
                            label="Factuality",
                            interactive=True,
                        )
                        length = gr.Slider(
                            minimum=10,
                            maximum=80,
                            value=10,
                            step=1,
                            interactive=True,
                            label="Generated Caption Length",
                        )
                        enable_wiki = gr.Radio(
                            choices=["Yes", "No"],
                            value="No",
                            label="Enable Wiki",
                            interactive=True)
                with gr.Column(visible=True) as modules_not_need_gpt3:
                    gr.Examples(
                        examples=examples,
                        inputs=[example_image],
                    )
            with gr.Column(scale=0.5):
                openai_api_key = gr.Textbox(
                    placeholder="Input openAI API key",
                    show_label=False,
                    label="OpenAI API Key",
                    lines=1,
                    type="password")
                with gr.Row(scale=0.5):
                    enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
                    disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
                                                       variant='primary')
                with gr.Column(visible=False) as modules_need_gpt2:
                    wiki_output = gr.Textbox(lines=5, label="Wiki", max_lines=5)
                with gr.Column(visible=False) as modules_not_need_gpt2:
                    chatbot = gr.Chatbot(label="Chat about Selected Object", ).style(height=550, scale=0.5)
                with gr.Column(visible=False) as modules_need_gpt3:
                    chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
                        container=False)
                    with gr.Row():
                        clear_button_text = gr.Button(value="Clear Text", interactive=True)
                        submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")

        openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
                              outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
                                       modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
        enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
                                    outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
                                             modules_not_need_gpt,
                                             modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
        disable_chatGPT_button.click(init_openai_api_key,
                                     outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
                                              modules_not_need_gpt,
                                              modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])

        clear_button_click.click(
            lambda x: ([[], [], []], x, ""),
            [origin_image],
            [click_state, image_input, wiki_output],
            queue=False,
            show_progress=False
        )
        clear_button_image.click(
            lambda: (None, [], [], [], [[], [], []], "", "", ""),
            [],
            [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
            queue=False,
            show_progress=False
        )
        clear_button_text.click(
            lambda: ([], [], [[], [], [], []], []),
            [],
            [chatbot, state, click_state, chat_state],
            queue=False,
            show_progress=False
        )
        image_input.clear(
            lambda: (None, [], [], [], [[], [], []], "", "", ""),
            [],
            [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
            queue=False,
            show_progress=False
        )

        image_input.upload(upload_callback, [image_input, state],
                           [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
                            image_embedding, original_size, input_size, img_caption])
        sketcher_input.upload(upload_callback, [sketcher_input, state],
                              [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
                               image_embedding, original_size, input_size, img_caption])
        chat_input.submit(chat_with_points, [chat_input, click_state, chat_state, state, text_refiner, img_caption],
                          [chatbot, state, chat_state])
        chat_input.submit(lambda: "", None, chat_input)
        example_image.change(upload_callback, [example_image, state],
                             [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
                              image_embedding, original_size, input_size, img_caption])

        # select coordinate
        image_input.select(
            inference_click,
            inputs=[
                origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
                image_embedding, state, click_state, original_size, input_size, text_refiner
            ],
            outputs=[chatbot, state, click_state, image_input, wiki_output],
            show_progress=False, queue=True
        )

        submit_button_sketcher.click(
            inference_traject,
            inputs=[
                sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
                original_size, input_size, text_refiner
            ],
            outputs=[chatbot, state, sketcher_input, wiki_output],
            show_progress=False, queue=True
        )

        return iface


if __name__ == '__main__':
    iface = create_ui()
    iface.queue(concurrency_count=5, api_open=False, max_size=10)
    iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
caption_anything/captioner/base_captioner.py
CHANGED
@@ -9,8 +9,10 @@ from typing import Union
 import time
 import clip
 
+from caption_anything.utils.utils import load_image
+
+
 def boundary(inputs):
-
     col = inputs.shape[1]
     inputs = inputs.reshape(-1)
     lens = len(inputs)
@@ -20,11 +22,11 @@ def boundary(inputs):
 
     top = start // col
     bottom = end // col
-
+
     return top, bottom
 
+
 def new_seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
-
     if type(seg_mask) == str:
         seg_mask = Image.open(seg_mask)
     elif type(seg_mask) == np.ndarray:
@@ -35,12 +37,13 @@ def new_seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
     left, right = boundary(seg_mask.T)
     return [left / size, top / size, right / size, bottom / size]
 
+
 def seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
     if type(seg_mask) == str:
         seg_mask = cv2.imread(seg_mask, cv2.IMREAD_GRAYSCALE)
         _, seg_mask = cv2.threshold(seg_mask, 127, 255, 0)
     elif type(seg_mask) == np.ndarray:
-        assert seg_mask.ndim == 2
+        assert seg_mask.ndim == 2  # only support single-channel segmentation mask
     seg_mask = seg_mask.astype('uint8')
     if seg_mask.dtype == 'bool':
         seg_mask = seg_mask * 255
@@ -49,25 +52,28 @@ def seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]):
     rect = cv2.minAreaRect(contours)
     box = cv2.boxPoints(rect)
     if rect[-1] >= 45:
-        newstart = box.argmin(axis=0)[1]
+        newstart = box.argmin(axis=0)[1]  # leftmost
     else:
-        newstart = box.argmax(axis=0)[0]
+        newstart = box.argmax(axis=0)[0]  # topmost
     box = np.concatenate([box[newstart:], box[:newstart]], axis=0)
     box = np.int0(box)
     return box
 
+
 def get_w_h(rect_points):
     w = np.linalg.norm(rect_points[0] - rect_points[1], ord=2).astype('int')
     h = np.linalg.norm(rect_points[0] - rect_points[3], ord=2).astype('int')
     return w, h
-
+
+
 def cut_box(img, rect_points):
     w, h = get_w_h(rect_points)
-    dst_pts = np.array([[h, 0], [h, w], [0, w], [0, 0],], dtype="float32")
+    dst_pts = np.array([[h, 0], [h, w], [0, w], [0, 0], ], dtype="float32")
     transform = cv2.getPerspectiveTransform(rect_points.astype("float32"), dst_pts)
     cropped_img = cv2.warpPerspective(img, transform, (h, w))
     return cropped_img
-
+
+
 class BaseCaptioner:
     def __init__(self, device, enable_filter=False):
         print(f"Initializing ImageCaptioning to {device}")
@@ -82,18 +88,15 @@ class BaseCaptioner:
 
     @torch.no_grad()
     def filter_caption(self, image: Union[np.ndarray, Image.Image, str], caption: str):
-        ...
-        text_features = self.filter.encode_text(text)  # (1, 512)
-        image_features /= image_features.norm(dim = -1, keepdim = True)
-        text_features /= text_features.norm(dim = -1, keepdim = True)
+
+        image = load_image(image, return_type='pil')
+
+        image = self.preprocess(image).unsqueeze(0).to(self.device)  # (1, 3, 224, 224)
+        text = clip.tokenize(caption).to(self.device)  # (1, 77)
+        image_features = self.filter.encode_image(image)  # (1, 512)
+        text_features = self.filter.encode_text(text)  # (1, 512)
+        image_features /= image_features.norm(dim=-1, keepdim=True)
+        text_features /= text_features.norm(dim=-1, keepdim=True)
         similarity = torch.matmul(image_features, text_features.transpose(1, 0)).item()
         if similarity < self.threshold:
             print('There seems to be nothing where you clicked.')
@@ -103,24 +106,21 @@ class BaseCaptioner:
         print(f'Clip score of the caption is {similarity}')
         return out
 
-
-    def inference(self, image: Union[np.ndarray, Image.Image, str], filter: bool=False):
+    def inference(self, image: Union[np.ndarray, Image.Image, str], filter: bool = False):
         raise NotImplementedError()
-
-    def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, filter: bool=False):
+
+    def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, filter: bool = False):
         raise NotImplementedError()
-
+
     def inference_box(self, image: Union[np.ndarray, Image.Image, str], box: Union[list, np.ndarray], filter=False):
-        ...
-            image = Image.open(image)
-        elif type(image) == np.ndarray:
-            image = Image.fromarray(image)
+        image = load_image(image, return_type="pil")
 
-        if np.array(box).size == 4:
+        if np.array(box).size == 4:
+            # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
             size = max(image.width, image.height)
             x1, y1, x2, y2 = box
-            image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
-        elif np.array(box).size == 8:
+            image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
+        elif np.array(box).size == 8:  # four corners of an irregular rectangle
             image_crop = cut_box(np.array(image), box)
 
         crop_save_path = f'result/crop_{time.time()}.png'
@@ -128,24 +128,20 @@ class BaseCaptioner:
         print(f'croped image saved in {crop_save_path}')
         caption = self.inference(image_crop, filter)
         return caption, crop_save_path
-
 
-    def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str]
+    def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str] = None,
+                      crop_mode="w_bg", filter=False, disable_regular_box=False):
         if seg_mask is None:
             seg_mask = np.ones(image.size).astype(bool)
-        ...
-        if type(seg_mask) == str:
-            seg_mask = Image.open(seg_mask)
-        elif type(seg_mask) == np.ndarray:
-            seg_mask = Image.fromarray(seg_mask)
+
+        image = load_image(image, return_type="pil")
+        seg_mask = load_image(seg_mask, return_type="pil")
 
         seg_mask = seg_mask.resize(image.size)
         seg_mask = np.array(seg_mask) > 0
-
-        if crop_mode=="wo_bg":
-            image = np.array(image) * seg_mask[
+
+        if crop_mode == "wo_bg":
+            image = np.array(image) * seg_mask[:, :, np.newaxis] + (1 - seg_mask[:, :, np.newaxis]) * 255
             image = np.uint8(image)
         else:
             image = np.array(image)
@@ -155,20 +151,17 @@ class BaseCaptioner:
         else:
             min_area_box = new_seg_to_box(seg_mask)
         return self.inference_box(image, min_area_box, filter)
-
-        ...
-            seg_mask = Image.open(seg_mask)
-        elif type(seg_mask) == np.ndarray:
-            seg_mask = Image.fromarray(seg_mask)
+
+    def generate_seg_cropped_image(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str],
+                                   crop_mode="w_bg", disable_regular_box=False):
+        image = load_image(image, return_type="pil")
+        seg_mask = load_image(seg_mask, return_type="pil")
+
         seg_mask = seg_mask.resize(image.size)
         seg_mask = np.array(seg_mask) > 0
 
-        if crop_mode=="wo_bg":
-            image = np.array(image) * seg_mask[
+        if crop_mode == "wo_bg":
+            image = np.array(image) * seg_mask[:, :, np.newaxis] + (1 - seg_mask[:, :, np.newaxis]) * 255
         else:
             image = np.array(image)
 
@@ -176,24 +169,24 @@ class BaseCaptioner:
             box = seg_to_box(seg_mask)
         else:
             box = new_seg_to_box(seg_mask)
-
-        if np.array(box).size == 4:
+
+        if np.array(box).size == 4:
+            # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
             size = max(image.shape[0], image.shape[1])
             x1, y1, x2, y2 = box
-            image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
-        elif np.array(box).size == 8:
+            image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size)))
+        elif np.array(box).size == 8:  # four corners of an irregular rectangle
            image_crop = cut_box(np.array(image), box)
         crop_save_path = f'result/crop_{time.time()}.png'
         Image.fromarray(image_crop).save(crop_save_path)
         print(f'croped image saved in {crop_save_path}')
         return crop_save_path
 
-
+
 if __name__ == '__main__':
     model = BaseCaptioner(device='cuda:0')
     image_path = 'test_images/img2.jpg'
-    seg_mask = np.zeros((15,15))
+    seg_mask = np.zeros((15, 15))
     seg_mask[5:10, 5:10] = 1
     seg_mask = 'image/SAM/img10.jpg.raw_mask.png'
     print(model.inference_seg(image_path, seg_mask))
-
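The captioners above now delegate input normalization to load_image from caption_anything/utils/utils.py (also changed in this commit, but its body is not shown in this view). The following is a hypothetical sketch of what such a helper could look like, assuming it only needs to accept a path, a numpy array, or a PIL image; the project's actual implementation may differ.

from typing import Union

import numpy as np
from PIL import Image


def load_image(image: Union[np.ndarray, Image.Image, str], return_type: str = "pil"):
    # Illustrative re-implementation only; not the code from utils.py.
    if isinstance(image, str):            # treat strings as file paths
        image = Image.open(image)
    elif isinstance(image, np.ndarray):   # wrap raw arrays
        image = Image.fromarray(image)
    if return_type == "numpy":
        return np.array(image)
    return image                          # PIL.Image by default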
caption_anything/captioner/blip.py
CHANGED
@@ -1,14 +1,13 @@
 import torch
-from PIL import Image
+from PIL import Image
 from transformers import BlipProcessor
+
+from caption_anything.utils.utils import load_image
 from .modeling_blip import BlipForConditionalGeneration
-import json
-import pdb
-import cv2
 import numpy as np
 from typing import Union
 from .base_captioner import BaseCaptioner
-import torchvision.transforms.functional as F
+import torchvision.transforms.functional as F
 
 
 class BLIPCaptioner(BaseCaptioner):
@@ -17,12 +16,12 @@ class BLIPCaptioner(BaseCaptioner):
         self.device = device
         self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
         self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
-        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large",
-                                                                  ...
+        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large",
+                                                                  torch_dtype=self.torch_dtype).to(self.device)
+
     @torch.no_grad()
     def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
-        ...
-        image = Image.open(image)
+        image = load_image(image, return_type="pil")
         inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
         out = self.model.generate(**inputs, max_new_tokens=50)
         captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
@@ -30,12 +29,13 @@ class BLIPCaptioner(BaseCaptioner):
             captions = self.filter_caption(image, captions)
         print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
         return captions
-
+
     @torch.no_grad()
-    def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
-        ...
+    def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
+                                      filter=False, disable_regular_box=False):
+        crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
+                                                         disable_regular_box=disable_regular_box)
+        image = load_image(image, return_type="pil")
         inputs = self.processor(image, return_tensors="pt")
         pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
         _, _, H, W = pixel_values.shape
@@ -56,11 +56,10 @@ if __name__ == '__main__':
     model = BLIPCaptioner(device='cuda:0')
     # image_path = 'test_images/img2.jpg'
     image_path = 'image/SAM/img10.jpg'
-    seg_mask = np.zeros((15,15))
+    seg_mask = np.zeros((15, 15))
     seg_mask[5:10, 5:10] = 1
     seg_mask = 'test_images/img10.jpg.raw_mask.png'
     image_path = 'test_images/img2.jpg'
    seg_mask = 'test_images/img2.jpg.raw_mask.png'
     print(f'process image {image_path}')
     print(model.inference_with_reduced_tokens(image_path, seg_mask))
-
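With load_image in place, BLIPCaptioner.inference accepts a file path, a numpy array, or a PIL image. A usage sketch follows; it assumes the BLIP weights can be downloaded from the Hub and that test_images/img2.jpg exists locally, mirroring the paths used in the module's __main__ block.

import numpy as np
from PIL import Image

from caption_anything.captioner.blip import BLIPCaptioner

captioner = BLIPCaptioner(device='cpu')               # or 'cuda:0' as in the __main__ block
pil_img = Image.open('test_images/img2.jpg')

print(captioner.inference('test_images/img2.jpg'))    # path input
print(captioner.inference(pil_img))                   # PIL image input
print(captioner.inference(np.asarray(pil_img)))       # numpy array input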
caption_anything/captioner/blip2.py
CHANGED
@@ -4,7 +4,7 @@ import numpy as np
 from typing import Union
 from transformers import AutoProcessor, Blip2ForConditionalGeneration
 
-from caption_anything.utils.utils import is_platform_win
+from caption_anything.utils.utils import is_platform_win, load_image
 from .base_captioner import BaseCaptioner
 
 class BLIP2Captioner(BaseCaptioner):
@@ -21,11 +21,10 @@ class BLIP2Captioner(BaseCaptioner):
 
     @torch.no_grad()
     def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
-        ...
-        image = Image.open(image)
+        image = load_image(image, return_type="pil")
 
         if not self.dialogue:
-            text_prompt = '
+            text_prompt = 'The image shows'
             inputs = self.processor(image, text = text_prompt, return_tensors="pt").to(self.device, self.torch_dtype)
             out = self.model.generate(**inputs, max_new_tokens=50)
             captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
caption_anything/captioner/git.py
CHANGED
@@ -1,4 +1,6 @@
 from transformers import GitProcessor, AutoProcessor
+
+from caption_anything.utils.utils import load_image
 from .modeling_git import GitForCausalLM
 from PIL import Image
 import torch
@@ -15,11 +17,10 @@ class GITCaptioner(BaseCaptioner):
         self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
         self.processor = AutoProcessor.from_pretrained("microsoft/git-large")
         self.model = GitForCausalLM.from_pretrained("microsoft/git-large", torch_dtype=self.torch_dtype).to(self.device)
-
+
     @torch.no_grad()
     def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
-
-        image = Image.open(image)
+        image = load_image(image, return_type="pil")
         pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, self.torch_dtype)
         generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50)
         generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
@@ -29,10 +30,11 @@ class GITCaptioner(BaseCaptioner):
         return generated_caption
 
     @torch.no_grad()
-    def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
-
-
-
+    def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
+                                      filter=False, disable_regular_box=False):
+        crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
+                                                         disable_regular_box=disable_regular_box)
+        image = load_image(image, return_type="pil")
         inputs = self.processor(images=image, return_tensors="pt")
         pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
         _, _, H, W = pixel_values.shape
@@ -48,10 +50,11 @@ class GITCaptioner(BaseCaptioner):
         print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
         return captions, crop_save_path
 
+
 if __name__ == '__main__':
     model = GITCaptioner(device='cuda:2', enable_filter=False)
     image_path = 'test_images/img2.jpg'
-    seg_mask = np.zeros((224,224))
+    seg_mask = np.zeros((224, 224))
     seg_mask[50:200, 50:200] = 1
     print(f'process image {image_path}')
-    print(model.inference_with_reduced_tokens(image_path, seg_mask))
+    print(model.inference_with_reduced_tokens(image_path, seg_mask))
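Usage sketch (not part of the diff above): inference_with_reduced_tokens takes a 2-D 0/1 mask aligned with the image, as the __main__ block illustrates with a 224x224 toy mask; a mask produced by the SAM segmenter can be passed the same way. Paths and device below are placeholders.

import numpy as np
from PIL import Image
from caption_anything.captioner.git import GITCaptioner

captioner = GITCaptioner(device='cuda:0', enable_filter=False)
image = Image.open('test_images/img2.jpg')
seg_mask = np.zeros((image.height, image.width))   # 0 = background
seg_mask[50:200, 50:200] = 1                       # 1 = region to caption
caption, crop_path = captioner.inference_with_reduced_tokens(image, seg_mask)
print(caption, crop_path)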
caption_anything/model.py
CHANGED
@@ -62,9 +62,12 @@ class CaptionAnything:
             print('OpenAI GPT is not available')
 
     def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
+        # TODO: Add support to multiple seg masks.
+
         # segment with prompt
         print("CA prompt: ", prompt, "CA controls", controls)
         seg_mask = self.segmenter.inference(image, prompt)[0, ...]
+
         if self.args.enable_morphologyex:
             seg_mask = 255 * seg_mask.astype(np.uint8)
             seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
@@ -80,6 +83,7 @@ class CaptionAnything:
         seg_mask_img.save(mask_save_path)
         print('seg_mask path: ', mask_save_path)
         print("seg_mask.shape: ", seg_mask.shape)
+
         # captioning with mask
         if self.args.enable_reduce_tokens:
             caption, crop_save_path = self.captioner. \
@@ -92,6 +96,7 @@ class CaptionAnything:
                 inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode,
                               filter=self.args.clip_filter,
                               disable_regular_box=self.args.disable_regular_box)
+
         # refining with TextRefiner
         context_captions = []
         if self.args.context_captions:
@@ -111,6 +116,7 @@ class CaptionAnything:
 
 if __name__ == "__main__":
     from caption_anything.utils.parser import parse_augment
+
     args = parse_augment()
     # image_path = 'test_images/img3.jpg'
     image_path = 'test_images/img1.jpg'
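Illustrative sketch of the kind of clean-up the enable_morphologyex branch applies to the scaled mask before captioning (the repo's exact kernel sizes are not shown in this hunk and are assumed here): morphological opening removes speckles and closing fills small holes.

import cv2
import numpy as np

seg_mask = np.zeros((256, 256), dtype=np.uint8)
seg_mask[60:200, 60:200] = 1
seg_mask[64, 64] = 0                      # a small hole in the foreground
scaled = 255 * seg_mask                   # same scaling as in inference()
kernel = np.ones((5, 5), np.uint8)        # kernel size is an assumption
opened = cv2.morphologyEx(scaled, cv2.MORPH_OPEN, kernel)
closed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, kernel)
print(closed.min(), closed.max())         # still a 0/255 mask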
caption_anything/segmenter/base_segmenter.py
CHANGED
@@ -5,7 +5,7 @@ from PIL import Image, ImageDraw, ImageOps
 import numpy as np
 from typing import Union
 from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
-from caption_anything.utils.utils import prepare_segmenter, seg_model_map
+from caption_anything.utils.utils import prepare_segmenter, seg_model_map, load_image
 import matplotlib.pyplot as plt
 import PIL
 
@@ -30,21 +30,9 @@ class BaseSegmenter:
         self.image_embedding = None
         self.image = None
 
-    def read_image(self, image: Union[np.ndarray, Image.Image, str]):
-        if type(image) == str:  # input path
-            image = Image.open(image)
-            image = np.array(image)
-        elif type(image) == Image.Image:
-            image = np.array(image)
-        elif type(image) == np.ndarray:
-            image = image
-        else:
-            raise TypeError
-        return image
-
     @torch.no_grad()
     def set_image(self, image: Union[np.ndarray, Image.Image, str]):
-        image =
+        image = load_image(image, return_type='numpy')
         self.image = image
         if self.reuse_feature:
             self.predictor.set_image(image)
@@ -57,7 +45,7 @@ class BaseSegmenter:
         SAM inference of image according to control.
         Args:
             image: str or PIL.Image or np.ndarray
-            control:
+            control: dict to control SAM.
             prompt_type:
                 1. {control['prompt_type'] = ['everything']} to segment everything in the image.
                 2. {control['prompt_type'] = ['click', 'box']} to segment according to click and box.
@@ -77,7 +65,7 @@ class BaseSegmenter:
             masks: np.ndarray of shape [num_masks, height, width]
 
         """
-        image =
+        image = load_image(image, return_type='numpy')
         if 'everything' in control['prompt_type']:
             masks = self.mask_generator.generate(image)
             new_masks = np.concatenate([mask["segmentation"][np.newaxis, :] for mask in masks])
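A hedged sketch of the control dicts the docstring above describes. Only 'prompt_type' is taken verbatim from the docstring; the constructor arguments and the 'input_point', 'input_label' and 'multimask_output' keys are assumptions modelled on the SAM predictor inputs and may differ from the keys the repository actually uses.

from caption_anything.segmenter.base_segmenter import BaseSegmenter

segmenter = BaseSegmenter(device='cuda:0', checkpoint='sam_vit_h_4b8939.pth')  # args assumed

everything_control = {'prompt_type': ['everything']}   # segment the whole image
click_control = {
    'prompt_type': ['click'],
    'input_point': [[250, 300]],   # assumed key: (x, y) pixel coordinate
    'input_label': [1],            # assumed key: 1 = foreground click
    'multimask_output': 'True',    # assumed key
}
masks = segmenter.inference('test_images/img1.jpg', click_control)  # np.ndarray [num_masks, H, W]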
caption_anything/utils/chatbot.py
CHANGED
@@ -19,22 +19,11 @@ from PIL import Image, ImageDraw, ImageOps
 from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
 
 VISUAL_CHATGPT_PREFIX = """
-Caption Anything Chatbox (short as CATchat) is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics.
+I want you act as Caption Anything Chatbox (short as CATchat), which is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. You are able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
 
-As a language model,
+As a language model, you can not directly read images, but can invoke VQA tool to indirectly understand pictures, by repeatly asking questions about the objects and scene of the image. You should carefully asking informative questions to maximize your information about this image content. Each image will have a file name formed as "chat_image/xxx.png", you are very strict to the file name and will never fabricate nonexistent files.
 
-
-
-
-# VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
-
-# Visual ChatGPT is able to process and understand large amounts of text and images. As a language model, Visual ChatGPT can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "chat_image/xxx.png", and Visual ChatGPT can invoke different tools to indirectly understand pictures. When talking about images, Visual ChatGPT is very strict to the file name and will never fabricate nonexistent files. Visual ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name.
-
-# Visual ChatGPT is aware of the coordinate of an object in the image, which is represented as a point (X, Y) on the object. Note that (0, 0) represents the bottom-left corner of the image.
-
-# Human may provide new figures to Visual ChatGPT with a description. The description helps Visual ChatGPT to understand this image, but Visual ChatGPT should use tools to finish following tasks, rather than directly imagine from the description.
-
-# Overall, Visual ChatGPT is a powerful visual dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics.
+You have access to the following tools:"""
 
 
 # TOOLS:
@@ -63,8 +52,7 @@ Previous conversation history:
 {chat_history}
 
 New input: {input}
-
-The thoughts and observations are only visible for CATchat, CATchat should remember to repeat important information in the final response for Human.
+As a language model, you must repeatly to use VQA tools to observe images. You response should be consistent with the outputs of the VQA tool instead of imagination. Do not repeat asking the same question.
 
 Thought: Do I need to use a tool? {agent_scratchpad} (You are strictly to use the aforementioned "Thought/Action/Action Input/Observation" format as the answer.)"""
 
@@ -111,9 +99,9 @@ class VisualQuestionAnswering:
         # "Salesforce/blip-vqa-capfilt-large", torch_dtype=self.torch_dtype).to(self.device)
 
     @prompts(name="Answer Question About The Image",
-             description="useful when you need an answer for a question based on an image. "
-                         "like: what is the
-                         "The input to this tool should be a comma separated string of two, representing the
+             description="VQA tool is useful when you need an answer for a question based on an image. "
+                         "like: what is the color of an object, how many cats in this figure, where is the child sitting, what does the cat doing, why is he laughing."
+                         "The input to this tool should be a comma separated string of two, representing the image path and the question.")
    def inference(self, inputs):
        image_path, question = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
        raw_image = Image.open(image_path).convert('RGB')
@@ -151,12 +139,13 @@ def build_chatbot_tools(load_dict):
 class ConversationBot:
     def __init__(self, tools, api_key=""):
         # load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
-        llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
+        llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0.7, openai_api_key=api_key)
         self.llm = llm
         self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
         self.tools = tools
         self.current_image = None
         self.point_prompt = ""
+        self.global_prompt = ""
         self.agent = initialize_agent(
             self.tools,
             self.llm,
@@ -212,7 +201,7 @@ if __name__ == '__main__':
     bot = ConversationBot(tools)
     with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
         with gr.Row():
-            chatbot = gr.Chatbot(elem_id="chatbot", label="
+            chatbot = gr.Chatbot(elem_id="chatbot", label="CATchat").style(height=1000,scale=0.5)
             auxwindow = gr.Chatbot(elem_id="chatbot", label="Aux Window").style(height=1000,scale=0.5)
             state = gr.State([])
             aux_state = gr.State([])
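Usage sketch (not part of the diff above), based on the __main__ block and the inline comment in ConversationBot.__init__; the exact load_dict contents and devices are assumptions.

from caption_anything.utils.chatbot import build_chatbot_tools, ConversationBot

load_dict = {'VisualQuestionAnswering': 'cuda:0', 'ImageCaptioning': 'cuda:0'}
tools = build_chatbot_tools(load_dict)
bot = ConversationBot(tools, api_key='sk-...')   # drives gpt-3.5-turbo at temperature 0.7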
caption_anything/utils/utils.py
CHANGED
@@ -1,13 +1,41 @@
 import os
+import time
+import sys
+
 import cv2
+import hashlib
 import requests
 import numpy as np
+
+from typing import Union
+
 from PIL import Image
-import time
-import sys
-import urllib
 from tqdm import tqdm
-
+
+
+def load_image(image: Union[np.ndarray, Image.Image, str], return_type='numpy'):
+    """
+    Load image from path or PIL.Image or numpy.ndarray to required format.
+    """
+
+    # Check if image is already in return_type
+    if isinstance(image, Image.Image) and return_type == 'pil' or \
+            isinstance(image, np.ndarray) and return_type == 'numpy':
+        return image
+
+    # PIL.Image as intermediate format
+    if isinstance(image, str):
+        image = Image.open(image)
+    elif isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    if return_type == 'pil':
+        return image
+    elif return_type == 'numpy':
+        return np.asarray(image)
+    else:
+        raise NotImplementedError()
+
 
 def is_platform_win():
     return sys.platform == "win32"
@@ -114,7 +142,7 @@ def vis_add_mask(image, mask, color, alpha, kernel_size):
     mask = mask.astype('float').copy()
     mask = (cv2.GaussianBlur(mask, (kernel_size, kernel_size), kernel_size) / 255.) * (alpha)
     for i in range(3):
-        image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)
+        image[:, :, i] = image[:, :, i] * (1 - alpha + mask) + color[i] * (alpha - mask)
     return image
 
 
@@ -122,11 +150,12 @@ def vis_add_mask_wo_blur(image, mask, color, alpha):
     color = np.array(color)
     mask = mask.astype('float').copy()
     for i in range(3):
-        image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)
+        image[:, :, i] = image[:, :, i] * (1 - alpha + mask) + color[i] * (alpha - mask)
     return image
 
 
-def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha,
+def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha,
+                             contour_alpha):
     background_color = np.array(background_color)
     contour_color = np.array(contour_color)
 
@@ -134,16 +163,17 @@ def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_co
     # contour_mask = 1 - contour_mask
 
     for i in range(3):
-        image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
-                         + background_color[i] * (background_alpha-background_mask*background_alpha)
+        image[:, :, i] = image[:, :, i] * (1 - background_alpha + background_mask * background_alpha) \
+                         + background_color[i] * (background_alpha - background_mask * background_alpha)
 
-        image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
-                         + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
+        image[:, :, i] = image[:, :, i] * (1 - contour_alpha + contour_mask * contour_alpha) \
+                         + contour_color[i] * (contour_alpha - contour_mask * contour_alpha)
 
     return image.astype('uint8')
 
 
-def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3,
+def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3,
+                 contour_color=3, contour_alpha=1, background_color=0, paint_foreground=False):
     """
     add color mask to the background/foreground area
     input_image: numpy array (w, h, C)
@@ -163,23 +193,27 @@ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_
     assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
 
     # 0: background, 1: foreground
-    input_mask[input_mask>0] = 255
+    input_mask[input_mask > 0] = 255
     if paint_foreground:
-        painted_image = vis_add_mask(input_image, 255 - input_mask, color_list[background_color], background_alpha,
+        painted_image = vis_add_mask(input_image, 255 - input_mask, color_list[background_color], background_alpha,
+                                     background_blur_radius)  # black for background
     else:
-
-        painted_image = vis_add_mask(input_image, input_mask, color_list[background_color], background_alpha,
+        # mask background
+        painted_image = vis_add_mask(input_image, input_mask, color_list[background_color], background_alpha,
+                                     background_blur_radius)  # black for background
     # mask contour
     contour_mask = input_mask.copy()
-    contour_mask = cv2.Canny(contour_mask, 100, 200)
+    contour_mask = cv2.Canny(contour_mask, 100, 200)  # contour extraction
     # widden contour
     kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
     contour_mask = cv2.dilate(contour_mask, kernel)
-    painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha,
+    painted_image = vis_add_mask(painted_image, 255 - contour_mask, color_list[contour_color], contour_alpha,
+                                 contour_width)
     return painted_image
 
 
-def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7, background_blur_radius=7,
+def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7, background_blur_radius=7,
+                                contour_width=3, contour_color=3, contour_alpha=1):
     """
     paint color mask on the all foreground area
     input_image: numpy array with shape (w, h, C)
@@ -194,22 +228,24 @@ def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7,
     Output:
         painted_image: numpy array
     """
-
+
     for i, input_mask in enumerate(input_masks):
-        input_image = mask_painter(input_image, input_mask,
+        input_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width,
+                                   contour_color, contour_alpha, background_color=i + 2, paint_foreground=True)
     return input_image
 
+
 def mask_generator_00(mask, background_radius, contour_radius):
     # no background width when '00'
     # distance map
     dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
-    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
     dist_map = dist_transform_fore - dist_transform_back
     # ...:::!!!:::...
     contour_radius += 2
     contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
     contour_mask = contour_mask / np.max(contour_mask)
-    contour_mask[contour_mask>0.5] = 1.
+    contour_mask[contour_mask > 0.5] = 1.
 
     return mask, contour_mask
 
@@ -218,7 +254,7 @@ def mask_generator_01(mask, background_radius, contour_radius):
     # no background width when '00'
     # distance map
     dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
-    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
     dist_map = dist_transform_fore - dist_transform_back
     # ...:::!!!:::...
     contour_radius += 2
@@ -230,7 +266,7 @@ def mask_generator_01(mask, background_radius, contour_radius):
 def mask_generator_10(mask, background_radius, contour_radius):
     # distance map
     dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
-    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
     dist_map = dist_transform_fore - dist_transform_back
     # .....:::::!!!!!
     background_mask = np.clip(dist_map, -background_radius, background_radius)
@@ -240,14 +276,14 @@ def mask_generator_10(mask, background_radius, contour_radius):
     contour_radius += 2
     contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
     contour_mask = contour_mask / np.max(contour_mask)
-    contour_mask[contour_mask>0.5] = 1.
+    contour_mask[contour_mask > 0.5] = 1.
     return background_mask, contour_mask
 
 
 def mask_generator_11(mask, background_radius, contour_radius):
     # distance map
     dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
-    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
     dist_map = dist_transform_fore - dist_transform_back
     # .....:::::!!!!!
     background_mask = np.clip(dist_map, -background_radius, background_radius)
@@ -260,7 +296,8 @@ def mask_generator_11(mask, background_radius, contour_radius):
     return background_mask, contour_mask
 
 
-def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3,
+def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3,
+                             contour_color=3, contour_alpha=1, mode='11'):
     """
     Input:
         input_image: numpy array
@@ -283,8 +320,8 @@ def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, back
     width, height = input_image.shape[0], input_image.shape[1]
     res = 1024
     ratio = min(1.0 * res / max(width, height), 1.0)
-    input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
-    input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
+    input_image = cv2.resize(input_image, (int(height * ratio), int(width * ratio)))
+    input_mask = cv2.resize(input_mask, (int(height * ratio), int(width * ratio)))
 
     # 0: background, 1: foreground
     msk = np.clip(input_mask, 0, 1)
@@ -292,23 +329,78 @@ def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, back
     # generate masks for background and contour pixels
     background_radius = (background_blur_radius - 1) // 2
     contour_radius = (contour_width - 1) // 2
-    generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10,
+    generator_dict = {'00': mask_generator_00, '01': mask_generator_01, '10': mask_generator_10,
+                      '11': mask_generator_11}
     background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
 
     # paint
     painted_image = vis_add_mask_wo_gaussian \
-        (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha,
+        (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha,
+         contour_alpha)  # black for background
 
     return painted_image
 
 
+seg_model_map = {
+    'base': 'vit_b',
+    'large': 'vit_l',
+    'huge': 'vit_h'
+}
+ckpt_url_map = {
+    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
+    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
+    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
+}
+expected_sha256_map = {
+    'vit_b': 'ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912',
+    'vit_l': '3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622',
+    'vit_h': 'a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e'
+}
+
+
+def prepare_segmenter(segmenter="huge", download_root: str = None):
+    """
+    Prepare segmenter model and download checkpoint if necessary.
+
+    Returns: segmenter model name from 'vit_b', 'vit_l', 'vit_h'.
+
+    """
+
+    os.makedirs('result', exist_ok=True)
+    seg_model_name = seg_model_map[segmenter]
+    checkpoint_url = ckpt_url_map[seg_model_name]
+    folder = download_root or os.path.expanduser("~/.cache/SAM")
+    filename = os.path.basename(checkpoint_url)
+    segmenter_checkpoint = download_checkpoint(checkpoint_url, folder, filename, expected_sha256_map[seg_model_name])
+
+    return seg_model_name, segmenter_checkpoint
+
+
+def download_checkpoint(url, folder, filename, expected_sha256):
+    os.makedirs(folder, exist_ok=True)
+    download_target = os.path.join(folder, filename)
+    if os.path.isfile(download_target):
+        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+            return download_target
+
+    print(f'Download SAM checkpoint {url}, saving to {download_target} ...')
+    with requests.get(url, stream=True) as response, open(download_target, "wb") as output:
+        progress = tqdm(total=int(response.headers.get('content-length', 0)), unit='B', unit_scale=True)
+        for data in response.iter_content(chunk_size=1024):
+            size = output.write(data)
+            progress.update(size)
+    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
+    return download_target
+
+
 if __name__ == '__main__':
 
-    background_alpha = 0.7
-    background_blur_radius = 31
-    contour_width = 11
-    contour_color = 3
-    contour_alpha = 1
+    background_alpha = 0.7  # transparency of background 1: all black, 0: do nothing
+    background_blur_radius = 31  # radius of background blur, must be odd number
+    contour_width = 11  # contour width, must be odd number
+    contour_color = 3  # id in color map, 0: black, 1: white, >1: others
+    contour_alpha = 1  # transparency of background, 0: no contour highlighted
 
     # load input image and mask
     input_image = np.array(Image.open('./test_images/painter_input_image.jpg').convert('RGB'))
@@ -323,23 +415,28 @@ if __name__ == '__main__':
 
     for i in range(50):
         t2 = time.time()
-        painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
+        painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
+                                                    contour_width, contour_color, contour_alpha, mode='00')
         e2 = time.time()
 
         t3 = time.time()
-        painted_image_10 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
+        painted_image_10 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
+                                                    contour_width, contour_color, contour_alpha, mode='10')
        e3 = time.time()

        t1 = time.time()
-        painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width,
+        painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width,
+                                     contour_color, contour_alpha)
         e1 = time.time()
 
         t4 = time.time()
-        painted_image_01 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
+        painted_image_01 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
+                                                    contour_width, contour_color, contour_alpha, mode='01')
         e4 = time.time()
 
         t5 = time.time()
-        painted_image_11 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
+        painted_image_11 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius,
+                                                    contour_width, contour_color, contour_alpha, mode='11')
         e5 = time.time()
 
         overall_time_1 += (e1 - t1)
@@ -348,11 +445,11 @@ if __name__ == '__main__':
         overall_time_4 += (e4 - t4)
         overall_time_5 += (e5 - t5)
 
-    print(f'average time w gaussian: {overall_time_1/50}')
-    print(f'average time w/o gaussian00: {overall_time_2/50}')
-    print(f'average time w/o gaussian10: {overall_time_3/50}')
-    print(f'average time w/o gaussian01: {overall_time_4/50}')
-    print(f'average time w/o gaussian11: {overall_time_5/50}')
+    print(f'average time w gaussian: {overall_time_1 / 50}')
+    print(f'average time w/o gaussian00: {overall_time_2 / 50}')
+    print(f'average time w/o gaussian10: {overall_time_3 / 50}')
+    print(f'average time w/o gaussian01: {overall_time_4 / 50}')
+    print(f'average time w/o gaussian11: {overall_time_5 / 50}')
 
     # save
     painted_image_00 = Image.fromarray(painted_image_00)
@@ -366,54 +463,3 @@ if __name__ == '__main__':
 
     painted_image_11 = Image.fromarray(painted_image_11)
     painted_image_11.save('./test_images/painter_output_image_11.png')
-
-
-seg_model_map = {
-    'base': 'vit_b',
-    'large': 'vit_l',
-    'huge': 'vit_h'
-}
-ckpt_url_map = {
-    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
-    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
-    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
-}
-expected_sha256_map = {
-    'vit_b': 'ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912',
-    'vit_l': '3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622',
-    'vit_h': 'a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e'
-}
-def prepare_segmenter(segmenter = "huge", download_root: str = None):
-    """
-    Prepare segmenter model and download checkpoint if necessary.
-
-    Returns: segmenter model name from 'vit_b', 'vit_l', 'vit_h'.
-
-    """
-
-    os.makedirs('result', exist_ok=True)
-    seg_model_name = seg_model_map[segmenter]
-    checkpoint_url = ckpt_url_map[seg_model_name]
-    folder = download_root or os.path.expanduser("~/.cache/SAM")
-    filename = os.path.basename(checkpoint_url)
-    segmenter_checkpoint = download_checkpoint(checkpoint_url, folder, filename, expected_sha256_map[seg_model_name])
-
-    return seg_model_name, segmenter_checkpoint
-
-
-def download_checkpoint(url, folder, filename, expected_sha256):
-    os.makedirs(folder, exist_ok=True)
-    download_target = os.path.join(folder, filename)
-    if os.path.isfile(download_target):
-        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
-            return download_target
-
-    print(f'Download SAM checkpoint {url}, saving to {download_target} ...')
-    with requests.get(url, stream=True) as response, open(download_target, "wb") as output:
-        progress = tqdm(total=int(response.headers.get('content-length', 0)), unit='B', unit_scale=True)
-        for data in response.iter_content(chunk_size=1024):
-            size = output.write(data)
-            progress.update(size)
-    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
-        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
-    return download_target
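Usage sketch of the two relocated/new helpers (not part of the diff above): load_image normalises the accepted input types, and prepare_segmenter downloads and sha256-verifies a SAM checkpoint (into ~/.cache/SAM unless download_root is given). The image path below is a placeholder.

import numpy as np
from PIL import Image
from caption_anything.utils.utils import load_image, prepare_segmenter

arr = load_image('test_images/img1.jpg', return_type='numpy')   # path -> np.ndarray
pil = load_image(arr, return_type='pil')                        # ndarray -> PIL.Image
assert isinstance(arr, np.ndarray) and isinstance(pil, Image.Image)
assert load_image(arr, return_type='numpy') is arr              # already in target type: returned as-is

seg_model_name, ckpt_path = prepare_segmenter('base')           # fetches sam_vit_b_01ec64.pth if missing
print(seg_model_name, ckpt_path)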