# OWLSAM / app.py
from transformers import pipeline, SamModel, SamProcessor
import torch
import numpy as np
import gradio as gr
import spaces
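
# OWL-ViT turns free-text labels into bounding boxes; SAM turns those boxes into masks.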
checkpoint = "google/owlvit-base-patch16"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
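
# ZeroGPU decorator: a GPU is attached only while this function runs.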
@spaces.GPU
def query(image, texts, threshold):
    texts = [t.strip() for t in texts.split(",")]  # comma-separated labels, e.g. "cat, remote"
    predictions = detector(
        image,
        candidate_labels=texts,
        threshold=threshold,  # forward the slider value; it was previously accepted but unused
    )
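
    # Each prediction is a dict: {"score": float, "label": str,
    # "box": {"xmin", "ymin", "xmax", "ymax"}}.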
    result_labels = []
    for pred in predictions:
        label = pred["label"]
        box = [
            round(pred["box"]["xmin"], 2),
            round(pred["box"]["ymin"], 2),
            round(pred["box"]["xmax"], 2),
            round(pred["box"]["ymax"], 2),
        ]
        inputs = sam_processor(
            image,
            input_boxes=[[[box]]],
            return_tensors="pt",
        ).to("cuda")
        with torch.no_grad():
            outputs = sam_model(**inputs)
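
        # Resize the predicted masks back to the original image size;
        # [0][0][0] selects the first image, first box, and first mask proposal.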
        mask = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0][0][0].numpy()
        mask = mask[np.newaxis, ...]  # leading axis before pairing the mask with its label
        result_labels.append((mask, label))
    return image, result_labels
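
# Gradio UI: an image, comma-separated labels, and a detection threshold in;
# the image annotated with (mask, label) pairs out.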
description = "This Space combines OWL-ViT, a state-of-the-art zero-shot object detection model, with SAM, a state-of-the-art mask generation model. SAM doesn't normally accept text input; pairing it with OWL-ViT makes SAM text-promptable."
demo = gr.Interface(
    query,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Candidate labels (comma-separated)"),
        gr.Slider(0, 1, value=0.2, label="Detection threshold"),
    ],
    outputs=gr.AnnotatedImage(),
    title="OWL 🀝 SAM",
    description=description,
    examples=[
        ["./cats.png", "cat", 0.1],
    ],
    cache_examples=True,
)
demo.launch(debug=True)