import gradio as gr | |
import torch | |
import yolov5 | |
from transformers import pipeline | |
# Hugging Face image-classification pipeline for the
# "PranomVignesh/Police-vs-Public" model. Bound to ``classifier`` (not
# ``pipeline``) so it does not shadow the ``transformers.pipeline`` factory.
classifier = pipeline(task="image-classification", model="PranomVignesh/Police-vs-Public")

# Path to the YOLOv5 detection weights used by ``_get_detector``.
DETECTOR_WEIGHTS = "./best.pt"

# YOLOv5 model, populated on the first call to ``_get_detector``.
_detector = None


def _get_detector():
    """Return the YOLOv5 detector, loading the weights from disk on first use only."""
    global _detector
    if _detector is None:
        _detector = yolov5.load(DETECTOR_WEIGHTS, device="cpu")
    return _detector


def yolov5_inference(
    image
):
    """
    Run YOLOv5 detection and the Police-vs-Public classifier on one image.

    Args:
        image: Input image; a PIL image when called from the
            ``gr.Image(type="pil")`` input component.

    Returns:
        A ``(rendered, scores)`` tuple: ``rendered`` is the first image from
        the YOLOv5 ``results.render()`` output, and ``scores`` maps each
        classifier label to its score.
    """
    # Reuse the already-loaded detector instead of re-reading the weights
    # file on every request.
    detector = _get_detector()
    results = detector([image], size=224)

    predictions = classifier(image)
    scores = {p["label"]: p["score"] for p in predictions}
    return results.render()[0], scores
# Gradio UI: one PIL image in; an annotated image and classifier label
# scores out (matching the tuple returned by ``yolov5_inference``).
inputs = gr.Image(type="pil")
outputs = [
    gr.Image(type="pil"),
    gr.Label(),
]
title = "Detection"
description = "YOLOv5 is a family of object detection models pretrained on COCO dataset. This model is a pip implementation of the original YOLOv5 model."

demo_app = gr.Interface(
    fn=yolov5_inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    # ``description`` was previously defined but never passed, so it was
    # never shown in the UI.
    description=description,
)

if __name__ == "__main__":
    # ``launch(enable_queue=...)`` is deprecated and was removed in Gradio 4;
    # ``queue()`` is the supported way to enable request queuing.
    demo_app.queue().launch(debug=True)