Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
import torch | |
import gradio as gr | |
from transformers import Owlv2Processor, Owlv2ForObjectDetection | |
import spaces | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as patches | |
from PIL import Image | |
import numpy as np | |
# Select the compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the OWLv2 detection model (moved onto the selected device) and the
# processor that prepares its inputs and post-processes its outputs.
model = Owlv2ForObjectDetection.from_pretrained(
    "google/owlv2-base-patch16-ensemble"
).to(device)
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
# Load the image
# Inputs: image, search text, detection score threshold
def query_image(img, text_queries, score_threshold, compress):
    """Detect objects matching free-text queries in an image with OWLv2.

    Args:
        img: Input image as a NumPy array (it is handed to
            ``load_image_as_np_array``).
        text_queries: All query phrases in one string, separated by an
            ASCII comma (``,``) or a full-width comma (``,``).
        score_threshold: Minimum confidence a detection needs to be kept.
        compress: Forwarded to ``load_image_as_np_array``; when true, large
            images are downscaled before inference.

    Returns:
        A 4-tuple ``(image, annotations, image, annotations)``. ``image`` is
        the (possibly downscaled) NumPy image and ``annotations`` is a list
        of ``([x_min, y_min, x_max, y_max], query_text)`` pairs. The
        duplicated pair is kept so existing callers see the same shape.

    Raises:
        gr.Error: If no image was supplied or no non-empty query remains
            after splitting.
    """
    if img is None:
        raise gr.Error("Please provide an image.")

    # The UI label asks for a full-width comma, so accept both comma forms.
    # Surrounding whitespace is stripped and empty entries are dropped.
    queries = [
        query.strip()
        for query in text_queries.replace(",", ",").split(",")
    ]
    queries = [query for query in queries if query]
    if not queries:
        raise gr.Error("Please provide at least one text query.")

    img = load_image_as_np_array(img, compress)
    height, width = img.shape[:2]

    # Box coordinates are rescaled to a square whose side is the image's
    # longer edge (presumably matching the processor's square padding;
    # confirm against the OWLv2 preprocessing docs).
    size = max(height, width)
    target_sizes = torch.Tensor([[size, size]])

    # Tokenize the queries, preprocess the image, and move both to the device.
    inputs = processor(text=queries, images=img, return_tensors="pt").to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing runs on the CPU.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()

    # Pass the threshold explicitly: otherwise the post-processor applies its
    # own default and drops detections when the user picks a lower value.
    results = processor.post_process_object_detection(
        outputs=outputs,
        threshold=score_threshold,
        target_sizes=target_sizes,
    )
    detections = results[0]

    result_labels = []
    for box, label in zip(detections["boxes"], detections["labels"]):
        x_min, y_min, x_max, y_max = (int(value) for value in box.tolist())
        # Boxes are relative to the padded square, so they may fall outside
        # the real image; clip them to the image bounds.
        clipped_box = [
            min(max(x_min, 0), width),
            min(max(y_min, 0), height),
            min(max(x_max, 0), width),
            min(max(y_max, 0), height),
        ]
        result_labels.append((clipped_box, queries[label.item()]))

    return img, result_labels, img, result_labels
# Image input: image compression
def load_image_as_np_array(img, compress=False, max_size=1024):
    """Return the image as a NumPy array, optionally downscaled.

    Args:
        img: Image as a NumPy array whose first two axes are height and
            width.
        compress: When true, an image whose longer side exceeds ``max_size``
            is resized by the factor ``max_size / longer_side``, keeping the
            aspect ratio.
        max_size: Longest allowed side, in pixels, when ``compress`` is set.

    Returns:
        A new NumPy array holding the (possibly resized) image.
    """
    if compress:
        height, width = img.shape[:2]
        longest_side = max(width, height)
        if longest_side > max_size:
            scale_factor = max_size / longest_side
            # Clamp to one pixel so an extreme aspect ratio cannot truncate
            # a dimension to zero, which the resize would reject.
            new_size = (
                max(1, int(width * scale_factor)),
                max(1, int(height * scale_factor)),
            )
            resized = Image.fromarray(img).resize(new_size)
            return np.array(resized)

    # No resize needed: skip the PIL round trip but still hand back a copy.
    return np.array(img)
# Input widgets, in the positional order expected by ``query_image``:
# image, comma-separated prompts, confidence threshold, compression toggle.
image_input = gr.Image()
prompt_input = gr.Text(value="insect", label="提示词(多个用,分开)")
threshold_input = gr.Slider(0, 1, value=0.2, label="确信度阈值")
compress_input = gr.Checkbox(value=True, label="图像压缩")

# A single annotated-image output shows the detections over the input image.
demo1 = gr.Interface(
    fn=query_image,
    inputs=[image_input, prompt_input, threshold_input, compress_input],
    outputs=[gr.AnnotatedImage()],
    title="Zero-Shot Object Detection with OWLv2",
)

# Start the web UI.
demo1.launch()