Upload folder using huggingface_hub
- app.py +187 -0
- models/aya_vision.py +128 -0
- models/gpt4o.py +111 -0
- models/pixtral.py +113 -0
- models/qwen.py +121 -0
- prompts/prompt.txt +145 -0
- requirements.txt +10 -0
app.py
ADDED
@@ -0,0 +1,187 @@
import gradio as gr
import importlib
from PIL import Image
import json

# === Model Mapping ===
MODEL_MAP = {
    "Qwen": "models.qwen",
    "Pixtral": "models.pixtral",
    "Aya Vision": "models.aya_vision",
    "GPT-4o": "models.gpt4o"
}

# === Load Model
def load_model_runner(model_name):
    module = importlib.import_module(MODEL_MAP[model_name])
    return module.run_model

# === Format Raw JSON Output
def format_result_json(output):
    if isinstance(output, dict):
        return json.dumps(output, indent=2)
    else:
        return str(output).strip()

# === Prettified Output View
def format_pretty_view(output):
    if not isinstance(output, dict):
        return "No structured JSON found.\n\n" + str(output)

    lines = []
    process = output.get("process", output)

    if "name" in process:
        lines.append(f"📦 Process Name: {process['name']}\n")

    if "startEvent" in process:
        start = process["startEvent"]
        name = start.get("name", "")
        type_ = start.get("type", "")
        desc = start.get("description", "")
        line = f"▶️ Start: {name}"
        if type_:
            line += f" ({type_})"
        if desc:
            line += f" - {desc}"
        lines.append(line)

    if "endEvent" in process:
        end = process["endEvent"]
        name = end.get("name", "")
        type_ = end.get("type", "")
        desc = end.get("description", "")
        line = f"⏹ End: {name}"
        if type_:
            line += f" ({type_})"
        if desc:
            line += f" - {desc}"
        lines.append(line)

    if "tasks" in process:
        lines.append("\n🔹 Tasks:")
        for t in process["tasks"]:
            name = t.get("name", "")
            type_ = t.get("type", "")
            desc = t.get("description", "")
            line = f" - {name}"
            if type_:
                line += f" ({type_})"
            if desc:
                line += f" - {desc}"
            lines.append(line)

    if "events" in process:
        lines.append("\n📨 Events:")
        for e in process["events"]:
            name = e.get("name", "")
            type_ = e.get("type", "")
            desc = e.get("description", "")
            line = f" - {name}"
            if type_:
                line += f" ({type_})"
            if desc:
                line += f" - {desc}"
            lines.append(line)

    if "gateways" in process:
        lines.append("\n🔀 Gateways:")
        for g in process["gateways"]:
            name = g.get("name", "")
            type_ = g.get("type", "")
            label = g.get("label", "")  # some outputs may use 'label'
            desc = g.get("description", "")
            line = f" - {name}"
            if type_:
                line += f" ({type_})"
            if label:
                line += f" | Label: {label}"
            if desc:
                line += f" - {desc}"
            lines.append(line)

    if "sequenceFlows" in process:
        lines.append("\n➡️ Sequence Flows:")
        for f in process["sequenceFlows"]:
            src = f.get("sourceTask") or f.get("sourceEvent") or "Unknown"
            tgt = f.get("targetTask") or f.get("targetEvent") or "Unknown"
            condition = f.get("condition", "")
            line = f" - {src} → {tgt}"
            if condition:
                line += f" [Condition: {condition}]"
            lines.append(line)

    if "connections" in process:
        lines.append("\n🔗 Connections:")
        for c in process["connections"]:
            src = c.get("sourceTask") or c.get("sourceEvent") or "Unknown"
            tgt = c.get("targetTask") or c.get("targetEvent") or "Unknown"
            condition = c.get("condition", "")
            line = f" - {src} → {tgt}"
            if condition:
                line += f" [Condition: {condition}]"
            lines.append(line)

    if "relationships" in process:
        lines.append("\n🔗 Relationships:")
        for r in process["relationships"]:
            source = r.get("source")
            target = r.get("target")
            src = source.get("ref", "Unknown") if isinstance(source, dict) else str(source)
            tgt = target.get("ref", "Unknown") if isinstance(target, dict) else str(target)
            desc = r.get("description", "")
            line = f" - {src} → {tgt}"
            if desc:
                line += f" | {desc}"
            lines.append(line)

    return "\n".join(lines).strip()

# === Main Inference Handler
def process_single_image(model_name, image_file, api_key_file=None):
    runner = load_model_runner(model_name)
    image = Image.open(image_file.name).convert("RGB")

    api_key = None
    if model_name == "GPT-4o" and api_key_file is not None:
        try:
            api_key = open(api_key_file.name, "r").read().strip()
        except Exception as e:
            return image, "(API key file could not be read)", f"(Error: {e})"

    if model_name == "GPT-4o":
        result = runner(image, api_key=api_key)
    else:
        result = runner(image)

    parsed_json = result.get("json")
    raw_text = result.get("raw", "")

    if parsed_json:
        json_output = format_result_json(parsed_json)
        pretty_output = format_pretty_view(parsed_json)
    else:
        json_output = "(No valid JSON extracted)"
        pretty_output = "(No structured content extracted)\n\n⚠️ Raw Model Output:\n" + raw_text

    return image, json_output, pretty_output

# === Gradio Interface (Simple)
iface = gr.Interface(
    fn=process_single_image,
    inputs=[
        gr.Dropdown(choices=list(MODEL_MAP.keys()), label="Select Vision Model"),
        gr.File(file_types=["image"], label="Upload a BPMN Image"),
        gr.File(file_types=[".txt"], label="🔐 Upload OpenAI API Key File (only for GPT-4o)")
    ],
    outputs=[
        gr.Image(label="Input Image"),
        gr.Textbox(label="Raw JSON Output (Technical)", lines=20),
        gr.Textbox(label="Prettified View (User-Friendly)", lines=25)
    ],
    title="🖼️ Vision Model Extractor - JSON + Pretty View",
    description="Upload a BPMN image and select a vision model to extract structured output. API key file is required only for GPT-4o.",
    allow_flagging="never"
)

iface.launch(share=True)
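
Note on the structure above: app.py treats each MODEL_MAP entry as a pluggable backend. The value is a module path imported lazily with importlib, and the module must expose a run_model function that takes a PIL image (plus an optional api_key for GPT-4o) and returns a dict with "json" and "raw" keys. A minimal sketch of a hypothetical extra backend (the module name and payload are illustrative, not part of this upload):

# models/dummy.py - hypothetical stub showing the contract app.py expects
import json
from PIL import Image

def run_model(image: Image.Image):
    # A real backend would run a vision-language model here; this stub just
    # returns a fixed payload in the {"json": ..., "raw": ...} shape app.py reads.
    payload = {"process": {"name": "Example Process", "tasks": []}}
    return {"json": payload, "raw": json.dumps(payload)}

Registering it would only require adding "Dummy": "models.dummy" to MODEL_MAP in app.py.
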
models/aya_vision.py
ADDED
@@ -0,0 +1,128 @@
import os
import json
import re
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Set Hugging Face Token
hf_token = os.getenv("HF_TOKEN")

# Initialize Aya Vision Model
model_id = "CohereForAI/aya-vision-8b"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.float16
)

# Initialize Pix2Struct OCR Model
ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

# Load prompt (relative path, same as the other model modules)
def load_prompt():
    with open("prompts/prompt.txt", "r", encoding="utf-8") as f:
        return f.read()

# Try extracting JSON from model output
def try_extract_json(text):
    if not text or not text.strip():
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Try extracting JSON substring by brace balancing
        start = text.find('{')
        if start == -1:
            return None

        brace_count = 0
        json_candidate = ''
        for i in range(start, len(text)):
            char = text[i]
            if char == '{':
                brace_count += 1
            elif char == '}':
                brace_count -= 1
            json_candidate += char
            if brace_count == 0:
                break

        try:
            return json.loads(json_candidate)
        except json.JSONDecodeError:
            return None

# Extract OCR text using Pix2Struct
def extract_all_text_pix2struct(image: Image.Image):
    inputs = ocr_processor(images=image, return_tensors="pt")
    predictions = ocr_model.generate(**inputs, max_new_tokens=512)
    output_text = ocr_processor.decode(predictions[0], skip_special_tokens=True)
    return output_text.strip()

# Assign event/gateway names from OCR text
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    if not ocr_text or not json_data:
        return json_data

    lines = [line.strip() for line in ocr_text.split('\n') if line.strip()]

    def assign_best_guess(obj):
        if not obj.get("name") or obj["name"].strip() == "":
            obj["name"] = "(label unknown)"

    for evt in json_data.get("events", []):
        assign_best_guess(evt)

    for gw in json_data.get("gateways", []):
        assign_best_guess(gw)

    return json_data

# Run Aya model on image
def run_model(image: Image.Image):
    prompt = load_prompt()

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        padding=True,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    gen_tokens = model.generate(
        **inputs,
        max_new_tokens=5000,
        do_sample=True,
        temperature=0.3,
    )

    output_text = processor.tokenizer.decode(
        gen_tokens[0][inputs.input_ids.shape[1]:],
        skip_special_tokens=True
    )

    parsed_json = try_extract_json(output_text)

    # Apply OCR post-processing
    ocr_text = extract_all_text_pix2struct(image)
    parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

    # Return both parsed and raw
    return {
        "json": parsed_json,
        "raw": output_text
    }
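
For reference, the brace-balancing fallback in try_extract_json above is what lets the app tolerate conversational model output. A quick illustration of the intended behaviour, assuming the helper has been imported or copied into the current scope (the values are made up):

# Illustrative only: expected behaviour of try_extract_json on noisy output
messy = 'Sure! Here is the diagram: {"events": [{"name": "Start"}]} Hope this helps.'

# json.loads(messy) would raise, so the helper scans from the first '{', appends
# characters until the braces balance, and parses that slice instead.
assert try_extract_json(messy) == {"events": [{"name": "Start"}]}
assert try_extract_json("no json here") is None
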
models/gpt4o.py
ADDED
@@ -0,0 +1,111 @@
# gpt4o_pix2struct_ocr.py

import os
import json
import base64
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import numpy as np

import openai

model = "gpt-4o"

# Load Pix2Struct model + processor (vision-language OCR)
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")


def load_prompt(prompt_file="prompts/prompt.txt"):
    with open(prompt_file, "r", encoding="utf-8") as f:
        return f.read().strip()


def try_extract_json(text):
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start = text.find('{')
        if start == -1:
            return None
        brace_count = 0
        json_candidate = ''
        for i in range(start, len(text)):
            if text[i] == '{':
                brace_count += 1
            elif text[i] == '}':
                brace_count -= 1
            json_candidate += text[i]
            if brace_count == 0 and json_candidate.strip():
                break
        try:
            return json.loads(json_candidate)
        except json.JSONDecodeError:
            return None


def encode_image_base64(image: Image.Image):
    from io import BytesIO
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def extract_all_text_pix2struct(image: Image.Image):
    inputs = processor(images=image, return_tensors="pt")
    predictions = pix2struct_model.generate(**inputs, max_new_tokens=512)
    output_text = processor.decode(predictions[0], skip_special_tokens=True)
    return output_text.strip()


# Optional: assign best-matching label from full extracted text using proximity (simplified version)
def assign_event_gateway_names_from_ocr(image: Image.Image, json_data, ocr_text):
    # Also bail out when no JSON was parsed, so a None result doesn't crash below
    if not ocr_text or not json_data:
        return json_data

    # You could use NLP matching or regex in complex cases
    words = ocr_text.split()

    def guess_name_fallback(obj):
        if not obj.get("name") or obj["name"].strip() == "":
            obj["name"] = "(label unknown)"  # fallback if matching logic isn't yet implemented

    for evt in json_data.get("events", []):
        guess_name_fallback(evt)

    for gw in json_data.get("gateways", []):
        guess_name_fallback(gw)

    return json_data


def run_model(image: Image.Image, api_key: str = None):
    prompt_text = load_prompt()
    encoded_image = encode_image_base64(image)

    if not api_key:
        return {"json": None, "raw": "⚠️ API key is missing. Please provide your OpenAI API key."}

    client = openai.OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
                ]
            }
        ],
        max_tokens=5000
    )

    output_text = response.choices[0].message.content.strip()
    parsed_json = try_extract_json(output_text)

    # Vision-language OCR assist step (Pix2Struct)
    full_ocr_text = extract_all_text_pix2struct(image)
    parsed_json = assign_event_gateway_names_from_ocr(image, parsed_json, full_ocr_text)

    return {"json": parsed_json, "raw": output_text}
models/pixtral.py
ADDED
@@ -0,0 +1,113 @@
import os
import json
import base64
from PIL import Image
from vllm import LLM
from vllm.sampling_params import SamplingParams
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Optional: Replace with your Hugging Face token or use environment variable
hf_token = os.getenv("HF_TOKEN")
Image.MAX_IMAGE_PIXELS = None

# Initialize Pixtral model
model_name = "mistralai/Pixtral-12B-2409"
sampling_params = SamplingParams(max_tokens=5000)
llm = LLM(model=model_name, tokenizer_mode="mistral", dtype="bfloat16", max_model_len=30000)

# Initialize Pix2Struct OCR model
ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

# Load prompt from file
def load_prompt():
    with open("prompts/prompt.txt", "r", encoding="utf-8") as f:
        return f.read()

# Extract structured JSON from text
def try_extract_json(text):
    if not text or not text.strip():
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start = text.find('{')
        if start == -1:
            return None

        brace_count = 0
        json_candidate = ''
        for i in range(start, len(text)):
            if text[i] == '{':
                brace_count += 1
            elif text[i] == '}':
                brace_count -= 1
            json_candidate += text[i]
            if brace_count == 0:
                break
        try:
            return json.loads(json_candidate)
        except json.JSONDecodeError:
            return None

# Base64 encode image
def encode_image_as_base64(pil_image):
    from io import BytesIO
    buffer = BytesIO()
    pil_image.save(buffer, format="JPEG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return encoded

# Extract OCR text using Pix2Struct
def extract_all_text_pix2struct(image: Image.Image):
    inputs = ocr_processor(images=image, return_tensors="pt")
    predictions = ocr_model.generate(**inputs, max_new_tokens=512)
    output_text = ocr_processor.decode(predictions[0], skip_special_tokens=True)
    return output_text.strip()

# Assign event/gateway names from OCR text
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    if not ocr_text or not json_data:
        return json_data

    lines = [line.strip() for line in ocr_text.split('\n') if line.strip()]

    def assign_best_guess(obj):
        if not obj.get("name") or obj["name"].strip() == "":
            obj["name"] = "(label unknown)"

    for evt in json_data.get("events", []):
        assign_best_guess(evt)

    for gw in json_data.get("gateways", []):
        assign_best_guess(gw)

    return json_data

# Run model
def run_model(image: Image.Image):
    prompt = load_prompt()
    encoded_image = encode_image_as_base64(image)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
            ]
        }
    ]

    outputs = llm.chat(messages, sampling_params=sampling_params)
    raw_output = outputs[0].outputs[0].text
    parsed_json = try_extract_json(raw_output)

    # Apply OCR post-processing
    ocr_text = extract_all_text_pix2struct(image)
    parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

    return {
        "json": parsed_json,
        "raw": raw_output
    }
models/qwen.py
ADDED
@@ -0,0 +1,121 @@
import os
import json
from PIL import Image
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Initialize Qwen2.5-VL model
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    attn_implementation="flash_attention_2"
)

min_pixels = 256 * 28 * 28
max_pixels = 1080 * 28 * 28
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)

# Initialize Pix2Struct OCR model
ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

# Load prompt
def load_prompt():
    with open("prompts/prompt.txt", "r") as f:
        return f.read()

# Try extracting JSON from text
def try_extract_json(text):
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start = text.find('{')
        if start == -1:
            # Return None (not the raw text) so the OCR post-processing and the
            # app's fallback to the raw output behave like the other backends.
            return None
        brace_count = 0
        json_candidate = ''
        for i in range(start, len(text)):
            if text[i] == '{':
                brace_count += 1
            elif text[i] == '}':
                brace_count -= 1
            json_candidate += text[i]
            if brace_count == 0:
                break
        try:
            return json.loads(json_candidate)
        except json.JSONDecodeError:
            return None

# Extract OCR text using Pix2Struct
def extract_all_text_pix2struct(image: Image.Image):
    inputs = ocr_processor(images=image, return_tensors="pt")
    predictions = ocr_model.generate(**inputs, max_new_tokens=512)
    output_text = ocr_processor.decode(predictions[0], skip_special_tokens=True)
    return output_text.strip()

# Assign event/gateway names from OCR text
def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
    if not ocr_text or not json_data:
        return json_data

    lines = [line.strip() for line in ocr_text.split('\n') if line.strip()]

    def assign_best_guess(obj):
        if not obj.get("name") or obj["name"].strip() == "":
            obj["name"] = "(label unknown)"

    for evt in json_data.get("events", []):
        assign_best_guess(evt)

    for gw in json_data.get("gateways", []):
        assign_best_guess(gw)

    return json_data

# Run model
def run_model(image: Image.Image):
    prompt = load_prompt()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }
    ]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    ).to("cuda")

    generated_ids = model.generate(**inputs, max_new_tokens=5000)
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]

    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]

    parsed_json = try_extract_json(output_text)

    # Apply OCR post-processing
    ocr_text = extract_all_text_pix2struct(image)
    parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)

    return {
        "json": parsed_json,
        "raw": output_text
    }
prompts/prompt.txt
ADDED
@@ -0,0 +1,145 @@
You are an advanced BPMN diagram analysis engine specialized in extracting structured data from visual BPMN diagrams.

You will be given an image containing a BPMN diagram. Your job is to identify, categorize, and extract all visible BPMN components based strictly on their visual appearance and layout.

[IMPORTANT] Instructions:
- Output must be a single structured JSON object.
- Do NOT include any explanation or extra text — only the JSON output.
- Do NOT infer or assume missing details. Only include elements that are visually present and identifiable.
- If a label is unreadable or not clearly visible, exclude that element.
- Detect and extract all BPMN components strictly based on visual appearance.
- Include bounding boxes for all components.
- Identify all text labels near tasks, events, and gateways and attach them accordingly.
- Detect arrow types: solid arrows = Sequence Flows, dashed arrows = Message Flows.
- Do not skip or infer missing elements.
- If any label is attached to an event (e.g., "Money received", "Customer disagreed"), include it in the JSON output.


[VISUAL ELEMENTS TO DETECT AND EXTRACT]
- Pools: Rectangles enclosing process areas.
- Lanes: Subdivisions within pools.
- Tasks: Rounded rectangles representing activities.
- Events: Circles (Start = thin border, Intermediate = double border, End = thick border).
- Gateways: Diamond shapes (Exclusive, Parallel, Inclusive).
- Sequence Flows: Solid arrows showing progression.
- Message Flows: Dashed arrows indicating communication between participants.
- Data Objects: Document symbols representing data or content.
- Data Stores: Cylindrical storage symbols.
- Other BPMN Artifacts: Any annotations, groups, or visual elements explicitly present.

[IMPORTANT]
For every event and gateway, detect and include the text label located near or next to the symbol,
and assign it as the name field. Do not skip event/gateway labels even if the labels are outside the shape.
Treat these labels as part of the component. If a name is not detected, return an empty string for name, do not omit the field.

[EXAMPLE JSON OUTPUT FORMAT]
Please follow the structure below exactly. Your output must start with a single JSON object as shown:

{{
  "pools": [
    {{
      "id": "pool_1",
      "name": "Customer Process",
      "bounding_box": {{ "x": 10, "y": 20, "width": 1200, "height": 700 }}
    }}
  ],
  "lanes": [
    {{
      "id": "lane_1",
      "name": "Customer",
      "bounding_box": {{ "x": 50, "y": 30, "width": 1150, "height": 200 }},
      "parent_pool": "pool_1"
    }}
  ],
  "tasks": [
    {{
      "id": "task_1",
      "name": "Submit Request",
      "lane": "Customer",
      "bounding_box": {{ "x": 100, "y": 50, "width": 150, "height": 80 }},
      "incoming": ["start_event_1"],
      "outgoing": ["task_2"]
    }}
  ],
  "events": [
    {{
      "id": "start_event_1",
      "type": "StartEvent",
      "name": "Start Process",
      "bounding_box": {{ "x": 50, "y": 75, "width": 40, "height": 40 }},
      "outgoing": ["task_1"]
    }},
    {{
      "id": "end_event_1",
      "type": "EndEvent",
      "name": "Process Complete",
      "bounding_box": {{ "x": 1200, "y": 600, "width": 40, "height": 40 }},
      "incoming": ["task_5"]
    }},
    {{
      "id": "intermediate_event_1",
      "type": "IntermediateEvent",
      "name": "Customer Disagreed",
      "bounding_box": { "x": 600, "y": 200, "width": 40, "height": 40 },
      "incoming": ["task_3"],
      "outgoing": ["task_6"]
    }}
  ],
  "gateways": [
    {{
      "id": "gateway_1",
      "type": "ExclusiveGateway",
      "name": "Request Valid?",
      "bounding_box": {{ "x": 600, "y": 150, "width": 50, "height": 50 }},
      "incoming": ["task_3"],
      "outgoing": ["task_4", "task_5"]
    }},
    {{
      "id": "gateway_2",
      "type": "ParallelGateway",
      "name": "Split Tasks",
      "bounding_box": { "x": 700, "y": 100, "width": 50, "height": 50 },
      "incoming": ["task_4"],
      "outgoing": ["task_5", "task_6"]
    }}
  ],
  "datastores": [
    {{
      "id": "data_1",
      "name": "Customer Database",
      "bounding_box": {{ "x": 800, "y": 400, "width": 100, "height": 100 }},
      "incoming": ["task_6"],
      "outgoing": ["task_7"]
    }}
  ],
  "flows": [
    {{
      "id": "flow_1",
      "type": "SequenceFlow",
      "name": "Submit Review",
      "source": "start_event_1",
      "target": "task_1",
      "waypoints": [
        {{ "x": 70, "y": 100 }},
        {{ "x": 100, "y": 100 }}
      ]
    }},
    {{
      "id": "flow_2",
      "type": "MessageFlow",
      "name": "Disagreement Notification",
      "source": "task_4",
      "target": "external_system",
      "waypoints": [
        {{ "x": 400, "y": 200 }},
        {{ "x": 500, "y": 300 }}
      ]
    }}
  ]
}}

[FINAL INSTRUCTIONS]
- Return ONLY the final JSON structure without any introductory or explanatory text.
- Make sure the output is valid JSON and complete.
- Do NOT include placeholder elements or guess missing labels.
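
One detail worth flagging about the template above: the doubled braces ({{ and }}) are the escape form Python's str.format uses for literal braces, but every load_prompt in this upload reads the file verbatim with f.read(), so the model actually receives the {{ ... }} text unchanged (and two bounding_box entries use single braces, so the escaping is also mixed). A small illustration of the difference (plain Python, not part of the upload):

# Hypothetical snippet: how doubled braces behave with and without str.format
template = '{{ "pools": [ {{ "id": "pool_1" }} ] }}'

print(template)           # read verbatim, as load_prompt() does: {{ "pools": [ {{ "id": "pool_1" }} ] }}
print(template.format())  # after str.format, the braces collapse: { "pools": [ { "id": "pool_1" } ] }
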
requirements.txt
ADDED
@@ -0,0 +1,10 @@
torch
git+https://github.com/huggingface/transformers
pillow
gradio
qwen-vl-utils
flash_attn
vllm
mistral_common
paddleocr
paddlepaddle