File size: 3,363 Bytes
f26688e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b165ce
f26688e
 
 
 
 
 
 
 
3b165ce
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# -*- coding: utf-8 -*-

import torch
import gradio as gr
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import spaces
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import numpy as np
# Select the compute device: CUDA when torch reports it available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the detection model (moved onto the selected device) and its matching
# processor; both come from the same pretrained checkpoint.
_OWLV2_CHECKPOINT = "google/owlv2-base-patch16-ensemble"
model = Owlv2ForObjectDetection.from_pretrained(_OWLV2_CHECKPOINT).to(device)
processor = Owlv2Processor.from_pretrained(_OWLV2_CHECKPOINT)
# Detection callback used by the Gradio interface
@spaces.GPU
# Inputs: image, search text, detection score threshold, compression flag
def query_image(img, text_queries, score_threshold, compress):
    """Detect objects in ``img`` that match comma-separated text prompts.

    Args:
        img: Input image as a NumPy array (height x width [x channels]).
        text_queries: Prompts separated by ``,``. Surrounding whitespace is
            trimmed and empty entries are ignored.
        score_threshold: Minimum detection score; forwarded to the
            processor's post-processing step so that the whole slider range
            is honoured.
        compress: Passed to ``load_image_as_np_array`` to optionally
            downscale large images before inference.

    Returns:
        ``(image, annotations, image, annotations)`` where ``image`` is the
        (possibly downscaled) array and ``annotations`` is a list of
        ``([x_min, y_min, x_max, y_max], prompt)`` pairs with integer pixel
        coordinates clamped to the image bounds.

    Raises:
        gr.Error: If no image was supplied.
    """
    if img is None:
        # Surface a readable message instead of an AttributeError deeper in.
        raise gr.Error("Please provide an image.")

    img = load_image_as_np_array(img, compress)

    # Split the prompt string; drop blanks so "cat, dog," -> ["cat", "dog"].
    queries = [part.strip() for part in (text_queries or "").split(",")]
    queries = [query for query in queries if query]
    if not queries:
        # Nothing to search for: return the image without annotations.
        return img, [], img, []

    # Boxes are rescaled against a square whose side is the longer image edge.
    # NOTE(review): presumably because the processor pads the image to a
    # square before inference -- confirm against the OWLv2 documentation.
    height, width = img.shape[:2]
    side = max(height, width)
    target_sizes = torch.Tensor([[side, side]])

    # Tokenise the prompts, preprocess the image and move tensors to the device.
    inputs = processor(text=queries, images=img, return_tensors="pt").to(device)

    # Inference only: gradients are not needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing runs on CPU tensors.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()

    # Pass the threshold explicitly: when omitted, the library default is used
    # and detections under that default are discarded before any filtering
    # done here, which made low slider values ineffective.
    detections = processor.post_process_object_detection(
        outputs=outputs,
        threshold=score_threshold,
        target_sizes=target_sizes,
    )[0]

    result_labels = []
    for box, label in zip(detections["boxes"].tolist(),
                          detections["labels"].tolist()):
        x_min, y_min, x_max, y_max = (int(value) for value in box)
        # Clamp to the real image: coordinates are relative to the square
        # above, so they may be negative or exceed the shorter edge.
        clamped_box = [
            min(max(x_min, 0), width),
            min(max(y_min, 0), height),
            min(max(x_max, 0), width),
            min(max(y_max, 0), height),
        ]
        result_labels.append((clamped_box, queries[label]))

    # NOTE(review): the interface declares a single annotated-image output,
    # which presumably reads only the first two elements; the duplicated pair
    # is kept so the return shape stays unchanged for existing callers.
    return img, result_labels, img, result_labels

#图像输入:图像压缩
def load_image_as_np_array(img, compress=False, max_dimension=1024):
    """Return ``img`` as a new NumPy array, optionally downscaled.

    When ``compress`` is true and the longer side of the image exceeds
    ``max_dimension`` pixels, the image is resized with PIL (default
    resampling filter) so that its longer side is at most ``max_dimension``
    while keeping the aspect ratio. Otherwise a copy of the input is
    returned without going through PIL.

    Args:
        img: Image data convertible with ``numpy.asarray``; the first two
            axes are treated as height and width.
        compress: Whether to downscale images larger than ``max_dimension``.
        max_dimension: Largest allowed side length, in pixels, when
            ``compress`` is true.

    Returns:
        A ``numpy.ndarray`` holding the (possibly resized) image.

    Raises:
        ValueError: If ``compress`` is true and ``max_dimension`` is below 1.
    """
    pixels = np.asarray(img)
    height, width = pixels.shape[:2]
    longest_side = max(width, height)

    # Nothing to shrink: skip the PIL round trip and hand back a copy.
    if not compress or longest_side <= max_dimension:
        return np.array(pixels)

    if max_dimension < 1:
        raise ValueError("max_dimension must be at least 1")

    # Scale both sides by the same factor to keep the aspect ratio. Each side
    # is kept at >= 1 pixel: truncation could otherwise give 0 for very
    # elongated images, which PIL rejects.
    scale_factor = max_dimension / longest_side
    new_width = max(1, int(width * scale_factor))
    new_height = max(1, int(height * scale_factor))

    resized = Image.fromarray(pixels).resize((new_width, new_height))
    return np.array(resized)

# Input widgets, in the positional order expected by query_image:
# image, prompt text, score threshold, compression flag.
image_input = gr.Image()
prompt_input = gr.Text(value="insect", label="提示词(多个用,分开)")
threshold_input = gr.Slider(0, 1, value=0.2, label="确信度阈值")
compress_input = gr.Checkbox(value=True, label="图像压缩")

# Single output: the image with the detected boxes drawn as annotations.
annotated_output = gr.Annotatedimage()

demo = gr.Interface(
    fn=query_image,
    inputs=[image_input, prompt_input, threshold_input, compress_input],
    outputs=[annotated_output],
    title="Zero-Shot Object Detection with OWLv2",
)
# Start the web UI.
demo.launch()