Prasanna Sridhar
committed on
Commit · 346623e
1 Parent(s): c469934
remove unused imports
app.py CHANGED
@@ -1,16 +1,10 @@
 import spaces
 import gradio as gr
-import copy
 import random
 import torch
-import
-from PIL import Image, ImageDraw, ImageFont
-import torchvision.transforms.functional as F
+from PIL import Image
 import numpy as np
 import argparse
-import json
-import plotly.express as px
-import pandas as pd
 from util.slconfig import SLConfig, DictAction
 from util.misc import nested_tensor_from_tensor_list
 import datasets.transforms as T
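The hunk above drops the imports app.py no longer references (copy, torchvision.transforms.functional, json, plotly.express, pandas) and narrows the PIL import to Image only. As a rough illustration of how such dead imports can be found, the sketch below walks a module with Python's ast module and reports imported names that never appear elsewhere in the tree. It is a minimal, hypothetical helper, not the tool used for this commit; linters such as flake8 or ruff report the same issue as rule F401 and handle many more edge cases.

# Minimal sketch (assumption: no dynamic name access via globals(), exec, or
# string-based lookups; real linters cover far more cases than this).
import ast
import sys


def unused_imports(source: str):
    tree = ast.parse(source)
    imported = {}  # top-level alias -> line number of the import
    used = set()   # every bare name referenced anywhere in the module

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                imported[(alias.asname or alias.name).split(".")[0]] = node.lineno
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                if alias.name == "*":
                    continue  # star imports cannot be tracked this way
                imported[alias.asname or alias.name] = node.lineno
        elif isinstance(node, ast.Name):
            # Attribute access such as np.zeros still visits ast.Name("np"),
            # so module usage is caught here as well.
            used.add(node.id)

    return sorted((line, name) for name, line in imported.items() if name not in used)


if __name__ == "__main__":
    src = open(sys.argv[1]).read()
    for line, name in unused_imports(src):
        print(f"{sys.argv[1]}:{line}: '{name}' imported but unused")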
@@ -258,14 +252,14 @@ def get_ind_to_filter(text, word_ids, keywords):
 def count(image, text, prompts, state, device):
 
     keywords = "" # do not handle this for now
-    
+
     # Handle no prompt case.
     if prompts is None:
         prompts = {"image": image, "points": []}
     input_image, _ = transform(image, {"exemplars": torch.tensor([])})
     input_image = input_image.unsqueeze(0).to(device)
     exemplars = get_box_inputs(prompts["points"])
-    
+
     input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
     input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
     exemplars = [exemplars["exemplars"].to(device)]
@@ -278,7 +272,7 @@ def count(image, text, prompts, state, device):
             [torch.tensor([0]).to(device) for _ in range(len(input_image))],
             captions=[text + " ."] * len(input_image),
         )
-    
+
     ind_to_filter = get_ind_to_filter(text, model_output["token"][0].word_ids, keywords)
     logits = model_output["pred_logits"].sigmoid()[0][:, ind_to_filter]
     boxes = model_output["pred_boxes"][0]
@@ -288,7 +282,7 @@ def count(image, text, prompts, state, device):
     box_mask = logits.max(dim=-1).values > CONF_THRESH
     logits = logits[box_mask, :].cpu().numpy()
     boxes = boxes[box_mask, :].cpu().numpy()
-    
+
     # Plot results.
     (w, h) = image.size
     det_map = np.zeros((h, w))
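For context on the unchanged filtering step in this hunk: box_mask keeps only those predicted boxes whose best token score clears the confidence threshold, and the number of surviving boxes is what the app later reports as the predicted count. Below is a standalone sketch of that masking pattern; the tensor shapes and the threshold value are made up for illustration, since app.py defines its own CONF_THRESH elsewhere.

# Standalone sketch of the confidence-threshold filtering used in count();
# shapes and the threshold value are illustrative, not taken from app.py.
import torch

CONF_THRESH = 0.23  # hypothetical value; app.py defines its own constant

# Pretend the model scored 5 candidate boxes against 3 text tokens.
logits = torch.rand(5, 3)  # per-box, per-token scores in [0, 1)
boxes = torch.rand(5, 4)   # box parameters, one row per candidate

# A box survives if its best token score exceeds the threshold.
box_mask = logits.max(dim=-1).values > CONF_THRESH

logits = logits[box_mask, :].cpu().numpy()
boxes = boxes[box_mask, :].cpu().numpy()

print(f"kept {boxes.shape[0]} of 5 boxes")  # boxes.shape[0] is the predicted count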
@@ -327,7 +321,7 @@ def count(image, text, prompts, state, device):
     if len(text.strip()) > 0:
         out_label += " text"
         if exemplars[0].size()[0] == 1:
-            out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar." 
+            out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
         elif exemplars[0].size()[0] > 1:
             out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplars."
         else:
@@ -339,7 +333,7 @@ def count(image, text, prompts, state, device):
             out_label += " " + str(exemplars[0].size()[0]) + " visual exemplars."
     else:
         out_label = "Nothing specified to detect."
-    
+
     return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]), new_submit_btn, gr.Tab(visible=True), step_3, state)
 
 @spaces.GPU
@@ -351,11 +345,11 @@ def count_main(image, text, prompts, device):
     input_image, _ = transform(image, {"exemplars": torch.tensor([])})
     input_image = input_image.unsqueeze(0).to(device)
     exemplars = get_box_inputs(prompts["points"])
-    
+
     input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
     input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
     exemplars = [exemplars["exemplars"].to(device)]
-    
+
     with torch.no_grad():
         model_output = model(
             nested_tensor_from_tensor_list(input_image),
@@ -364,7 +358,7 @@ def count_main(image, text, prompts, device):
             [torch.tensor([0]).to(device) for _ in range(len(input_image))],
             captions=[text + " ."] * len(input_image),
         )
-    
+
     ind_to_filter = get_ind_to_filter(text, model_output["token"][0].word_ids, keywords)
     logits = model_output["pred_logits"].sigmoid()[0][:, ind_to_filter]
     boxes = model_output["pred_boxes"][0]
@@ -374,7 +368,7 @@ def count_main(image, text, prompts, device):
     box_mask = logits.max(dim=-1).values > CONF_THRESH
     logits = logits[box_mask, :].cpu().numpy()
     boxes = boxes[box_mask, :].cpu().numpy()
-    
+
     # Plot results.
     (w, h) = image.size
     det_map = np.zeros((h, w))
@@ -395,7 +389,7 @@ def count_main(image, text, prompts, device):
     if len(text.strip()) > 0:
         out_label += " text"
         if exemplars[0].size()[0] == 1:
-            out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar." 
+            out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
         elif exemplars[0].size()[0] > 1:
             out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplars."
         else:
@@ -407,7 +401,7 @@ def count_main(image, text, prompts, device):
             out_label += " " + str(exemplars[0].size()[0]) + " visual exemplars."
     else:
         out_label = "Nothing specified to detect."
-    
+
     return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]))
 
 def remove_label(image):
@@ -452,20 +446,20 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
             with gr.Tab("Step 1", visible=True) as step_1:
                 input_image = gr.Image(type='pil', label='Input Image', show_label='True', value="strawberry.jpg", interactive=False, width="30vw")
                 gr.Markdown('# Click "Count" to count the strawberries.')
-            
+
             with gr.Column():
                 with gr.Tab("Output Image"):
                     detected_instances = gr.Image(label="Detected Instances", show_label='True', interactive=False, visible=True, width="40vw")
-                
+
                 with gr.Row():
                     input_text = gr.Textbox(label="What would you like to count?", value="strawberry", interactive=True)
                     pred_count = gr.Number(label="Predicted Count", visible=False)
                     submit_btn = gr.Button("Count", variant="primary", interactive=True)
-        
+
         submit_btn.click(fn=remove_label, inputs=[detected_instances], outputs=[detected_instances]).then(fn=count, inputs=[input_image, input_text, exemplar_image, state, device], outputs=[detected_instances, pred_count, submit_btn, step_2, step_3, state])
         exemplar_image.change(check_submit_btn, inputs=[exemplar_image, state], outputs=[submit_btn])
     with gr.Tab("App", visible=True) as main_app:
-        
+
         gr.Markdown(
             """
             # <center>CountGD: Multi-Modal Open-World Counting
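The wiring in the hunk above chains two callbacks: submit_btn.click(fn=remove_label, ...) first resets the output component, and .then(fn=count, ...) runs the counting model only after that reset has completed. The sketch below shows the same Gradio click/.then() pattern in isolation, using placeholder components and callbacks rather than the ones defined in app.py.

# Minimal sketch of the click()/.then() chaining used in this app; the
# components and callbacks here are placeholders, not app.py's own.
import gradio as gr


def clear_output():
    # Step 1: reset the output before the slow step runs.
    return gr.Number(value=None, label="Predicted Count")


def slow_count(text):
    # Step 2: stand-in for the real model call in count()/count_main().
    return gr.Number(value=len(text.split()), label="Predicted Count")


with gr.Blocks() as sketch_demo:
    prompt = gr.Textbox(label="What would you like to count?")
    result = gr.Number(label="Predicted Count")
    run_btn = gr.Button("Count", variant="primary")

    # The second callback only fires after the first one has finished.
    run_btn.click(fn=clear_output, inputs=[], outputs=[result]).then(
        fn=slow_count, inputs=[prompt], outputs=[result]
    )

if __name__ == "__main__":
    sketch_demo.launch()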
@@ -476,7 +470,7 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
             Limitation: this app does not support fine-grained counting based on attributes or visual grounding inputs yet. Note: if the exemplar and text conflict each other, both will be counted.</center>
             """
         )
-        
+
         with gr.Row():
             with gr.Column():
                 input_image_main = gr.Image(type='pil', label='Input Image', show_label='True', value="strawberry.jpg", interactive=True)
@@ -490,6 +484,6 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
         gr.Examples(label="Examples: click on a row to load the example. Add visual exemplars by drawing boxes on the loaded \"Visual Exemplar Image.\"", examples=examples, inputs=[input_image_main, input_text_main, exemplar_image_main])
         submit_btn_main.click(fn=remove_label, inputs=[detected_instances_main], outputs=[detected_instances_main]).then(fn=count_main, inputs=[input_image_main, input_text_main, exemplar_image_main, device], outputs=[detected_instances_main, pred_count_main])
         clear_btn_main.add([input_image_main, input_text_main, exemplar_image_main, detected_instances_main, pred_count_main])
-        
+
 
 demo.queue().launch(allowed_paths=['back-icon.jpg', 'paste-icon.jpg', 'upload-icon.jpg', 'button-legend.jpg'])