ARCQUB committed on
Commit 462e5c0 · verified · 1 Parent(s): 74e0e2d

Update models/pixtral.py

Files changed (1)
  1. models/pixtral.py +124 -113
models/pixtral.py CHANGED
@@ -1,113 +1,124 @@
- import os
- import json
- import base64
- from PIL import Image
- from vllm import LLM
- from vllm.sampling_params import SamplingParams
- from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
-
- # Optional: Replace with your Hugging Face token or use environment variable
- hf_token = os.getenv("HF_TOKEN")
- Image.MAX_IMAGE_PIXELS = None
-
- # Initialize Pixtral model
- model_name = "mistralai/Pixtral-12B-2409"
- sampling_params = SamplingParams(max_tokens=5000)
- llm = LLM(model=model_name, tokenizer_mode="mistral", dtype="bfloat16", max_model_len=30000)
-
- # Initialize Pix2Struct OCR model
- ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
- ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
-
- # Load prompt from file
- def load_prompt():
-     with open("prompts/prompt.txt", "r", encoding="utf-8") as f:
-         return f.read()
-
- # Extract structured JSON from text
- def try_extract_json(text):
-     if not text or not text.strip():
-         return None
-     try:
-         return json.loads(text)
-     except json.JSONDecodeError:
-         start = text.find('{')
-         if start == -1:
-             return None
-
-         brace_count = 0
-         json_candidate = ''
-         for i in range(start, len(text)):
-             if text[i] == '{':
-                 brace_count += 1
-             elif text[i] == '}':
-                 brace_count -= 1
-             json_candidate += text[i]
-             if brace_count == 0:
-                 break
-         try:
-             return json.loads(json_candidate)
-         except json.JSONDecodeError:
-             return None
-
- # Base64 encode image
- def encode_image_as_base64(pil_image):
-     from io import BytesIO
-     buffer = BytesIO()
-     pil_image.save(buffer, format="JPEG")
-     encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
-     return encoded
-
- # Extract OCR text using Pix2Struct
- def extract_all_text_pix2struct(image: Image.Image):
-     inputs = ocr_processor(images=image, return_tensors="pt")
-     predictions = ocr_model.generate(**inputs, max_new_tokens=512)
-     output_text = ocr_processor.decode(predictions[0], skip_special_tokens=True)
-     return output_text.strip()
-
- # Assign event/gateway names from OCR text
- def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
-     if not ocr_text or not json_data:
-         return json_data
-
-     lines = [line.strip() for line in ocr_text.split('\n') if line.strip()]
-
-     def assign_best_guess(obj):
-         if not obj.get("name") or obj["name"].strip() == "":
-             obj["name"] = "(label unknown)"
-
-     for evt in json_data.get("events", []):
-         assign_best_guess(evt)
-
-     for gw in json_data.get("gateways", []):
-         assign_best_guess(gw)
-
-     return json_data
-
- # Run model
- def run_model(image: Image.Image):
-     prompt = load_prompt()
-     encoded_image = encode_image_as_base64(image)
-
-     messages = [
-         {
-             "role": "user",
-             "content": [
-                 {"type": "text", "text": prompt},
-                 {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
-             ]
-         }
-     ]
-
-     outputs = llm.chat(messages, sampling_params=sampling_params)
-     raw_output = outputs[0].outputs[0].text
-     parsed_json = try_extract_json(raw_output)
-
-     # Apply OCR post-processing
-     ocr_text = extract_all_text_pix2struct(image)
-     parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)
-
-     return {
-         "json": parsed_json,
-         "raw": raw_output
-     }
+ import os
+ import json
+ import base64
+ import torch  # required for the CUDA availability check in extract_all_text_pix2struct
+ from PIL import Image
+ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
+ from vllm import LLM
+ from vllm.sampling_params import SamplingParams
+
+ # Hugging Face token from environment (optional)
+ hf_token = os.getenv("HF_TOKEN")
+ Image.MAX_IMAGE_PIXELS = None
+
+ # Global placeholders (lazy-loaded later)
+ llm = None
+ ocr_model = None
+ ocr_processor = None
+ sampling_params = SamplingParams(max_tokens=5000)
+
+
+ def load_prompt():
+     with open("prompts/prompt.txt", "r", encoding="utf-8") as f:
+         return f.read()
+
+
+ def try_extract_json(text):
+     if not text or not text.strip():
+         return None
+     try:
+         return json.loads(text)
+     except json.JSONDecodeError:
+         start = text.find('{')
+         if start == -1:
+             return None
+         brace_count = 0
+         json_candidate = ''
+         for i in range(start, len(text)):
+             if text[i] == '{':
+                 brace_count += 1
+             elif text[i] == '}':
+                 brace_count -= 1
+             json_candidate += text[i]
+             if brace_count == 0:
+                 break
+         try:
+             return json.loads(json_candidate)
+         except json.JSONDecodeError:
+             return None
+
+
+ def encode_image_as_base64(pil_image):
+     from io import BytesIO
+     buffer = BytesIO()
+     pil_image.save(buffer, format="JPEG")
+     return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+
+ def extract_all_text_pix2struct(image: Image.Image):
+     global ocr_processor, ocr_model
+
+     if ocr_processor is None or ocr_model is None:
+         ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
+         ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         ocr_model = ocr_model.to(device)
+
+     inputs = ocr_processor(images=image, return_tensors="pt").to(ocr_model.device)
+     predictions = ocr_model.generate(**inputs, max_new_tokens=512)
+     return ocr_processor.decode(predictions[0], skip_special_tokens=True).strip()
+
+
+ def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
+     if not ocr_text or not json_data:
+         return json_data
+
+     def assign_best_guess(obj):
+         if not obj.get("name") or obj["name"].strip() == "":
+             obj["name"] = "(label unknown)"
+
+     for evt in json_data.get("events", []):
+         assign_best_guess(evt)
+
+     for gw in json_data.get("gateways", []):
+         assign_best_guess(gw)
+
+     return json_data
+
+
+ def run_model(image: Image.Image):
+     global llm
+
+     if llm is None:
+         llm = LLM(
+             model="mistralai/Pixtral-12B-2409",
+             tokenizer_mode="mistral",
+             dtype="bfloat16",
+             max_model_len=30000,
+         )
+
+     prompt = load_prompt()
+     encoded_image = encode_image_as_base64(image)
+
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "text", "text": prompt},
+                 {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
+             ]
+         }
+     ]
+
+     outputs = llm.chat(messages, sampling_params=sampling_params)
+     raw_output = outputs[0].outputs[0].text
+     parsed_json = try_extract_json(raw_output)
+
+     # Apply OCR enrichment
+     ocr_text = extract_all_text_pix2struct(image)
+     parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)
+
+     return {
+         "json": parsed_json,
+         "raw": raw_output
+     }
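
For reference, a minimal usage sketch of the updated module. The import path models.pixtral and the image filename are illustrative assumptions, not part of this commit; both heavy models are loaded lazily on the first call to run_model:

    from PIL import Image
    from models.pixtral import run_model

    # .convert("RGB") guards encode_image_as_base64, which saves as JPEG
    # (JPEG cannot encode an alpha channel)
    image = Image.open("diagram.png").convert("RGB")

    result = run_model(image)   # first call initializes Pixtral and the OCR model
    print(result["json"])       # parsed dict, or None if no JSON could be extracted
    print(result["raw"])        # raw text returned by Pixtral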