File size: 3,504 Bytes
e29cf96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# gpt4o_pix2struct_ocr.py

import os
import json
import base64
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import numpy as np

import openai

# Name of the OpenAI chat model that run_model passes to the completions API.
model = "gpt-4o"

# Load Pix2Struct model + processor (vision-language OCR).
# Both are created once, at import time, from the same
# "google/pix2struct-textcaps-base" checkpoint and are shared by every call to
# extract_all_text_pix2struct.
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")


def load_prompt(prompt_file="prompts/prompt.txt"):
    """Return the text of *prompt_file* without leading/trailing whitespace.

    The file is opened in text mode and decoded as UTF-8.
    """
    with open(prompt_file, encoding="utf-8") as handle:
        contents = handle.read()
    return contents.strip()


def try_extract_json(text):
    """Parse *text* as JSON, falling back to the first embedded JSON object.

    The whole string is tried first. If that fails, parsing restarts at the
    first ``{`` in the text and stops at the end of that JSON value, so any
    prose or Markdown fence around the object is ignored.

    Args:
        text: String that is expected to contain a JSON document.

    Returns:
        The decoded value, or ``None`` when there is no ``{`` in the text or
        the data starting at the first ``{`` is not valid JSON.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    start = text.find("{")
    if start == -1:
        return None

    # Let the JSON decoder find where the object ends instead of counting
    # braces by hand: a manual counter is fooled by "{" / "}" characters that
    # appear inside string values and would cut the object short.
    try:
        parsed, _end = json.JSONDecoder().raw_decode(text[start:])
    except json.JSONDecodeError:
        return None
    return parsed


def encode_image_base64(image: Image.Image):
    """Encode *image* as a JPEG and return the bytes as a base64 string.

    JPEG cannot store an alpha channel or a colour palette, and Pillow raises
    ``OSError`` when asked to save such an image as JPEG. Images with
    transparency are therefore flattened onto a white background first, and
    any other mode the JPEG writer does not accept is converted to RGB. The
    caller's image object is never modified.

    Args:
        image: The PIL image to encode.

    Returns:
        The base64-encoded JPEG data as a ``str``.
    """
    from io import BytesIO

    has_alpha = image.mode in ("RGBA", "LA") or (
        image.mode == "P" and "transparency" in image.info
    )
    if has_alpha:
        # Composite onto white so transparent regions do not turn black.
        rgba = image.convert("RGBA")
        flattened = Image.new("RGB", rgba.size, (255, 255, 255))
        flattened.paste(rgba, mask=rgba.split()[3])
        image = flattened
    elif image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        # Modes outside this list are rejected by Pillow's JPEG writer.
        image = image.convert("RGB")

    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def extract_all_text_pix2struct(image: Image.Image):
    """Run the module-level Pix2Struct model on *image* and return its text.

    The image is preprocessed into PyTorch tensors, up to 512 new tokens are
    generated, and the first generated sequence is decoded without special
    tokens. Surrounding whitespace is stripped from the result.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    generated = pix2struct_model.generate(**model_inputs, max_new_tokens=512)
    first_sequence = generated[0]
    decoded = processor.decode(first_sequence, skip_special_tokens=True)
    return decoded.strip()


# Optional OCR-assisted labelling step. Matching OCR text to individual diagram
# elements is not implemented yet; for now this only fills in a placeholder
# for events and gateways that have no name.
def assign_event_gateway_names_from_ocr(image: Image.Image, json_data, ocr_text):
    """Give unnamed events and gateways in *json_data* a placeholder name.

    Entries under the ``"events"`` and ``"gateways"`` keys whose ``"name"`` is
    missing, falsy, or a blank string get the name ``"(label unknown)"``.
    Entries are updated in place.

    Args:
        image: The source diagram image. Currently unused.
        json_data: Parsed model output. Anything that is not a dict (for
            example ``None`` when parsing failed) is returned unchanged.
        ocr_text: Text extracted from the image. When it is empty the data is
            returned unchanged; its content is not otherwise used yet.

    Returns:
        *json_data*, the same object that was passed in.
    """
    # Non-dict input (e.g. None from a failed JSON parse) has nothing to fill
    # in and must not crash the caller.
    if not ocr_text or not isinstance(json_data, dict):
        return json_data

    for key in ("events", "gateways"):
        # "or []" also covers an explicit null value for the key.
        for element in json_data.get(key) or []:
            if not isinstance(element, dict):
                continue
            name = element.get("name")
            # Only call .strip() on real strings; other truthy values are kept.
            if not name or (isinstance(name, str) and not name.strip()):
                element["name"] = "(label unknown)"  # fallback until matching logic exists

    return json_data


def run_model(image: Image.Image, api_key: str = None):
    """Send *image* and the stored prompt to the OpenAI model and parse the reply.

    The image is sent as a base64 JPEG data URL together with the prompt text.
    The reply is parsed as JSON; when that yields a dict, Pix2Struct OCR is run
    on the image and unnamed events/gateways receive a placeholder name.

    Args:
        image: The diagram image to analyse.
        api_key: OpenAI API key. When missing or empty, no request is made.

    Returns:
        A dict with two keys: ``"json"`` holds the parsed model output (or
        ``None`` when the key is missing or the reply contains no parseable
        JSON), and ``"raw"`` holds the stripped reply text (or the missing-key
        warning).
    """
    # Fail fast: without a key there is no point reading the prompt file or
    # encoding the image.
    if not api_key:
        return {"json": None, "raw": "⚠️ API key is missing. Please provide your OpenAI API key."}

    prompt_text = load_prompt()
    encoded_image = encode_image_base64(image)

    client = openai.OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
                ],
            }
        ],
        max_tokens=5000,
    )

    # The message content may be None (e.g. a reply without text), so guard
    # before stripping.
    output_text = (response.choices[0].message.content or "").strip()
    parsed_json = try_extract_json(output_text)

    # Vision-language OCR assist step (Pix2Struct). Only run it when there is
    # a parsed object to post-process; this avoids a crash on unparseable
    # replies and skips a needless model inference.
    if isinstance(parsed_json, dict):
        full_ocr_text = extract_all_text_pix2struct(image)
        parsed_json = assign_event_gateway_names_from_ocr(image, parsed_json, full_ocr_text)

    return {"json": parsed_json, "raw": output_text}