# File size: 1,861 Bytes (revision 426e73b)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from transformers import pipeline, SamModel, SamProcessor
import torch
import numpy as np

checkpoint = "google/owlvit-base-patch16"
sam_checkpoint = "facebook/sam-vit-base"
# Prefer the GPU, but fall back to CPU so the app can still start (slowly)
# on machines without CUDA instead of crashing at import time.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Place the detector explicitly on the same device as SAM.
detector = pipeline(model=checkpoint, task="zero-shot-object-detection", device=device)
sam_model = SamModel.from_pretrained(sam_checkpoint).to(device)
sam_processor = SamProcessor.from_pretrained(sam_checkpoint)

def query(image, texts, threshold):
    """Detect text-prompted objects in *image* and segment each one with SAM.

    Each comma-separated phrase in *texts* becomes a candidate label for the
    zero-shot detector. Every detection it returns is used as a box prompt
    for SAM, and the resulting mask is paired with the detection's label.

    Args:
        image: Input image (the Gradio input is declared with ``type="pil"``).
        texts: Comma-separated candidate labels, e.g. ``"cat, dog"``.
        threshold: Detection score threshold forwarded to the detector pipeline.

    Returns:
        ``(image, annotations)``, where ``annotations`` is a list of
        ``(mask, label)`` tuples for the ``annotatedimage`` output.
    """
    # Strip whitespace so "cat, dog" yields "dog" rather than " dog", and drop
    # empty entries produced by stray or trailing commas.
    labels = [text.strip() for text in texts.split(",") if text.strip()]
    if not labels:
        return image, []

    # Forward the slider value; previously it was silently ignored.
    predictions = detector(image, candidate_labels=labels, threshold=threshold)

    # Follow wherever the SAM model actually lives instead of hard-coding "cuda".
    device = sam_model.device
    result_labels = []
    for pred in predictions:
        box = [round(pred["box"][key], 2) for key in ("xmin", "ymin", "xmax", "ymax")]

        inputs = sam_processor(
            image,
            input_boxes=[[[box]]],
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = sam_model(**inputs)

        # Resize the predicted low-resolution masks back to the original image size.
        mask = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0][0][0].numpy()
        # Add a leading axis, as the original code did, before handing the mask to Gradio.
        result_labels.append((mask[np.newaxis, ...], pred["label"]))
    return image, result_labels

import gradio as gr

# The description names the detector actually loaded above (OWL-ViT,
# "google/owlvit-base-patch16"), not OWLv2.
description = (
    "This Space combines OWL-ViT, a zero-shot object detection model, with SAM, "
    "the state-of-the-art mask generation model. SAM normally doesn't accept text "
    "input. Combining SAM with OWL-ViT makes SAM text promptable."
)
demo = gr.Interface(
    query,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Labels (comma-separated)"),
        gr.Slider(0, 1, value=0.2, label="Detection threshold"),
    ],
    outputs="annotatedimage",
    title="OWL 🤝 SAM",
    description=description,
    examples=[
        ["/content/cats.png", "cat", 0.1],
    ],
)

# Only start the server when run as a script/notebook, not when imported.
if __name__ == "__main__":
    demo.launch(debug=True)