prithivMLmods committed
Commit f17f462 · verified · 1 Parent(s): 966f36e

Update app.py

Files changed (1):
  1. app.py +634 -374
app.py CHANGED
@@ -1,3 +1,7 @@
  import spaces
  import json
  import math
@@ -6,458 +10,714 @@ import traceback
  from io import BytesIO
  from typing import Any, Dict, List, Optional, Tuple
  import re

  import fitz  # PyMuPDF
  import gradio as gr
  import requests
  import torch
  from huggingface_hub import snapshot_download
  from PIL import Image, ImageDraw, ImageFont
  from qwen_vl_utils import process_vision_info
- from transformers import AutoModelForCausalLM, AutoProcessor, Qwen2_5_VLForConditionalGeneration

- # Constants
  MIN_PIXELS = 3136
  MAX_PIXELS = 11289600
  IMAGE_FACTOR = 28
-
- # Prompts
- prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
-
  1. Bbox format: [x1, y1, x2, y2]
-
  2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
-
  3. Text Extraction & Formatting Rules:
  - Picture: For the 'Picture' category, the text field should be omitted.
  - Formula: Format its text as LaTeX.
  - Table: Format its text as HTML.
  - All Others (Text, Title, etc.): Format their text as Markdown.
-
  4. Constraints:
  - The output text must be the original text from the image, with no translation.
  - All layout elements must be sorted according to human reading order.
-
  5. Final Output: The entire output must be a single JSON object.
  """

- # Utility Functions
- def round_by_factor(number: int, factor: int) -> int:
-     """Returns the closest integer to 'number' that is divisible by 'factor'."""
-     return round(number / factor) * factor
-
- def smart_resize(
-     height: int,
-     width: int,
-     factor: int = 28,
-     min_pixels: int = 3136,
-     max_pixels: int = 11289600,
- ):
-     """Rescales the image so that dimensions are divisible by 'factor', within pixel range, maintaining aspect ratio."""
-     if max(height, width) / min(height, width) > 200:
-         raise ValueError(f"Aspect ratio must be < 200, got {max(height, width) / min(height, width)}")
-     h_bar = max(factor, round_by_factor(height, factor))
-     w_bar = max(factor, round_by_factor(width, factor))
-
-     if h_bar * w_bar > max_pixels:
-         beta = math.sqrt((height * width) / max_pixels)
-         h_bar = round_by_factor(height / beta, factor)
-         w_bar = round_by_factor(width / beta, factor)
-     elif h_bar * w_bar < min_pixels:
-         beta = math.sqrt(min_pixels / (height * width))
-         h_bar = round_by_factor(height * beta, factor)
-         w_bar = round_by_factor(width * beta, factor)
-     return h_bar, w_bar
-
- def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
-     """Fetch and process an image."""
-     if isinstance(image_input, str):
-         if image_input.startswith(("http://", "https://")):
-             response = requests.get(image_input)
-             image = Image.open(BytesIO(response.content)).convert('RGB')
-         else:
-             image = Image.open(image_input).convert('RGB')
-     elif isinstance(image_input, Image.Image):
-         image = image_input.convert('RGB')
-     else:
-         raise ValueError(f"Invalid image input type: {type(image_input)}")
-
-     if min_pixels is not None or max_pixels is not None:
-         min_pixels = min_pixels or MIN_PIXELS
-         max_pixels = max_pixels or MAX_PIXELS
-         height, width = smart_resize(
-             image.height,
-             image.width,
-             factor=IMAGE_FACTOR,
-             min_pixels=min_pixels,
-             max_pixels=max_pixels
-         )
-         image = image.resize((width, height), Image.LANCZOS)
-
-     return image

- def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
-     """Load images from PDF file."""
-     images = []
-     try:
-         pdf_document = fitz.open(pdf_path)
-         for page_num in range(len(pdf_document)):
-             page = pdf_document.load_page(page_num)
-             mat = fitz.Matrix(2.0, 2.0)  # Increase resolution
-             pix = page.get_pixmap(matrix=mat)
-             img_data = pix.tobytes("ppm")
-             image = Image.open(BytesIO(img_data)).convert('RGB')
-             images.append(image)
-         pdf_document.close()
-     except Exception as e:
-         print(f"Error loading PDF: {e}")
-         return []
-     return images
-
- def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
-     """Draw layout bounding boxes on image."""
-     img_copy = image.copy()
-     draw = ImageDraw.Draw(img_copy)
-
-     colors = {
-         'Caption': '#FF6B6B', 'Footnote': '#4ECDC4', 'Formula': '#45B7D1',
-         'List-item': '#96CEB4', 'Page-footer': '#FFEAA7', 'Page-header': '#DDA0DD',
-         'Picture': '#FFD93D', 'Section-header': '#6C5CE7', 'Table': '#FD79A8',
-         'Text': '#74B9FF', 'Title': '#E17055'
-     }
-
-     try:
-         font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12) or ImageFont.load_default()
-         for item in layout_data:
-             if 'bbox' in item and 'category' in item:
-                 bbox = item['bbox']
-                 category = item['category']
-                 color = colors.get(category, '#000000')
-
-                 draw.rectangle(bbox, outline=color, width=2)
-
-                 label = category
-                 label_bbox = draw.textbbox((0, 0), label, font=font)
-                 label_width, label_height = label_bbox[2] - label_bbox[0], label_bbox[3] - label_bbox[1]
-
-                 label_x, label_y = bbox[0], max(0, bbox[1] - label_height - 2)
-                 draw.rectangle([label_x, label_y, label_x + label_width + 4, label_y + label_height + 2], fill=color)
-                 draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
-     except Exception as e:
-         print(f"Error drawing layout: {e}")
-
-     return img_copy

- def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
-     """Convert layout JSON to markdown format."""
-     import base64
-
-     markdown_lines = []
      try:
-         sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
-         for item in sorted_items:
-             category = item.get('category', '')
-             text = item.get(text_key, '')
-             bbox = item.get('bbox', [])
-
-             if category == 'Picture' and bbox and len(bbox) == 4:
-                 try:
-                     x1, y1, x2, y2 = [max(0, int(x1)), max(0, int(y1)), min(image.width, int(x2)), min(image.height, int(y2))]
-                     if x2 > x1 and y2 > y1:
-                         cropped_img = image.crop((x1, y1, x2, y2))
-                         buffer = BytesIO()
-                         cropped_img.save(buffer, format='PNG')
-                         img_data = base64.b64encode(buffer.getvalue()).decode()
-                         markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
                      else:
-                         markdown_lines.append("![Image](Image region detected)\n")
-                 except Exception as e:
-                     print(f"Error processing image region: {e}")
-                     markdown_lines.append("![Image](Image detected)\n")
-             elif not text:
-                 continue
-             elif category == 'Title':
-                 markdown_lines.append(f"# {text}\n")
-             elif category == 'Section-header':
-                 markdown_lines.append(f"## {text}\n")
-             elif category == 'Text':
-                 markdown_lines.append(f"{text}\n")
-             elif category == 'List-item':
-                 markdown_lines.append(f"- {text}\n")
-             elif category == 'Table':
-                 markdown_lines.append(f"{text}\n" if text.strip().startswith('<') else f"**Table:** {text}\n")
-             elif category == 'Formula':
-                 markdown_lines.append(f"$$\n{text}\n$$\n" if text.strip().startswith('$') or '\\' in text else f"**Formula:** {text}\n")
-             elif category == 'Caption':
-                 markdown_lines.append(f"*{text}*\n")
-             elif category == 'Footnote':
-                 markdown_lines.append(f"^{text}^\n")
-             elif category in ['Page-header', 'Page-footer']:
-                 continue
              else:
-                 markdown_lines.append(f"{text}\n")
-         markdown_lines.append("")
-     except Exception as e:
-         print(f"Error converting to markdown: {e}")
-         return str(layout_data)
-
-     return "\n".join(markdown_lines)
-
- # Load Models
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # Load dot.ocr
- model_id = "rednote-hilab/dots.ocr"
- model_path = "./models/dots-ocr-local"
- snapshot_download(repo_id=model_id, local_dir=model_path, local_dir_use_symlinks=False)
- model = AutoModelForCausalLM.from_pretrained(
-     model_path,
-     attn_implementation="flash_attention_2",
-     torch_dtype=torch.bfloat16,
-     device_map="auto",
-     trust_remote_code=True
- )
- processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
-
- # Load Camel-Doc-OCR-062825
- MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
- processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
- model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_M,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- # Load Megalodon-OCR-Sync-0713
- MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
- processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
- model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_T,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- # Model Dictionary
- model_dict = {
-     "dot.ocr": {"model": model, "processor": processor, "process_layout": True},
-     "Camel-Doc-OCR-062825": {"model": model_m, "processor": processor_m, "process_layout": False},
-     "Megalodon-OCR-Sync-0713": {"model": model_t, "processor": processor_t, "process_layout": False},
- }
-
- # Global State
- pdf_cache = {"images": [], "current_page": 0, "total_pages": 0, "file_type": None, "is_parsed": False, "results": []}
-
- @spaces.GPU()
- def inference(model, processor, image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
-     """Run inference on an image with the given prompt using the specified model and processor."""
      try:
-         messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
-         text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         image_inputs, video_inputs = process_vision_info(messages)
-         inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to(device)
-
-         with torch.no_grad():
-             generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.1)
-         generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
-         output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-         return output_text[0] if output_text else ""
      except Exception as e:
-         print(f"Error during inference: {e}")
-         traceback.print_exc()
-         return f"Error during inference: {str(e)}"
-
- def process_image(
-     image: Image.Image,
-     model,
-     processor,
-     process_layout: bool,
-     min_pixels: Optional[int] = None,
-     max_pixels: Optional[int] = None
- ) -> Dict[str, Any]:
-     """Process a single image with the specified model and processor."""
      try:
-         if min_pixels is not None or max_pixels is not None:
-             image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)

-         raw_output = inference(model, processor, image, prompt)
-         result = {'original_image': image, 'raw_output': raw_output, 'processed_image': image, 'layout_result': None, 'markdown_content': raw_output}

-         if process_layout:
-             try:
-                 layout_data = json.loads(raw_output)
-                 result['layout_result'] = layout_data
-                 result['processed_image'] = draw_layout_on_image(image, layout_data)
-                 result['markdown_content'] = layoutjson2md(image, layout_data, text_key='text')
-             except json.JSONDecodeError:
-                 print("Failed to parse JSON output, using raw output")
-             except Exception as e:
-                 print(f"Error processing layout: {e}")

-         return result
      except Exception as e:
-         print(f"Error processing image: {e}")
-         traceback.print_exc()
-         return {'original_image': image, 'raw_output': str(e), 'processed_image': image, 'layout_result': None, 'markdown_content': str(e)}
-
- def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
-     """Load file for preview (supports PDF and images)."""
-     global pdf_cache
-     if not file_path or not os.path.exists(file_path):
-         return None, "No file selected"
-
-     file_ext = os.path.splitext(file_path)[1].lower()
      try:
-         if file_ext == '.pdf':
-             images = load_images_from_pdf(file_path)
-             if not images:
-                 return None, "Failed to load PDF"
-             pdf_cache.update({"images": images, "current_page": 0, "total_pages": len(images), "file_type": "pdf", "is_parsed": False, "results": []})
-             return images[0], f"Page 1 / {len(images)}"
-         elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
-             image = Image.open(file_path).convert('RGB')
-             pdf_cache.update({"images": [image], "current_page": 0, "total_pages": 1, "file_type": "image", "is_parsed": False, "results": []})
-             return image, "Page 1 / 1"
-         else:
-             return None, f"Unsupported file format: {file_ext}"
      except Exception as e:
-         print(f"Error loading file: {e}")
-         return None, f"Error loading file: {str(e)}"
-
- def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, Any, Optional[Image.Image], Optional[Dict]]:
-     """Navigate through PDF pages and update outputs."""
-     global pdf_cache
-     if not pdf_cache["images"]:
-         return None, '<div class="page-info">No file loaded</div>', "No results yet", None, None

-     pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1) if direction == "prev" else min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)
-     index = pdf_cache["current_page"]
-     current_image_preview = pdf_cache["images"][index]
-     page_info_html = f'<div class="page-info">Page {index + 1} / {pdf_cache["total_pages"]}</div>'

-     markdown_content, processed_img, layout_json = "Page not processed yet", None, None
-     if pdf_cache["is_parsed"] and index < len(pdf_cache["results"]) and pdf_cache["results"][index]:
-         result = pdf_cache["results"][index]
-         markdown_content = result.get('markdown_content') or result.get('raw_output', 'No content available')
-         processed_img = result.get('processed_image')
-         layout_json = result.get('layout_result')

-     return current_image_preview, page_info_html, markdown_content, processed_img, layout_json

  def create_gradio_interface():
-     """Create the Gradio interface."""
      css = """
      .main-container { max-width: 1400px; margin: 0 auto; }
      .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
-     .process-button {
-         border: none !important;
-         color: white !important;
-         font-weight: bold !important;
-         background-color: blue !important;}
-     .process-button:hover {
-         background-color: darkblue !important;
-         transform: translateY(-2px) !important;
-         box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
      .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
      .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
-     .model-status { padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center; font-weight: bold; }
-     .status-ready { background: #d1edff; color: #0c5460; border: 1px solid #b8daff; }
      """
-
-     with gr.Blocks(theme="bethecloud/storj_theme", css=css, title="DotOCR Comparator") as demo:
          gr.HTML("""
          <div class="title" style="text-align: center">
-             <h1>Dot<span style="color: red;">●</span>OCR Comparator</h1>
              <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
                  Advanced vision-language model for image/PDF to markdown document processing
              </p>
          </div>
          """)

-         with gr.Row():
              with gr.Column(scale=1):
-                 model_choice = gr.Radio(
-                     choices=["dot.ocr", "Camel-Doc-OCR-062825", "Megalodon-OCR-Sync-0713"],
-                     label="Select Model",
-                     value="dot.ocr"
-                 )
-                 file_input = gr.File(label="Upload Image or PDF", file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"], type="filepath")
                  with gr.Row():
                      examples = gr.Examples(
                          examples=["examples/sample_image1.png", "examples/sample_image2.png", "examples/sample_pdf.pdf"],
                          inputs=file_input,
                          label="Example Documents"
                      )
-                 image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=300)
                  with gr.Row():
-                     prev_page_btn = gr.Button("◀ Previous", size="md")
                      page_info = gr.HTML('<div class="page-info">No file loaded</div>')
-                     next_page_btn = gr.Button("Next ▶", size="md")
-                 with gr.Accordion("Advanced Settings", open=False):
-                     max_new_tokens = gr.Slider(minimum=1000, maximum=32000, value=24000, step=1000, label="Max New Tokens")
-                     min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels")
-                     max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels")
-                 process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
-                 clear_btn = gr.Button("🗑️ Clear All", variant="secondary")

              with gr.Column(scale=2):
                  with gr.Tabs():
-                     with gr.Tab("🖼️ Processed Image"):
-                         processed_image = gr.Image(label="Image with Layout Detection", type="pil", interactive=False, height=500)
                      with gr.Tab("📝 Extracted Content"):
-                         markdown_output = gr.Markdown(value="Click 'Process Document' to see extracted content...", height=500)
                      with gr.Tab("📋 Layout JSON"):
-                         json_output = gr.JSON(label="Layout Analysis Results", value=None)
-
-         def process_document(file_path, model_choice, max_tokens, min_pix, max_pix):
-             """Process the uploaded document with the selected model."""
-             global pdf_cache
-             if not file_path:
-                 return None, "Please upload a file first.", None
-             if model_choice not in model_dict:
-                 return None, "Invalid model selected", None

-             selected_model = model_dict[model_choice]["model"]
-             selected_processor = model_dict[model_choice]["processor"]
-             process_layout = model_dict[model_choice]["process_layout"]

-             image, page_info = load_file_for_preview(file_path)
-             if image is None:
-                 return None, page_info, None

-             if pdf_cache["file_type"] == "pdf":
-                 all_results, all_markdown = [], []
-                 for i, img in enumerate(pdf_cache["images"]):
-                     result = process_image(img, selected_model, selected_processor, process_layout, int(min_pix) if min_pix else None, int(max_pix) if max_pix else None)
-                     all_results.append(result)
-                     if result.get('markdown_content'):
-                         all_markdown.append(f"## Page {i+1}\n\n{result['markdown_content']}")
-                 pdf_cache["results"] = all_results
-                 pdf_cache["is_parsed"] = True
-                 first_result = all_results[0]
-                 return first_result['processed_image'], "\n\n---\n\n".join(all_markdown), first_result['layout_result']
-             else:
-                 result = process_image(image, selected_model, selected_processor, process_layout, int(min_pix) if min_pix else None, int(max_pix) if max_pix else None)
-                 pdf_cache["results"] = [result]
-                 pdf_cache["is_parsed"] = True
-                 return result['processed_image'], result['markdown_content'] or "No content extracted", result['layout_result']

-         def handle_file_upload(file_path):
-             image, page_info = load_file_for_preview(file_path)
-             return image, page_info

          def clear_all():
-             global pdf_cache
-             pdf_cache = {"images": [], "current_page": 0, "total_pages": 0, "file_type": None, "is_parsed": False, "results": []}
-             return None, None, '<div class="page-info">No file loaded</div>', None, "Click 'Process Document' to see extracted content...", None
-
-         file_input.change(handle_file_upload, inputs=[file_input], outputs=[image_preview, page_info])
-         prev_page_btn.click(lambda: turn_page("prev"), outputs=[image_preview, page_info, markdown_output, processed_image, json_output])
-         next_page_btn.click(lambda: turn_page("next"), outputs=[image_preview, page_info, markdown_output, processed_image, json_output])
-         process_btn.click(process_document, inputs=[file_input, model_choice, max_new_tokens, min_pixels, max_pixels], outputs=[processed_image, markdown_output, json_output])
-         clear_btn.click(clear_all, outputs=[file_input, image_preview, page_info, processed_image, markdown_output, json_output])
-
      return demo

  if __name__ == "__main__":
-     demo = create_gradio_interface()
-     demo.queue(max_size=50).launch(share=False, debug=True, show_error=True)

+ # app.py
+ # All code combined into a single file for convenience.
+
+ # --- Imports ---
  import spaces
  import json
  import math
  from io import BytesIO
  from typing import Any, Dict, List, Optional, Tuple
  import re
+ import base64
+ import copy
+ from dataclasses import dataclass

+ # Vision and ML Libraries
  import fitz  # PyMuPDF
  import gradio as gr
  import requests
  import torch
  from huggingface_hub import snapshot_download
  from PIL import Image, ImageDraw, ImageFont
+ from transformers import AutoModelForCausalLM, AutoProcessor, VisionEncoderDecoderModel
  from qwen_vl_utils import process_vision_info

+ # Image Processing Libraries
+ import cv2
+ import numpy as np
+ import albumentations as alb
+ from albumentations.pytorch import ToTensorV2
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+ # --- Constants & Global State ---
  MIN_PIXELS = 3136
  MAX_PIXELS = 11289600
  IMAGE_FACTOR = 28
+ DOT_OCR_PROMPT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
  1. Bbox format: [x1, y1, x2, y2]
  2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
  3. Text Extraction & Formatting Rules:
  - Picture: For the 'Picture' category, the text field should be omitted.
  - Formula: Format its text as LaTeX.
  - Table: Format its text as HTML.
  - All Others (Text, Title, etc.): Format their text as Markdown.
  4. Constraints:
  - The output text must be the original text from the image, with no translation.
  - All layout elements must be sorted according to human reading order.
  5. Final Output: The entire output must be a single JSON object.
  """
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ PDF_CACHE = {
+     "images": [],
+     "current_page": 0,
+     "total_pages": 0,
+     "file_type": None,
+     "is_parsed": False,
+     "results": [],
+     "model_used": None,
+ }
+ MODELS = {}

+ # =================================================================================
+ # --- UTILITY FUNCTIONS (from markdown_utils.py and utils.py) ---
+ # =================================================================================

+ # --- Markdown Conversion Utilities ---

+ def extract_table_from_html(html_string):
+     """Extract and clean table tags from HTML string"""
      try:
+         table_pattern = re.compile(r'<table.*?>.*?</table>', re.DOTALL)
+         tables = table_pattern.findall(html_string)
+         tables = [re.sub(r'<table[^>]*>', '<table>', table) for table in tables]
+         return '\n'.join(tables)
+     except Exception as e:
+         print(f"extract_table_from_html error: {str(e)}")
+         return f"<table><tr><td>Error extracting table: {str(e)}</td></tr></table>"
+
+
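A minimal sanity check of the helper above (illustrative input, assuming the function is in scope as defined):

html = '<p>intro</p><table border="1" class="x"><tr><td>A</td></tr></table>'
print(extract_table_from_html(html))
# -> <table><tr><td>A</td></tr></table>   (attributes stripped, surrounding text dropped)
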
+ class MarkdownConverter:
+     """Convert structured recognition results to Markdown format"""
+     def __init__(self):
+         self.heading_levels = {'title': '#', 'sec': '##', 'sub_sec': '###'}
+         self.special_labels = {'tab', 'fig', 'title', 'sec', 'sub_sec', 'list', 'formula', 'reference', 'alg'}
+
+     def try_remove_newline(self, text: str) -> str:
+         try:
+             text = text.strip().replace('-\n', '')
+             def is_chinese(char): return '\u4e00' <= char <= '\u9fff'
+             lines, processed_lines = text.split('\n'), []
+             for i in range(len(lines)-1):
+                 current_line, next_line = lines[i].strip(), lines[i+1].strip()
+                 if current_line:
+                     if next_line:
+                         if is_chinese(current_line[-1]) and is_chinese(next_line[0]):
+                             processed_lines.append(current_line)
+                         else:
+                             processed_lines.append(current_line + ' ')
                      else:
+                         processed_lines.append(current_line + '\n')
+                 else:
+                     processed_lines.append('\n')
+             if lines and lines[-1].strip():
+                 processed_lines.append(lines[-1].strip())
+             return ''.join(processed_lines)
+         except Exception as e:
+             print(f"try_remove_newline error: {str(e)}")
+             return text
+
+     def _handle_text(self, text: str) -> str:
+         try:
+             if not text: return ""
+             if text.strip().startswith("\\begin{array}") and text.strip().endswith("\\end{array}"):
+                 text = "$$" + text + "$$"
+             elif ("_{" in text or "^{" in text or "\\" in text or "_ {" in text or "^ {" in text) and ("$" not in text) and ("\\begin" not in text):
+                 text = "$" + text + "$"
+             text = self._process_formulas_in_text(text)
+             text = self.try_remove_newline(text)
+             return text
+         except Exception as e:
+             print(f"_handle_text error: {str(e)}")
+             return text
+
+     def _process_formulas_in_text(self, text: str) -> str:
+         try:
+             delimiters = [('$$', '$$'), ('\\[', '\\]'), ('$', '$'), ('\\(', '\\)')]
+             result = text
+             for start_delim, end_delim in delimiters:
+                 current_pos, processed_parts = 0, []
+                 while current_pos < len(result):
+                     start_pos = result.find(start_delim, current_pos)
+                     if start_pos == -1:
+                         processed_parts.append(result[current_pos:])
+                         break
+                     processed_parts.append(result[current_pos:start_pos])
+                     end_pos = result.find(end_delim, start_pos + len(start_delim))
+                     if end_pos == -1:
+                         processed_parts.append(result[start_pos:])
+                         break
+                     formula_content = result[start_pos + len(start_delim):end_pos]
+                     processed_formula = formula_content.replace('\n', ' \\\\ ')
+                     processed_parts.append(f"{start_delim}{processed_formula}{end_delim}")
+                     current_pos = end_pos + len(end_delim)
+                 result = ''.join(processed_parts)
+             return result
+         except Exception as e:
+             print(f"_process_formulas_in_text error: {str(e)}")
+             return text
+
+     def _remove_newline_in_heading(self, text: str) -> str:
+         try:
+             def is_chinese(char): return '\u4e00' <= char <= '\u9fff'
+             return text.replace('\n', '') if any(is_chinese(char) for char in text) else text.replace('\n', ' ')
+         except Exception as e:
+             print(f"_remove_newline_in_heading error: {str(e)}")
+             return text
+
+     def _handle_heading(self, text: str, label: str) -> str:
+         try:
+             level = self.heading_levels.get(label, '#')
+             text = self._remove_newline_in_heading(text.strip())
+             text = self._handle_text(text)
+             return f"{level} {text}\n\n"
+         except Exception as e:
+             print(f"_handle_heading error: {str(e)}")
+             return f"# Error processing heading: {text}\n\n"
+
+     def _handle_list_item(self, text: str) -> str:
+         try:
+             return f"- {text.strip()}\n"
+         except Exception as e:
+             print(f"_handle_list_item error: {str(e)}")
+             return f"- Error processing list item: {text}\n"
+
+     def _handle_figure(self, text: str, section_count: int) -> str:
+         try:
+             if not text.strip():
+                 return f"![Figure {section_count}](data:image/png;base64,)\n\n"
+             if text.startswith("data:image/"):
+                 return f"![Figure {section_count}]({text})\n\n"
              else:
+                 return f"![Figure {section_count}](data:image/png;base64,{text})\n\n"
+         except Exception as e:
+             print(f"_handle_figure error: {str(e)}")
+             return f"*[Error processing figure: {str(e)}]*\n\n"
+
+     def _handle_table(self, text: str) -> str:
+         try:
+             if '<table' in text.lower() or '<tr' in text.lower():
+                 return extract_table_from_html(text) + "\n\n"
+             else:
+                 table_lines = text.split('\n')
+                 if not table_lines: return "\n\n"
+                 col_count = len(table_lines[0].split()) if table_lines[0] else 1
+                 header = '| ' + ' | '.join(table_lines[0].split()) + ' |'
+                 separator = '| ' + ' | '.join(['---'] * col_count) + ' |'
+                 rows = [f"| {' | '.join(line.split())} |" for line in table_lines[1:]]
+                 return '\n'.join([header, separator] + rows) + '\n\n'
+         except Exception as e:
+             print(f"_handle_table error: {str(e)}")
+             return f"*[Error processing table: {str(e)}]*\n\n"
+
+     def _handle_algorithm(self, text: str) -> str:
+         try:
+             text = re.sub(r'\\begin\{algorithm\}(.*?)\\end\{algorithm\}', r'\1', text, flags=re.DOTALL)
+             text = text.replace('\\begin{algorithmic}', '').replace('\\end{algorithmic}', '')
+             caption_match = re.search(r'\\caption\{(.*?)\}', text)
+             caption = f"**{caption_match.group(1)}**\n\n" if caption_match else ""
+             algorithm_text = re.sub(r'\\caption\{.*?\}', '', text).strip()
+             return f"{caption}```\n{algorithm_text}\n```\n\n"
+         except Exception as e:
+             print(f"_handle_algorithm error: {str(e)}")
+             return f"*[Error processing algorithm: {str(e)}]*\n\n{text}\n\n"
+
+     def _handle_formula(self, text: str) -> str:
+         try:
+             processed_text = self._process_formulas_in_text(text)
+             if '$$' not in processed_text and '\\[' not in processed_text:
+                 processed_text = f'$${processed_text}$$'
+             return f"{processed_text}\n\n"
+         except Exception as e:
+             print(f"_handle_formula error: {str(e)}")
+             return f"*[Error processing formula: {str(e)}]*\n\n"
+
+     def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
+         markdown_content = []
+         for i, result in enumerate(recognition_results):
+             try:
+                 label, text = result.get('label', ''), result.get('text', '').strip()
+                 if label == 'fig':
+                     markdown_content.append(self._handle_figure(text, i))
+                     continue
+                 if not text: continue
+
+                 if label in {'title', 'sec', 'sub_sec'}:
+                     markdown_content.append(self._handle_heading(text, label))
+                 elif label == 'list':
+                     markdown_content.append(self._handle_list_item(text))
+                 elif label == 'tab':
+                     markdown_content.append(self._handle_table(text))
+                 elif label == 'alg':
+                     markdown_content.append(self._handle_algorithm(text))
+                 elif label == 'formula':
+                     markdown_content.append(self._handle_formula(text))
+                 elif label not in self.special_labels:
+                     markdown_content.append(f"{self._handle_text(text)}\n\n")
+             except Exception as e:
+                 print(f"Error processing item {i}: {str(e)}")
+                 markdown_content.append("*[Error processing content]*\n\n")
+         return self._post_process(''.join(markdown_content))
+
+     def _post_process(self, md: str) -> str:
+         try:
+             md = re.sub(r'\\author\{(.*?)\}', lambda m: self._handle_text(m.group(1)), md, flags=re.DOTALL)
+             md = re.sub(r'\$(\\author\{.*?\})\$', lambda m: self._handle_text(re.search(r'\\author\{(.*?)\}', m.group(1), re.DOTALL).group(1)), md, flags=re.DOTALL)
+             md = re.sub(r'\\begin\{abstract\}(.*?)\\end\{abstract\}', r'**Abstract** \1', md, flags=re.DOTALL)
+             md = re.sub(r'\\begin\{abstract\}', r'**Abstract**', md)
+             md = re.sub(r'\\eqno\{\((.*?)\)\}', r'\\tag{\1}', md)
+             md = md.replace("\\[ \\\\", "$$ \\\\").replace("\\\\ \\]", "\\\\ $$")
+             # collapse spaced sub/superscript braces; '^' is escaped so it matches literally
+             md = re.sub(r'_ \{', r'_{', md)
+             md = re.sub(r'\^ \{', r'^{', md)
+             md = re.sub(r'\n{3,}', r'\n\n', md)
+             return md
+         except Exception as e:
+             print(f"_post_process error: {str(e)}")
+             return md
+
+ # --- General Processing Utilities ---
270
+ @dataclass
271
+ class ImageDimensions:
272
+ original_w: int
273
+ original_h: int
274
+ padded_w: int
275
+ padded_h: int
276
+
277
+ def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2):
278
+ if isinstance(image, str):
279
+ image = cv2.imread(image)
280
+ img_h, img_w = image.shape[:2]
281
+ new_boxes = []
282
+ for box in boxes:
283
+ best_box = copy.deepcopy(box)
284
+
285
+ def check_edge(img, current_box, i, is_vertical):
286
+ edge = current_box[i]
287
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
288
+ _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
289
+ if is_vertical:
290
+ line = binary[current_box[1] : current_box[3] + 1, edge]
291
+ else:
292
+ line = binary[edge, current_box[0] : current_box[2] + 1]
293
+ transitions = np.abs(np.diff(line))
294
+ return np.sum(transitions) / len(transitions)
295
+
296
+ edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
297
+ current_box = copy.deepcopy(box)
298
+ current_box = [min(max(c, 0), d - 1) for c, d in zip(current_box, [img_w, img_h, img_w, img_h])]
299
+
300
+ for i, direction, is_vertical in edges:
301
+ best_score = check_edge(image, current_box, i, is_vertical)
302
+ if best_score <= threshold: continue
303
+ for _ in range(max_pixels):
304
+ current_box[i] += direction
305
+ dim = img_w if i in [0, 2] else img_h
306
+ current_box[i] = min(max(current_box[i], 0), dim - 1)
307
+ score = check_edge(image, current_box, i, is_vertical)
308
+ if score < best_score:
309
+ best_score, best_box = score, copy.deepcopy(current_box)
310
+ if score <= threshold: break
311
+ new_boxes.append(best_box)
312
+ return new_boxes
313
+
+ def parse_layout_string(bbox_str):
+     pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
+     matches = re.finditer(pattern, bbox_str)
+     return [([float(m.group(i)) for i in range(1, 5)], m.group(5).strip()) for m in matches]
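The layout string this expects is a flat sequence of normalized boxes followed by labels, e.g. (hand-made input; the real label vocabulary comes from the Dolphin layout pass):

layout = "[0.1, 0.05, 0.9, 0.12] title [0.1, 0.15, 0.9, 0.8] para"
print(parse_layout_string(layout))
# -> [([0.1, 0.05, 0.9, 0.12], 'title'), ([0.1, 0.15, 0.9, 0.8], 'para')]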
+
+ def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]:
      try:
+         top, left = (dims.padded_h - dims.original_h) // 2, (dims.padded_w - dims.original_w) // 2
+         orig_x1, orig_y1 = max(0, x1 - left), max(0, y1 - top)
+         orig_x2, orig_y2 = min(dims.original_w, x2 - left), min(dims.original_h, y2 - top)
+         if orig_x2 <= orig_x1: orig_x2 = min(orig_x1 + 1, dims.original_w)
+         if orig_y2 <= orig_y1: orig_y2 = min(orig_y1 + 1, dims.original_h)
+         return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
      except Exception as e:
+         print(f"map_to_original_coordinates error: {str(e)}")
+         return 0, 0, min(100, dims.original_w), min(100, dims.original_h)
+
+ def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
      try:
+         x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
+         x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)

+         x1, y1, x2, y2 = max(0, x1), max(0, y1), min(dims.padded_w, x2), min(dims.padded_h, y2)
+         if x2 <= x1: x2 = min(x1 + 1, dims.padded_w)
+         if y2 <= y1: y2 = min(y1 + 1, dims.padded_h)

+         x1, y1, x2, y2 = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])[0]

+         if previous_box:
+             prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
+             if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
+                 y1 = min(prev_y2, dims.padded_h - 1)
+                 if y2 <= y1: y2 = min(y1 + 1, dims.padded_h)
+
+         orig_coords = map_to_original_coordinates(x1, y1, x2, y2, dims)
+         return x1, y1, x2, y2, *orig_coords, [x1, y1, x2, y2]
      except Exception as e:
+         print(f"process_coordinates error: {str(e)}")
+         orig_coords = 0, 0, min(100, dims.original_w), min(100, dims.original_h)
+         return 0, 0, 100, 100, *orig_coords, [0, 0, 100, 100]
+
+ def prepare_image(image) -> Tuple[np.ndarray, ImageDimensions]:
      try:
+         image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+         original_h, original_w = image_cv.shape[:2]
+         max_size = max(original_h, original_w)
+         top, bottom = (max_size - original_h) // 2, max_size - original_h - ((max_size - original_h) // 2)
+         left, right = (max_size - original_w) // 2, max_size - original_w - ((max_size - original_w) // 2)
+         padded_image = cv2.copyMakeBorder(image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0))
+         padded_h, padded_w = padded_image.shape[:2]
+         dims = ImageDimensions(original_w, original_h, padded_w, padded_h)
+         return padded_image, dims
      except Exception as e:
+         print(f"prepare_image error: {str(e)}")
+         dims = ImageDimensions(image.width, image.height, image.width, image.height)
+         return np.zeros((image.height, image.width, 3), dtype=np.uint8), dims
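A quick numeric check of the padding round-trip (hypothetical 600x1000 page: prepare_image pads to a square, map_to_original_coordinates subtracts the pad offsets back out):

page = Image.new('RGB', (600, 1000), 'white')               # width x height
padded, dims = prepare_image(page)
print(padded.shape)                                          # (1000, 1000, 3): 200px black bars left/right
print(map_to_original_coordinates(250, 0, 850, 100, dims))   # -> (50, 0, 600, 100)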
+
+ # =================================================================================
+ # --- MODEL WRAPPER CLASSES ---
+ # =================================================================================
+
+ class DotOcrModel:
+     def __init__(self, device: str):
+         self.model, self.processor, self.device = None, None, device
+         self.model_id, self.model_path = "rednote-hilab/dots.ocr", "./models/dots-ocr-local"
+
+     @spaces.GPU()
+     def load_model(self):
+         if self.model is None:
+             print("Loading dot.ocr model...")
+             snapshot_download(repo_id=self.model_id, local_dir=self.model_path, local_dir_use_symlinks=False)
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 self.model_path, attn_implementation="flash_attention_2",
+                 torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
+             )
+             self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
+             print("dot.ocr model loaded.")
+
+     @staticmethod
+     def smart_resize(height, width, factor, min_pixels, max_pixels):
+         if max(height, width) / min(height, width) > 200: raise ValueError("Aspect ratio too high")
+         h_bar, w_bar = max(factor, round(height / factor) * factor), max(factor, round(width / factor) * factor)
+         if h_bar * w_bar > max_pixels:
+             beta = math.sqrt((height * width) / max_pixels)
+             h_bar, w_bar = round(height / beta / factor) * factor, round(width / beta / factor) * factor
+         elif h_bar * w_bar < min_pixels:
+             beta = math.sqrt(min_pixels / (height * width))
+             # both dimensions scale up by beta to reach the minimum pixel budget
+             h_bar, w_bar = round(height * beta / factor) * factor, round(width * beta / factor) * factor
+         return h_bar, w_bar
+
+     def fetch_image(self, image_input, min_pixels, max_pixels):
+         image = image_input.convert('RGB')
+         height, width = self.smart_resize(image.height, image.width, IMAGE_FACTOR, min_pixels, max_pixels)
+         return image.resize((width, height), Image.LANCZOS)

+     @spaces.GPU()
+     def inference(self, image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
+         self.load_model()
+         messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
+         text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         image_inputs, _ = process_vision_info(messages)
+         inputs = self.processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(self.device)
+         with torch.no_grad():
+             generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.1)
+         generated_ids_trimmed = [out[len(ins):] for ins, out in zip(inputs.input_ids, generated_ids)]
+         return self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+
+     def process_image(self, image: Image.Image, min_pixels: int, max_pixels: int):
+         resized_image = self.fetch_image(image, min_pixels, max_pixels)
+         raw_output = self.inference(resized_image, DOT_OCR_PROMPT)
+         result = {'original_image': image, 'raw_output': raw_output, 'layout_result': None}
+         try:
+             layout_data = json.loads(raw_output)
+             result['layout_result'] = layout_data
+             result['processed_image'] = self.draw_layout_on_image(image, layout_data)
+             result['markdown_content'] = self.layoutjson2md(image, layout_data)
+         except (json.JSONDecodeError, KeyError) as e:
+             print(f"Failed to parse or process dot.ocr layout: {e}")
+             result['processed_image'] = image
+             result['markdown_content'] = f"### Error processing output\nRaw model output:\n```json\n{raw_output}\n```"
+         return result

+     def draw_layout_on_image(self, image: Image.Image, layout_data: List[Dict]) -> Image.Image:
+         img_copy = image.copy()
+         draw = ImageDraw.Draw(img_copy)
+         colors = {'Caption': '#FF6B6B', 'Footnote': '#4ECDC4', 'Formula': '#45B7D1', 'List-item': '#96CEB4',
+                   'Page-footer': '#FFEAA7', 'Page-header': '#DDA0DD', 'Picture': '#FFD93D', 'Section-header': '#6C5CE7',
+                   'Table': '#FD79A8', 'Text': '#74B9FF', 'Title': '#E17055'}
+         try:
+             font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 15)
+         except Exception:
+             font = ImageFont.load_default()
+         for item in layout_data:
+             if 'bbox' in item and 'category' in item:
+                 bbox, category = item['bbox'], item['category']
+                 color = colors.get(category, '#000000')
+                 draw.rectangle(bbox, outline=color, width=3)
+                 label_bbox = draw.textbbox((0, 0), category, font=font)
+                 label_width, label_height = label_bbox[2] - label_bbox[0], label_bbox[3] - label_bbox[1]
+                 label_x, label_y = bbox[0], max(0, bbox[1] - label_height - 5)
+                 draw.rectangle([label_x, label_y, label_x + label_width + 4, label_y + label_height + 4], fill=color)
+                 draw.text((label_x + 2, label_y + 2), category, fill='white', font=font)
+         return img_copy

+     def layoutjson2md(self, image: Image.Image, layout_data: List[Dict]) -> str:
+         md_lines, sorted_items = [], sorted(layout_data, key=lambda x: (x.get('bbox', [0]*4)[1], x.get('bbox', [0]*4)[0]))
+         for item in sorted_items:
+             cat, txt, bbox = item.get('category'), item.get('text'), item.get('bbox')
+             if cat == 'Picture' and bbox:
+                 try:
+                     x1, y1, x2, y2 = max(0, int(bbox[0])), max(0, int(bbox[1])), min(image.width, int(bbox[2])), min(image.height, int(bbox[3]))
+                     if x2 > x1 and y2 > y1:
+                         cropped = image.crop((x1, y1, x2, y2))
+                         buffer = BytesIO()
+                         cropped.save(buffer, format='PNG')
+                         img_data = base64.b64encode(buffer.getvalue()).decode()
+                         md_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
+                 except Exception: md_lines.append("![Image](Image region detected)\n")
+             elif not txt: continue
+             elif cat == 'Title': md_lines.append(f"# {txt}\n")
+             elif cat == 'Section-header': md_lines.append(f"## {txt}\n")
+             elif cat == 'List-item': md_lines.append(f"- {txt}\n")
+             elif cat == 'Formula': md_lines.append(f"$$\n{txt}\n$$\n")
+             elif cat == 'Caption': md_lines.append(f"*{txt}*\n")
+             elif cat == 'Footnote': md_lines.append(f"^{txt}^\n")
+             elif cat in ['Text', 'Table']: md_lines.append(f"{txt}\n")
+         return "\n".join(md_lines)

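A sketch of driving this wrapper directly (assumes a GPU Space, the model download succeeding, and the example file actually existing under examples/):

ocr = DotOcrModel(DEVICE)
page = Image.open("examples/sample_image1.png").convert("RGB")
out = ocr.process_image(page, MIN_PIXELS, MAX_PIXELS)
print(out["markdown_content"][:300])                 # markdown rendering of the parsed layout
out["processed_image"].save("layout_preview.png")    # page with labeled bounding boxes
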
+ class DolphinModel:
+     def __init__(self, device: str):
+         self.model, self.processor, self.tokenizer, self.device = None, None, None, device
+         self.model_id = "ByteDance/Dolphin"
+
+     @spaces.GPU()
+     def load_model(self):
+         if self.model is None:
+             print("Loading Dolphin model...")
+             self.processor = AutoProcessor.from_pretrained(self.model_id)
+             self.model = VisionEncoderDecoderModel.from_pretrained(self.model_id).eval().to(self.device).half()
+             self.tokenizer = self.processor.tokenizer
+             print("Dolphin model loaded.")
+
+     @spaces.GPU()
+     def model_chat(self, prompt, image):
+         self.load_model()
+         images = image if isinstance(image, list) else [image]
+         prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
+         batch_inputs = self.processor(images, return_tensors="pt", padding=True)
+         batch_pixel_values = batch_inputs.pixel_values.half().to(self.device)
+         prompts = [f"<s>{p} <Answer/>" for p in prompts]
+         batch_prompt_inputs = self.tokenizer(prompts, add_special_tokens=False, return_tensors="pt")
+         batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
+         batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)
+         outputs = self.model.generate(
+             pixel_values=batch_pixel_values, decoder_input_ids=batch_prompt_ids,
+             decoder_attention_mask=batch_attention_mask, max_length=4096,
+             pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id,
+             use_cache=True, bad_words_ids=[[self.tokenizer.unk_token_id]],
+             return_dict_in_generate=True, do_sample=False, num_beams=1, repetition_penalty=1.1
+         )
+         sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
+         results = [seq.replace(p, "").replace("<pad>", "").replace("</s>", "").strip() for p, seq in zip(prompts, sequences)]
+         return results if isinstance(image, list) else results[0]
+
+     def process_elements(self, layout_str: str, image: Image.Image, max_batch_size: int = 16):
+         padded_image, dims = prepare_image(image)
+         layout_results = parse_layout_string(layout_str)
+         elements, reading_order = [], 0
+         for bbox, label in layout_results:
+             try:
+                 coords = process_coordinates(bbox, padded_image, dims)
+                 x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2 = coords[:8]
+                 cropped = padded_image[y1:y2, x1:x2]
+                 if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
+                     pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
+                     elements.append({"crop": pil_crop, "label": label, "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], "reading_order": reading_order})
+                     reading_order += 1
+             except Exception as e:
+                 print(f"Error processing Dolphin element bbox {bbox}: {e}")
+
+         text_elems = self.process_element_batch([e for e in elements if e['label'] != 'tab' and e['label'] != 'fig'], "Read text in the image.", max_batch_size)
+         table_elems = self.process_element_batch([e for e in elements if e['label'] == 'tab'], "Parse the table in the image.", max_batch_size)
+         fig_elems = [{"label": e['label'], "bbox": e['bbox'], "text": "", "reading_order": e['reading_order']} for e in elements if e['label'] == 'fig']
+
+         all_results = sorted(text_elems + table_elems + fig_elems, key=lambda x: x['reading_order'])
+         return all_results
+
+     def process_element_batch(self, elements, prompt, max_batch_size=16):
+         results = []
+         for i in range(0, len(elements), max_batch_size):
+             batch = elements[i:i+max_batch_size]
+             crops = [elem["crop"] for elem in batch]
+             prompts = [prompt] * len(crops)
+             batch_results = self.model_chat(prompts, crops)
+             for j, res_text in enumerate(batch_results):
+                 elem = batch[j]
+                 results.append({"label": elem["label"], "bbox": elem["bbox"], "text": res_text.strip(), "reading_order": elem["reading_order"]})
+         return results
+
+     def process_image(self, image: Image.Image):
+         layout_output = self.model_chat("Parse the reading order of this document.", image)
+         recognition_results = self.process_elements(layout_output, image)
+         markdown_content = MarkdownConverter().convert(recognition_results)
+         return {
+             'original_image': image, 'processed_image': image, 'markdown_content': markdown_content,
+             'layout_result': recognition_results, 'raw_output': layout_output
+         }

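Dolphin runs a two-stage parse: one prompt for reading order and layout, then batched per-element prompts for text and tables. A usage sketch under the same assumptions as the dot.ocr example above:

dolphin = DolphinModel(DEVICE)
result = dolphin.process_image(page)          # 'page' as in the sketch above
print(result["raw_output"])                   # stage 1: "[x1, y1, x2, y2] label ..." layout string
print(result["markdown_content"][:300])       # stage 2 element texts, merged by MarkdownConverter
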
+
+ # =================================================================================
+ # --- GRADIO UI AND EVENT HANDLERS ---
+ # =================================================================================

  def create_gradio_interface():
+     """Create the main Gradio interface and define all event handlers"""
+
      css = """
      .main-container { max-width: 1400px; margin: 0 auto; }
      .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
+     .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
+     .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
      .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
      .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
      """
+
+     with gr.Blocks(theme="bethecloud/storj_theme", css=css, title="Dot.OCR Comparator") as demo:
          gr.HTML("""
          <div class="title" style="text-align: center">
+             <h1>Dot<span style="color: red;">●</span>OCR Comparator</h1>
              <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
                  Advanced vision-language model for image/PDF to markdown document processing
              </p>
          </div>
          """)

+         with gr.Row(elem_classes=["main-container"]):
              with gr.Column(scale=1):
+                 file_input = gr.File(label="Upload Image or PDF", file_types=[".jpg", ".jpeg", ".png", ".pdf"], type="filepath")
+
                  with gr.Row():
                      examples = gr.Examples(
                          examples=["examples/sample_image1.png", "examples/sample_image2.png", "examples/sample_pdf.pdf"],
                          inputs=file_input,
                          label="Example Documents"
                      )
+
+                 model_choice = gr.Radio(choices=["dot.ocr", "Dolphin"], label="Select Model", value="dot.ocr")
+                 image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=400)
+
                  with gr.Row():
+                     prev_page_btn = gr.Button("◀ Previous")
                      page_info = gr.HTML('<div class="page-info">No file loaded</div>')
+                     next_page_btn = gr.Button("Next ▶")
+
+                 with gr.Accordion("Advanced Settings (dot.ocr only)", open=False):
+                     min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels", step=1)
+                     max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels", step=1)
+
+                 with gr.Row():
+                     process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], scale=2)
+                     clear_btn = gr.Button("🗑️ Clear", variant="secondary")

              with gr.Column(scale=2):
                  with gr.Tabs():
                      with gr.Tab("📝 Extracted Content"):
+                         markdown_output = gr.Markdown(value="Click 'Process Document' to see extracted content...", elem_id="markdown_output")
+                     with gr.Tab("🖼️ Processed Image"):
+                         processed_image_output = gr.Image(label="Image with Layout Detection", type="pil", interactive=False)
                      with gr.Tab("📋 Layout JSON"):
+                         json_output = gr.JSON(label="Layout Analysis Results")
+
+         def load_file_for_preview(file_path: str) -> Tuple[List[Image.Image], str]:
+             images = []
+             if not file_path or not os.path.exists(file_path): return [], "No file selected"
+             try:
+                 ext = os.path.splitext(file_path)[1].lower()
+                 if ext == '.pdf':
+                     doc = fitz.open(file_path)
+                     for page in doc:
+                         pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+                         images.append(Image.open(BytesIO(pix.tobytes("ppm"))).convert('RGB'))
+                     doc.close()
+                 elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
+                     images.append(Image.open(file_path).convert('RGB'))
+                 return images, f"Page 1 / {len(images)}"
+             except Exception as e:
+                 print(f"Error loading file for preview: {e}")
+                 return [], f"Error loading file: {e}"
+
+         def handle_file_upload(file_path):
+             global PDF_CACHE
+             images, page_info_str = load_file_for_preview(file_path)
+             if not images:
+                 return None, page_info_str
+             PDF_CACHE = {
+                 "images": images, "current_page": 0, "total_pages": len(images),
+                 "is_parsed": False, "results": [], "model_used": None
+             }
+             return images[0], f'<div class="page-info">{page_info_str}</div>'
+
+         def process_document(file_path, model_name, min_pix, max_pix):
+             global PDF_CACHE
+             if not file_path or not PDF_CACHE["images"]:
+                 return "Please upload a file first.", None, None

+             # lazily construct model wrappers on first use
+             if model_name not in MODELS:
+                 if model_name == 'dot.ocr': MODELS[model_name] = DotOcrModel(DEVICE)
+                 elif model_name == 'Dolphin': MODELS[model_name] = DolphinModel(DEVICE)
+             model = MODELS[model_name]
+
+             all_results, all_markdown = [], []
+             for i, img in enumerate(PDF_CACHE["images"]):
+                 gr.Info(f"Processing page {i+1}/{len(PDF_CACHE['images'])} with {model_name}...")
+                 if model_name == 'dot.ocr':
+                     result = model.process_image(img, int(min_pix), int(max_pix))
+                 else:  # Dolphin
+                     result = model.process_image(img)
+                 all_results.append(result)
+                 if result.get('markdown_content'):
+                     all_markdown.append(f"### Page {i+1}\n\n{result['markdown_content']}")
+
+             PDF_CACHE.update({"results": all_results, "is_parsed": True, "model_used": model_name})
+             if not all_results: return "Processing failed.", None, None

+             first_result = all_results[0]
+             combined_md = "\n\n---\n\n".join(all_markdown)

+             return combined_md, first_result.get('processed_image'), first_result.get('layout_result')

+         def turn_page(direction):
+             global PDF_CACHE
+             if not PDF_CACHE["images"] or not PDF_CACHE["is_parsed"]:
+                 return None, '<div class="page-info">No file parsed</div>', "No results yet", None, None
+
+             if direction == "prev": PDF_CACHE["current_page"] = max(0, PDF_CACHE["current_page"] - 1)
+             else: PDF_CACHE["current_page"] = min(PDF_CACHE["total_pages"] - 1, PDF_CACHE["current_page"] + 1)
+
+             idx = PDF_CACHE["current_page"]
+             page_info_html = f'<div class="page-info">Page {idx + 1} / {PDF_CACHE["total_pages"]}</div>'
+             preview_img = PDF_CACHE["images"][idx]
+             result = PDF_CACHE["results"][idx]
+
+             all_md = [f"### Page {i+1}\n\n{res.get('markdown_content', '')}" for i, res in enumerate(PDF_CACHE["results"])]
+             md_content = "\n\n---\n\n".join(all_md) if PDF_CACHE["total_pages"] > 1 else result.get('markdown_content', 'No content')
+
+             return preview_img, page_info_html, md_content, result.get('processed_image'), result.get('layout_result')

          def clear_all():
+             global PDF_CACHE
+             PDF_CACHE = {"images": [], "current_page": 0, "total_pages": 0, "is_parsed": False, "results": [], "model_used": None}
+             return None, None, '<div class="page-info">No file loaded</div>', "Click 'Process Document' to see extracted content...", None, None
+
+         # --- Wire UI components ---
+         file_input.change(handle_file_upload, inputs=file_input, outputs=[image_preview, page_info])
+         process_btn.click(
+             process_document,
+             inputs=[file_input, model_choice, min_pixels, max_pixels],
+             outputs=[markdown_output, processed_image_output, json_output]
+         )
+         prev_page_btn.click(lambda: turn_page("prev"), outputs=[image_preview, page_info, markdown_output, processed_image_output, json_output])
+         next_page_btn.click(lambda: turn_page("next"), outputs=[image_preview, page_info, markdown_output, processed_image_output, json_output])
+         clear_btn.click(clear_all, outputs=[file_input, image_preview, page_info, markdown_output, processed_image_output, json_output])
+
      return demo

  if __name__ == "__main__":
+     # Create example directory if it doesn't exist
+     if not os.path.exists("examples"):
+         os.makedirs("examples")
+         print("Created 'examples' directory. Please add sample images/PDFs there.")
+
+     app = create_gradio_interface()
+     app.queue().launch(debug=True, show_error=True)