ARCQUB committed
Commit 8093104 · verified · 1 parent: c395a91

Update models/aya_vision.py

Files changed (1): models/aya_vision.py (+128 −128)
models/aya_vision.py CHANGED
@@ -1,128 +1,128 @@
 import os
 import json
 import re
 from PIL import Image
 import torch
 from transformers import AutoProcessor, AutoModelForImageTextToText
 from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

 # Set Hugging Face Token
 hf_token = os.getenv("HF_TOKEN")

 # Initialize Aya Vision Model
 model_id = "CohereForAI/aya-vision-8b"
 processor = AutoProcessor.from_pretrained(model_id)
 model = AutoModelForImageTextToText.from_pretrained(
     model_id, device_map="auto", torch_dtype=torch.float16
 )

 # Initialize Pix2Struct OCR Model
 ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
 ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

 # Load prompt
 def load_prompt():
-    with open("/content/vision_model_space/vision_model_space_new/prompts/prompt.txt", "r", encoding="utf-8") as f:
+    with open("prompts/prompt.txt", "r", encoding="utf-8") as f:
         return f.read()

 # Try extracting JSON from model output
 def try_extract_json(text):
     if not text or not text.strip():
         return None
     try:
         return json.loads(text)
     except json.JSONDecodeError:
         # Try extracting JSON substring by brace balancing
         start = text.find('{')
         if start == -1:
             return None

         brace_count = 0
         json_candidate = ''
         for i in range(start, len(text)):
             char = text[i]
             if char == '{':
                 brace_count += 1
             elif char == '}':
                 brace_count -= 1
             json_candidate += char
             if brace_count == 0:
                 break

         try:
             return json.loads(json_candidate)
         except json.JSONDecodeError:
             return None

 # Extract OCR text using Pix2Struct
 def extract_all_text_pix2struct(image: Image.Image):
     inputs = ocr_processor(images=image, return_tensors="pt")
     predictions = ocr_model.generate(**inputs, max_new_tokens=512)
     output_text = ocr_processor.decode(predictions[0], skip_special_tokens=True)
     return output_text.strip()

 # Assign event/gateway names from OCR text
 def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
     if not ocr_text or not json_data:
         return json_data

     lines = [line.strip() for line in ocr_text.split('\n') if line.strip()]

     def assign_best_guess(obj):
         if not obj.get("name") or obj["name"].strip() == "":
             obj["name"] = "(label unknown)"

     for evt in json_data.get("events", []):
         assign_best_guess(evt)

     for gw in json_data.get("gateways", []):
         assign_best_guess(gw)

     return json_data

 # Run Aya model on image
 def run_model(image: Image.Image):
     prompt = load_prompt()

     messages = [
         {
             "role": "user",
             "content": [
                 {"type": "image", "image": image},
                 {"type": "text", "text": prompt}
             ]
         }
     ]

     inputs = processor.apply_chat_template(
         messages,
         padding=True,
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
         return_tensors="pt"
     ).to(model.device)

     gen_tokens = model.generate(
         **inputs,
         max_new_tokens=5000,
         do_sample=True,
         temperature=0.3,
     )

     output_text = processor.tokenizer.decode(
         gen_tokens[0][inputs.input_ids.shape[1]:],
         skip_special_tokens=True
     )

     parsed_json = try_extract_json(output_text)

     # Apply OCR post-processing
     ocr_text = extract_all_text_pix2struct(image)
     parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

     # Return both parsed and raw
     return {
         "json": parsed_json,
         "raw": output_text
     }
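
For reference, a minimal usage sketch of the updated module. Assumptions (not part of the commit): the module is importable as models.aya_vision, HF_TOKEN is set in the environment, and "sample_diagram.png" is a hypothetical input image path.

# Minimal usage sketch (hypothetical paths and import layout; adjust to
# your checkout). Exercises the full pipeline: Aya Vision generation,
# brace-balanced JSON extraction, and Pix2Struct OCR post-processing.
from PIL import Image

from models.aya_vision import run_model

image = Image.open("sample_diagram.png").convert("RGB")
result = run_model(image)

if result["json"] is not None:
    print("Parsed JSON:", result["json"])
else:
    # try_extract_json returned None: no valid JSON object could be
    # recovered, so fall back to inspecting the raw generation.
    print("No JSON parsed; raw output:", result["raw"])

Because run_model returns both keys, callers can always fall back to the raw text when the model's output is not valid JSON.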