qq-hzlh committed
Commit 03418d9 · verified · 1 Parent(s): 2fcecb0

Upload 3 files

Files changed (3)
  1. README.md +3 -3
  2. app.py +152 -0
  3. requirements.txt +6 -0
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
  title: VLM R1 OVD
  emoji: 👁
- colorFrom: blue
- colorTo: blue
+ colorFrom: yellow
+ colorTo: purple
  sdk: gradio
- sdk_version: 5.22.0
+ sdk_version: 5.0.1
  app_file: app.py
  pinned: false
  license: mit
app.py ADDED
@@ -0,0 +1,152 @@
+ import re
+ import torch
+ import json_repair
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+ from PIL import Image, ImageDraw
+ 
+ def draw_bbox(image, annotation):
+     x1, y1, x2, y2 = annotation["bbox_2d"]
+     label = annotation["label"]
+     draw = ImageDraw.Draw(image)
+ 
+     # Draw the bounding box
+     draw.rectangle((x1, y1, x2, y2), outline="red", width=5)
+ 
+     # Draw the label text above the box, or below it if there is no room
+     font_size = 20
+     text_position = (x1, y1 - font_size - 5) if y1 > font_size + 5 else (x1, y2 + 5)
+     try:
+         draw.text(text_position, label, fill="red", font_size=font_size)
+     except Exception as e:
+         print(f"Text drawing error: {e}")
+         # If drawing with an explicit font size fails, fall back to the simple default
+         draw.text(text_position, label, fill="red")
+ 
+     return image
+ 
+ def draw_bboxes(image, annotations):
+     """Draw multiple bounding boxes and labels."""
+     result_image = image.copy()
+     for annotation in annotations:
+         result_image = draw_bbox(result_image, annotation)
+ 
+     return result_image
+ 
+ def extract_bbox_answer(content):
+     # Extract content between <answer> and </answer> if present
+     answer_matches = re.findall(r'<answer>(.*?)</answer>', content, re.DOTALL)
+     if answer_matches:
+         # Use the last match
+         text = answer_matches[-1]
+     else:
+         text = content
+ 
+     # Repair and parse the (possibly malformed) JSON with json_repair
+     try:
+         data = json_repair.loads(text)
+         if isinstance(data, list) and len(data) > 0:
+             return data
+         else:
+             return []
+     except Exception as e:
+         print(f"JSON parsing error: {e}")
+         return []
+ 
+ import spaces
+ 
+ @spaces.GPU
+ def process_image_and_text(image, text):
+     """Process image and text input, return thinking process and bbox"""
+     question = f"Please carefully check the image and detect the following objects: [{text}]. "
+ 
+     # R1-style prompt: reasoning goes in <think>, detections as a JSON list in <answer>
+     question = question + "First thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Output the bbox coordinates of detected objects in <answer></answer>. The bbox coordinates in Markdown format should be: \n```json\n[{\"bbox_2d\": [x1, y1, x2, y2], \"label\": \"object name\"}]\n```\n If no targets are detected in the image, simply respond with \"None\"."
+ 
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image"},
+                 {"type": "text", "text": question},
+             ],
+         }
+     ]
+ 
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+ 
+     inputs = processor(
+         text=[text],
+         images=image,
+         return_tensors="pt",
+         padding=True,
+         padding_side="left",
+         add_special_tokens=False,
+     )
+ 
+     inputs = inputs.to("cuda")
+ 
+     with torch.no_grad():
+         generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=1024, do_sample=False)
+     generated_ids_trimmed = [
+         out_ids[len(inputs.input_ids[0]):] for out_ids in generated_ids
+     ]
+ 
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True
+     )[0]
+     print("output_text: ", output_text)
+ 
+     # Extract thinking process
+     think_match = re.search(r'<think>(.*?)</think>', output_text, re.DOTALL)
+     thinking_process = think_match.group(1).strip() if think_match else "No thinking process found"
+ 
+     answer_match = re.search(r'<answer>(.*?)</answer>', output_text, re.DOTALL)
+     answer_output = answer_match.group(1).strip() if answer_match else "No answer extracted"
+ 
+     # Get bbox and draw
+     bbox = extract_bbox_answer(output_text)
+ 
+     # Draw bbox on the image
+     result_image = image.copy()
+     result_image = draw_bboxes(result_image, bbox)
+ 
+     return thinking_process, answer_output, result_image
+ 
+ if __name__ == "__main__":
+     import gradio as gr
+ 
+     model_path = "omlab/VLM-R1-Qwen2.5VL-3B-Math-0305"
+     # device = "cuda" if torch.cuda.is_available() else "cpu"
+     device = "cuda"
+     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
+     model.to(device)
+     processor = AutoProcessor.from_pretrained(model_path)
+ 
+     def gradio_interface(image, text):
+         thinking, output, result_image = process_image_and_text(image, text)
+         return thinking, output, result_image
+ 
+     demo = gr.Interface(
+         fn=gradio_interface,
+         inputs=[
+             gr.Image(type="pil", label="Input Image"),
+             gr.Textbox(label="Description Text")
+         ],
+         outputs=[
+             gr.Textbox(label="Thinking Process"),
+             gr.Textbox(label="Response"),
+             gr.Image(type="pil", label="Result with Bbox")
+         ],
+         title="Open-Vocabulary Object Detection Demo",
+         description="Upload an image and input description text, the system will return the thinking process and region annotation. \n\nOur GitHub: [VLM-R1](https://github.com/om-ai-lab/VLM-R1/tree/main)",
+         examples=[
+             ["examples/image1.jpg", "person"],
+             ["examples/image2.jpg", "drink, fruit"],
+             ["examples/image3.png", "keyboard, white cup, laptop"],
+         ],
+         cache_examples=False,
+         examples_per_page=10
+     )
+ 
+     demo.launch(server_name="0.0.0.0", server_port=7861, share=True)
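
For context on the parsing path in app.py: the prompt asks the model to wrap its detections in `<answer>` tags as a fenced JSON list, and `extract_bbox_answer` hands that text to `json_repair`. Below is a minimal, self-contained sketch of that flow; the `sample_output` string and the printed result are illustrative assumptions, not actual model output, and the fences are stripped explicitly here even though app.py relies on `json_repair` to cope with them.

```python
import re
import json_repair

# Illustrative model output in the <think>/<answer> format requested by the prompt (assumed sample).
sample_output = (
    "<think>The rider on the left is wearing a dark helmet.</think>"
    "<answer>```json\n"
    '[{"bbox_2d": [120, 40, 260, 170], "label": "helmet"}]\n'
    "```</answer>"
)

# Same idea as extract_bbox_answer: keep the last <answer> block, then parse it leniently.
answer = re.findall(r"<answer>(.*?)</answer>", sample_output, re.DOTALL)[-1]
answer = answer.replace("```json", "").replace("```", "")  # drop the Markdown fences for this sketch
boxes = json_repair.loads(answer)
print(boxes)  # [{'bbox_2d': [120, 40, 260, 170], 'label': 'helmet'}]
```

Each returned dict is what `draw_bboxes` expects: a `bbox_2d` box in pixel coordinates plus a `label` string.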
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch>=2.0.0
+ git+https://github.com/huggingface/transformers
+ Pillow>=10.0.0
+ httpx[socks]
+ accelerate>=0.26.0
+ json_repair>=0.1.0
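
A note on the pins: `transformers` is installed from GitHub, presumably because `Qwen2_5_VLForConditionalGeneration` only exists in very recent builds. A quick environment sanity check (a sketch, not part of the commit):

```python
# Confirm the installed stack exposes what app.py imports; the from-GitHub
# transformers pin exists because older releases lack Qwen2.5-VL support.
import torch
import json_repair  # noqa: F401  (used by app.py's answer parsing)
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration  # ImportError on old transformers

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
```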