ARCQUB committed
Commit 9e1e3ef · verified · 1 Parent(s): 462e5c0

Update models/qwen.py

Files changed (1)
  1. models/qwen.py +48 -33
models/qwen.py CHANGED
@@ -3,31 +3,21 @@ import json
 from PIL import Image
 import torch
 from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
-from qwen_vl_utils import process_vision_info
 from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
+from qwen_vl_utils import process_vision_info
 
-# Initialize Qwen2.5-VL model
-model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    "Qwen/Qwen2.5-VL-7B-Instruct",
-    torch_dtype=torch.bfloat16,
-    device_map="cuda"
-    #attn_implementation="flash_attention_2"
-)
-
-min_pixels = 256 * 28 * 28
-max_pixels = 1080 * 28 * 28
-processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
+# Globals (lazy-loaded at runtime)
+qwen_model = None
+qwen_processor = None
+ocr_model = None
+ocr_processor = None
 
-# Initialize Pix2Struct OCR model
-ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
-ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
 
-# Load prompt
 def load_prompt():
-    with open("prompts/prompt.txt", "r") as f:
+    with open("prompts/prompt.txt", "r", encoding="utf-8") as f:
         return f.read()
 
-# Try extracting JSON from text
+
 def try_extract_json(text):
     try:
         return json.loads(text)
@@ -50,20 +40,25 @@ def try_extract_json(text):
     except json.JSONDecodeError:
         return text
 
-# Extract OCR text using Pix2Struct
+
 def extract_all_text_pix2struct(image: Image.Image):
-    inputs = ocr_processor(images=image, return_tensors="pt")
+    global ocr_model, ocr_processor
+
+    if ocr_model is None or ocr_processor is None:
+        ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
+        ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        ocr_model = ocr_model.to(device)
+
+    inputs = ocr_processor(images=image, return_tensors="pt").to(ocr_model.device)
     predictions = ocr_model.generate(**inputs, max_new_tokens=512)
-    output_text = ocr_processor.decode(predictions[0], skip_special_tokens=True)
-    return output_text.strip()
+    return ocr_processor.decode(predictions[0], skip_special_tokens=True).strip()
+
 
-# Assign event/gateway names from OCR text
 def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
     if not ocr_text or not json_data:
         return json_data
 
-    lines = [line.strip() for line in ocr_text.split('\n') if line.strip()]
-
     def assign_best_guess(obj):
         if not obj.get("name") or obj["name"].strip() == "":
             obj["name"] = "(label unknown)"
@@ -76,9 +71,29 @@ def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
 
     return json_data
 
-# Run model
+
 def run_model(image: Image.Image):
+    global qwen_model, qwen_processor
+
+    if qwen_model is None or qwen_processor is None:
+        qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2.5-VL-7B-Instruct",
+            torch_dtype=torch.bfloat16,
+            device_map="auto"
+            # You can enable flash attention here if needed:
+            # attn_implementation="flash_attention_2"
+        )
+
+        min_pixels = 256 * 28 * 28
+        max_pixels = 1080 * 28 * 28
+        qwen_processor = AutoProcessor.from_pretrained(
+            "Qwen/Qwen2.5-VL-7B-Instruct",
+            min_pixels=min_pixels,
+            max_pixels=max_pixels
+        )
+
     prompt = load_prompt()
+
     messages = [
         {
             "role": "user",
@@ -89,21 +104,21 @@ def run_model(image: Image.Image):
         }
     ]
 
-    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    text = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
 
-    inputs = processor(
+    inputs = qwen_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
-    ).to("cuda")
+    ).to(qwen_model.device)
 
-    generated_ids = model.generate(**inputs, max_new_tokens=5000)
+    generated_ids = qwen_model.generate(**inputs, max_new_tokens=5000)
     generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
 
-    output_text = processor.batch_decode(
+    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
@@ -111,11 +126,11 @@ def run_model(image: Image.Image):
 
     parsed_json = try_extract_json(output_text)
 
-    # Apply OCR post-processing
+    # OCR post-processing
     ocr_text = extract_all_text_pix2struct(image)
     parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)
 
     return {
         "json": parsed_json,
         "raw": output_text
-    }
+    }
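
A minimal usage sketch of the updated module (the import path and image file are illustrative assumptions, not part of the commit; run_model now lazy-loads both models on first call):

from PIL import Image
from models.qwen import run_model  # assumes the repo root is on PYTHONPATH

# Hypothetical input file; any RGB diagram image works
image = Image.open("diagram.png").convert("RGB")

result = run_model(image)   # first call loads Qwen2.5-VL-7B-Instruct and the Pix2Struct OCR model
print(result["json"])       # parsed JSON if extraction succeeded, otherwise the raw text
print(result["raw"])        # unprocessed model output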