prithivMLmods committed
Commit 888b5aa · verified · 1 Parent(s): 9ebf911

Update app.py

Files changed (1):
  1. app.py +145 -398
app.py CHANGED
@@ -6,8 +6,9 @@ import traceback
 from io import BytesIO
 from typing import Any, Dict, List, Optional, Tuple
 import re
+from threading import Thread
+import time
 
-import fitz
 import gradio as gr
 import requests
 import torch
@@ -17,15 +18,15 @@ from transformers import (
     AutoProcessor,
     TextIteratorStreamer,
 )
-from qwen_vl_utils import process_vision_info
 
 # Constants
 MIN_PIXELS = 3136
 MAX_PIXELS = 11289600
 IMAGE_FACTOR = 28
+MAX_INPUT_TOKEN_LENGTH = 4096
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Prompts
+# Prompt for Layout Analysis
 prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
 
 1. Bbox format: [x1, y1, x2, y2]
@@ -45,190 +46,74 @@ prompt = """Please output the layout information from the PDF image, including e
 5. Final Output: The entire output must be a single JSON object.
 """
 
-# Load models
+# Load Models
 MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
 processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
 model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_M,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
+    MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
 ).to(device).eval()
 
 MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
 processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
 model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_T,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
+    MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
 ).to(device).eval()
 
 MODEL_ID_C = "nanonets/Nanonets-OCR-s"
 processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
 model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_C,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
+    MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
 ).to(device).eval()
 
 MODEL_ID_G = "echo840/MonkeyOCR"
 SUBFOLDER = "Recognition"
 processor_g = AutoProcessor.from_pretrained(
-    MODEL_ID_G,
-    trust_remote_code=True,
-    subfolder=SUBFOLDER
+    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
 )
 model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID_G,
-    trust_remote_code=True,
-    subfolder=SUBFOLDER,
-    torch_dtype=torch.float16
+    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
 ).to(device).eval()
 
-# Utility functions
-def round_by_factor(number: int, factor: int) -> int:
-    return round(number / factor) * factor
-
-def smart_resize(
-    height: int,
-    width: int,
-    factor: int = 28,
-    min_pixels: int = 3136,
-    max_pixels: int = 11289600,
-):
-    if max(height, width) / min(height, width) > 200:
-        raise ValueError(f"Aspect ratio too extreme: {max(height, width) / min(height, width)}")
-    h_bar = max(factor, round_by_factor(height, factor))
-    w_bar = max(factor, round_by_factor(width, factor))
-    if h_bar * w_bar > max_pixels:
-        beta = math.sqrt((height * width) / max_pixels)
-        h_bar = round_by_factor(height / beta, factor)
-        w_bar = round_by_factor(width / beta, factor)
-    elif h_bar * w_bar < min_pixels:
-        beta = math.sqrt(min_pixels / (height * width))
-        h_bar = round_by_factor(height * beta, factor)
-        w_bar = round_by_factor(width * beta, factor)
-    return h_bar, w_bar
-
-def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
-    if isinstance(image_input, str):
-        if image_input.startswith(("http://", "https://")):
-            response = requests.get(image_input)
-            image = Image.open(BytesIO(response.content)).convert('RGB')
-        else:
-            image = Image.open(image_input).convert('RGB')
-    elif isinstance(image_input, Image.Image):
-        image = image_input.convert('RGB')
-    else:
-        raise ValueError(f"Invalid image input type: {type(image_input)}")
-    if min_pixels or max_pixels:
-        min_pixels = min_pixels or MIN_PIXELS
-        max_pixels = max_pixels or MAX_PIXELS
-        height, width = smart_resize(image.height, image.width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
-        image = image.resize((width, height), Image.LANCZOS)
-    return image
-
-def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
-    images = []
-    try:
-        pdf_document = fitz.open(pdf_path)
-        for page_num in range(len(pdf_document)):
-            page = pdf_document.load_page(page_num)
-            mat = fitz.Matrix(2.0, 2.0)
-            pix = page.get_pixmap(matrix=mat)
-            img_data = pix.tobytes("ppm")
-            image = Image.open(BytesIO(img_data)).convert('RGB')
-            images.append(image)
-        pdf_document.close()
-    except Exception as e:
-        print(f"Error loading PDF: {e}")
-    return images
-
-def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
-    img_copy = image.copy()
-    draw = ImageDraw.Draw(img_copy)
-    colors = {
-        'Caption': '#FF6B6B', 'Footnote': '#4ECDC4', 'Formula': '#45B7D1',
-        'List-item': '#96CEB4', 'Page-footer': '#FFEAA7', 'Page-header': '#DDA0DD',
-        'Picture': '#FFD93D', 'Section-header': '#6C5CE7', 'Table': '#FD79A8',
-        'Text': '#74B9FF', 'Title': '#E17055'
-    }
-    try:
-        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
-    except Exception:
-        font = ImageFont.load_default()
-    for item in layout_data:
-        if 'bbox' in item and 'category' in item:
-            bbox = item['bbox']
-            category = item['category']
-            color = colors.get(category, '#000000')
-            draw.rectangle(bbox, outline=color, width=2)
-            label = category
-            label_bbox = draw.textbbox((0, 0), label, font=font)
-            label_width = label_bbox[2] - label_bbox[0]
-            label_height = label_bbox[3] - label_bbox[1]
-            label_x = bbox[0]
-            label_y = max(0, bbox[1] - label_height - 2)
-            draw.rectangle([label_x, label_y, label_x + label_width + 4, label_y + label_height + 2], fill=color)
-            draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
-    return img_copy
 
+# Utility functions
 def is_arabic_text(text: str) -> bool:
+    """Check if text contains mostly Arabic characters."""
     if not text:
         return False
-    header_pattern = r'^#{1,6}\s+(.+)$'
-    paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
-    content_text = []
-    for line in text.split('\n'):
-        line = line.strip()
-        if not line:
-            continue
-        header_match = re.match(header_pattern, line, re.MULTILINE)
-        if header_match:
-            content_text.append(header_match.group(1))
-            continue
-        if re.match(paragraph_pattern, line, re.MULTILINE):
-            content_text.append(line)
-    if not content_text:
-        return False
-    combined_text = ' '.join(content_text)
+    # Simplified check for Arabic characters in the given text
     arabic_chars = 0
     total_chars = 0
-    for char in combined_text:
+    for char in text:
         if char.isalpha():
             total_chars += 1
-            if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'):
+            if '\u0600' <= char <= '\u06FF':
                 arabic_chars += 1
     return total_chars > 0 and (arabic_chars / total_chars) > 0.5
 
 def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
+    """Convert layout JSON to markdown format."""
     import base64
     from io import BytesIO
     markdown_lines = []
     try:
+        # Sort items by reading order (top to bottom, left to right)
         sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
         for item in sorted_items:
             category = item.get('category', '')
             text = item.get(text_key, '')
             bbox = item.get('bbox', [])
+
             if category == 'Picture':
                 if bbox and len(bbox) == 4:
                     try:
-                        x1, y1, x2, y2 = bbox
-                        x1, y1 = max(0, int(x1)), max(0, int(y1))
-                        x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))
-                        if x2 > x1 and y2 > y1:
-                            cropped_img = image.crop((x1, y1, x2, y2))
-                            buffer = BytesIO()
-                            cropped_img.save(buffer, format='PNG')
-                            img_data = base64.b64encode(buffer.getvalue()).decode()
-                            markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
-                        else:
-                            markdown_lines.append("![Image](Image region detected)\n")
+                        x1, y1, x2, y2 = [int(coord) for coord in bbox]
+                        cropped_img = image.crop((x1, y1, x2, y2))
+                        buffer = BytesIO()
+                        cropped_img.save(buffer, format='PNG')
+                        img_data = base64.b64encode(buffer.getvalue()).decode()
+                        markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
                     except Exception as e:
-                        print(f"Error processing image region: {e}")
-                        markdown_lines.append("![Image](Image detected)\n")
-                else:
-                    markdown_lines.append("![Image](Image detected)\n")
+                        markdown_lines.append("![Image](Image region detected)\n")
             elif not text:
                 continue
             elif category == 'Title':
@@ -239,311 +124,173 @@ def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = '
                 markdown_lines.append(f"{text}\n")
             elif category == 'List-item':
                 markdown_lines.append(f"- {text}\n")
-            elif category == 'Table':
-                if text.strip().startswith('<'):
-                    markdown_lines.append(f"{text}\n")
-                else:
-                    markdown_lines.append(f"**Table:** {text}\n")
-            elif category == 'Formula':
-                if text.strip().startswith('$') or '\\' in text:
-                    markdown_lines.append(f"$$\n{text}\n$$\n")
-                else:
-                    markdown_lines.append(f"**Formula:** {text}\n")
+            elif category == 'Table' and text.strip().startswith('<'):
+                markdown_lines.append(f"{text}\n")
+            elif category == 'Formula' and (text.strip().startswith('$') or '\\' in text):
+                markdown_lines.append(f"$$\n{text}\n$$\n")
             elif category == 'Caption':
                 markdown_lines.append(f"*{text}*\n")
             elif category == 'Footnote':
-                markdown_lines.append(f"^{text}^\n")
-            elif category in ['Page-header', 'Page-footer']:
-                continue
-            else:
+                markdown_lines.append(f"^{text}^\n")
+            elif category not in ['Page-header', 'Page-footer']:
                 markdown_lines.append(f"{text}\n")
-            markdown_lines.append("")
     except Exception as e:
         print(f"Error converting to markdown: {e}")
-        return str(layout_data)
+        return f"### Error converting to Markdown\n\n```\n{str(layout_data)}\n```"
     return "\n".join(markdown_lines)
 
-# PDF handling state
-pdf_cache = {
-    "images": [],
-    "current_page": 0,
-    "total_pages": 0,
-    "file_type": None,
-    "is_parsed": False,
-    "results": []
-}
 
 @spaces.GPU
-def inference(model_name: str, image: Image.Image, prompt: str, max_new_tokens: int = 1024) -> str:
-    try:
-        if model_name == "Camel-Doc-OCR-062825":
-            processor = processor_m
-            model = model_m
-        elif model_name == "Megalodon-OCR-Sync-0713":
-            processor = processor_t
-            model = model_t
-        elif model_name == "Nanonets-OCR-s":
-            processor = processor_c
-            model = model_c
-        elif model_name == "MonkeyOCR-Recognition":
-            processor = processor_g
-            model = model_g
-        else:
-            raise ValueError(f"Invalid model selected: {model_name}")
-
-        messages = [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "image", "image": image},
-                    {"type": "text", "text": prompt}
-                ]
-            }
-        ]
-
-        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = processor(text=[text], images=[image], return_tensors="pt").to(device)
-
-        with torch.no_grad():
-            generated_ids = model.generate(
-                **inputs,
-                max_new_tokens=max_new_tokens,
-                do_sample=False,
-                temperature=0.1
-            )
-
-        generated_ids_trimmed = [
-            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-        ]
-        output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-        return output_text[0] if output_text else ""
-
-    except Exception as e:
-        print(f"Error during inference: {e}")
-        traceback.print_exc()
-        return f"Error during inference: {str(e)}"
-
-def process_image(
-    model_name: str,
-    image: Image.Image,
-    min_pixels: Optional[int] = None,
-    max_pixels: Optional[int] = None
-) -> Dict[str, Any]:
-    try:
-        if min_pixels or max_pixels:
-            image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
-        raw_output = inference(model_name, image, prompt)
-        result = {
-            'original_image': image,
-            'raw_output': raw_output,
-            'processed_image': image,
-            'layout_result': None,
-            'markdown_content': None
-        }
-        try:
-            json_match = re.search(r'```json\s*([\s\S]+?)\s*```', raw_output)
-            json_str = json_match.group(1) if json_match else raw_output
-            layout_data = json.loads(json_str)
-            result['layout_result'] = layout_data
-            try:
-                processed_image = draw_layout_on_image(image, layout_data)
-                result['processed_image'] = processed_image
-            except Exception as e:
-                print(f"Error drawing layout: {e}")
-            try:
-                markdown_content = layoutjson2md(image, layout_data, text_key='text')
-                result['markdown_content'] = markdown_content
-            except Exception as e:
-                print(f"Error generating markdown: {e}")
-                result['markdown_content'] = raw_output
-        except json.JSONDecodeError:
-            print("Failed to parse JSON output, using raw output")
-            result['markdown_content'] = raw_output
-        return result
-    except Exception as e:
-        print(f"Error processing image: {e}")
-        traceback.print_exc()
-        return {
-            'original_image': image,
-            'raw_output': f"Error processing image: {str(e)}",
-            'processed_image': image,
-            'layout_result': None,
-            'markdown_content': f"Error processing image: {str(e)}"
-        }
-
-def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
-    global pdf_cache
-    if not file_path or not os.path.exists(file_path):
-        return None, "No file selected"
-    file_ext = os.path.splitext(file_path)[1].lower()
+def generate_and_process(model_name: str, image: Image.Image, max_new_tokens: int):
+    """
+    Generates a response using streaming, then processes the final output.
+    Yields updates for the raw stream, final markdown, and JSON output.
+    """
+    if image is None:
+        yield "Please upload an image.", "Please upload an image.", None
+        return
+
+    # 1. Select Model and Processor
+    if model_name == "Camel-Doc-OCR-062825":
+        processor, model = processor_m, model_m
+    elif model_name == "Megalodon-OCR-Sync-0713":
+        processor, model = processor_t, model_t
+    elif model_name == "Nanonets-OCR-s":
+        processor, model = processor_c, model_c
+    elif model_name == "MonkeyOCR-Recognition":
+        processor, model = processor_g, model_g
+    else:
+        yield "Invalid model selected.", "Invalid model selected.", None
+        return
+
+    # 2. Prepare inputs for the model
+    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(
+        text=[prompt_full],
+        images=[image],
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=MAX_INPUT_TOKEN_LENGTH
+    ).to(device)
+
+    # 3. Stream the generation
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    buffer = ""
+    # Initial placeholder yield
+    yield buffer, "⏳ Generating response...", None
+
+    for new_text in streamer:
+        buffer += new_text
+        buffer = buffer.replace("<|im_end|>", "")
+        time.sleep(0.01)  # Small delay for smoother streaming
+        yield buffer, "⏳ Generating response...", None
+
+    # 4. Process the final buffer content
     try:
-        if file_ext == '.pdf':
-            images = load_images_from_pdf(file_path)
-            if not images:
-                return None, "Failed to load PDF"
-            pdf_cache.update({
-                "images": images,
-                "current_page": 0,
-                "total_pages": len(images),
-                "file_type": "pdf",
-                "is_parsed": False,
-                "results": []
-            })
-            return images[0], f"Page 1 / {len(images)}"
-        elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
-            image = Image.open(file_path).convert('RGB')
-            pdf_cache.update({
-                "images": [image],
-                "current_page": 0,
-                "total_pages": 1,
-                "file_type": "image",
-                "is_parsed": False,
-                "results": []
-            })
-            return image, "Page 1 / 1"
-        else:
-            return None, f"Unsupported file format: {file_ext}"
+        json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
+        json_str = json_match.group(1) if json_match else buffer
+        layout_data = json.loads(json_str)
+
+        markdown_content = layoutjson2md(image, layout_data)
+
+        # Final yield with all processed content
+        yield buffer, markdown_content, layout_data
+
+    except json.JSONDecodeError:
+        error_msg = "❌ Failed to parse JSON from model output."
+        yield buffer, error_msg, {"error": "JSONDecodeError", "raw_output": buffer}
    except Exception as e:
-        print(f"Error loading file: {e}")
-        return None, f"Error loading file: {str(e)}"
-
-def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, Any, Optional[Image.Image], Optional[Dict]]:
-    global pdf_cache
-    if not pdf_cache["images"]:
-        return None, '<div class="page-info">No file loaded</div>', "No results yet", None, None
-    if direction == "prev":
-        pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
-    elif direction == "next":
-        pdf_cache["current_page"] = min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)
-    index = pdf_cache["current_page"]
-    current_image_preview = pdf_cache["images"][index]
-    page_info_html = f'<div class="page-info">Page {index + 1} / {pdf_cache["total_pages"]}</div>'
-    markdown_content = "Page not processed yet"
-    processed_img = None
-    layout_json = None
-    if pdf_cache["is_parsed"] and index < len(pdf_cache["results"]) and pdf_cache["results"][index]:
-        result = pdf_cache["results"][index]
-        markdown_content = result.get('markdown_content') or result.get('raw_output', 'No content available')
-        processed_img = result.get('processed_image', None)
-        layout_json = result.get('layout_result', None)
-    if is_arabic_text(markdown_content):
-        markdown_update = gr.update(value=markdown_content, rtl=True)
-    else:
-        markdown_update = markdown_content
-    return current_image_preview, page_info_html, markdown_update, processed_img, layout_json
+        error_msg = f"❌ An error occurred during post-processing: {e}"
+        yield buffer, error_msg, {"error": str(e), "raw_output": buffer}
+
 
 def create_gradio_interface():
+    """Create the Gradio interface."""
     css = """
     .main-container { max-width: 1400px; margin: 0 auto; }
    .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
     .process-button {
-        border: none !important;
-        color: white !important;
-        font-weight: bold !important;
-        background-color: blue !important;}
+        border: none !important; color: white !important; font-weight: bold !important;
+        background-color: blue !important;
+    }
     .process-button:hover {
-        background-color: darkblue !important;
-        transform: translateY(-2px) !important;
-        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
-    .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
-    .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
-    .model-status { padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center; font-weight: bold; }
-    .status-ready { background: #d1edff; color: #0c5460; border: 1px solid #b8daff; }
+        background-color: darkblue !important; transform: translateY(-2px) !important;
+        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
+    }
     """
     with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
         gr.HTML("""
        <div class="title" style="text-align: center">
            <h1>Dot<span style="color: red;">●</span><strong></strong>OCR Comparator</h1>
            <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
-                Advanced vision-language model for image/PDF to markdown document processing
+                Advanced vision-language model for image to markdown document processing
            </p>
        </div>
        """)
+
+        # Keep track of the uploaded image
+        image_state = gr.State(None)
+
         with gr.Row():
+            # Left column - Input and controls
             with gr.Column(scale=1):
                 model_choice = gr.Radio(
                     choices=["Camel-Doc-OCR-062825", "MonkeyOCR-Recognition", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713"],
                     label="Select Model",
                     value="Camel-Doc-OCR-062825"
                 )
-                file_input = gr.File(
-                    label="Upload Image or PDF",
-                    file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"],
-                    type="filepath"
+                file_input = gr.Image(
+                    label="Upload Image",
+                    type="pil",
+                    sources=['upload']
                 )
-                image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=300)
-                with gr.Row():
-                    prev_page_btn = gr.Button("◀ Previous", size="md")
-                    page_info = gr.HTML('<div class="page-info">No file loaded</div>')
-                    next_page_btn = gr.Button("Next ▶", size="md")
                 with gr.Accordion("Advanced Settings", open=False):
                     max_new_tokens = gr.Slider(minimum=1000, maximum=32000, value=24000, step=1000, label="Max New Tokens")
-                    min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels")
-                    max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels")
+
                 process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                 clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
+
+            # Right column - Results
             with gr.Column(scale=2):
                 with gr.Tabs():
-                    with gr.Tab("🖼️ Processed Image"):
-                        processed_image = gr.Image(label="Image with Layout Detection", type="pil", interactive=False, height=500)
                     with gr.Tab("📝 Extracted Content"):
-                        markdown_output = gr.Markdown(value="Click 'Process Document' to see extracted content...", height=500)
+                        output_stream = gr.Textbox(label="Raw Output Stream", interactive=False, lines=10, show_copy_button=True)
+                        with gr.Accordion("(Formatted Result)", open=True):
+                            markdown_output = gr.Markdown(label="Formatted Result (Result.md)")
+
                     with gr.Tab("📋 Layout JSON"):
-                        json_output = gr.JSON(label="Layout Analysis Results", value=None)
-        def process_document(model_name, file_path, max_tokens, min_pix, max_pix):
-            global pdf_cache
-            try:
-                if not file_path:
-                    return None, "Please upload a file first.", None
-                load_file_for_preview(file_path)
-                if pdf_cache["file_type"] == "pdf":
-                    all_results = []
-                    all_markdown = []
-                    for i, img in enumerate(pdf_cache["images"]):
-                        result = process_image(model_name, img, min_pixels=int(min_pix) if min_pix else None, max_pixels=int(max_pix) if max_pix else None)
-                        all_results.append(result)
-                        if result.get('markdown_content'):
-                            all_markdown.append(f"## Page {i+1}\n\n{result['markdown_content']}")
-                    pdf_cache["results"] = all_results
-                    pdf_cache["is_parsed"] = True
-                    first_result = all_results[0]
-                    combined_markdown = "\n\n---\n\n".join(all_markdown)
-                    if is_arabic_text(combined_markdown):
-                        markdown_update = gr.update(value=combined_markdown, rtl=True)
-                    else:
-                        markdown_update = combined_markdown
-                    return first_result['processed_image'], markdown_update, first_result['layout_result']
-                else:
-                    result = process_image(model_name, pdf_cache["images"][0], min_pixels=int(min_pix) if min_pix else None, max_pixels=int(max_pix) if max_pix else None)
-                    pdf_cache["results"] = [result]
-                    pdf_cache["is_parsed"] = True
-                    content = result['markdown_content'] or "No content extracted"
-                    if is_arabic_text(content):
-                        markdown_update = gr.update(value=content, rtl=True)
-                    else:
-                        markdown_update = content
-                    return result['processed_image'], markdown_update, result['layout_result']
-            except Exception as e:
-                error_msg = f"Error processing document: {str(e)}"
-                print(error_msg)
-                traceback.print_exc()
-                return None, error_msg, None
-        def handle_file_upload(file_path):
-            if not file_path:
-                return None, "No file loaded"
-            image, page_info = load_file_for_preview(file_path)
-            return image, page_info
+                        json_output = gr.JSON(label="Layout Analysis Results (JSON)", value=None)
+
+        # Event Handlers
+        def handle_file_upload(image):
+            """Store the uploaded image in the state."""
+            return image
+
         def clear_all():
-            global pdf_cache
-            pdf_cache = {"images": [], "current_page": 0, "total_pages": 0, "file_type": None, "is_parsed": False, "results": []}
-            return None, None, '<div class="page-info">No file loaded</div>', None, "Click 'Process Document' to see extracted content...", None
-        file_input.change(handle_file_upload, inputs=[file_input], outputs=[image_preview, page_info])
-        prev_page_btn.click(lambda: turn_page("prev"), outputs=[image_preview, page_info, markdown_output, processed_image, json_output])
-        next_page_btn.click(lambda: turn_page("next"), outputs=[image_preview, page_info, markdown_output, processed_image, json_output])
-        process_btn.click(process_document, inputs=[model_choice, file_input, max_new_tokens, min_pixels, max_pixels], outputs=[processed_image, markdown_output, json_output])
-        clear_btn.click(clear_all, outputs=[file_input, image_preview, page_info, processed_image, markdown_output, json_output])
+            """Clear all data and reset the interface."""
+            return None, None, "Click 'Process Document' to see extracted content...", None, None
+
+        file_input.upload(handle_file_upload, inputs=[file_input], outputs=[image_state])
+
+        process_btn.click(
+            generate_and_process,
+            inputs=[model_choice, image_state, max_new_tokens],
+            outputs=[output_stream, markdown_output, json_output]
+        )
+
+        clear_btn.click(
+            clear_all,
+            outputs=[file_input, image_state, markdown_output, json_output, output_stream]
+        )
+
     return demo
 
 if __name__ == "__main__":
     demo = create_gradio_interface()
-    demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True, show_error=True)
+    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)
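For context, the heart of this change is the standard `transformers` streaming idiom used in the new `generate_and_process`: `model.generate` blocks until generation finishes, so it runs on a worker thread while a `TextIteratorStreamer` yields decoded text to the caller as it is produced. A minimal, text-only sketch of that pattern; the `gpt2` checkpoint and prompt below are illustrative assumptions, not code from this Space:

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Illustrative checkpoint only; the Space wires this same loop to
# Qwen2.5-VL-based OCR models and yields Gradio updates instead of printing.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Document layout analysis is", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs in a background thread; the streamer
# then yields decoded text chunks here as soon as they are ready.
thread = Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 48})
thread.start()

buffer = ""
for chunk in streamer:  # same consumption loop as generate_and_process
    buffer += chunk
thread.join()
print(buffer)
```

And the shape of the layout JSON that the prompt requests and `layoutjson2md` consumes, as a hypothetical sample (the bbox values and text are invented; the category names are among those the converter handles):

```python
layout_data = [
    {"bbox": [72, 40, 520, 80], "category": "Title", "text": "A Sample Document"},
    {"bbox": [72, 100, 520, 180], "category": "Text", "text": "Body paragraph..."},
    {"bbox": [72, 200, 520, 400], "category": "Picture", "text": ""},
    {"bbox": [72, 420, 520, 440], "category": "Footnote", "text": "Source: example."},
]
# layoutjson2md(image, layout_data) sorts items top-to-bottom, embeds the
# Picture region as a base64 PNG crop, and skips Page-header/Page-footer items.
```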