File size: 3,775 Bytes
6c0c37c
 
 
 
 
 
 
 
 
 
 
 
d6daf6b
 
6c0c37c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import json
from PIL import Image
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Pixel bounds handed to the Qwen2.5-VL processor, written as multiples of a
# 28x28 area (per the Qwen2.5-VL docs these limit the resized image area).
min_pixels = 256 * 28 * 28
max_pixels = 1080 * 28 * 28

# Qwen2.5-VL vision-language model: bfloat16 weights, placed on the GPU.
# Optionally pass attn_implementation="flash_attention_2" here (requires the
# flash-attn package to be installed).
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    max_pixels=max_pixels,
    min_pixels=min_pixels,
)

# Pix2Struct model used for the secondary OCR pass. No device_map is given,
# so (per transformers defaults) it loads on the CPU; its inputs are likewise
# never moved to the GPU.
ocr_processor = Pix2StructProcessor.from_pretrained(
    "google/pix2struct-textcaps-base"
)
ocr_model = Pix2StructForConditionalGeneration.from_pretrained(
    "google/pix2struct-textcaps-base"
)

# Load prompt
def load_prompt(path="prompts/prompt.txt") -> str:
    """Return the full text of the prompt file.

    Args:
        path: Location of the prompt file (str or path-like). Defaults to
            ``prompts/prompt.txt``, resolved against the current working
            directory.

    Returns:
        The whole file contents as a string.
    """
    # Explicit UTF-8: without it the prompt is decoded with the platform's
    # locale encoding, which garbles non-ASCII text on e.g. Windows.
    with open(path, "r", encoding="utf-8") as f:
        return f.read()

# Try extracting JSON from text
def try_extract_json(text):
    """Parse JSON out of model output, tolerating surrounding prose or code fences.

    First tries to parse *text* as a whole. If that fails, tries each ``{`` in
    turn and decodes the first complete JSON object that starts there. The real
    JSON decoder is used, so braces inside string values do not confuse it, and
    a stray ``{`` before the actual object is skipped rather than fatal.

    Args:
        text: Decoded model output.

    Returns:
        The parsed value when *text* is itself valid JSON; otherwise the first
        decodable embedded object (a dict); otherwise *text* unchanged.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode stops at the end of the first complete value and
            # ignores whatever follows it (e.g. a closing ``` fence or prose).
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            # Not a valid object at this brace; move on to the next one.
            start = text.find("{", start + 1)
        else:
            return obj
    return text

# Extract OCR text using Pix2Struct
def extract_all_text_pix2struct(image: Image.Image):
    """Transcribe the text in *image* with the module-level Pix2Struct model.

    Generates at most 512 new tokens, decodes the first sequence of the
    batch with special tokens removed, and strips surrounding whitespace.
    """
    encoded = ocr_processor(images=image, return_tensors="pt")
    first_sequence = ocr_model.generate(**encoded, max_new_tokens=512)[0]
    decoded = ocr_processor.decode(first_sequence, skip_special_tokens=True)
    return decoded.strip()

# Assign event/gateway names from OCR text
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    """Ensure every event and gateway in *json_data* has a non-empty ``name``.

    Dict entries in the ``"events"`` and ``"gateways"`` lists whose ``name``
    is missing, falsy, or a whitespace-only string get ``"(label unknown)"``.
    The dict is modified in place and also returned.

    NOTE(review): despite the function name, *ocr_text* only acts as a switch
    (nothing is changed when it is empty); no label is actually derived from
    it.

    Args:
        json_data: Parsed model output. Anything that is not a non-empty dict
            (e.g. the raw string or a list from a failed/odd parse) is
            returned unchanged instead of raising.
        ocr_text: OCR transcription of the same image.

    Returns:
        *json_data*, with placeholder names filled in where applicable.
    """
    if not ocr_text or not json_data or not isinstance(json_data, dict):
        return json_data

    for key in ("events", "gateways"):
        entries = json_data.get(key)
        # The model may emit null or a non-list here; skip rather than crash.
        if not isinstance(entries, list):
            continue
        for obj in entries:
            if not isinstance(obj, dict):
                continue
            name = obj.get("name")
            # Only strings can be blank-but-truthy; other truthy values
            # (e.g. numbers) are kept as-is instead of raising on .strip().
            if not name or (isinstance(name, str) and not name.strip()):
                obj["name"] = "(label unknown)"

    return json_data

# Run model
def run_model(image: Image.Image, max_new_tokens: int = 5000):
    """Run Qwen2.5-VL on *image* with the prompt file and parse its answer.

    Args:
        image: PIL image to analyse.
        max_new_tokens: Generation budget for the answer (default 5000, the
            value previously hard-coded).

    Returns:
        A dict with:
          - ``"json"``: the result of ``try_extract_json`` on the model output
            (the parsed JSON value, or the raw string if nothing parsed). When
            it is a dict, OCR post-processing has been applied to it.
          - ``"raw"``: the decoded model output text.
    """
    prompt = load_prompt()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    chat_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[chat_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + completion; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    parsed_json = try_extract_json(output_text)

    # OCR post-processing only makes sense for a parsed JSON object; skipping
    # it otherwise avoids a wasted OCR pass and a .get() on a str/list.
    if isinstance(parsed_json, dict):
        ocr_text = extract_all_text_pix2struct(image)
        parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

    return {
        "json": parsed_json,
        "raw": output_text,
    }