# Vishaltiwari2019's picture
# Update app.py
# 6ac27af verified
# raw
# history blame
# 2.19 kB
import gradio as gr
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image, ImageDraw
import random
_DETR_CHECKPOINT = "facebook/detr-resnet-50"
_detr_cache = {}


def _load_detr():
    """Return the ``(processor, model)`` pair, loading it on first use only."""
    if "model" not in _detr_cache:
        _detr_cache["processor"] = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)
        _detr_cache["model"] = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT)
    return _detr_cache["processor"], _detr_cache["model"]


def detect_objects(image, threshold=0.9):
    """Detect objects in *image* with DETR and draw labelled boxes on a copy.

    Args:
        image: A PIL image in any mode; it is converted to RGB before
            inference and is not modified.
        threshold: Minimum confidence score for a detection to be kept.

    Returns:
        A ``(annotated_image, labels)`` tuple where ``annotated_image`` is an
        RGB PIL image with one rectangle and caption per detection, and
        ``labels`` is a list of ``"<class name>: <score>"`` strings in the
        order the detections were returned by the post-processor.
    """
    processor, model = _load_detr()

    # The processor is fed 3-channel input; convert() also returns a new
    # image, so the drawing below never touches the caller's object.
    image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); the post-processor wants (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    draw = ImageDraw.Draw(image)
    labels = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        corners = [round(coord, 2) for coord in box.tolist()]
        # A random colour per detection keeps neighbouring boxes distinguishable.
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        draw.rectangle(corners, outline=color, width=3)
        label_text = f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}"
        draw.text((corners[0], corners[1]), label_text, fill=color)
        labels.append(label_text)
    return image, labels
def upload_image(file):
    """Open an uploaded image file and run object detection on it.

    Args:
        file: Either a filesystem path, or an object exposing the path as a
            ``name`` attribute (which of the two Gradio's ``"file"`` input
            delivers appears to depend on the Gradio version — TODO confirm).

    Returns:
        The ``(annotated_image, labels)`` tuple produced by ``detect_objects``.

    Raises:
        ValueError: If no file was supplied.
    """
    if file is None:
        raise ValueError("No image file was provided.")

    # Accept both a file-like wrapper (has .name) and a plain path.
    path = getattr(file, "name", file)

    # convert() forces the pixel data to load, so the underlying file handle
    # can be closed before the (slow) detection step runs.
    with Image.open(path) as opened:
        image = opened.convert("RGB")

    image_with_boxes, labels = detect_objects(image)
    return image_with_boxes, labels
def show_labels(labels):
    """Format detected labels as newline-separated text.

    Args:
        labels: An iterable of labels, a single already-formatted string, or
            ``None``.

    Returns:
        One label per line. A string is returned unchanged (joining it would
        insert a newline between every character); ``None`` yields ``""``.
    """
    if labels is None:
        return ""
    if isinstance(labels, str):
        return labels
    return "\n".join(str(label) for label in labels)
# Interface to display the image with bounding boxes.
# Takes an uploaded file and shows the annotated image plus the label text.
iface_objects = gr.Interface(
    fn=upload_image,
    inputs="file",
    outputs=["image", "text"],
    title="Object Detection",
    description="Upload an image and detect objects using DETR model.",
    # Use the documented string form; a boolean here is not a supported
    # value in newer Gradio releases.
    allow_flagging="never",
)
# Interface to display the detected labels.
# Text in, text out: passes the entered label text through show_labels.
iface_labels = gr.Interface(
    fn=show_labels,
    inputs="text",
    outputs="text",
    title="Detected Labels",
    description="Displays the labels detected in the uploaded image.",
    # Use the documented string form; a boolean here is not a supported
    # value in newer Gradio releases.
    allow_flagging="never",
)
# Combine both interfaces into a single tabbed app.
# gr.Interface expects one callable as `fn`, not a list of Interface objects;
# gr.TabbedInterface is the container that shows several interfaces as tabs.
# It takes no description argument — each tab keeps its own description.
interface = gr.TabbedInterface(
    [iface_objects, iface_labels],
    tab_names=["Object Detection", "Detected Labels"],
    title="Object Detection with Labels",
)
interface.launch()