ARCQUB committed
Commit 6c0c37c · verified · 1 Parent(s): 4ca761a

Upload folder using huggingface_hub

Files changed (7)
  1. app.py +187 -0
  2. models/aya_vision.py +128 -0
  3. models/gpt4o.py +111 -0
  4. models/pixtral.py +113 -0
  5. models/qwen.py +121 -0
  6. prompts/prompt.txt +145 -0
  7. requirements.txt +10 -0
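
The commit message indicates the folder was pushed with the huggingface_hub library rather than through git. A minimal sketch of how such an upload is typically performed (the repo id below is a placeholder, not taken from this commit):

import os
from huggingface_hub import HfApi

# Push the local Space folder (app.py, models/, prompts/, requirements.txt) in one commit.
api = HfApi(token=os.getenv("HF_TOKEN"))  # token needs write access to the Space
api.upload_folder(
    folder_path=".",
    repo_id="your-username/your-space",  # placeholder Space id
    repo_type="space",
    commit_message="Upload folder using huggingface_hub",
)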
app.py ADDED
@@ -0,0 +1,187 @@
+ import gradio as gr
+ import importlib
+ from PIL import Image
+ import json
+
+ # === Model Mapping ===
+ MODEL_MAP = {
+     "Qwen": "models.qwen",
+     "Pixtral": "models.pixtral",
+     "Aya Vision": "models.aya_vision",
+     "GPT-4o": "models.gpt4o"
+ }
+
+ # === Load Model
+ def load_model_runner(model_name):
+     module = importlib.import_module(MODEL_MAP[model_name])
+     return module.run_model
+
+ # === Format Raw JSON Output
+ def format_result_json(output):
+     if isinstance(output, dict):
+         return json.dumps(output, indent=2)
+     else:
+         return str(output).strip()
+
+ # === Prettified Output View
+ def format_pretty_view(output):
+     if not isinstance(output, dict):
+         return "No structured JSON found.\n\n" + str(output)
+
+     lines = []
+     process = output.get("process", output)
+
+     if "name" in process:
+         lines.append(f"📦 Process Name: {process['name']}\n")
+
+     if "startEvent" in process:
+         start = process["startEvent"]
+         name = start.get("name", "")
+         type_ = start.get("type", "")
+         desc = start.get("description", "")
+         line = f"▶️ Start: {name}"
+         if type_:
+             line += f" ({type_})"
+         if desc:
+             line += f" - {desc}"
+         lines.append(line)
+
+     if "endEvent" in process:
+         end = process["endEvent"]
+         name = end.get("name", "")
+         type_ = end.get("type", "")
+         desc = end.get("description", "")
+         line = f"⏹ End: {name}"
+         if type_:
+             line += f" ({type_})"
+         if desc:
+             line += f" - {desc}"
+         lines.append(line)
+
+     if "tasks" in process:
+         lines.append("\n🔹 Tasks:")
+         for t in process["tasks"]:
+             name = t.get("name", "")
+             type_ = t.get("type", "")
+             desc = t.get("description", "")
+             line = f" - {name}"
+             if type_:
+                 line += f" ({type_})"
+             if desc:
+                 line += f" - {desc}"
+             lines.append(line)
+
+     if "events" in process:
+         lines.append("\n📨 Events:")
+         for e in process["events"]:
+             name = e.get("name", "")
+             type_ = e.get("type", "")
+             desc = e.get("description", "")
+             line = f" - {name}"
+             if type_:
+                 line += f" ({type_})"
+             if desc:
+                 line += f" - {desc}"
+             lines.append(line)
+
+     if "gateways" in process:
+         lines.append("\n🔀 Gateways:")
+         for g in process["gateways"]:
+             name = g.get("name", "")
+             type_ = g.get("type", "")
+             label = g.get("label", "")  # some outputs may use 'label'
+             desc = g.get("description", "")
+             line = f" - {name}"
+             if type_:
+                 line += f" ({type_})"
+             if label:
+                 line += f" | Label: {label}"
+             if desc:
+                 line += f" - {desc}"
+             lines.append(line)
+
+     if "sequenceFlows" in process:
+         lines.append("\n➡️ Sequence Flows:")
+         for f in process["sequenceFlows"]:
+             src = f.get("sourceTask") or f.get("sourceEvent") or "Unknown"
+             tgt = f.get("targetTask") or f.get("targetEvent") or "Unknown"
+             condition = f.get("condition", "")
+             line = f" - {src} → {tgt}"
+             if condition:
+                 line += f" [Condition: {condition}]"
+             lines.append(line)
+
+     if "connections" in process:
+         lines.append("\n🔗 Connections:")
+         for c in process["connections"]:
+             src = c.get("sourceTask") or c.get("sourceEvent") or "Unknown"
+             tgt = c.get("targetTask") or c.get("targetEvent") or "Unknown"
+             condition = c.get("condition", "")
+             line = f" - {src} → {tgt}"
+             if condition:
+                 line += f" [Condition: {condition}]"
+             lines.append(line)
+
+     if "relationships" in process:
+         lines.append("\n🔗 Relationships:")
+         for r in process["relationships"]:
+             source = r.get("source")
+             target = r.get("target")
+             src = source.get("ref", "Unknown") if isinstance(source, dict) else str(source)
+             tgt = target.get("ref", "Unknown") if isinstance(target, dict) else str(target)
+             desc = r.get("description", "")
+             line = f" - {src} → {tgt}"
+             if desc:
+                 line += f" | {desc}"
+             lines.append(line)
+
+     return "\n".join(lines).strip()
+
+ # === Main Inference Handler
+ def process_single_image(model_name, image_file, api_key_file=None):
+     runner = load_model_runner(model_name)
+     image = Image.open(image_file.name).convert("RGB")
+
+     api_key = None
+     if model_name == "GPT-4o" and api_key_file is not None:
+         try:
+             api_key = open(api_key_file.name, "r").read().strip()
+         except Exception as e:
+             return image, "(API key file could not be read)", f"(Error: {e})"
+
+     if model_name == "GPT-4o":
+         result = runner(image, api_key=api_key)
+     else:
+         result = runner(image)
+
+     parsed_json = result.get("json")
+     raw_text = result.get("raw", "")
+
+     if parsed_json:
+         json_output = format_result_json(parsed_json)
+         pretty_output = format_pretty_view(parsed_json)
+     else:
+         json_output = "(No valid JSON extracted)"
+         pretty_output = "(No structured content extracted)\n\n⚠️ Raw Model Output:\n" + raw_text
+
+     return image, json_output, pretty_output
+
+ # === Gradio Interface (Simple)
+ iface = gr.Interface(
+     fn=process_single_image,
+     inputs=[
+         gr.Dropdown(choices=list(MODEL_MAP.keys()), label="Select Vision Model"),
+         gr.File(file_types=["image"], label="Upload a BPMN Image"),
+         gr.File(file_types=[".txt"], label="🔐 Upload OpenAI API Key File (only for GPT-4o)")
+     ],
+     outputs=[
+         gr.Image(label="Input Image"),
+         gr.Textbox(label="Raw JSON Output (Technical)", lines=20),
+         gr.Textbox(label="Prettified View (User-Friendly)", lines=25)
+     ],
+     title="🖼️ Vision Model Extractor - JSON + Pretty View",
+     description="Upload a BPMN image and select a vision model to extract structured output. API key file is required only for GPT-4o.",
+     allow_flagging="never"
+ )
+
+ iface.launch(share=True)
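
Each backend module exposes the same run_model interface and returns the {"json", "raw"} dict that process_single_image consumes, so a backend can also be exercised without the Gradio UI. A minimal sketch, assuming a hypothetical local file diagram.png and that the chosen backend's dependencies are installed:

from PIL import Image
from models.qwen import run_model  # models.pixtral or models.aya_vision work the same way

image = Image.open("diagram.png").convert("RGB")  # hypothetical local BPMN diagram
result = run_model(image)

print(result["raw"])   # raw model text, useful when JSON extraction fails
print(result["json"])  # parsed BPMN structure, or None if no valid JSON was extracted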
models/aya_vision.py ADDED
@@ -0,0 +1,128 @@
+ import os
+ import json
+ import re
+ from PIL import Image
+ import torch
+ from transformers import AutoProcessor, AutoModelForImageTextToText
+ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
+
+ # Set Hugging Face Token
+ hf_token = os.getenv("HF_TOKEN")
+
+ # Initialize Aya Vision Model
+ model_id = "CohereForAI/aya-vision-8b"
+ processor = AutoProcessor.from_pretrained(model_id)
+ model = AutoModelForImageTextToText.from_pretrained(
+     model_id, device_map="auto", torch_dtype=torch.float16
+ )
+
+ # Initialize Pix2Struct OCR Model
+ ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
+ ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
+
+ # Load prompt
+ def load_prompt():
+     with open("prompts/prompt.txt", "r", encoding="utf-8") as f:  # path relative to the repo root, as in the other model modules
+         return f.read()
+
+ # Try extracting JSON from model output
+ def try_extract_json(text):
+     if not text or not text.strip():
+         return None
+     try:
+         return json.loads(text)
+     except json.JSONDecodeError:
+         # Try extracting JSON substring by brace balancing
+         start = text.find('{')
+         if start == -1:
+             return None
+
+         brace_count = 0
+         json_candidate = ''
+         for i in range(start, len(text)):
+             char = text[i]
+             if char == '{':
+                 brace_count += 1
+             elif char == '}':
+                 brace_count -= 1
+             json_candidate += char
+             if brace_count == 0:
+                 break
+
+         try:
+             return json.loads(json_candidate)
+         except json.JSONDecodeError:
+             return None
+
+ # Extract OCR text using Pix2Struct
+ def extract_all_text_pix2struct(image: Image.Image):
+     inputs = ocr_processor(images=image, return_tensors="pt")
+     predictions = ocr_model.generate(**inputs, max_new_tokens=512)
+     output_text = ocr_processor.decode(predictions[0], skip_special_tokens=True)
+     return output_text.strip()
+
+ # Assign event/gateway names from OCR text
+ def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
+     if not ocr_text or not json_data:
+         return json_data
+
+     lines = [line.strip() for line in ocr_text.split('\n') if line.strip()]
+
+     def assign_best_guess(obj):
+         if not obj.get("name") or obj["name"].strip() == "":
+             obj["name"] = "(label unknown)"
+
+     for evt in json_data.get("events", []):
+         assign_best_guess(evt)
+
+     for gw in json_data.get("gateways", []):
+         assign_best_guess(gw)
+
+     return json_data
+
+ # Run Aya model on image
+ def run_model(image: Image.Image):
+     prompt = load_prompt()
+
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": prompt}
+             ]
+         }
+     ]
+
+     inputs = processor.apply_chat_template(
+         messages,
+         padding=True,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt"
+     ).to(model.device)
+
+     gen_tokens = model.generate(
+         **inputs,
+         max_new_tokens=5000,
+         do_sample=True,
+         temperature=0.3,
+     )
+
+     output_text = processor.tokenizer.decode(
+         gen_tokens[0][inputs.input_ids.shape[1]:],
+         skip_special_tokens=True
+     )
+
+     parsed_json = try_extract_json(output_text)
+
+     # Apply OCR post-processing
+     ocr_text = extract_all_text_pix2struct(image)
+     parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)
+
+     # Return both parsed and raw
+     return {
+         "json": parsed_json,
+         "raw": output_text
+     }
models/gpt4o.py ADDED
@@ -0,0 +1,111 @@
+ # gpt4o_pix2struct_ocr.py
+
+ import os
+ import json
+ import base64
+ from PIL import Image
+ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
+ import numpy as np
+
+ import openai
+
+ model = "gpt-4o"
+
+ # Load Pix2Struct model + processor (vision-language OCR)
+ processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
+ pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
+
+
+ def load_prompt(prompt_file="prompts/prompt.txt"):  # path relative to the repo root, as in the other model modules
+     with open(prompt_file, "r", encoding="utf-8") as f:
+         return f.read().strip()
+
+
+ def try_extract_json(text):
+     try:
+         return json.loads(text)
+     except json.JSONDecodeError:
+         start = text.find('{')
+         if start == -1:
+             return None
+         brace_count = 0
+         json_candidate = ''
+         for i in range(start, len(text)):
+             if text[i] == '{':
+                 brace_count += 1
+             elif text[i] == '}':
+                 brace_count -= 1
+             json_candidate += text[i]
+             if brace_count == 0 and json_candidate.strip():
+                 break
+         try:
+             return json.loads(json_candidate)
+         except json.JSONDecodeError:
+             return None
+
+
+ def encode_image_base64(image: Image.Image):
+     from io import BytesIO
+     buffer = BytesIO()
+     image.save(buffer, format="JPEG")
+     return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+
+ def extract_all_text_pix2struct(image: Image.Image):
+     inputs = processor(images=image, return_tensors="pt")
+     predictions = pix2struct_model.generate(**inputs, max_new_tokens=512)
+     output_text = processor.decode(predictions[0], skip_special_tokens=True)
+     return output_text.strip()
+
+
+ # Optional: assign best-matching label from full extracted text using proximity (simplified version)
+ def assign_event_gateway_names_from_ocr(image: Image.Image, json_data, ocr_text):
+     if not ocr_text or not json_data:  # skip if OCR produced nothing or JSON parsing failed
+         return json_data
+
+     # You could use NLP matching or regex in complex cases
+     words = ocr_text.split()
+
+     def guess_name_fallback(obj):
+         if not obj.get("name") or obj["name"].strip() == "":
+             obj["name"] = "(label unknown)"  # fallback if matching logic isn't yet implemented
+
+     for evt in json_data.get("events", []):
+         guess_name_fallback(evt)
+
+     for gw in json_data.get("gateways", []):
+         guess_name_fallback(gw)
+
+     return json_data
+
+
+ def run_model(image: Image.Image, api_key: str = None):
+     prompt_text = load_prompt()
+     encoded_image = encode_image_base64(image)
+
+     if not api_key:
+         return {"json": None, "raw": "⚠️ API key is missing. Please provide your OpenAI API key."}
+
+     client = openai.OpenAI(api_key=api_key)
+     response = client.chat.completions.create(
+         model=model,
+         messages=[
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": prompt_text},
+                     {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
+                 ]
+             }
+         ],
+         max_tokens=5000
+     )
+
+     output_text = response.choices[0].message.content.strip()
+     parsed_json = try_extract_json(output_text)
+
+     # Vision-language OCR assist step (Pix2Struct)
+     full_ocr_text = extract_all_text_pix2struct(image)
+     parsed_json = assign_event_gateway_names_from_ocr(image, parsed_json, full_ocr_text)
+
+     return {"json": parsed_json, "raw": output_text}
models/pixtral.py ADDED
@@ -0,0 +1,113 @@
+ import os
+ import json
+ import base64
+ from PIL import Image
+ from vllm import LLM
+ from vllm.sampling_params import SamplingParams
+ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
+
+ # Optional: Replace with your Hugging Face token or use environment variable
+ hf_token = os.getenv("HF_TOKEN")
+ Image.MAX_IMAGE_PIXELS = None
+
+ # Initialize Pixtral model
+ model_name = "mistralai/Pixtral-12B-2409"
+ sampling_params = SamplingParams(max_tokens=5000)
+ llm = LLM(model=model_name, tokenizer_mode="mistral", dtype="bfloat16", max_model_len=30000)
+
+ # Initialize Pix2Struct OCR model
+ ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
+ ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
+
+ # Load prompt from file
+ def load_prompt():
+     with open("prompts/prompt.txt", "r", encoding="utf-8") as f:
+         return f.read()
+
+ # Extract structured JSON from text
+ def try_extract_json(text):
+     if not text or not text.strip():
+         return None
+     try:
+         return json.loads(text)
+     except json.JSONDecodeError:
+         start = text.find('{')
+         if start == -1:
+             return None
+
+         brace_count = 0
+         json_candidate = ''
+         for i in range(start, len(text)):
+             if text[i] == '{':
+                 brace_count += 1
+             elif text[i] == '}':
+                 brace_count -= 1
+             json_candidate += text[i]
+             if brace_count == 0:
+                 break
+         try:
+             return json.loads(json_candidate)
+         except json.JSONDecodeError:
+             return None
+
+ # Base64 encode image
+ def encode_image_as_base64(pil_image):
+     from io import BytesIO
+     buffer = BytesIO()
+     pil_image.save(buffer, format="JPEG")
+     encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
+     return encoded
+
+ # Extract OCR text using Pix2Struct
+ def extract_all_text_pix2struct(image: Image.Image):
+     inputs = ocr_processor(images=image, return_tensors="pt")
+     predictions = ocr_model.generate(**inputs, max_new_tokens=512)
+     output_text = ocr_processor.decode(predictions[0], skip_special_tokens=True)
+     return output_text.strip()
+
+ # Assign event/gateway names from OCR text
+ def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
+     if not ocr_text or not json_data:
+         return json_data
+
+     lines = [line.strip() for line in ocr_text.split('\n') if line.strip()]
+
+     def assign_best_guess(obj):
+         if not obj.get("name") or obj["name"].strip() == "":
+             obj["name"] = "(label unknown)"
+
+     for evt in json_data.get("events", []):
+         assign_best_guess(evt)
+
+     for gw in json_data.get("gateways", []):
+         assign_best_guess(gw)
+
+     return json_data
+
+ # Run model
+ def run_model(image: Image.Image):
+     prompt = load_prompt()
+     encoded_image = encode_image_as_base64(image)
+
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "text", "text": prompt},
+                 {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}
+             ]
+         }
+     ]
+
+     outputs = llm.chat(messages, sampling_params=sampling_params)
+     raw_output = outputs[0].outputs[0].text
+     parsed_json = try_extract_json(raw_output)
+
+     # Apply OCR post-processing
+     ocr_text = extract_all_text_pix2struct(image)
+     parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)
+
+     return {
+         "json": parsed_json,
+         "raw": raw_output
+     }
models/qwen.py ADDED
@@ -0,0 +1,121 @@
+ import os
+ import json
+ from PIL import Image
+ import torch
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+ from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
+
+ # Initialize Qwen2.5-VL model
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     "Qwen/Qwen2.5-VL-7B-Instruct",
+     torch_dtype=torch.bfloat16,
+     device_map="cuda",
+     attn_implementation="flash_attention_2"
+ )
+
+ min_pixels = 256 * 28 * 28
+ max_pixels = 1080 * 28 * 28
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
+
+ # Initialize Pix2Struct OCR model
+ ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
+ ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
+
+ # Load prompt
+ def load_prompt():
+     with open("prompts/prompt.txt", "r") as f:
+         return f.read()
+
+ # Try extracting JSON from text
+ def try_extract_json(text):
+     try:
+         return json.loads(text)
+     except json.JSONDecodeError:
+         start = text.find('{')
+         if start == -1:
+             return None  # no JSON object found; keeps the return type consistent with the other model modules
+         brace_count = 0
+         json_candidate = ''
+         for i in range(start, len(text)):
+             if text[i] == '{':
+                 brace_count += 1
+             elif text[i] == '}':
+                 brace_count -= 1
+             json_candidate += text[i]
+             if brace_count == 0:
+                 break
+         try:
+             return json.loads(json_candidate)
+         except json.JSONDecodeError:
+             return None  # candidate was not valid JSON
+
+ # Extract OCR text using Pix2Struct
+ def extract_all_text_pix2struct(image: Image.Image):
+     inputs = ocr_processor(images=image, return_tensors="pt")
+     predictions = ocr_model.generate(**inputs, max_new_tokens=512)
+     output_text = ocr_processor.decode(predictions[0], skip_special_tokens=True)
+     return output_text.strip()
+
+ # Assign event/gateway names from OCR text
+ def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
+     if not ocr_text or not json_data:
+         return json_data
+
+     lines = [line.strip() for line in ocr_text.split('\n') if line.strip()]
+
+     def assign_best_guess(obj):
+         if not obj.get("name") or obj["name"].strip() == "":
+             obj["name"] = "(label unknown)"
+
+     for evt in json_data.get("events", []):
+         assign_best_guess(evt)
+
+     for gw in json_data.get("gateways", []):
+         assign_best_guess(gw)
+
+     return json_data
+
+ # Run model
+ def run_model(image: Image.Image):
+     prompt = load_prompt()
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": prompt}
+             ]
+         }
+     ]
+
+     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     image_inputs, video_inputs = process_vision_info(messages)
+
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt"
+     ).to("cuda")
+
+     generated_ids = model.generate(**inputs, max_new_tokens=5000)
+     generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+
+     output_text = processor.batch_decode(
+         generated_ids_trimmed,
+         skip_special_tokens=True,
+         clean_up_tokenization_spaces=False
+     )[0]
+
+     parsed_json = try_extract_json(output_text)
+
+     # Apply OCR post-processing
+     ocr_text = extract_all_text_pix2struct(image)
+     parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)
+
+     return {
+         "json": parsed_json,
+         "raw": output_text
+     }
prompts/prompt.txt ADDED
@@ -0,0 +1,145 @@
+ You are an advanced BPMN diagram analysis engine specialized in extracting structured data from visual BPMN diagrams.
+
+ You will be given an image containing a BPMN diagram. Your job is to identify, categorize, and extract all visible BPMN components based strictly on their visual appearance and layout.
+
+ [IMPORTANT] Instructions:
+ - Output must be a single structured JSON object.
+ - Do NOT include any explanation or extra text — only the JSON output.
+ - Do NOT infer or assume missing details. Only include elements that are visually present and identifiable.
+ - If a label is unreadable or not clearly visible, exclude that element.
+ - Detect and extract all BPMN components strictly based on visual appearance.
+ - Include bounding boxes for all components.
+ - Identify all text labels near tasks, events, and gateways and attach them accordingly.
+ - Detect arrow types: solid arrows = Sequence Flows, dashed arrows = Message Flows.
+ - Do not skip or infer missing elements.
+ - If any label is attached to an event (e.g., "Money received", "Customer disagreed"), include it in the JSON output.
+
+
+ [VISUAL ELEMENTS TO DETECT AND EXTRACT]
+ - Pools: Rectangles enclosing process areas.
+ - Lanes: Subdivisions within pools.
+ - Tasks: Rounded rectangles representing activities.
+ - Events: Circles (Start = thin border, Intermediate = double border, End = thick border).
+ - Gateways: Diamond shapes (Exclusive, Parallel, Inclusive).
+ - Sequence Flows: Solid arrows showing progression.
+ - Message Flows: Dashed arrows indicating communication between participants.
+ - Data Objects: Document symbols representing data or content.
+ - Data Stores: Cylindrical storage symbols.
+ - Other BPMN Artifacts: Any annotations, groups, or visual elements explicitly present.
+
+ [IMPORTANT]
+ For every event and gateway, detect and include the text label located near or next to the symbol,
+ and assign it as the name field. Do not skip event/gateway labels even if the labels are outside the shape.
+ Treat these labels as part of the component. If a name is not detected, return an empty string for name, do not omit the field.
+
+ [EXAMPLE JSON OUTPUT FORMAT]
+ Please follow the structure below exactly. Your output must start with a single JSON object as shown:
+
+ {{
+   "pools": [
+     {{
+       "id": "pool_1",
+       "name": "Customer Process",
+       "bounding_box": {{ "x": 10, "y": 20, "width": 1200, "height": 700 }}
+     }}
+   ],
+   "lanes": [
+     {{
+       "id": "lane_1",
+       "name": "Customer",
+       "bounding_box": {{ "x": 50, "y": 30, "width": 1150, "height": 200 }},
+       "parent_pool": "pool_1"
+     }}
+   ],
+   "tasks": [
+     {{
+       "id": "task_1",
+       "name": "Submit Request",
+       "lane": "Customer",
+       "bounding_box": {{ "x": 100, "y": 50, "width": 150, "height": 80 }},
+       "incoming": ["start_event_1"],
+       "outgoing": ["task_2"]
+     }}
+   ],
+   "events": [
+     {{
+       "id": "start_event_1",
+       "type": "StartEvent",
+       "name": "Start Process",
+       "bounding_box": {{ "x": 50, "y": 75, "width": 40, "height": 40 }},
+       "outgoing": ["task_1"]
+     }},
+     {{
+       "id": "end_event_1",
+       "type": "EndEvent",
+       "name": "Process Complete",
+       "bounding_box": {{ "x": 1200, "y": 600, "width": 40, "height": 40 }},
+       "incoming": ["task_5"]
+     }},
+     {{
+       "id": "intermediate_event_1",
+       "type": "IntermediateEvent",
+       "name": "Customer Disagreed",
+       "bounding_box": {{ "x": 600, "y": 200, "width": 40, "height": 40 }},
+       "incoming": ["task_3"],
+       "outgoing": ["task_6"]
+     }}
+
+   ],
+   "gateways": [
+     {{
+       "id": "gateway_1",
+       "type": "ExclusiveGateway",
+       "name": "Request Valid?",
+       "bounding_box": {{ "x": 600, "y": 150, "width": 50, "height": 50 }},
+       "incoming": ["task_3"],
+       "outgoing": ["task_4", "task_5"]
+     }},
+     {{
+       "id": "gateway_2",
+       "type": "ParallelGateway",
+       "name": "Split Tasks",
+       "bounding_box": {{ "x": 700, "y": 100, "width": 50, "height": 50 }},
+       "incoming": ["task_4"],
+       "outgoing": ["task_5", "task_6"]
+     }}
+   ],
+   "datastores": [
+     {{
+       "id": "data_1",
+       "name": "Customer Database",
+       "bounding_box": {{ "x": 800, "y": 400, "width": 100, "height": 100 }},
+       "incoming": ["task_6"],
+       "outgoing": ["task_7"]
+     }}
+   ],
+   "flows": [
+     {{
+       "id": "flow_1",
+       "type": "SequenceFlow",
+       "name": "Submit Review",
+       "source": "start_event_1",
+       "target": "task_1",
+       "waypoints": [
+         {{ "x": 70, "y": 100 }},
+         {{ "x": 100, "y": 100 }}
+       ]
+     }},
+     {{
+       "id": "flow_2",
+       "type": "MessageFlow",
+       "name": "Disagreement Notification",
+       "source": "task_4",
+       "target": "external_system",
+       "waypoints": [
+         {{ "x": 400, "y": 200 }},
+         {{ "x": 500, "y": 300 }}
+       ]
+     }}
+   ]
+ }}
+
+ [FINAL INSTRUCTIONS]
+ - Return ONLY the final JSON structure without any introductory or explanatory text.
+ - Make sure the output is valid JSON and complete.
+ - Do NOT include placeholder elements or guess missing labels.
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ git+https://github.com/huggingface/transformers
+ pillow
+ gradio
+ qwen-vl-utils
+ flash_attn
+ vllm
+ mistral_common
+ paddleocr
+ paddlepaddle