Spaces:
Running
Running
Upload app.py
Browse files · new app.py: add compression of oversized images down to a maximum side of 1024 px.
app.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import gradio as gr
|
5 |
+
from transformers import Owlv2Processor, Owlv2ForObjectDetection
|
6 |
+
import spaces
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import matplotlib.patches as patches
|
9 |
+
from PIL import Image
|
10 |
+
import numpy as np
|
11 |
+
# Pick the compute device: CUDA when torch reports it is available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the detection model (moved onto the chosen device) and its processor
# from the same pretrained checkpoint.
_CHECKPOINT = "google/owlv2-base-patch16-ensemble"
model = Owlv2ForObjectDetection.from_pretrained(_CHECKPOINT).to(device)
processor = Owlv2Processor.from_pretrained(_CHECKPOINT)
|
19 |
+
#载入图像
|
20 |
+
@spaces.GPU
def query_image(img, text_queries, score_threshold, compress):
    """Detect objects in an image that match comma-separated text prompts.

    Args:
        img: Input image as an array; it is passed through
            ``load_image_as_np_array`` before inference.
        text_queries: Prompts separated by ``,`` (e.g. ``"insect,leaf"``).
            Surrounding whitespace of each prompt is stripped.
        score_threshold: Confidence threshold applied to the detections.
        compress: Forwarded to ``load_image_as_np_array`` to optionally
            downscale the image first.

    Returns:
        The tuple ``(img, annotations, img, annotations)`` where ``img`` is
        the (possibly downscaled) image array and ``annotations`` is a list
        of ``(box, label)`` pairs; ``box`` is a list of ints and ``label`` is
        the matching prompt string.
    """
    img = load_image_as_np_array(img, compress)

    # Strip each prompt so "cat, dog" labels a hit as "dog", not " dog".
    queries = [query.strip() for query in text_queries.split(",")]

    # Boxes are rescaled against a square whose side is the longer image edge.
    # NOTE(review): presumably because the processor pads the image to a
    # square before inference -- confirm against the OWLv2 documentation.
    size = max(img.shape[:2])
    target_sizes = torch.Tensor([[size, size]])

    # Tokenize the prompts, preprocess the image, and move tensors to the device.
    inputs = processor(text=queries, images=img, return_tensors="pt").to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing runs on CPU tensors.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()

    # Pass the caller's threshold through explicitly. Without it the
    # post-processor applies its own default threshold first, so requested
    # thresholds below that default would silently have no effect.
    results = processor.post_process_object_detection(
        outputs=outputs,
        threshold=score_threshold,
        target_sizes=target_sizes,
    )
    boxes = results[0]["boxes"]
    scores = results[0]["scores"]
    labels = results[0]["labels"]

    result_labels = []
    for box, score, label in zip(boxes, scores, labels):
        # Cheap guard kept from the original: skip anything under the threshold
        # before doing the coordinate conversion.
        if score < score_threshold:
            continue
        int_box = [int(coord) for coord in box.tolist()]
        result_labels.append((int_box, queries[label.item()]))

    # NOTE(review): the image/annotation pair is returned twice, as before;
    # presumably only the first pair is consumed by the single annotated-image
    # output -- confirm before trimming the tuple.
    return img, result_labels, img, result_labels
|
58 |
+
|
59 |
+
#图像输入:图像压缩
|
60 |
+
def load_image_as_np_array(img, compress=False, max_size=1024):
    """Return the image as a NumPy array, optionally downscaled.

    Args:
        img: Image array accepted by ``PIL.Image.fromarray``.
        compress: When true, an image whose longer side exceeds ``max_size``
            is resized so that side becomes (about) ``max_size`` pixels,
            keeping the aspect ratio. Smaller images are left untouched.
        max_size: Longest side, in pixels, allowed when ``compress`` is true.
            Defaults to 1024, the limit previously hard-coded here.

    Returns:
        The (possibly resized) image converted back to a NumPy array.
    """
    pil_img = Image.fromarray(img)

    if compress:
        width, height = pil_img.size
        longest_side = max(width, height)

        # Only shrink images that exceed the limit; never upscale.
        if longest_side > max_size:
            scale_factor = max_size / longest_side
            # Clamp to 1 px: int() truncation would otherwise yield a
            # zero-sized side for images with an extreme aspect ratio.
            new_width = max(1, int(width * scale_factor))
            new_height = max(1, int(height * scale_factor))
            pil_img = pil_img.resize((new_width, new_height))

    return np.array(pil_img)
|
86 |
+
|
87 |
+
# Input widgets, in the order query_image receives them:
# image, prompt text, confidence threshold, compression toggle.
_demo_inputs = [
    gr.Image(),
    gr.Text(value="insect", label="提示词(多个用,分开)"),
    gr.Slider(0, 1, value=0.2, label="确信度阈值"),
    gr.Checkbox(value=True, label="图像压缩"),
]

# Wire the detector into a single-page demo with one annotated-image output.
demo1 = gr.Interface(
    fn=query_image,
    inputs=_demo_inputs,
    outputs=[gr.Annotatedimage()],
    title="Zero-Shot Object Detection with OWLv2",
)

# Start the web app.
demo1.launch()
|