File size: 1,733 Bytes
3b64d74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35c26d1
6b475f6
 
1de9d3e
6b475f6
35c26d1
 
3b64d74
 
 
 
 
91b8028
5e702a8
 
3b64d74
 
 
daa0506
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Credits to IDEA Research for the model:
# https://huggingface.co/IDEA-Research/grounding-dino-tiny

from base64 import b64decode
from io import BytesIO

import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection 

# Hub checkpoint to serve, and the device inference runs on: the default
# CUDA device when PyTorch reports one is available, otherwise the CPU.
model_id = "IDEA-Research/grounding-dino-tiny"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the matching processor and detector once, at import time, so every
# request reuses them; the detector weights are moved to `device` up front.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id)
model = model.to(device)

def predict(base64: str, queries: str, box_threshold: float, text_threshold: float):
    """Detect objects matching text queries in a base64-encoded image.

    Args:
        base64: The encoded image file, as base64 text. An optional
            ``data:<mime>;base64,`` prefix is stripped before decoding.
        queries: Text prompt, phrases separated by full stops, e.g.
            ``"a bird. a blue bird."``. It is lowercased and a trailing full
            stop is appended when missing, the form the model expects.
        box_threshold: Score threshold forwarded to the post-processor for
            keeping a detected box.
        text_threshold: Score threshold forwarded to the post-processor for
            the text tokens that make up each label.

    Returns:
        A JSON-serialisable dict with parallel lists ``"scores"`` (floats),
        ``"labels"`` (as returned by the post-processor) and ``"boxes"``
        (lists of floats scaled to the input image's pixel size; presumably
        ``[x0, y0, x1, y1]`` per the transformers post-processing docs).
    """
    # Accept data URLs as well as bare base64. b64decode would silently drop
    # the ':', ';' and ',' of a "data:...;base64," header but decode its
    # letters as garbage bytes in front of the actual image.
    payload = base64.strip()
    if payload.startswith("data:"):
        payload = payload.partition(",")[2]

    with Image.open(BytesIO(b64decode(payload))) as raw_image:
        # convert() loads the pixels into a new image, so the source stream
        # can be closed afterwards. The detector works on 3-channel input,
        # so RGBA / palette / grayscale images are converted to RGB.
        image = raw_image.convert("RGB")

    # Grounding DINO prompts are lowercase phrases each ending in a full stop.
    prompt = queries.strip().lower()
    if prompt and not prompt.endswith("."):
        prompt += "."

    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        # PIL's size is (width, height); reverse it to (height, width).
        target_sizes=[image.size[::-1]],
    )
    detections = results[0]  # only one image was submitted
    fmt_results = {
        "scores": [float(s) for s in detections["scores"]],
        "labels": detections["labels"],
        "boxes": [[float(x) for x in box] for box in detections["boxes"]],
    }
    print(fmt_results)
    return fmt_results

# Gradio UI / API. The four inputs map positionally onto predict()'s
# parameters (base64 image, queries, box_threshold, text_threshold), and the
# dict predict() returns is rendered as JSON.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Text(label="Image (B64)"),
        gr.Text(label="Queries, in lowercase, separated by full stop", placeholder="a bird. a blue bird."),
        gr.Number(label="box_threshold", value=0.4),
        gr.Number(label="text_threshold", value=0.3),
    ],
    outputs=gr.JSON(label="Predictions"),
)

# Start the server only when this file is executed as a script, so importing
# the module (for tests or reuse of `predict`/`demo`) does not block.
if __name__ == "__main__":
    demo.launch()