import os
import json
import re
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Hugging Face token from the environment.
# NOTE(review): hf_token is read but never passed to from_pretrained below;
# huggingface_hub normally honors the HF_TOKEN env var on its own — confirm,
# otherwise gated model downloads would fail.
hf_token = os.getenv("HF_TOKEN")

# Initialize Aya Vision model (multimodal chat model used by run_model).
# device_map="auto" lets accelerate place weights on available devices;
# float16 halves memory at a small precision cost.
model_id = "CohereForAI/aya-vision-8b"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.float16
)

# Initialize Pix2Struct OCR model (used by extract_all_text_pix2struct).
# Loaded without device_map or dtype override, so it stays on CPU in full
# precision by default.
ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

# Load prompt
def load_prompt():
    """Read and return the instruction prompt from prompts/prompt.txt."""
    with open("prompts/prompt.txt", "r", encoding="utf-8") as prompt_file:
        contents = prompt_file.read()
    return contents

# Try extracting JSON from model output
def try_extract_json(text):
    """Parse JSON from model output, tolerating surrounding prose.

    First attempts to parse *text* as a whole. If that fails, scans for the
    first parseable JSON object embedded in the text using
    json.JSONDecoder.raw_decode, which — unlike naive brace counting —
    correctly handles braces inside string literals (e.g. '{"a": "}"}').

    Returns the parsed object, or None if *text* is empty or contains no
    valid JSON object.
    """
    if not text or not text.strip():
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    decoder = json.JSONDecoder()
    start = text.find('{')
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
            return obj
        except json.JSONDecodeError:
            # This '{' did not begin a valid object; try the next one.
            start = text.find('{', start + 1)
    return None

# Extract OCR text using Pix2Struct
def extract_all_text_pix2struct(image: Image.Image):
    """Run the Pix2Struct model over *image* and return its decoded text, stripped."""
    encoded = ocr_processor(images=image, return_tensors="pt")
    generated_ids = ocr_model.generate(**encoded, max_new_tokens=512)
    decoded = ocr_processor.decode(generated_ids[0], skip_special_tokens=True)
    return decoded.strip()

# Assign event/gateway names from OCR text
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    """Give unnamed events/gateways a placeholder name.

    Mutates *json_data* in place and returns it. Any event or gateway whose
    "name" is missing, None, empty, or whitespace-only is set to
    "(label unknown)". If *json_data* or *ocr_text* is falsy, *json_data* is
    returned unchanged.

    NOTE(review): despite the name, the OCR text is only used as a gate —
    no actual label matching against *ocr_text* happens (the original
    computed a `lines` list from it and never used it; that dead code is
    removed here).
    """
    if not ocr_text or not json_data:
        return json_data

    def _ensure_named(obj: dict) -> None:
        # Missing/None/empty/whitespace-only names all get the placeholder.
        name = obj.get("name")
        if not name or not name.strip():
            obj["name"] = "(label unknown)"

    for evt in json_data.get("events", []):
        _ensure_named(evt)

    for gw in json_data.get("gateways", []):
        _ensure_named(gw)

    return json_data

# Run Aya model on image
def run_model(image: Image.Image):
    """Run the Aya Vision model on *image* and post-process its output.

    Returns a dict with:
      - "json": the parsed JSON object extracted from the model output
        (None if unparseable), with empty event/gateway names replaced by
        a placeholder via OCR post-processing
      - "raw": the raw decoded model text
    """
    prompt = load_prompt()

    # Single-turn chat message pairing the image with the text prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }
    ]

    # Tokenize through the model's chat template and move the resulting
    # tensors to the model's device.
    inputs = processor.apply_chat_template(
        messages,
        padding=True,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    # Sampled decoding with low temperature for near-greedy output.
    # NOTE(review): do_sample=True makes results non-reproducible between
    # runs — confirm determinism is not required here.
    gen_tokens = model.generate(
        **inputs,
        max_new_tokens=5000,
        do_sample=True,
        temperature=0.3,
    )

    # Slice off the prompt tokens so only the newly generated continuation
    # is decoded.
    output_text = processor.tokenizer.decode(
        gen_tokens[0][inputs.input_ids.shape[1]:],
        skip_special_tokens=True
    )

    parsed_json = try_extract_json(output_text)

    # Apply OCR post-processing (fills placeholder names for unnamed
    # events/gateways when OCR text is non-empty).
    ocr_text = extract_all_text_pix2struct(image)
    parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

    # Return both parsed and raw
    return {
        "json": parsed_json,
        "raw": output_text
    }