prithivMLmods committed on
Commit db537bc · verified
1 Parent(s): 7b7c3b2

Update app.py

Files changed (1)
  1. app.py +756 -647
app.py CHANGED
@@ -1,583 +1,724 @@
- # app.py
- # All code combined into a single file for convenience.
-
- # --- Imports ---
- import spaces
- import json
- import math
  import os
- import traceback
- from io import BytesIO
- from typing import Any, Dict, List, Optional, Tuple
- import re
- import base64
- import copy
- from dataclasses import dataclass
- #import flash_attn_2_cuda as flash_attn_gpu
-
- # Vision and ML Libraries
- import fitz  # PyMuPDF
  import gradio as gr
- import requests
  import torch
- #import subprocess
- from huggingface_hub import snapshot_download
  from PIL import Image, ImageDraw, ImageFont
  from transformers import AutoModelForCausalLM, AutoProcessor, VisionEncoderDecoderModel
  from qwen_vl_utils import process_vision_info
-
- # Image Processing Libraries
- import cv2
- import numpy as np
- import albumentations as alb
- from albumentations.pytorch import ToTensorV2
- from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-
- # --- Constants & Global State ---
  MIN_PIXELS = 3136
  MAX_PIXELS = 11289600
  IMAGE_FACTOR = 28
- DOT_OCR_PROMPT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

  1. Bbox format: [x1, y1, x2, y2]

  2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].

  3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: Format its text as HTML.
    - All Others (Text, Title, etc.): Format their text as Markdown.

  4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.

  5. Final Output: The entire output must be a single JSON object.
  """
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- PDF_CACHE = {
-     "images": [],
-     "current_page": 0,
-     "total_pages": 0,
-     "file_type": None,
-     "is_parsed": False,
-     "results": [],
-     "model_used": None,
- }
- MODELS = {}

- # =================================================================================
- # --- UTILITY FUNCTIONS (from markdown_utils.py and utils.py) ---
- # =================================================================================
-
- # --- Markdown Conversion Utilities ---
- def extract_table_from_html(html_string):
-     """Extract and clean table tags from HTML string"""
      try:
-         table_pattern = re.compile(r'<table.*?>.*?</table>', re.DOTALL)
-         tables = table_pattern.findall(html_string)
-         tables = [re.sub(r'<table[^>]*>', '<table>', table) for table in tables]
-         return '\n'.join(tables)
      except Exception as e:
-         print(f"extract_table_from_html error: {str(e)}")
-         return f"<table><tr><td>Error extracting table: {str(e)}</td></tr></table>"
-
-
- class MarkdownConverter:
-     """Convert structured recognition results to Markdown format"""
-     def __init__(self):
-         self.heading_levels = {'title': '#', 'sec': '##', 'sub_sec': '###'}
-         self.special_labels = {'tab', 'fig', 'title', 'sec', 'sub_sec', 'list', 'formula', 'reference', 'alg'}
-     def try_remove_newline(self, text: str) -> str:
-         try:
-             text = text.strip().replace('-\n', '')
-             def is_chinese(char): return '\u4e00' <= char <= '\u9fff'
-             lines, processed_lines = text.split('\n'), []
-             for i in range(len(lines)-1):
-                 current_line, next_line = lines[i].strip(), lines[i+1].strip()
-                 if current_line:
-                     if next_line:
-                         if is_chinese(current_line[-1]) and is_chinese(next_line[0]):
-                             processed_lines.append(current_line)
                          else:
-                             processed_lines.append(current_line + ' ')
-                     else:
-                         processed_lines.append(current_line + '\n')
                  else:
-                     processed_lines.append('\n')
-             if lines and lines[-1].strip():
-                 processed_lines.append(lines[-1].strip())
-             return ''.join(processed_lines)
-         except Exception as e:
-             print(f"try_remove_newline error: {str(e)}")
-             return text
-
-     def _handle_text(self, text: str) -> str:
-         try:
-             if not text: return ""
-             if text.strip().startswith("\\begin{array}") and text.strip().endswith("\\end{array}"):
-                 text = "$$" + text + "$$"
-             elif ("_{" in text or "^{" in text or "\\" in text or "_ {" in text or "^ {" in text) and ("$" not in text) and ("\\begin" not in text):
-                 text = "$" + text + "$"
-             text = self._process_formulas_in_text(text)
-             text = self.try_remove_newline(text)
-             return text
-         except Exception as e:
-             print(f"_handle_text error: {str(e)}")
-             return text
-
-     def _process_formulas_in_text(self, text: str) -> str:
-         try:
-             delimiters = [('$$', '$$'), ('\\[', '\\]'), ('$', '$'), ('\\(', '\\)')]
-             result = text
-             for start_delim, end_delim in delimiters:
-                 current_pos, processed_parts = 0, []
-                 while current_pos < len(result):
-                     start_pos = result.find(start_delim, current_pos)
-                     if start_pos == -1:
-                         processed_parts.append(result[current_pos:])
-                         break
-                     processed_parts.append(result[current_pos:start_pos])
-                     end_pos = result.find(end_delim, start_pos + len(start_delim))
-                     if end_pos == -1:
-                         processed_parts.append(result[start_pos:])
-                         break
-                     formula_content = result[start_pos + len(start_delim):end_pos]
-                     processed_formula = formula_content.replace('\n', ' \\\\ ')
-                     processed_parts.append(f"{start_delim}{processed_formula}{end_delim}")
-                     current_pos = end_pos + len(end_delim)
-                 result = ''.join(processed_parts)
-             return result
-         except Exception as e:
-             print(f"_process_formulas_in_text error: {str(e)}")
-             return text
-
-     def _remove_newline_in_heading(self, text: str) -> str:
-         try:
-             def is_chinese(char): return '\u4e00' <= char <= '\u9fff'
-             return text.replace('\n', '') if any(is_chinese(char) for char in text) else text.replace('\n', ' ')
-         except Exception as e:
-             print(f"_remove_newline_in_heading error: {str(e)}")
-             return text
-
-     def _handle_heading(self, text: str, label: str) -> str:
-         try:
-             level = self.heading_levels.get(label, '#')
-             text = self._remove_newline_in_heading(text.strip())
-             text = self._handle_text(text)
-             return f"{level} {text}\n\n"
-         except Exception as e:
-             print(f"_handle_heading error: {str(e)}")
-             return f"# Error processing heading: {text}\n\n"
-
-     def _handle_list_item(self, text: str) -> str:
-         try:
-             return f"- {text.strip()}\n"
-         except Exception as e:
-             print(f"_handle_list_item error: {str(e)}")
-             return f"- Error processing list item: {text}\n"
-
-     def _handle_figure(self, text: str, section_count: int) -> str:
-         try:
-             if not text.strip():
-                 return f"![Figure {section_count}](data:image/png;base64,)\n\n"
-             if text.startswith("data:image/"):
-                 return f"![Figure {section_count}]({text})\n\n"
-             else:
-                 return f"![Figure {section_count}](data:image/png;base64,{text})\n\n"
-         except Exception as e:
-             print(f"_handle_figure error: {str(e)}")
-             return f"*[Error processing figure: {str(e)}]*\n\n"
-
-     def _handle_table(self, text: str) -> str:
-         try:
-             if '<table' in text.lower() or '<tr' in text.lower():
-                 return extract_table_from_html(text) + "\n\n"
-             else:
-                 table_lines = text.split('\n')
-                 if not table_lines: return "\n\n"
-                 col_count = len(table_lines[0].split()) if table_lines[0] else 1
-                 header = '| ' + ' | '.join(table_lines[0].split()) + ' |'
-                 separator = '| ' + ' | '.join(['---'] * col_count) + ' |'
-                 rows = [f"| {' | '.join(line.split())} |" for line in table_lines[1:]]
-                 return '\n'.join([header, separator] + rows) + '\n\n'
-         except Exception as e:
-             print(f"_handle_table error: {str(e)}")
-             return f"*[Error processing table: {str(e)}]*\n\n"
-
-     def _handle_algorithm(self, text: str) -> str:
-         try:
-             text = re.sub(r'\\begin\{algorithm\}(.*?)\\end\{algorithm\}', r'\1', text, flags=re.DOTALL)
-             text = text.replace('\\begin{algorithmic}', '').replace('\\end{algorithmic}', '')
-             caption_match = re.search(r'\\caption\{(.*?)\}', text)
-             caption = f"**{caption_match.group(1)}**\n\n" if caption_match else ""
-             algorithm_text = re.sub(r'\\caption\{.*?\}', '', text).strip()
-             return f"{caption}```\n{algorithm_text}\n```\n\n"
-         except Exception as e:
-             print(f"_handle_algorithm error: {str(e)}")
-             return f"*[Error processing algorithm: {str(e)}]*\n\n{text}\n\n"
-
-     def _handle_formula(self, text: str) -> str:
-         try:
-             processed_text = self._process_formulas_in_text(text)
-             if '$$' not in processed_text and '\\[' not in processed_text:
-                 processed_text = f'$${processed_text}$$'
-             return f"{processed_text}\n\n"
-         except Exception as e:
-             print(f"_handle_formula error: {str(e)}")
-             return f"*[Error processing formula: {str(e)}]*\n\n"
-
-     def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
-         markdown_content = []
-         for i, result in enumerate(recognition_results):
-             try:
-                 label, text = result.get('label', ''), result.get('text', '').strip()
-                 if label == 'fig':
-                     markdown_content.append(self._handle_figure(text, i))
-                     continue
-                 if not text: continue
-
-                 if label in {'title', 'sec', 'sub_sec'}:
-                     markdown_content.append(self._handle_heading(text, label))
-                 elif label == 'list':
-                     markdown_content.append(self._handle_list_item(text))
-                 elif label == 'tab':
-                     markdown_content.append(self._handle_table(text))
-                 elif label == 'alg':
-                     markdown_content.append(self._handle_algorithm(text))
-                 elif label == 'formula':
-                     markdown_content.append(self._handle_formula(text))
-                 elif label not in self.special_labels:
-                     markdown_content.append(f"{self._handle_text(text)}\n\n")
-             except Exception as e:
-                 print(f"Error processing item {i}: {str(e)}")
-                 markdown_content.append(f"*[Error processing content]*\n\n")
-         return self._post_process(''.join(markdown_content))
-
-     def _post_process(self, md: str) -> str:
-         try:
-             md = re.sub(r'\\author\{(.*?)\}', lambda m: self._handle_text(m.group(1)), md, flags=re.DOTALL)
-             md = re.sub(r'\$(\\author\{.*?\})\$', lambda m: self._handle_text(re.search(r'\\author\{(.*?)\}', m.group(1), re.DOTALL).group(1)), md, flags=re.DOTALL)
-             md = re.sub(r'\\begin\{abstract\}(.*?)\\end\{abstract\}', r'**Abstract** \1', md, flags=re.DOTALL)
-             md = re.sub(r'\\begin\{abstract\}', r'**Abstract**', md)
-             md = re.sub(r'\\eqno\{\((.*?)\)\}', r'\\tag{\1}', md)
-             md = md.replace("\[ \\\\", "$$ \\\\").replace("\\\\ \]", "\\\\ $$")
-             md = re.sub(r'_ {', r'_{', md)
-             md = re.sub(r'^ {', r'^{', md)
-             md = re.sub(r'\n{3,}', r'\n\n', md)
-             return md
-         except Exception as e:
-             print(f"_post_process error: {str(e)}")
-             return md
-
- # --- General Processing Utilities ---
- @dataclass
- class ImageDimensions:
-     original_w: int
-     original_h: int
-     padded_w: int
-     padded_h: int
-
- def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2):
-     if isinstance(image, str):
-         image = cv2.imread(image)
-     img_h, img_w = image.shape[:2]
-     new_boxes = []
-     for box in boxes:
-         best_box = copy.deepcopy(box)
-
-         def check_edge(img, current_box, i, is_vertical):
-             edge = current_box[i]
-             gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
-             _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
-             if is_vertical:
-                 line = binary[current_box[1] : current_box[3] + 1, edge]
              else:
-                 line = binary[edge, current_box[0] : current_box[2] + 1]
-             transitions = np.abs(np.diff(line))
-             return np.sum(transitions) / len(transitions)
-
-         edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
-         current_box = copy.deepcopy(box)
-         current_box = [min(max(c, 0), d - 1) for c, d in zip(current_box, [img_w, img_h, img_w, img_h])]
-
-         for i, direction, is_vertical in edges:
-             best_score = check_edge(image, current_box, i, is_vertical)
-             if best_score <= threshold: continue
-             for _ in range(max_pixels):
-                 current_box[i] += direction
-                 dim = img_w if i in [0, 2] else img_h
-                 current_box[i] = min(max(current_box[i], 0), dim - 1)
-                 score = check_edge(image, current_box, i, is_vertical)
-                 if score < best_score:
-                     best_score, best_box = score, copy.deepcopy(current_box)
-                 if score <= threshold: break
-         new_boxes.append(best_box)
-     return new_boxes
-
- def parse_layout_string(bbox_str):
-     pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
-     matches = re.finditer(pattern, bbox_str)
-     return [([float(m.group(i)) for i in range(1, 5)], m.group(5).strip()) for m in matches]
-
- def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]:
-     try:
-         top, left = (dims.padded_h - dims.original_h) // 2, (dims.padded_w - dims.original_w) // 2
-         orig_x1, orig_y1 = max(0, x1 - left), max(0, y1 - top)
-         orig_x2, orig_y2 = min(dims.original_w, x2 - left), min(dims.original_h, y2 - top)
-         if orig_x2 <= orig_x1: orig_x2 = min(orig_x1 + 1, dims.original_w)
-         if orig_y2 <= orig_y1: orig_y2 = min(orig_y1 + 1, dims.original_h)
-         return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
      except Exception as e:
-         print(f"map_to_original_coordinates error: {str(e)}")
-         return 0, 0, min(100, dims.original_w), min(100, dims.original_h)

- def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
      try:
-         x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
-         x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)
-
-         x1, y1, x2, y2 = max(0, x1), max(0, y1), min(dims.padded_w, x2), min(dims.padded_h, y2)
-         if x2 <= x1: x2 = min(x1 + 1, dims.padded_w)
-         if y2 <= y1: y2 = min(y1 + 1, dims.padded_h)
-
-         x1, y1, x2, y2 = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])[0]
-
-         if previous_box:
-             prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
-             if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
-                 y1 = min(prev_y2, dims.padded_h - 1)
-                 if y2 <= y1: y2 = min(y1 + 1, dims.padded_h)
-
-         orig_coords = map_to_original_coordinates(x1, y1, x2, y2, dims)
-         return x1, y1, x2, y2, *orig_coords, [x1, y1, x2, y2]
      except Exception as e:
-         print(f"process_coordinates error: {str(e)}")
-         orig_coords = 0, 0, min(100, dims.original_w), min(100, dims.original_h)
-         return 0, 0, 100, 100, *orig_coords, [0, 0, 100, 100]
 
- def prepare_image(image) -> Tuple[np.ndarray, ImageDimensions]:
      try:
-         image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
-         original_h, original_w = image_cv.shape[:2]
-         max_size = max(original_h, original_w)
-         top, bottom = (max_size - original_h) // 2, max_size - original_h - ((max_size - original_h) // 2)
-         left, right = (max_size - original_w) // 2, max_size - original_w - ((max_size - original_w) // 2)
-         padded_image = cv2.copyMakeBorder(image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0))
-         padded_h, padded_w = padded_image.shape[:2]
-         dims = ImageDimensions(original_w, original_h, padded_w, padded_h)
-         return padded_image, dims
-     except Exception as e:
-         print(f"prepare_image error: {str(e)}")
-         dims = ImageDimensions(image.width, image.height, image.width, image.height)
-         return np.zeros((image.height, image.width, 3), dtype=np.uint8), dims
-
-
- # =================================================================================
- # --- MODEL WRAPPER CLASSES ---
- # =================================================================================
-
- class DotOcrModel:
-     def __init__(self, device: str):
-         self.model, self.processor, self.device = None, None, device
-         self.model_id, self.model_path = "rednote-hilab/dots.ocr", "./models/dots-ocr-local"
-
-     @spaces.GPU()
-     def load_model(self):
-         if self.model is None:
-             print("Loading dot.ocr model...")
-             snapshot_download(repo_id=self.model_id, local_dir=self.model_path, local_dir_use_symlinks=False)
-             self.model = AutoModelForCausalLM.from_pretrained(
-                 self.model_path, attn_implementation="flash_attention_2",
-                 torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
-             )
-             self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
-             print("dot.ocr model loaded.")
-
-     @staticmethod
-     def smart_resize(height, width, factor, min_pixels, max_pixels):
-         if max(height, width) / min(height, width) > 200: raise ValueError("Aspect ratio too high")
-         h_bar, w_bar = max(factor, round(height / factor) * factor), max(factor, round(width / factor) * factor)
-         if h_bar * w_bar > max_pixels:
-             beta = math.sqrt((height * width) / max_pixels)
-             h_bar, w_bar = round(height / beta / factor) * factor, round(width / beta / factor) * factor
-         elif h_bar * w_bar < min_pixels:
-             beta = math.sqrt(min_pixels / (height * width))
-             h_bar, w_bar = round(height * beta / factor) * factor, round(width * beta / factor) * factor
-         return h_bar, w_bar
-
-     def fetch_image(self, image_input, min_pixels, max_pixels):
-         image = image_input.convert('RGB')
-         height, width = self.smart_resize(image.height, image.width, IMAGE_FACTOR, min_pixels, max_pixels)
-         return image.resize((width, height), Image.LANCZOS)
-
-     @spaces.GPU()
-     def inference(self, image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
-         self.load_model()
-         messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
-         text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         image_inputs, _ = process_vision_info(messages)
-         inputs = self.processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(self.device)
-         with torch.no_grad():
-             generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.1)
-         generated_ids_trimmed = [out[len(ins):] for ins, out in zip(inputs.input_ids, generated_ids)]
-         return self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-
-     def process_image(self, image: Image.Image, min_pixels: int, max_pixels: int):
-         resized_image = self.fetch_image(image, min_pixels, max_pixels)
-         raw_output = self.inference(resized_image, DOT_OCR_PROMPT)
-         result = {'original_image': image, 'raw_output': raw_output, 'layout_result': None}
          try:
              layout_data = json.loads(raw_output)
              result['layout_result'] = layout_data
-             result['processed_image'] = self.draw_layout_on_image(image, layout_data)
-             result['markdown_content'] = self.layoutjson2md(image, layout_data)
-         except (json.JSONDecodeError, KeyError) as e:
-             print(f"Failed to parse or process dot.ocr layout: {e}")
-             result['processed_image'] = image
-             result['markdown_content'] = f"### Error processing output\nRaw model output:\n```json\n{raw_output}\n```"
          return result
-
-     def draw_layout_on_image(self, image: Image.Image, layout_data: List[Dict]) -> Image.Image:
-         img_copy = image.copy()
-         draw = ImageDraw.Draw(img_copy)
-         colors = {'Caption': '#FF6B6B', 'Footnote': '#4ECDC4', 'Formula': '#45B7D1', 'List-item': '#96CEB4',
-                   'Page-footer': '#FFEAA7', 'Page-header': '#DDA0DD', 'Picture': '#FFD93D', 'Section-header': '#6C5CE7',
-                   'Table': '#FD79A8', 'Text': '#74B9FF', 'Title': '#E17055'}
-         try: font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 15)
-         except: font = ImageFont.load_default()
-         for item in layout_data:
-             if 'bbox' in item and 'category' in item:
-                 bbox, category = item['bbox'], item['category']
-                 color = colors.get(category, '#000000')
-                 draw.rectangle(bbox, outline=color, width=3)
-                 label_bbox = draw.textbbox((0, 0), category, font=font)
-                 label_width, label_height = label_bbox[2] - label_bbox[0], label_bbox[3] - label_bbox[1]
-                 label_x, label_y = bbox[0], max(0, bbox[1] - label_height - 5)
-                 draw.rectangle([label_x, label_y, label_x + label_width + 4, label_y + label_height + 4], fill=color)
-                 draw.text((label_x + 2, label_y + 2), category, fill='white', font=font)
-         return img_copy
-
-     def layoutjson2md(self, image: Image.Image, layout_data: List[Dict]) -> str:
-         md_lines, sorted_items = [], sorted(layout_data, key=lambda x: (x.get('bbox', [0]*4)[1], x.get('bbox', [0]*4)[0]))
-         for item in sorted_items:
-             cat, txt, bbox = item.get('category'), item.get('text'), item.get('bbox')
-             if cat == 'Picture' and bbox:
-                 try:
-                     x1, y1, x2, y2 = max(0, int(bbox[0])), max(0, int(bbox[1])), min(image.width, int(bbox[2])), min(image.height, int(bbox[3]))
-                     if x2 > x1 and y2 > y1:
-                         cropped = image.crop((x1, y1, x2, y2))
-                         buffer = BytesIO()
-                         cropped.save(buffer, format='PNG')
-                         img_data = base64.b64encode(buffer.getvalue()).decode()
-                         md_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
-                 except Exception: md_lines.append("![Image](Image region detected)\n")
-             elif not txt: continue
-             elif cat == 'Title': md_lines.append(f"# {txt}\n")
-             elif cat == 'Section-header': md_lines.append(f"## {txt}\n")
-             elif cat == 'List-item': md_lines.append(f"- {txt}\n")
-             elif cat == 'Formula': md_lines.append(f"$$\n{txt}\n$$\n")
-             elif cat == 'Caption': md_lines.append(f"*{txt}*\n")
-             elif cat == 'Footnote': md_lines.append(f"^{txt}^\n")
-             elif cat in ['Text', 'Table']: md_lines.append(f"{txt}\n")
-         return "\n".join(md_lines)
-
- class DolphinModel:
-     def __init__(self, device: str):
-         self.model, self.processor, self.tokenizer, self.device = None, None, None, device
-         self.model_id = "ByteDance/Dolphin"
-
-     @spaces.GPU()
-     def load_model(self):
-         if self.model is None:
-             print("Loading Dolphin model...")
-             self.processor = AutoProcessor.from_pretrained(self.model_id)
-             self.model = VisionEncoderDecoderModel.from_pretrained(self.model_id).eval().to(self.device).half()
-             self.tokenizer = self.processor.tokenizer
-             print("Dolphin model loaded.")
-
-     @spaces.GPU()
-     def model_chat(self, prompt, image):
-         self.load_model()
-         images = image if isinstance(image, list) else [image]
-         prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
-         batch_inputs = self.processor(images, return_tensors="pt", padding=True)
-         batch_pixel_values = batch_inputs.pixel_values.half().to(self.device)
-         prompts = [f"<s>{p} <Answer/>" for p in prompts]
-         batch_prompt_inputs = self.tokenizer(prompts, add_special_tokens=False, return_tensors="pt")
-         batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
-         batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)
-         outputs = self.model.generate(
-             pixel_values=batch_pixel_values, decoder_input_ids=batch_prompt_ids,
-             decoder_attention_mask=batch_attention_mask, max_length=4096,
-             pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id,
-             use_cache=True, bad_words_ids=[[self.tokenizer.unk_token_id]],
-             return_dict_in_generate=True, do_sample=False, num_beams=1, repetition_penalty=1.1
-         )
-         sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
-         results = [seq.replace(p, "").replace("<pad>", "").replace("</s>", "").strip() for p, seq in zip(prompts, sequences)]
-         return results if isinstance(image, list) else results[0]
-
-     def process_elements(self, layout_str: str, image: Image.Image, max_batch_size: int = 16):
-         padded_image, dims = prepare_image(image)
-         layout_results = parse_layout_string(layout_str)
-         elements, reading_order = [], 0
-         for bbox, label in layout_results:
-             try:
-                 coords = process_coordinates(bbox, padded_image, dims)
-                 x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2 = coords[:8]
-                 cropped = padded_image[y1:y2, x1:x2]
-                 if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
-                     pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
-                     elements.append({"crop": pil_crop, "label": label, "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], "reading_order": reading_order})
-                     reading_order += 1
-             except Exception as e:
-                 print(f"Error processing Dolphin element bbox {bbox}: {e}")
-
-         text_elems = self.process_element_batch([e for e in elements if e['label'] != 'tab' and e['label'] != 'fig'], "Read text in the image.", max_batch_size)
-         table_elems = self.process_element_batch([e for e in elements if e['label'] == 'tab'], "Parse the table in the image.", max_batch_size)
-         fig_elems = [{"label": e['label'], "bbox": e['bbox'], "text": "", "reading_order": e['reading_order']} for e in elements if e['label'] == 'fig']
-
-         all_results = sorted(text_elems + table_elems + fig_elems, key=lambda x: x['reading_order'])
-         return all_results
-
-     def process_element_batch(self, elements, prompt, max_batch_size=16):
-         results = []
-         for i in range(0, len(elements), max_batch_size):
-             batch = elements[i:i+max_batch_size]
-             crops = [elem["crop"] for elem in batch]
-             prompts = [prompt] * len(crops)
-             batch_results = self.model_chat(prompts, crops)
-             for j, res_text in enumerate(batch_results):
-                 elem = batch[j]
-                 results.append({"label": elem["label"], "bbox": elem["bbox"], "text": res_text.strip(), "reading_order": elem["reading_order"]})
-         return results
-
-     def process_image(self, image: Image.Image):
-         layout_output = self.model_chat("Parse the reading order of this document.", image)
-         recognition_results = self.process_elements(layout_output, image)
-         markdown_content = MarkdownConverter().convert(recognition_results)
          return {
-             'original_image': image, 'processed_image': image, 'markdown_content': markdown_content,
-             'layout_result': recognition_results, 'raw_output': layout_output
          }

- # =================================================================================
- # --- GRADIO UI AND EVENT HANDLERS ---
- # =================================================================================

  def create_gradio_interface():
-     """Create the main Gradio interface and define all event handlers"""
-
      css = """
      .main-container { max-width: 1400px; margin: 0 auto; }
      .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
-     .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
-     .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
      .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
      .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
      """
-
-     with gr.Blocks(theme="bethecloud/storj_theme", css=css, title="Dot.OCR Comparator") as demo:
          gr.HTML("""
          <div class="title" style="text-align: center">
              <h1>Dot<span style="color: red;">●</span><strong></strong>OCR Comparator</h1>
@@ -586,142 +727,110 @@ def create_gradio_interface():
              </p>
          </div>
          """)
-
-         with gr.Row(elem_classes=["main-container"]):
              with gr.Column(scale=1):
-                 file_input = gr.File(label="Upload Image or PDF", file_types=[".jpg", ".jpeg", ".png", ".pdf"], type="filepath")
-
                  with gr.Row():
                      examples = gr.Examples(
                          examples=["examples/sample_image1.png", "examples/sample_image2.png", "examples/sample_pdf.pdf"],
                          inputs=file_input,
                          label="Example Documents"
                      )
-
-                 model_choice = gr.Radio(choices=["dot.ocr", "Dolphin"], label="Select Model", value="dot.ocr")
-                 image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=400)
-
-                 with gr.Row():
-                     prev_page_btn = gr.Button("◀ Previous")
-                     page_info = gr.HTML('<div class="page-info">No file loaded</div>')
-                     next_page_btn = gr.Button("Next ▶")
-
-                 with gr.Accordion("Advanced Settings (dot.ocr only)", open=False):
-                     min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels", step=1)
-                     max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels", step=1)
-
                  with gr.Row():
-                     process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], scale=2)
-                     clear_btn = gr.Button("🗑️ Clear", variant="secondary")
-
              with gr.Column(scale=2):
                  with gr.Tabs():
-                     with gr.Tab("📝 Extracted Content"):
-                         markdown_output = gr.Markdown(value="Click 'Process Document' to see extracted content...", elem_id="markdown_output")
                      with gr.Tab("🖼️ Processed Image"):
-                         processed_image_output = gr.Image(label="Image with Layout Detection", type="pil", interactive=False)
                      with gr.Tab("📋 Layout JSON"):
-                         json_output = gr.JSON(label="Layout Analysis Results")
-
-         def load_file_for_preview(file_path: str) -> Tuple[List[Image.Image], str]:
-             images = []
-             if not file_path or not os.path.exists(file_path): return [], "No file selected"
-             try:
-                 ext = os.path.splitext(file_path)[1].lower()
-                 if ext == '.pdf':
-                     doc = fitz.open(file_path)
-                     for page in doc:
-                         pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
-                         images.append(Image.open(BytesIO(pix.tobytes("ppm"))).convert('RGB'))
-                     doc.close()
-                 elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
-                     images.append(Image.open(file_path).convert('RGB'))
-                 return images, f"Page 1 / {len(images)}"
-             except Exception as e:
-                 print(f"Error loading file for preview: {e}")
-                 return [], f"Error loading file: {e}"
-
-         def handle_file_upload(file_path):
-             global PDF_CACHE
-             images, page_info_str = load_file_for_preview(file_path)
-             if not images:
-                 return None, page_info_str
-             PDF_CACHE = {
-                 "images": images, "current_page": 0, "total_pages": len(images),
-                 "is_parsed": False, "results": [], "model_used": None
-             }
-             return images[0], f'<div class="page-info">{page_info_str}</div>'
-
-         def process_document(file_path, model_name, min_pix, max_pix):
-             global PDF_CACHE
-             if not file_path or not PDF_CACHE["images"]:
-                 return "Please upload a file first.", None, None
-
-             if model_name not in MODELS:
-                 if model_name == 'dot.ocr': MODELS[model_name] = DotOcrModel(DEVICE)
-                 elif model_name == 'Dolphin': MODELS[model_name] = DolphinModel(DEVICE)
-             model = MODELS[model_name]
-
-             all_results, all_markdown = [], []
-             for i, img in enumerate(PDF_CACHE["images"]):
-                 gr.Info(f"Processing page {i+1}/{len(PDF_CACHE['images'])} with {model_name}...")
-                 if model_name == 'dot.ocr':
-                     result = model.process_image(img, int(min_pix), int(max_pix))
-                 else:  # Dolphin
-                     result = model.process_image(img)
-                 all_results.append(result)
-                 if result.get('markdown_content'):
-                     all_markdown.append(f"### Page {i+1}\n\n{result['markdown_content']}")
-
-             PDF_CACHE.update({"results": all_results, "is_parsed": True, "model_used": model_name})
-             if not all_results: return "Processing failed.", None, None
-
-             first_result = all_results[0]
-             combined_md = "\n\n---\n\n".join(all_markdown)
-
-             return combined_md, first_result.get('processed_image'), first_result.get('layout_result')
-
-         def turn_page(direction):
-             global PDF_CACHE
-             if not PDF_CACHE["images"] or not PDF_CACHE["is_parsed"]:
-                 return None, '<div class="page-info">No file parsed</div>', "No results yet", None, None

-             if direction == "prev": PDF_CACHE["current_page"] = max(0, PDF_CACHE["current_page"] - 1)
-             else: PDF_CACHE["current_page"] = min(PDF_CACHE["total_pages"] - 1, PDF_CACHE["current_page"] + 1)
-
-             idx = PDF_CACHE["current_page"]
-             page_info_html = f'<div class="page-info">Page {idx + 1} / {PDF_CACHE["total_pages"]}</div>'
-             preview_img = PDF_CACHE["images"][idx]
-             result = PDF_CACHE["results"][idx]
-
-             all_md = [f"### Page {i+1}\n\n{res.get('markdown_content', '')}" for i, res in enumerate(PDF_CACHE["results"])]
-             md_content = "\n\n---\n\n".join(all_md) if PDF_CACHE["total_pages"] > 1 else result.get('markdown_content', 'No content')
-
-             return preview_img, page_info_html, md_content, result.get('processed_image'), result.get('layout_result')

-         def clear_all():
-             global PDF_CACHE
-             PDF_CACHE = {"images": [], "current_page": 0, "total_pages": 0, "is_parsed": False, "results": [], "model_used": None}
-             return None, None, '<div class="page-info">No file loaded</div>', "Click 'Process Document' to see extracted content...", None, None

-         # --- Wire UI components ---
-         file_input.change(handle_file_upload, inputs=file_input, outputs=[image_preview, page_info])
          process_btn.click(
              process_document,
-             inputs=[file_input, model_choice, min_pixels, max_pixels],
-             outputs=[markdown_output, processed_image_output, json_output]
          )
-         prev_page_btn.click(lambda: turn_page("prev"), outputs=[image_preview, page_info, markdown_output, processed_image_output, json_output])
-         next_page_btn.click(lambda: turn_page("next"), outputs=[image_preview, page_info, markdown_output, processed_image_output, json_output])
-         clear_btn.click(clear_all, outputs=[file_input, image_preview, page_info, markdown_output, processed_image_output, json_output])
-
      return demo

  if __name__ == "__main__":
-     # Create example directory if it doesn't exist
-     if not os.path.exists("examples"):
-         os.makedirs("examples")
-         print("Created 'examples' directory. Please add sample images/PDFs there.")
-
-     app = create_gradio_interface()
-     app.queue().launch(debug=True, show_error=True)
+ import io
  import os
+ import tempfile
+ import time
+ import uuid
+ import base64
+ import json
+ import math
+ import requests
+ import cv2
  import gradio as gr
+ import pymupdf
+ import spaces
  import torch
+ from io import BytesIO
+ from typing import Any, Dict, List, Optional, Tuple
  from PIL import Image, ImageDraw, ImageFont
  from transformers import AutoModelForCausalLM, AutoProcessor, VisionEncoderDecoderModel
+ from huggingface_hub import snapshot_download
  from qwen_vl_utils import process_vision_info
+ from utils.utils import prepare_image, parse_layout_string, process_coordinates, ImageDimensions
+ from utils.markdown_utils import MarkdownConverter
+
+ # Define device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load dot.ocr model
+ dot_ocr_model_id = "rednote-hilab/dots.ocr"
+ dot_ocr_model = AutoModelForCausalLM.from_pretrained(
+     dot_ocr_model_id,
+     attn_implementation="flash_attention_2",
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     trust_remote_code=True
+ )
+ dot_ocr_processor = AutoProcessor.from_pretrained(
+     dot_ocr_model_id,
+     trust_remote_code=True
+ )
+
+ # Load Dolphin model
+ dolphin_model_id = "ByteDance/Dolphin"
+ dolphin_processor = AutoProcessor.from_pretrained(dolphin_model_id)
+ dolphin_model = VisionEncoderDecoderModel.from_pretrained(dolphin_model_id)
+ dolphin_model.eval()
+ dolphin_model.to(device)
+ dolphin_model = dolphin_model.half()
+ dolphin_tokenizer = dolphin_processor.tokenizer
+
+ # Constants
  MIN_PIXELS = 3136
  MAX_PIXELS = 11289600
  IMAGE_FACTOR = 28
+
+ # Prompts
+ prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
+
  1. Bbox format: [x1, y1, x2, y2]
+
  2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
+
  3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: Format its text as HTML.
    - All Others (Text, Title, etc.): Format their text as Markdown.
+
  4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.
+
  5. Final Output: The entire output must be a single JSON object.
  """
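
For orientation, the downstream parser (`json.loads` followed by per-item `'bbox'`/`'category'`/`'text'` lookups) expects the model's reply to be a JSON array shaped roughly like this hand-written sample (illustrative values only, not real model output):

```python
# Hypothetical sample of the JSON array the prompt asks dots.ocr to emit.
example_layout = [
    {"bbox": [84, 52, 512, 90], "category": "Title", "text": "Attention Is All You Need"},
    {"bbox": [84, 120, 512, 310], "category": "Text", "text": "The dominant sequence transduction models are..."},
    {"bbox": [84, 340, 512, 520], "category": "Picture"},  # 'Picture' omits the text field
]
```
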
+ # Utility functions
+ def round_by_factor(number: int, factor: int) -> int:
+     """Returns the closest integer to 'number' that is divisible by 'factor'."""
+     return round(number / factor) * factor
+
+ def smart_resize(
+     height: int,
+     width: int,
+     factor: int = 28,
+     min_pixels: int = 3136,
+     max_pixels: int = 11289600,
+ ):
+     """Rescales the image so that the following conditions are met:
+     1. Both dimensions (height and width) are divisible by 'factor'.
+     2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+     3. The aspect ratio of the image is maintained as closely as possible.
+     """
+     if max(height, width) / min(height, width) > 200:
+         raise ValueError(
+             f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
+         )
+     h_bar = max(factor, round_by_factor(height, factor))
+     w_bar = max(factor, round_by_factor(width, factor))
+
+     if h_bar * w_bar > max_pixels:
+         beta = math.sqrt((height * width) / max_pixels)
+         h_bar = round_by_factor(height / beta, factor)
+         w_bar = round_by_factor(width / beta, factor)
+     elif h_bar * w_bar < min_pixels:
+         beta = math.sqrt(min_pixels / (height * width))
+         h_bar = round_by_factor(height * beta, factor)
+         w_bar = round_by_factor(width * beta, factor)
+     return h_bar, w_bar
+
+ def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
+     """Fetch and process an image"""
+     if isinstance(image_input, str):
+         if image_input.startswith(("http://", "https://")):
+             response = requests.get(image_input)
+             image = Image.open(BytesIO(response.content)).convert('RGB')
+         else:
+             image = Image.open(image_input).convert('RGB')
+     elif isinstance(image_input, Image.Image):
+         image = image_input.convert('RGB')
+     else:
+         raise ValueError(f"Invalid image input type: {type(image_input)}")
+
+     if min_pixels is not None or max_pixels is not None:
+         min_pixels = min_pixels or MIN_PIXELS
+         max_pixels = max_pixels or MAX_PIXELS
+         height, width = smart_resize(
+             image.height,
+             image.width,
+             factor=IMAGE_FACTOR,
+             min_pixels=min_pixels,
+             max_pixels=max_pixels
+         )
+         image = image.resize((width, height), Image.LANCZOS)
+
+     return image
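
A quick worked example of the resizing arithmetic with the defaults above (the numbers follow directly from the formulas, not from the commit):

```python
# 1080x1920 is already within [min_pixels, max_pixels], so only factor-rounding applies:
#   round_by_factor(1080, 28) = round(38.57) * 28 = 39 * 28 = 1092
#   round_by_factor(1920, 28) = round(68.57) * 28 = 69 * 28 = 1932
# 1092 * 1932 = 2,109,744, inside [3136, 11,289,600], so no beta rescale is triggered.
assert smart_resize(1080, 1920) == (1092, 1932)
```
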

+ def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
+     """Load images from PDF file"""
+     images = []
      try:
+         pdf_document = pymupdf.open(pdf_path)
+         for page_num in range(len(pdf_document)):
+             page = pdf_document.load_page(page_num)
+             mat = pymupdf.Matrix(2.0, 2.0)  # Increase resolution
+             pix = page.get_pixmap(matrix=mat)
+             img_data = pix.tobytes("ppm")
+             image = Image.open(BytesIO(img_data)).convert('RGB')
+             images.append(image)
+         pdf_document.close()
      except Exception as e:
+         print(f"Error loading PDF: {e}")
+         return []
+     return images
+
+ def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
+     """Draw layout bounding boxes on image"""
+     img_copy = image.copy()
+     draw = ImageDraw.Draw(img_copy)
+
+     colors = {
+         'Caption': '#FF6B6B',
+         'Footnote': '#4ECDC4',
+         'Formula': '#45B7D1',
+         'List-item': '#96CEB4',
+         'Page-footer': '#FFEAA7',
+         'Page-header': '#DDA0DD',
+         'Picture': '#FFD93D',
+         'Section-header': '#6C5CE7',
+         'Table': '#FD79A8',
+         'Text': '#74B9FF',
+         'Title': '#E17055'
+     }
+
+     try:
+         font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
+     except Exception:
+         font = ImageFont.load_default()
+
+     for item in layout_data:
+         if 'bbox' in item and 'category' in item:
+             bbox = item['bbox']
+             category = item['category']
+             color = colors.get(category, '#000000')
+             draw.rectangle(bbox, outline=color, width=2)
+             label = category
+             label_bbox = draw.textbbox((0, 0), label, font=font)
+             label_width = label_bbox[2] - label_bbox[0]
+             label_height = label_bbox[3] - label_bbox[1]
+             label_x = bbox[0]
+             label_y = max(0, bbox[1] - label_height - 2)
+             draw.rectangle(
+                 [label_x, label_y, label_x + label_width + 4, label_y + label_height + 2],
+                 fill=color
+             )
+             draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
+     return img_copy
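
A minimal usage sketch for the helper above; the layout item is hypothetical, and any dict carrying a pixel-space `'bbox'` plus a `'category'` will be drawn:

```python
from PIL import Image

page = Image.new("RGB", (600, 400), "white")  # stand-in for a rendered PDF page
layout = [{"bbox": [40, 30, 560, 80], "category": "Title", "text": "Sample"}]
annotated = draw_layout_on_image(page, layout)  # colored box + category label
annotated.save("annotated_page.png")
```
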

+ def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
+     """Convert layout JSON to markdown format"""
+     import base64
+     from io import BytesIO
+
+     markdown_lines = []
+
+     try:
+         sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
+
+         for item in sorted_items:
+             category = item.get('category', '')
+             text = item.get(text_key, '')
+             bbox = item.get('bbox', [])
+
+             if category == 'Picture':
+                 if bbox and len(bbox) == 4:
+                     try:
+                         x1, y1, x2, y2 = bbox
+                         x1, y1 = max(0, int(x1)), max(0, int(y1))
+                         x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))
+
+                         if x2 > x1 and y2 > y1:
+                             cropped_img = image.crop((x1, y1, x2, y2))
+                             buffer = BytesIO()
+                             cropped_img.save(buffer, format='PNG')
+                             img_data = base64.b64encode(buffer.getvalue()).decode()
+                             markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
                          else:
+                             markdown_lines.append("![Image](Image region detected)\n")
+                     except Exception as e:
+                         print(f"Error processing image region: {e}")
+                         markdown_lines.append("![Image](Image detected)\n")
                  else:
+                     markdown_lines.append("![Image](Image detected)\n")
+             elif not text:
+                 continue
+             elif category == 'Title':
+                 markdown_lines.append(f"# {text}\n")
+             elif category == 'Section-header':
+                 markdown_lines.append(f"## {text}\n")
+             elif category == 'Text':
+                 markdown_lines.append(f"{text}\n")
+             elif category == 'List-item':
+                 markdown_lines.append(f"- {text}\n")
+             elif category == 'Table':
+                 if text.strip().startswith('<'):
+                     markdown_lines.append(f"{text}\n")
+                 else:
+                     markdown_lines.append(f"**Table:** {text}\n")
+             elif category == 'Formula':
+                 if text.strip().startswith('$') or '\\' in text:
+                     markdown_lines.append(f"$$\n{text}\n$$\n")
+                 else:
+                     markdown_lines.append(f"**Formula:** {text}\n")
+             elif category == 'Caption':
+                 markdown_lines.append(f"*{text}*\n")
+             elif category == 'Footnote':
+                 markdown_lines.append(f"^{text}^\n")
+             elif category in ['Page-header', 'Page-footer']:
+                 continue
              else:
+                 markdown_lines.append(f"{text}\n")
+             markdown_lines.append("")
+     except Exception as e:
+         print(f"Error converting to markdown: {e}")
+         return str(layout_data)
+     return "\n".join(markdown_lines)
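
Concretely, a tiny hypothetical layout converts like this; the image argument is only touched for `'Picture'` items, so `None` is safe here, and `'Page-header'`/`'Page-footer'` items would be dropped:

```python
layout = [
    {"bbox": [0, 0, 100, 20], "category": "Title", "text": "Report"},
    {"bbox": [0, 30, 100, 60], "category": "List-item", "text": "first finding"},
    {"bbox": [0, 70, 100, 90], "category": "Footnote", "text": "1. source"},
]
md = layoutjson2md(None, layout)
# -> "# Report", "- first finding", "^1. source^", separated by blank lines
```
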
+
+ # Global state variables
+ pdf_cache = {
+     "images": [],
+     "current_page": 0,
+     "total_pages": 0,
+     "file_type": None,
+     "is_parsed": False,
+     "results": []
+ }
+ @spaces.GPU()
+ def dot_ocr_inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
+     """Run inference on an image with the given prompt using dot.ocr model"""
      try:
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image", "image": image},
+                     {"type": "text", "text": prompt}
+                 ]
+             }
+         ]
+         text = dot_ocr_processor.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True
+         )
+         image_inputs, video_inputs = process_vision_info(messages)
+         inputs = dot_ocr_processor(
+             text=[text],
+             images=image_inputs,
+             videos=video_inputs,
+             padding=True,
+             return_tensors="pt",
+         )
+         inputs = inputs.to(device)
+         with torch.no_grad():
+             generated_ids = dot_ocr_model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 do_sample=False,
+                 temperature=0.1
+             )
+         generated_ids_trimmed = [
+             out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+         ]
+         output_text = dot_ocr_processor.batch_decode(
+             generated_ids_trimmed,
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=False
+         )
+         return output_text[0] if output_text else ""
      except Exception as e:
+         print(f"Error during dot.ocr inference: {e}")
+         return f"Error during inference: {str(e)}"
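
Pieced together, a single page flows through the helpers above roughly like this (a sketch, assuming the globals loaded at import time and a well-formed model reply; `process_image_dot_ocr` below wraps the same steps with error handling):

```python
img = fetch_image("examples/sample_image1.png", min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS)
raw = dot_ocr_inference(img, prompt)   # ideally a single JSON array, per the prompt
layout = json.loads(raw)               # raises JSONDecodeError on malformed replies
markdown = layoutjson2md(img, layout, text_key='text')
```
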
 

+ def process_image_dot_ocr(image: Image.Image, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None) -> Dict[str, Any]:
+     """Process a single image with the dot.ocr model"""
      try:
+         if min_pixels is not None or max_pixels is not None:
+             image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
+         raw_output = dot_ocr_inference(image, prompt)
+         result = {
+             'original_image': image,
+             'raw_output': raw_output,
+             'processed_image': image,
+             'layout_result': None,
+             'markdown_content': None
+         }
          try:
              layout_data = json.loads(raw_output)
              result['layout_result'] = layout_data
+             processed_image = draw_layout_on_image(image, layout_data)
+             result['processed_image'] = processed_image
+             markdown_content = layoutjson2md(image, layout_data, text_key='text')
+             result['markdown_content'] = markdown_content
+         except json.JSONDecodeError:
+             print("Failed to parse JSON output, using raw output")
+             result['markdown_content'] = raw_output
          return result
+     except Exception as e:
+         print(f"Error processing image with dot.ocr: {e}")
          return {
+             'original_image': image,
+             'raw_output': f"Error processing image: {str(e)}",
+             'processed_image': image,
+             'layout_result': None,
+             'markdown_content': f"Error processing image: {str(e)}"
          }
+ def process_all_pages_dot_ocr(file_path, min_pixels, max_pixels):
+     """Process all pages of a document with dot.ocr model"""
+     if file_path.lower().endswith('.pdf'):
+         images = load_images_from_pdf(file_path)
+     else:
+         images = [Image.open(file_path).convert('RGB')]
+     results = []
+     for img in images:
+         result = process_image_dot_ocr(img, min_pixels, max_pixels)
+         results.append(result)
+     return results
+
+ # Dolphin model functions
+ @spaces.GPU()
+ def dolphin_model_chat(prompt, image):
+     """Process an image or batch of images with the given prompt(s) using Dolphin model"""
+     is_batch = isinstance(image, list)
+     if not is_batch:
+         images = [image]
+         prompts = [prompt]
+     else:
+         images = image
+         prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
+     batch_inputs = dolphin_processor(images, return_tensors="pt", padding=True)
+     batch_pixel_values = batch_inputs.pixel_values.half().to(device)
+     prompts = [f"<s>{p} <Answer/>" for p in prompts]
+     batch_prompt_inputs = dolphin_tokenizer(
+         prompts,
+         add_special_tokens=False,
+         return_tensors="pt"
+     )
+     batch_prompt_ids = batch_prompt_inputs.input_ids.to(device)
+     batch_attention_mask = batch_prompt_inputs.attention_mask.to(device)
+     outputs = dolphin_model.generate(
+         pixel_values=batch_pixel_values,
+         decoder_input_ids=batch_prompt_ids,
+         decoder_attention_mask=batch_attention_mask,
+         min_length=1,
+         max_length=4096,
+         pad_token_id=dolphin_tokenizer.pad_token_id,
+         eos_token_id=dolphin_tokenizer.eos_token_id,
+         use_cache=True,
+         bad_words_ids=[[dolphin_tokenizer.unk_token_id]],
+         return_dict_in_generate=True,
+         do_sample=False,
+         num_beams=1,
+         repetition_penalty=1.1
+     )
+     sequences = dolphin_tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
+     results = []
+     for i, sequence in enumerate(sequences):
+         cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
+         results.append(cleaned)
+     if not is_batch:
+         return results[0]
+     return results
+
+ def process_element_batch_dolphin(elements, prompt, max_batch_size=16):
+     """Process elements of the same type in batches for Dolphin model"""
+     results = []
+     batch_size = min(len(elements), max_batch_size)
+     for i in range(0, len(elements), batch_size):
+         batch_elements = elements[i:i+batch_size]
+         crops_list = [elem["crop"] for elem in batch_elements]
+         prompts_list = [prompt] * len(crops_list)
+         batch_results = dolphin_model_chat(prompts_list, crops_list)
+         for j, result in enumerate(batch_results):
+             elem = batch_elements[j]
+             results.append({
+                 "label": elem["label"],
+                 "bbox": elem["bbox"],
+                 "text": result.strip(),
+                 "reading_order": elem["reading_order"],
+             })
+     return results
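
Batching exists because Dolphin is invoked once per cropped region; grouping same-type crops lets a single generate() call serve up to max_batch_size regions. A self-contained, hypothetical call:

```python
from PIL import Image

crop = Image.new("RGB", (64, 32), "white")  # stand-in for a text-region crop
elements = [{"crop": crop, "label": "para", "bbox": [0, 0, 64, 32], "reading_order": 0}]
parsed = process_element_batch_dolphin(elements, "Read text in the image.")
# -> [{"label": "para", "bbox": [0, 0, 64, 32], "text": "<model output>", "reading_order": 0}]
```
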
+
+ def process_page_dolphin(image_path):
+     """Process a single page with Dolphin model"""
+     pil_image = Image.open(image_path).convert("RGB")
+     layout_output = dolphin_model_chat("Parse the reading order of this document.", pil_image)
+     padded_image, dims = prepare_image(pil_image)
+     recognition_results = process_elements_dolphin(layout_output, padded_image, dims)
+     return recognition_results
+
+ def process_elements_dolphin(layout_results, padded_image, dims):
+     """Parse all document elements for Dolphin model"""
+     layout_results = parse_layout_string(layout_results)
+     text_elements = []
+     table_elements = []
+     figure_results = []
+     previous_box = None
+     reading_order = 0
+     for bbox, label in layout_results:
+         try:
+             x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
+                 bbox, padded_image, dims, previous_box
+             )
+             cropped = padded_image[y1:y2, x1:x2]
+             if cropped.size > 0 and (cropped.shape[0] > 3 and cropped.shape[1] > 3):
+                 if label == "fig":
+                     try:
+                         pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
+                         buffered = io.BytesIO()
+                         pil_crop.save(buffered, format="PNG")
+                         img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
+                         figure_results.append(
+                             {
+                                 "label": label,
+                                 "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
+                                 "text": img_base64,
+                                 "reading_order": reading_order,
+                             }
+                         )
+                     except Exception as e:
+                         print(f"Error encoding figure to base64: {e}")
+                         figure_results.append(
+                             {
+                                 "label": label,
+                                 "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
+                                 "text": "",
+                                 "reading_order": reading_order,
+                             }
+                         )
+                 else:
+                     pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
+                     element_info = {
+                         "crop": pil_crop,
+                         "label": label,
+                         "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
+                         "reading_order": reading_order,
+                     }
+                     if label == "tab":
+                         table_elements.append(element_info)
+                     else:
+                         text_elements.append(element_info)
+             reading_order += 1
+         except Exception as e:
+             print(f"Error processing bbox with label {label}: {str(e)}")
+             continue
+     recognition_results = figure_results.copy()
+     if text_elements:
+         text_results = process_element_batch_dolphin(text_elements, "Read text in the image.")
+         recognition_results.extend(text_results)
+     if table_elements:
+         table_results = process_element_batch_dolphin(table_elements, "Parse the table in the image.")
+         recognition_results.extend(table_results)
+     recognition_results.sort(key=lambda x: x.get("reading_order", 0))
+     return recognition_results
+
+ def generate_markdown(recognition_results):
+     """Generate markdown from recognition results for Dolphin model"""
+     converter = MarkdownConverter()
+     return converter.convert(recognition_results)
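
The list handed to MarkdownConverter is flat and sorted by reading order, with Dolphin's short labels ('sec', 'tab', 'fig', ...). A hypothetical example, assuming the MarkdownConverter moved to utils.markdown_utils behaves like the in-file class deleted above:

```python
results = [
    {"label": "sec", "bbox": [0, 0, 50, 10], "text": "1 Introduction", "reading_order": 0},
    {"label": "para", "bbox": [0, 12, 50, 40], "text": "Body text of the section...", "reading_order": 1},
    {"label": "fig", "bbox": [0, 42, 50, 80], "text": "<base64 PNG>", "reading_order": 2},
]
print(generate_markdown(results))
# -> "## 1 Introduction", the paragraph, then an inline base64 image
```
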
504
+
+ def convert_all_pdf_pages_to_images(file_path, target_size=896):
+     """Convert all pages of a PDF to images for Dolphin model"""
+     if file_path is None:
+         return []
+     try:
+         file_ext = os.path.splitext(file_path)[1].lower()
+         if file_ext == '.pdf':
+             doc = pymupdf.open(file_path)
+             image_paths = []
+             for page_num in range(len(doc)):
+                 page = doc[page_num]
+                 rect = page.rect
+                 # Scale so the longer page side renders at target_size pixels.
+                 scale = target_size / max(rect.width, rect.height)
+                 mat = pymupdf.Matrix(scale, scale)
+                 pix = page.get_pixmap(matrix=mat)
+                 img_data = pix.tobytes("png")
+                 pil_image = Image.open(io.BytesIO(img_data))
+                 with tempfile.NamedTemporaryFile(suffix=f"_page_{page_num}.png", delete=False) as tmp_file:
+                     pil_image.save(tmp_file.name, "PNG")
+                     image_paths.append(tmp_file.name)
+             doc.close()
+             return image_paths
+         else:
+             converted_path = convert_to_image(file_path, target_size)
+             return [converted_path] if converted_path else []
+     except Exception as e:
+         print(f"Error converting PDF pages to images: {e}")
+         return []
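+
+ # Usage sketch ("sample.pdf" is a hypothetical path): every page is written to
+ # a delete=False temp file, so the caller owns the cleanup, as
+ # process_all_pages_dolphin does below.
+ #
+ #     paths = convert_all_pdf_pages_to_images("sample.pdf", target_size=896)
+ #     for p in paths:
+ #         ...  # consume the page image
+ #         os.remove(p)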

+ def convert_to_image(file_path, target_size=896, page_num=0):
+     """Convert input file to image format for Dolphin model"""
+     if file_path is None:
+         return None
+     try:
+         file_ext = os.path.splitext(file_path)[1].lower()
+         if file_ext == '.pdf':
+             doc = pymupdf.open(file_path)
+             if page_num >= len(doc):
+                 page_num = 0
+             page = doc[page_num]
+             rect = page.rect
+             scale = target_size / max(rect.width, rect.height)
+             mat = pymupdf.Matrix(scale, scale)
+             pix = page.get_pixmap(matrix=mat)
+             img_data = pix.tobytes("png")
+             pil_image = Image.open(io.BytesIO(img_data))
+             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
+                 pil_image.save(tmp_file.name, "PNG")
+             doc.close()
+             return tmp_file.name
+         else:
+             pil_image = Image.open(file_path).convert("RGB")
+             w, h = pil_image.size
+             if max(w, h) > target_size:
+                 if w > h:
+                     new_w, new_h = target_size, int(h * target_size / w)
+                 else:
+                     new_w, new_h = int(w * target_size / h), target_size
+                 pil_image = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
+             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
+                 pil_image.save(tmp_file.name, "PNG")
+             return tmp_file.name
+     except Exception as e:
+         print(f"Error converting file to image: {e}")
+         return file_path
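+
+ # Note on the non-PDF branch above: the longer side is scaled down to
+ # target_size while the aspect ratio is preserved. For example, a 1792x896
+ # input becomes 896x448, since new_h = int(896 * 896 / 1792) = 448.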
+
+ def process_all_pages_dolphin(file_path):
+     """Process all pages of a document with Dolphin model"""
+     image_paths = convert_all_pdf_pages_to_images(file_path)
+     per_page_results = []
+     for image_path in image_paths:
+         try:
+             original_image = Image.open(image_path).convert('RGB')
+             recognition_results = process_page_dolphin(image_path)
+             markdown_content = generate_markdown(recognition_results)
+             placeholder_text = "Layout visualization not available for Dolphin model"
+             processed_image = create_placeholder_image(placeholder_text, size=(original_image.width, original_image.height))
+             per_page_results.append({
+                 'original_image': original_image,
+                 'processed_image': processed_image,
+                 'markdown_content': markdown_content,
+                 'layout_result': recognition_results
+             })
+         except Exception as e:
+             print(f"Error processing page: {e}")
+             per_page_results.append({
+                 'original_image': Image.new('RGB', (100, 100), color='white'),
+                 'processed_image': create_placeholder_image("Error processing page", size=(100, 100)),
+                 'markdown_content': f"Error processing page: {str(e)}",
+                 'layout_result': None
+             })
+         finally:
+             if os.path.exists(image_path):
+                 os.remove(image_path)
+     return per_page_results
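+
+ # Per-page result sketch ("sample.pdf" is a hypothetical path); the keys
+ # mirror the dicts built above:
+ #
+ #     pages = process_all_pages_dolphin("sample.pdf")
+ #     full_md = "\n\n".join(p['markdown_content'] for p in pages)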
+
+ def create_placeholder_image(text, size=(400, 200)):
+     """Create a placeholder image with text"""
+     img = Image.new('RGB', size, color='white')
+     draw = ImageDraw.Draw(img)
+     try:
+         font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
+     except Exception:
+         font = ImageFont.load_default()
+     draw.text((10, 10), text, fill='black', font=font)
+     return img
+
+ # Gradio interface functions
+ def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
+     """Load file for preview (supports PDF and images)"""
+     global pdf_cache
+     if not file_path or not os.path.exists(file_path):
+         return None, "No file selected"
+     file_ext = os.path.splitext(file_path)[1].lower()
+     try:
+         if file_ext == '.pdf':
+             images = load_images_from_pdf(file_path)
+             if not images:
+                 return None, "Failed to load PDF"
+             pdf_cache.update({
+                 "images": images,
+                 "current_page": 0,
+                 "total_pages": len(images),
+                 "file_type": "pdf",
+                 "is_parsed": False,
+                 "results": []
+             })
+             return images[0], f"Page 1 / {len(images)}"
+         elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
+             image = Image.open(file_path).convert('RGB')
+             pdf_cache.update({
+                 "images": [image],
+                 "current_page": 0,
+                 "total_pages": 1,
+                 "file_type": "image",
+                 "is_parsed": False,
+                 "results": []
+             })
+             return image, "Page 1 / 1"
+         else:
+             return None, f"Unsupported file format: {file_ext}"
+     except Exception as e:
+         print(f"Error loading file: {e}")
+         return None, f"Error loading file: {str(e)}"
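+
+ # Usage sketch: loading a document populates the module-level pdf_cache and
+ # returns the first page plus a label for the UI, e.g.
+ #
+ #     first_page, label = load_file_for_preview("examples/sample_pdf.pdf")
+ #     # label -> "Page 1 / N"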
+
+ def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, str, Optional[Image.Image], Optional[Dict]]:
+     """Navigate through PDF pages and update all relevant outputs."""
+     global pdf_cache
+     if not pdf_cache["images"]:
+         return None, "No file loaded", "No results yet", None, None
+     if direction == "prev":
+         pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
+     elif direction == "next":
+         pdf_cache["current_page"] = min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)
+     index = pdf_cache["current_page"]
+     current_image_preview = pdf_cache["images"][index]
+     page_info_html = f"Page {index + 1} / {pdf_cache['total_pages']}"
+     if pdf_cache["is_parsed"] and index < len(pdf_cache["results"]):
+         result = pdf_cache["results"][index]
+         processed_img = result['processed_image']
+         markdown_content = result['markdown_content'] or "No content available"
+         layout_json = result['layout_result']
+     else:
+         processed_img = None
+         markdown_content = "Page not processed yet"
+         layout_json = None
+     return current_image_preview, page_info_html, markdown_content, processed_img, layout_json
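+
+ # Navigation sketch: turn_page only reads and writes pdf_cache, so it can be
+ # exercised without the UI once a file has been loaded:
+ #
+ #     preview, label, md, processed, layout = turn_page("next")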
+
+ def process_document(model_choice, file_path, max_tokens, min_pix, max_pix):
+     """Process the uploaded document with the selected model"""
+     global pdf_cache
+     try:
+         if not file_path:
+             return None, "Please upload a file first.", None
+         # Note: max_tokens comes from the UI slider but is not consumed by
+         # either pipeline in this function.
+         if model_choice == "dot.ocr":
+             results = process_all_pages_dot_ocr(file_path, min_pix, max_pix)
+         elif model_choice == "Dolphin":
+             results = process_all_pages_dolphin(file_path)
+         else:
+             raise ValueError("Invalid model choice")
+         pdf_cache["results"] = results
+         pdf_cache["is_parsed"] = True
+         # Both pipelines return per-page dicts with the same keys, so the first
+         # page can be unpacked uniformly regardless of the model choice.
+         first_result = results[0]
+         processed_img = first_result['processed_image']
+         markdown_content = first_result['markdown_content']
+         layout_json = first_result['layout_result']
+         return processed_img, markdown_content, layout_json
+     except Exception as e:
+         error_msg = f"Error processing document: {str(e)}"
+         print(error_msg)
+         return None, error_msg, None
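+
+ # End-to-end sketch (hypothetical inputs; MIN_PIXELS / MAX_PIXELS are the
+ # module-level defaults also used by the UI sliders):
+ #
+ #     img, md, layout = process_document(
+ #         "dot.ocr", "examples/sample_image1.png", 24000, MIN_PIXELS, MAX_PIXELS
+ #     )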

  def create_gradio_interface():
+     """Create the Gradio interface"""
      css = """
      .main-container { max-width: 1400px; margin: 0 auto; }
      .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
+     .process-button {
+         border: none !important;
+         color: white !important;
+         font-weight: bold !important;
+         background-color: blue !important;
+     }
+     .process-button:hover {
+         background-color: darkblue !important;
+         transform: translateY(-2px) !important;
+         box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
+     }
      .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
      .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
+     .model-status { padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center; font-weight: bold; }
+     .status-ready { background: #d1edff; color: #0c5460; border: 1px solid #b8daff; }
      """
+     with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
          gr.HTML("""
          <div class="title" style="text-align: center">
          <h1>Dot<span style="color: red;">●</span>OCR Comparator</h1>
          </p>
          </div>
          """)
+         with gr.Row():
              with gr.Column(scale=1):
+                 model_choice = gr.Radio(
+                     choices=["dot.ocr", "Dolphin"],
+                     label="Select Model",
+                     value="dot.ocr"
+                 )
+                 file_input = gr.File(
+                     label="Upload Image or PDF",
+                     file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"],
+                     type="filepath"
+                 )
                  with gr.Row():
                      examples = gr.Examples(
                          examples=["examples/sample_image1.png", "examples/sample_image2.png", "examples/sample_pdf.pdf"],
                          inputs=file_input,
                          label="Example Documents"
                      )
+                 image_preview = gr.Image(
+                     label="Preview",
+                     type="pil",
+                     interactive=False,
+                     height=300
+                 )
                  with gr.Row():
+                     prev_page_btn = gr.Button("◀ Previous", size="md")
+                     page_info = gr.HTML("No file loaded")
+                     next_page_btn = gr.Button("Next ▶", size="md")
+                 with gr.Accordion("Advanced Settings", open=False):
+                     max_new_tokens = gr.Slider(
+                         minimum=1000,
+                         maximum=32000,
+                         value=24000,
+                         step=1000,
+                         label="Max New Tokens",
+                         info="Maximum number of tokens to generate"
+                     )
+                     min_pixels = gr.Number(
+                         value=MIN_PIXELS,
+                         label="Min Pixels",
+                         info="Minimum image resolution"
+                     )
+                     max_pixels = gr.Number(
+                         value=MAX_PIXELS,
+                         label="Max Pixels",
+                         info="Maximum image resolution"
+                     )
+                 process_btn = gr.Button(
+                     "🚀 Process Document",
+                     variant="primary",
+                     elem_classes=["process-button"],
+                     size="lg"
+                 )
+                 clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
              with gr.Column(scale=2):
                  with gr.Tabs():
                      with gr.Tab("🖼️ Processed Image"):
+                         processed_image = gr.Image(
+                             label="Image with Layout Detection",
+                             type="pil",
+                             interactive=False,
+                             height=500
+                         )
+                     with gr.Tab("📝 Extracted Content"):
+                         markdown_output = gr.Markdown(
+                             value="Click 'Process Document' to see extracted content...",
+                             height=500
+                         )
                      with gr.Tab("📋 Layout JSON"):
+                         json_output = gr.JSON(
+                             label="Layout Analysis Results",
+                             value=None
+                         )
+
+         # Event handlers
+         file_input.change(
+             load_file_for_preview,
+             inputs=[file_input],
+             outputs=[image_preview, page_info]
+         )
+
+         prev_page_btn.click(
+             lambda: turn_page("prev"),
+             outputs=[image_preview, page_info, markdown_output, processed_image, json_output]
+         )
+
+         next_page_btn.click(
+             lambda: turn_page("next"),
+             outputs=[image_preview, page_info, markdown_output, processed_image, json_output]
+         )
+
          process_btn.click(
              process_document,
+             inputs=[model_choice, file_input, max_new_tokens, min_pixels, max_pixels],
+             outputs=[processed_image, markdown_output, json_output]
          )
+
+         clear_btn.click(
+             lambda: (None, None, "No file loaded", None, "Click 'Process Document' to see extracted content...", None),
+             outputs=[file_input, image_preview, page_info, processed_image, markdown_output, json_output]
+         )
+
      return demo
 
  if __name__ == "__main__":
+     demo = create_gradio_interface()
+     demo.queue(max_size=10).launch(share=False, debug=True, show_error=True)
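+     # A possible variant for container deployments (an assumption, not part of
+     # this commit): bind explicitly instead of relying on the defaults above.
+     # demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860)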