prithivMLmods commited on
Commit
9ebf911
·
verified ·
1 Parent(s): a007e7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -408
app.py CHANGED
@@ -11,7 +11,6 @@ import fitz
11
  import gradio as gr
12
  import requests
13
  import torch
14
- from huggingface_hub import snapshot_download
15
  from PIL import Image, ImageDraw, ImageFont
16
  from transformers import (
17
  Qwen2_5_VLForConditionalGeneration,
@@ -46,7 +45,7 @@ prompt = """Please output the layout information from the PDF image, including e
46
  5. Final Output: The entire output must be a single JSON object.
47
  """
48
 
49
- # Load Camel-Doc-OCR-062825
50
  MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
51
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
52
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -55,7 +54,6 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
55
  torch_dtype=torch.float16
56
  ).to(device).eval()
57
 
58
- # Load Megalodon-OCR-Sync-0713
59
  MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
60
  processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
61
  model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -64,7 +62,6 @@ model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
64
  torch_dtype=torch.float16
65
  ).to(device).eval()
66
 
67
- # Load Nanonets-OCR-s
68
  MODEL_ID_C = "nanonets/Nanonets-OCR-s"
69
  processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
70
  model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -73,7 +70,6 @@ model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
73
  torch_dtype=torch.float16
74
  ).to(device).eval()
75
 
76
- # Load MonkeyOCR
77
  MODEL_ID_G = "echo840/MonkeyOCR"
78
  SUBFOLDER = "Recognition"
79
  processor_g = AutoProcessor.from_pretrained(
@@ -88,13 +84,10 @@ model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
88
  torch_dtype=torch.float16
89
  ).to(device).eval()
90
 
91
-
92
  # Utility functions
93
  def round_by_factor(number: int, factor: int) -> int:
94
- """Returns the closest integer to 'number' that is divisible by 'factor'."""
95
  return round(number / factor) * factor
96
 
97
-
98
  def smart_resize(
99
  height: int,
100
  width: int,
@@ -102,18 +95,10 @@ def smart_resize(
102
  min_pixels: int = 3136,
103
  max_pixels: int = 11289600,
104
  ):
105
- """Rescales the image so that the following conditions are met:
106
- 1. Both dimensions (height and width) are divisible by 'factor'.
107
- 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
108
- 3. The aspect ratio of the image is maintained as closely as possible.
109
- """
110
  if max(height, width) / min(height, width) > 200:
111
- raise ValueError(
112
- f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
113
- )
114
  h_bar = max(factor, round_by_factor(height, factor))
115
  w_bar = max(factor, round_by_factor(width, factor))
116
-
117
  if h_bar * w_bar > max_pixels:
118
  beta = math.sqrt((height * width) / max_pixels)
119
  h_bar = round_by_factor(height / beta, factor)
@@ -124,9 +109,7 @@ def smart_resize(
124
  w_bar = round_by_factor(width * beta, factor)
125
  return h_bar, w_bar
126
 
127
-
128
  def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
129
- """Fetch and process an image"""
130
  if isinstance(image_input, str):
131
  if image_input.startswith(("http://", "https://")):
132
  response = requests.get(image_input)
@@ -137,31 +120,20 @@ def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
137
  image = image_input.convert('RGB')
138
  else:
139
  raise ValueError(f"Invalid image input type: {type(image_input)}")
140
-
141
- if min_pixels is not None or max_pixels is not None:
142
  min_pixels = min_pixels or MIN_PIXELS
143
  max_pixels = max_pixels or MAX_PIXELS
144
- height, width = smart_resize(
145
- image.height,
146
- image.width,
147
- factor=IMAGE_FACTOR,
148
- min_pixels=min_pixels,
149
- max_pixels=max_pixels
150
- )
151
  image = image.resize((width, height), Image.LANCZOS)
152
-
153
  return image
154
 
155
-
156
  def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
157
- """Load images from PDF file"""
158
  images = []
159
  try:
160
  pdf_document = fitz.open(pdf_path)
161
  for page_num in range(len(pdf_document)):
162
  page = pdf_document.load_page(page_num)
163
- # Convert page to image
164
- mat = fitz.Matrix(2.0, 2.0) # Increase resolution
165
  pix = page.get_pixmap(matrix=mat)
166
  img_data = pix.tobytes("ppm")
167
  image = Image.open(BytesIO(img_data)).convert('RGB')
@@ -169,157 +141,86 @@ def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
169
  pdf_document.close()
170
  except Exception as e:
171
  print(f"Error loading PDF: {e}")
172
- return []
173
  return images
174
 
175
-
176
  def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
177
- """Draw layout bounding boxes on image"""
178
  img_copy = image.copy()
179
  draw = ImageDraw.Draw(img_copy)
180
-
181
- # Colors for different categories
182
  colors = {
183
- 'Caption': '#FF6B6B',
184
- 'Footnote': '#4ECDC4',
185
- 'Formula': '#45B7D1',
186
- 'List-item': '#96CEB4',
187
- 'Page-footer': '#FFEAA7',
188
- 'Page-header': '#DDA0DD',
189
- 'Picture': '#FFD93D',
190
- 'Section-header': '#6C5CE7',
191
- 'Table': '#FD79A8',
192
- 'Text': '#74B9FF',
193
- 'Title': '#E17055'
194
  }
195
-
196
  try:
197
- # Load a font
198
- try:
199
- font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
200
- except Exception:
201
- font = ImageFont.load_default()
202
-
203
- for item in layout_data:
204
- if 'bbox' in item and 'category' in item:
205
- bbox = item['bbox']
206
- category = item['category']
207
- color = colors.get(category, '#000000')
208
-
209
- # Draw rectangle
210
- draw.rectangle(bbox, outline=color, width=2)
211
-
212
- # Draw label
213
- label = category
214
- label_bbox = draw.textbbox((0, 0), label, font=font)
215
- label_width = label_bbox[2] - label_bbox[0]
216
- label_height = label_bbox[3] - label_bbox[1]
217
-
218
- # Position label above the box
219
- label_x = bbox[0]
220
- label_y = max(0, bbox[1] - label_height - 2)
221
-
222
- # Draw background for label
223
- draw.rectangle(
224
- [label_x, label_y, label_x + label_width + 4, label_y + label_height + 2],
225
- fill=color
226
- )
227
-
228
- # Draw text
229
- draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
230
-
231
- except Exception as e:
232
- print(f"Error drawing layout: {e}")
233
-
234
  return img_copy
235
 
236
-
237
  def is_arabic_text(text: str) -> bool:
238
- """Check if text in headers and paragraphs contains mostly Arabic characters"""
239
  if not text:
240
  return False
241
-
242
- # Extract text from headers and paragraphs only
243
- # Match markdown headers (# ## ###) and regular paragraph text
244
  header_pattern = r'^#{1,6}\s+(.+)$'
245
  paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
246
-
247
  content_text = []
248
-
249
  for line in text.split('\n'):
250
  line = line.strip()
251
  if not line:
252
  continue
253
-
254
- # Check for headers
255
  header_match = re.match(header_pattern, line, re.MULTILINE)
256
  if header_match:
257
  content_text.append(header_match.group(1))
258
  continue
259
-
260
- # Check for paragraph text (exclude lists, tables, code blocks, images)
261
  if re.match(paragraph_pattern, line, re.MULTILINE):
262
  content_text.append(line)
263
-
264
  if not content_text:
265
  return False
266
-
267
- # Join all content text and check for Arabic characters
268
  combined_text = ' '.join(content_text)
269
-
270
- # Arabic Unicode ranges
271
  arabic_chars = 0
272
  total_chars = 0
273
-
274
  for char in combined_text:
275
  if char.isalpha():
276
  total_chars += 1
277
- # Arabic script ranges
278
  if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'):
279
  arabic_chars += 1
280
-
281
- if total_chars == 0:
282
- return False
283
-
284
- # Consider text as Arabic if more than 50% of alphabetic characters are Arabic
285
- return (arabic_chars / total_chars) > 0.5
286
-
287
 
288
  def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
289
- """Convert layout JSON to markdown format"""
290
  import base64
291
  from io import BytesIO
292
-
293
  markdown_lines = []
294
-
295
  try:
296
- # Sort items by reading order (top to bottom, left to right)
297
- sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox',), x.get('bbox',)))
298
-
299
  for item in sorted_items:
300
  category = item.get('category', '')
301
  text = item.get(text_key, '')
302
  bbox = item.get('bbox', [])
303
-
304
  if category == 'Picture':
305
- # Extract image region and embed it
306
  if bbox and len(bbox) == 4:
307
  try:
308
- # Extract the image region
309
  x1, y1, x2, y2 = bbox
310
- # Ensure coordinates are within image bounds
311
  x1, y1 = max(0, int(x1)), max(0, int(y1))
312
  x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))
313
-
314
  if x2 > x1 and y2 > y1:
315
  cropped_img = image.crop((x1, y1, x2, y2))
316
-
317
- # Convert to base64 for embedding
318
  buffer = BytesIO()
319
  cropped_img.save(buffer, format='PNG')
320
  img_data = base64.b64encode(buffer.getvalue()).decode()
321
-
322
- # Add as markdown image
323
  markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
324
  else:
325
  markdown_lines.append("![Image](Image region detected)\n")
@@ -339,13 +240,11 @@ def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = '
339
  elif category == 'List-item':
340
  markdown_lines.append(f"- {text}\n")
341
  elif category == 'Table':
342
- # If text is already HTML, keep it as is
343
  if text.strip().startswith('<'):
344
  markdown_lines.append(f"{text}\n")
345
  else:
346
  markdown_lines.append(f"**Table:** {text}\n")
347
  elif category == 'Formula':
348
- # If text is LaTeX, format it properly
349
  if text.strip().startswith('$') or '\\' in text:
350
  markdown_lines.append(f"$$\n{text}\n$$\n")
351
  else:
@@ -355,20 +254,15 @@ def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = '
355
  elif category == 'Footnote':
356
  markdown_lines.append(f"^{text}^\n")
357
  elif category in ['Page-header', 'Page-footer']:
358
- # Skip headers and footers in main content
359
  continue
360
  else:
361
  markdown_lines.append(f"{text}\n")
362
-
363
- markdown_lines.append("") # Add spacing
364
-
365
  except Exception as e:
366
  print(f"Error converting to markdown: {e}")
367
  return str(layout_data)
368
-
369
  return "\n".join(markdown_lines)
370
 
371
-
372
  # PDF handling state
373
  pdf_cache = {
374
  "images": [],
@@ -378,9 +272,9 @@ pdf_cache = {
378
  "is_parsed": False,
379
  "results": []
380
  }
 
381
  @spaces.GPU
382
- def inference(model_name: str, image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
383
- """Run inference on an image with the given prompt using the selected model."""
384
  try:
385
  if model_name == "Camel-Doc-OCR-062825":
386
  processor = processor_m
@@ -401,23 +295,14 @@ def inference(model_name: str, image: Image.Image, prompt: str, max_new_tokens:
401
  {
402
  "role": "user",
403
  "content": [
404
- {"type": "text", "text": prompt},
405
- {"type": "image"}
406
  ]
407
  }
408
  ]
409
-
410
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
411
- image_inputs, video_inputs = process_vision_info(messages)
412
-
413
- inputs = processor(
414
- text=[text],
415
- images=[image],
416
- videos=video_inputs,
417
- padding=True,
418
- return_tensors="pt"
419
- ).to(device)
420
 
 
 
421
 
422
  with torch.no_grad():
423
  generated_ids = model.generate(
@@ -427,34 +312,27 @@ def inference(model_name: str, image: Image.Image, prompt: str, max_new_tokens:
427
  temperature=0.1
428
  )
429
 
430
- generated_ids = [
431
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
432
  ]
433
- output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
434
- return output_text
435
 
436
  except Exception as e:
437
  print(f"Error during inference: {e}")
438
  traceback.print_exc()
439
  return f"Error during inference: {str(e)}"
440
 
441
-
442
  def process_image(
443
  model_name: str,
444
  image: Image.Image,
445
  min_pixels: Optional[int] = None,
446
  max_pixels: Optional[int] = None
447
  ) -> Dict[str, Any]:
448
- """Process a single image with the specified prompt mode"""
449
  try:
450
- # Resize image if needed
451
- if min_pixels is not None or max_pixels is not None:
452
  image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
453
-
454
- # Run inference with the default prompt
455
  raw_output = inference(model_name, image, prompt)
456
-
457
- # Process results based on prompt mode
458
  result = {
459
  'original_image': image,
460
  'raw_output': raw_output,
@@ -462,42 +340,26 @@ def process_image(
462
  'layout_result': None,
463
  'markdown_content': None
464
  }
465
-
466
- # Try to parse JSON and create visualizations (since we're doing layout analysis)
467
  try:
468
- # Clean the output to be valid JSON
469
- # Models sometimes add ```json ... ``` markers
470
  json_match = re.search(r'```json\s*([\s\S]+?)\s*```', raw_output)
471
- if json_match:
472
- json_str = json_match.group(1)
473
- else:
474
- json_str = raw_output
475
-
476
  layout_data = json.loads(json_str)
477
  result['layout_result'] = layout_data
478
-
479
- # Create visualization with bounding boxes
480
  try:
481
  processed_image = draw_layout_on_image(image, layout_data)
482
  result['processed_image'] = processed_image
483
  except Exception as e:
484
  print(f"Error drawing layout: {e}")
485
- result['processed_image'] = image
486
-
487
- # Generate markdown from layout data
488
  try:
489
  markdown_content = layoutjson2md(image, layout_data, text_key='text')
490
  result['markdown_content'] = markdown_content
491
  except Exception as e:
492
  print(f"Error generating markdown: {e}")
493
  result['markdown_content'] = raw_output
494
-
495
  except json.JSONDecodeError:
496
  print("Failed to parse JSON output, using raw output")
497
  result['markdown_content'] = raw_output
498
-
499
  return result
500
-
501
  except Exception as e:
502
  print(f"Error processing image: {e}")
503
  traceback.print_exc()
@@ -509,24 +371,16 @@ def process_image(
509
  'markdown_content': f"Error processing image: {str(e)}"
510
  }
511
 
512
-
513
  def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
514
- """Load file for preview (supports PDF and images)"""
515
  global pdf_cache
516
-
517
  if not file_path or not os.path.exists(file_path):
518
  return None, "No file selected"
519
-
520
- # FIX 1: Access the second element of the tuple returned by os.path.splitext
521
- file_ext = os.path.splitext(file_path).lower()
522
-
523
  try:
524
  if file_ext == '.pdf':
525
- # Load PDF pages
526
  images = load_images_from_pdf(file_path)
527
  if not images:
528
  return None, "Failed to load PDF"
529
-
530
  pdf_cache.update({
531
  "images": images,
532
  "current_page": 0,
@@ -535,14 +389,9 @@ def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
535
  "is_parsed": False,
536
  "results": []
537
  })
538
-
539
- # FIX 2: Return only the first image for the preview component
540
- return images, f"Page 1 / {len(images)}"
541
-
542
  elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
543
- # Load single image
544
  image = Image.open(file_path).convert('RGB')
545
-
546
  pdf_cache.update({
547
  "images": [image],
548
  "current_page": 0,
@@ -551,73 +400,50 @@ def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
551
  "is_parsed": False,
552
  "results": []
553
  })
554
-
555
  return image, "Page 1 / 1"
556
  else:
557
  return None, f"Unsupported file format: {file_ext}"
558
-
559
  except Exception as e:
560
  print(f"Error loading file: {e}")
561
  return None, f"Error loading file: {str(e)}"
562
 
563
-
564
  def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, Any, Optional[Image.Image], Optional[Dict]]:
565
- """Navigate through PDF pages and update all relevant outputs."""
566
  global pdf_cache
567
-
568
  if not pdf_cache["images"]:
569
  return None, '<div class="page-info">No file loaded</div>', "No results yet", None, None
570
-
571
  if direction == "prev":
572
  pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
573
  elif direction == "next":
574
- pdf_cache["current_page"] = min(
575
- pdf_cache["total_pages"] - 1,
576
- pdf_cache["current_page"] + 1
577
- )
578
-
579
  index = pdf_cache["current_page"]
580
  current_image_preview = pdf_cache["images"][index]
581
  page_info_html = f'<div class="page-info">Page {index + 1} / {pdf_cache["total_pages"]}</div>'
582
-
583
- # Initialize default result values
584
  markdown_content = "Page not processed yet"
585
  processed_img = None
586
  layout_json = None
587
-
588
- # Get results for current page if available
589
- if (pdf_cache["is_parsed"] and
590
- index < len(pdf_cache["results"]) and
591
- pdf_cache["results"][index]):
592
-
593
  result = pdf_cache["results"][index]
594
  markdown_content = result.get('markdown_content') or result.get('raw_output', 'No content available')
595
- processed_img = result.get('processed_image', None) # Get the processed image
596
- layout_json = result.get('layout_result', None) # Get the layout JSON
597
-
598
- # Check for Arabic text to set RTL property
599
  if is_arabic_text(markdown_content):
600
  markdown_update = gr.update(value=markdown_content, rtl=True)
601
  else:
602
  markdown_update = markdown_content
603
-
604
  return current_image_preview, page_info_html, markdown_update, processed_img, layout_json
605
 
606
-
607
  def create_gradio_interface():
608
- """Create the Gradio interface"""
609
-
610
  css = """
611
  .main-container { max-width: 1400px; margin: 0 auto; }
612
  .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
613
- .process-button {
614
- border: none !important;
615
- color: white !important;
616
- font-weight: bold !important;
617
- background-color: blue !important;}
618
- .process-button:hover {
619
  background-color: darkblue !important;
620
- transform: translateY(-2px) !important;
621
  box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
622
  .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
623
  .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
@@ -633,248 +459,91 @@ def create_gradio_interface():
633
  </p>
634
  </div>
635
  """)
636
-
637
- # Main interface
638
  with gr.Row():
639
- # Left column - Input and controls
640
  with gr.Column(scale=1):
641
-
642
- # Model selection
643
- model_choice = gr.Dropdown(
644
  choices=["Camel-Doc-OCR-062825", "MonkeyOCR-Recognition", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713"],
645
  label="Select Model",
646
  value="Camel-Doc-OCR-062825"
647
  )
648
-
649
- # File input
650
  file_input = gr.File(
651
  label="Upload Image or PDF",
652
  file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"],
653
  type="filepath"
654
  )
655
-
656
- # Image preview
657
- image_preview = gr.Image(
658
- label="Preview",
659
- type="pil",
660
- interactive=False,
661
- height=300
662
- )
663
-
664
- # Page navigation for PDFs
665
  with gr.Row():
666
  prev_page_btn = gr.Button("◀ Previous", size="md")
667
  page_info = gr.HTML('<div class="page-info">No file loaded</div>')
668
  next_page_btn = gr.Button("Next ▶", size="md")
669
-
670
- # Advanced settings
671
  with gr.Accordion("Advanced Settings", open=False):
672
- max_new_tokens = gr.Slider(
673
- minimum=1000,
674
- maximum=32000,
675
- value=24000,
676
- step=1000,
677
- label="Max New Tokens",
678
- info="Maximum number of tokens to generate"
679
- )
680
-
681
- min_pixels = gr.Number(
682
- value=MIN_PIXELS,
683
- label="Min Pixels",
684
- info="Minimum image resolution"
685
- )
686
-
687
- max_pixels = gr.Number(
688
- value=MAX_PIXELS,
689
- label="Max Pixels",
690
- info="Maximum image resolution"
691
- )
692
-
693
- # Process button
694
- process_btn = gr.Button(
695
- "🚀 Process Document",
696
- variant="primary",
697
- elem_classes=["process-button"],
698
- size="lg"
699
- )
700
-
701
- # Clear button
702
  clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
703
-
704
- # Right column - Results
705
  with gr.Column(scale=2):
706
-
707
- # Results tabs
708
  with gr.Tabs():
709
- # Processed image tab
710
  with gr.Tab("🖼️ Processed Image"):
711
- processed_image = gr.Image(
712
- label="Image with Layout Detection",
713
- type="pil",
714
- interactive=False,
715
- height=500
716
- )
717
- # Markdown output tab
718
  with gr.Tab("📝 Extracted Content"):
719
- markdown_output = gr.Markdown(
720
- value="Click 'Process Document' to see extracted content...",
721
- height=500
722
- )
723
- # JSON layout tab
724
  with gr.Tab("📋 Layout JSON"):
725
- json_output = gr.JSON(
726
- label="Layout Analysis Results",
727
- value=None
728
- )
729
-
730
- # Event handlers
731
  def process_document(model_name, file_path, max_tokens, min_pix, max_pix):
732
- """Process the uploaded document"""
733
  global pdf_cache
734
-
735
  try:
736
  if not file_path:
737
  return None, "Please upload a file first.", None
738
-
739
- # This function now correctly returns a single image for preview
740
- # and populates the cache for multi-page processing.
741
- preview_img, page_info_str = load_file_for_preview(file_path)
742
- if preview_img is None:
743
- return None, page_info_str, None
744
-
745
- # Process the image(s)
746
  if pdf_cache["file_type"] == "pdf":
747
- # Process all pages for PDF from the cache
748
  all_results = []
749
  all_markdown = []
750
-
751
  for i, img in enumerate(pdf_cache["images"]):
752
- result = process_image(
753
- model_name,
754
- img,
755
- min_pixels=int(min_pix) if min_pix else None,
756
- max_pixels=int(max_pix) if max_pix else None
757
- )
758
  all_results.append(result)
759
  if result.get('markdown_content'):
760
  all_markdown.append(f"## Page {i+1}\n\n{result['markdown_content']}")
761
-
762
  pdf_cache["results"] = all_results
763
  pdf_cache["is_parsed"] = True
764
-
765
- # Show results for first page
766
- first_result = all_results
767
  combined_markdown = "\n\n---\n\n".join(all_markdown)
768
-
769
- # Check if the combined markdown contains mostly Arabic text
770
  if is_arabic_text(combined_markdown):
771
  markdown_update = gr.update(value=combined_markdown, rtl=True)
772
  else:
773
  markdown_update = combined_markdown
774
-
775
- return (
776
- first_result['processed_image'],
777
- markdown_update,
778
- first_result['layout_result']
779
- )
780
  else:
781
- # Process single image
782
- result = process_image(
783
- model_name,
784
- preview_img, # Use the single loaded image
785
- min_pixels=int(min_pix) if min_pix else None,
786
- max_pixels=int(max_pix) if max_pix else None
787
- )
788
-
789
  pdf_cache["results"] = [result]
790
  pdf_cache["is_parsed"] = True
791
-
792
- # Check if the content contains mostly Arabic text
793
  content = result['markdown_content'] or "No content extracted"
794
  if is_arabic_text(content):
795
  markdown_update = gr.update(value=content, rtl=True)
796
  else:
797
  markdown_update = content
798
-
799
- return (
800
- result['processed_image'],
801
- markdown_update,
802
- result['layout_result']
803
- )
804
-
805
  except Exception as e:
806
  error_msg = f"Error processing document: {str(e)}"
807
  print(error_msg)
808
  traceback.print_exc()
809
  return None, error_msg, None
810
-
811
  def handle_file_upload(file_path):
812
- """Handle file upload and show preview"""
813
  if not file_path:
814
- return None, '<div class="page-info">No file loaded</div>'
815
-
816
  image, page_info = load_file_for_preview(file_path)
817
  return image, page_info
818
-
819
  def clear_all():
820
- """Clear all data and reset interface"""
821
  global pdf_cache
822
-
823
- pdf_cache = {
824
- "images": [], "current_page": 0, "total_pages": 0,
825
- "file_type": None, "is_parsed": False, "results": []
826
- }
827
-
828
- return (
829
- None, # file_input
830
- None, # image_preview
831
- '<div class="page-info">No file loaded</div>', # page_info
832
- None, # processed_image
833
- "Click 'Process Document' to see extracted content...", # markdown_output
834
- None, # json_output
835
- )
836
-
837
- # Wire up event handlers
838
- file_input.change(
839
- handle_file_upload,
840
- inputs=[file_input],
841
- outputs=[image_preview, page_info]
842
- )
843
-
844
- prev_page_btn.click(
845
- lambda: turn_page("prev"),
846
- outputs=[image_preview, page_info, markdown_output, processed_image, json_output]
847
- )
848
-
849
- next_page_btn.click(
850
- lambda: turn_page("next"),
851
- outputs=[image_preview, page_info, markdown_output, processed_image, json_output]
852
- )
853
-
854
- process_btn.click(
855
- process_document,
856
- inputs=[model_choice, file_input, max_new_tokens, min_pixels, max_pixels],
857
- outputs=[processed_image, markdown_output, json_output]
858
- )
859
-
860
- clear_btn.click(
861
- clear_all,
862
- outputs=[
863
- file_input, image_preview, page_info, processed_image,
864
- markdown_output, json_output
865
- ]
866
- )
867
-
868
  return demo
869
 
870
-
871
  if __name__ == "__main__":
872
- # Create and launch the interface
873
  demo = create_gradio_interface()
874
- demo.queue(max_size=10).launch(
875
- server_name="0.0.0.0",
876
- server_port=7860,
877
- share=False,
878
- debug=True,
879
- show_error=True
880
- )
 
11
  import gradio as gr
12
  import requests
13
  import torch
 
14
  from PIL import Image, ImageDraw, ImageFont
15
  from transformers import (
16
  Qwen2_5_VLForConditionalGeneration,
 
45
  5. Final Output: The entire output must be a single JSON object.
46
  """
47
 
48
+ # Load models
49
  MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
50
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
51
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
54
  torch_dtype=torch.float16
55
  ).to(device).eval()
56
 
 
57
  MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
58
  processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
59
  model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
62
  torch_dtype=torch.float16
63
  ).to(device).eval()
64
 
 
65
  MODEL_ID_C = "nanonets/Nanonets-OCR-s"
66
  processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
67
  model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
70
  torch_dtype=torch.float16
71
  ).to(device).eval()
72
 
 
73
  MODEL_ID_G = "echo840/MonkeyOCR"
74
  SUBFOLDER = "Recognition"
75
  processor_g = AutoProcessor.from_pretrained(
 
84
  torch_dtype=torch.float16
85
  ).to(device).eval()
86
 
 
87
  # Utility functions
88
  def round_by_factor(number: int, factor: int) -> int:
 
89
  return round(number / factor) * factor
90
 
 
91
  def smart_resize(
92
  height: int,
93
  width: int,
 
95
  min_pixels: int = 3136,
96
  max_pixels: int = 11289600,
97
  ):
 
 
 
 
 
98
  if max(height, width) / min(height, width) > 200:
99
+ raise ValueError(f"Aspect ratio too extreme: {max(height, width) / min(height, width)}")
 
 
100
  h_bar = max(factor, round_by_factor(height, factor))
101
  w_bar = max(factor, round_by_factor(width, factor))
 
102
  if h_bar * w_bar > max_pixels:
103
  beta = math.sqrt((height * width) / max_pixels)
104
  h_bar = round_by_factor(height / beta, factor)
 
109
  w_bar = round_by_factor(width * beta, factor)
110
  return h_bar, w_bar
111
 
 
112
  def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
 
113
  if isinstance(image_input, str):
114
  if image_input.startswith(("http://", "https://")):
115
  response = requests.get(image_input)
 
120
  image = image_input.convert('RGB')
121
  else:
122
  raise ValueError(f"Invalid image input type: {type(image_input)}")
123
+ if min_pixels or max_pixels:
 
124
  min_pixels = min_pixels or MIN_PIXELS
125
  max_pixels = max_pixels or MAX_PIXELS
126
+ height, width = smart_resize(image.height, image.width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
 
 
 
 
 
 
127
  image = image.resize((width, height), Image.LANCZOS)
 
128
  return image
129
 
 
130
  def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
 
131
  images = []
132
  try:
133
  pdf_document = fitz.open(pdf_path)
134
  for page_num in range(len(pdf_document)):
135
  page = pdf_document.load_page(page_num)
136
+ mat = fitz.Matrix(2.0, 2.0)
 
137
  pix = page.get_pixmap(matrix=mat)
138
  img_data = pix.tobytes("ppm")
139
  image = Image.open(BytesIO(img_data)).convert('RGB')
 
141
  pdf_document.close()
142
  except Exception as e:
143
  print(f"Error loading PDF: {e}")
 
144
  return images
145
 
 
146
  def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
 
147
  img_copy = image.copy()
148
  draw = ImageDraw.Draw(img_copy)
 
 
149
  colors = {
150
+ 'Caption': '#FF6B6B', 'Footnote': '#4ECDC4', 'Formula': '#45B7D1',
151
+ 'List-item': '#96CEB4', 'Page-footer': '#FFEAA7', 'Page-header': '#DDA0DD',
152
+ 'Picture': '#FFD93D', 'Section-header': '#6C5CE7', 'Table': '#FD79A8',
153
+ 'Text': '#74B9FF', 'Title': '#E17055'
 
 
 
 
 
 
 
154
  }
 
155
  try:
156
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
157
+ except Exception:
158
+ font = ImageFont.load_default()
159
+ for item in layout_data:
160
+ if 'bbox' in item and 'category' in item:
161
+ bbox = item['bbox']
162
+ category = item['category']
163
+ color = colors.get(category, '#000000')
164
+ draw.rectangle(bbox, outline=color, width=2)
165
+ label = category
166
+ label_bbox = draw.textbbox((0, 0), label, font=font)
167
+ label_width = label_bbox[2] - label_bbox[0]
168
+ label_height = label_bbox[3] - label_bbox[1]
169
+ label_x = bbox[0]
170
+ label_y = max(0, bbox[1] - label_height - 2)
171
+ draw.rectangle([label_x, label_y, label_x + label_width + 4, label_y + label_height + 2], fill=color)
172
+ draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  return img_copy
174
 
 
175
  def is_arabic_text(text: str) -> bool:
 
176
  if not text:
177
  return False
 
 
 
178
  header_pattern = r'^#{1,6}\s+(.+)$'
179
  paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
 
180
  content_text = []
 
181
  for line in text.split('\n'):
182
  line = line.strip()
183
  if not line:
184
  continue
 
 
185
  header_match = re.match(header_pattern, line, re.MULTILINE)
186
  if header_match:
187
  content_text.append(header_match.group(1))
188
  continue
 
 
189
  if re.match(paragraph_pattern, line, re.MULTILINE):
190
  content_text.append(line)
 
191
  if not content_text:
192
  return False
 
 
193
  combined_text = ' '.join(content_text)
 
 
194
  arabic_chars = 0
195
  total_chars = 0
 
196
  for char in combined_text:
197
  if char.isalpha():
198
  total_chars += 1
 
199
  if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'):
200
  arabic_chars += 1
201
+ return total_chars > 0 and (arabic_chars / total_chars) > 0.5
 
 
 
 
 
 
202
 
203
  def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
 
204
  import base64
205
  from io import BytesIO
 
206
  markdown_lines = []
 
207
  try:
208
+ sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
 
 
209
  for item in sorted_items:
210
  category = item.get('category', '')
211
  text = item.get(text_key, '')
212
  bbox = item.get('bbox', [])
 
213
  if category == 'Picture':
 
214
  if bbox and len(bbox) == 4:
215
  try:
 
216
  x1, y1, x2, y2 = bbox
 
217
  x1, y1 = max(0, int(x1)), max(0, int(y1))
218
  x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))
 
219
  if x2 > x1 and y2 > y1:
220
  cropped_img = image.crop((x1, y1, x2, y2))
 
 
221
  buffer = BytesIO()
222
  cropped_img.save(buffer, format='PNG')
223
  img_data = base64.b64encode(buffer.getvalue()).decode()
 
 
224
  markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
225
  else:
226
  markdown_lines.append("![Image](Image region detected)\n")
 
240
  elif category == 'List-item':
241
  markdown_lines.append(f"- {text}\n")
242
  elif category == 'Table':
 
243
  if text.strip().startswith('<'):
244
  markdown_lines.append(f"{text}\n")
245
  else:
246
  markdown_lines.append(f"**Table:** {text}\n")
247
  elif category == 'Formula':
 
248
  if text.strip().startswith('$') or '\\' in text:
249
  markdown_lines.append(f"$$\n{text}\n$$\n")
250
  else:
 
254
  elif category == 'Footnote':
255
  markdown_lines.append(f"^{text}^\n")
256
  elif category in ['Page-header', 'Page-footer']:
 
257
  continue
258
  else:
259
  markdown_lines.append(f"{text}\n")
260
+ markdown_lines.append("")
 
 
261
  except Exception as e:
262
  print(f"Error converting to markdown: {e}")
263
  return str(layout_data)
 
264
  return "\n".join(markdown_lines)
265
 
 
266
  # PDF handling state
267
  pdf_cache = {
268
  "images": [],
 
272
  "is_parsed": False,
273
  "results": []
274
  }
275
+
276
  @spaces.GPU
277
+ def inference(model_name: str, image: Image.Image, prompt: str, max_new_tokens: int = 1024) -> str:
 
278
  try:
279
  if model_name == "Camel-Doc-OCR-062825":
280
  processor = processor_m
 
295
  {
296
  "role": "user",
297
  "content": [
298
+ {"type": "image", "image": image},
299
+ {"type": "text", "text": prompt}
300
  ]
301
  }
302
  ]
 
 
 
 
 
 
 
 
 
 
 
303
 
304
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
305
+ inputs = processor(text=[text], images=[image], return_tensors="pt").to(device)
306
 
307
  with torch.no_grad():
308
  generated_ids = model.generate(
 
312
  temperature=0.1
313
  )
314
 
315
+ generated_ids_trimmed = [
316
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
317
  ]
318
+ output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
319
+ return output_text[0] if output_text else ""
320
 
321
  except Exception as e:
322
  print(f"Error during inference: {e}")
323
  traceback.print_exc()
324
  return f"Error during inference: {str(e)}"
325
 
 
326
  def process_image(
327
  model_name: str,
328
  image: Image.Image,
329
  min_pixels: Optional[int] = None,
330
  max_pixels: Optional[int] = None
331
  ) -> Dict[str, Any]:
 
332
  try:
333
+ if min_pixels or max_pixels:
 
334
  image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
 
 
335
  raw_output = inference(model_name, image, prompt)
 
 
336
  result = {
337
  'original_image': image,
338
  'raw_output': raw_output,
 
340
  'layout_result': None,
341
  'markdown_content': None
342
  }
 
 
343
  try:
 
 
344
  json_match = re.search(r'```json\s*([\s\S]+?)\s*```', raw_output)
345
+ json_str = json_match.group(1) if json_match else raw_output
 
 
 
 
346
  layout_data = json.loads(json_str)
347
  result['layout_result'] = layout_data
 
 
348
  try:
349
  processed_image = draw_layout_on_image(image, layout_data)
350
  result['processed_image'] = processed_image
351
  except Exception as e:
352
  print(f"Error drawing layout: {e}")
 
 
 
353
  try:
354
  markdown_content = layoutjson2md(image, layout_data, text_key='text')
355
  result['markdown_content'] = markdown_content
356
  except Exception as e:
357
  print(f"Error generating markdown: {e}")
358
  result['markdown_content'] = raw_output
 
359
  except json.JSONDecodeError:
360
  print("Failed to parse JSON output, using raw output")
361
  result['markdown_content'] = raw_output
 
362
  return result
 
363
  except Exception as e:
364
  print(f"Error processing image: {e}")
365
  traceback.print_exc()
 
371
  'markdown_content': f"Error processing image: {str(e)}"
372
  }
373
 
 
374
  def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
 
375
  global pdf_cache
 
376
  if not file_path or not os.path.exists(file_path):
377
  return None, "No file selected"
378
+ file_ext = os.path.splitext(file_path)[1].lower()
 
 
 
379
  try:
380
  if file_ext == '.pdf':
 
381
  images = load_images_from_pdf(file_path)
382
  if not images:
383
  return None, "Failed to load PDF"
 
384
  pdf_cache.update({
385
  "images": images,
386
  "current_page": 0,
 
389
  "is_parsed": False,
390
  "results": []
391
  })
392
+ return images[0], f"Page 1 / {len(images)}"
 
 
 
393
  elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
 
394
  image = Image.open(file_path).convert('RGB')
 
395
  pdf_cache.update({
396
  "images": [image],
397
  "current_page": 0,
 
400
  "is_parsed": False,
401
  "results": []
402
  })
 
403
  return image, "Page 1 / 1"
404
  else:
405
  return None, f"Unsupported file format: {file_ext}"
 
406
  except Exception as e:
407
  print(f"Error loading file: {e}")
408
  return None, f"Error loading file: {str(e)}"
409
 
 
410
  def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, Any, Optional[Image.Image], Optional[Dict]]:
 
411
  global pdf_cache
 
412
  if not pdf_cache["images"]:
413
  return None, '<div class="page-info">No file loaded</div>', "No results yet", None, None
 
414
  if direction == "prev":
415
  pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
416
  elif direction == "next":
417
+ pdf_cache["current_page"] = min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)
 
 
 
 
418
  index = pdf_cache["current_page"]
419
  current_image_preview = pdf_cache["images"][index]
420
  page_info_html = f'<div class="page-info">Page {index + 1} / {pdf_cache["total_pages"]}</div>'
 
 
421
  markdown_content = "Page not processed yet"
422
  processed_img = None
423
  layout_json = None
424
+ if pdf_cache["is_parsed"] and index < len(pdf_cache["results"]) and pdf_cache["results"][index]:
 
 
 
 
 
425
  result = pdf_cache["results"][index]
426
  markdown_content = result.get('markdown_content') or result.get('raw_output', 'No content available')
427
+ processed_img = result.get('processed_image', None)
428
+ layout_json = result.get('layout_result', None)
 
 
429
  if is_arabic_text(markdown_content):
430
  markdown_update = gr.update(value=markdown_content, rtl=True)
431
  else:
432
  markdown_update = markdown_content
 
433
  return current_image_preview, page_info_html, markdown_update, processed_img, layout_json
434
 
 
435
  def create_gradio_interface():
 
 
436
  css = """
437
  .main-container { max-width: 1400px; margin: 0 auto; }
438
  .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
439
+ .process-button {
440
+ border: none !important;
441
+ color: white !important;
442
+ font-weight: bold !important;
443
+ background-color: blue !important;}
444
+ .process-button:hover {
445
  background-color: darkblue !important;
446
+ transform: translateY(-2px) !important;
447
  box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
448
  .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
449
  .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
 
459
  </p>
460
  </div>
461
  """)
 
 
462
  with gr.Row():
 
463
  with gr.Column(scale=1):
464
+ model_choice = gr.Radio(
 
 
465
  choices=["Camel-Doc-OCR-062825", "MonkeyOCR-Recognition", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713"],
466
  label="Select Model",
467
  value="Camel-Doc-OCR-062825"
468
  )
 
 
469
  file_input = gr.File(
470
  label="Upload Image or PDF",
471
  file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"],
472
  type="filepath"
473
  )
474
+ image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=300)
 
 
 
 
 
 
 
 
 
475
  with gr.Row():
476
  prev_page_btn = gr.Button("◀ Previous", size="md")
477
  page_info = gr.HTML('<div class="page-info">No file loaded</div>')
478
  next_page_btn = gr.Button("Next ▶", size="md")
 
 
479
  with gr.Accordion("Advanced Settings", open=False):
480
+ max_new_tokens = gr.Slider(minimum=1000, maximum=32000, value=24000, step=1000, label="Max New Tokens")
481
+ min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels")
482
+ max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels")
483
+ process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
  clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
 
 
485
  with gr.Column(scale=2):
 
 
486
  with gr.Tabs():
 
487
  with gr.Tab("🖼️ Processed Image"):
488
+ processed_image = gr.Image(label="Image with Layout Detection", type="pil", interactive=False, height=500)
 
 
 
 
 
 
489
  with gr.Tab("📝 Extracted Content"):
490
+ markdown_output = gr.Markdown(value="Click 'Process Document' to see extracted content...", height=500)
 
 
 
 
491
  with gr.Tab("📋 Layout JSON"):
492
+ json_output = gr.JSON(label="Layout Analysis Results", value=None)
 
 
 
 
 
493
  def process_document(model_name, file_path, max_tokens, min_pix, max_pix):
 
494
  global pdf_cache
 
495
  try:
496
  if not file_path:
497
  return None, "Please upload a file first.", None
498
+ load_file_for_preview(file_path)
 
 
 
 
 
 
 
499
  if pdf_cache["file_type"] == "pdf":
 
500
  all_results = []
501
  all_markdown = []
 
502
  for i, img in enumerate(pdf_cache["images"]):
503
+ result = process_image(model_name, img, min_pixels=int(min_pix) if min_pix else None, max_pixels=int(max_pix) if max_pix else None)
 
 
 
 
 
504
  all_results.append(result)
505
  if result.get('markdown_content'):
506
  all_markdown.append(f"## Page {i+1}\n\n{result['markdown_content']}")
 
507
  pdf_cache["results"] = all_results
508
  pdf_cache["is_parsed"] = True
509
+ first_result = all_results[0]
 
 
510
  combined_markdown = "\n\n---\n\n".join(all_markdown)
 
 
511
  if is_arabic_text(combined_markdown):
512
  markdown_update = gr.update(value=combined_markdown, rtl=True)
513
  else:
514
  markdown_update = combined_markdown
515
+ return first_result['processed_image'], markdown_update, first_result['layout_result']
 
 
 
 
 
516
  else:
517
+ result = process_image(model_name, pdf_cache["images"][0], min_pixels=int(min_pix) if min_pix else None, max_pixels=int(max_pix) if max_pix else None)
 
 
 
 
 
 
 
518
  pdf_cache["results"] = [result]
519
  pdf_cache["is_parsed"] = True
 
 
520
  content = result['markdown_content'] or "No content extracted"
521
  if is_arabic_text(content):
522
  markdown_update = gr.update(value=content, rtl=True)
523
  else:
524
  markdown_update = content
525
+ return result['processed_image'], markdown_update, result['layout_result']
 
 
 
 
 
 
526
  except Exception as e:
527
  error_msg = f"Error processing document: {str(e)}"
528
  print(error_msg)
529
  traceback.print_exc()
530
  return None, error_msg, None
 
531
  def handle_file_upload(file_path):
 
532
  if not file_path:
533
+ return None, "No file loaded"
 
534
  image, page_info = load_file_for_preview(file_path)
535
  return image, page_info
 
536
  def clear_all():
 
537
  global pdf_cache
538
+ pdf_cache = {"images": [], "current_page": 0, "total_pages": 0, "file_type": None, "is_parsed": False, "results": []}
539
+ return None, None, '<div class="page-info">No file loaded</div>', None, "Click 'Process Document' to see extracted content...", None
540
+ file_input.change(handle_file_upload, inputs=[file_input], outputs=[image_preview, page_info])
541
+ prev_page_btn.click(lambda: turn_page("prev"), outputs=[image_preview, page_info, markdown_output, processed_image, json_output])
542
+ next_page_btn.click(lambda: turn_page("next"), outputs=[image_preview, page_info, markdown_output, processed_image, json_output])
543
+ process_btn.click(process_document, inputs=[model_choice, file_input, max_new_tokens, min_pixels, max_pixels], outputs=[processed_image, markdown_output, json_output])
544
+ clear_btn.click(clear_all, outputs=[file_input, image_preview, page_info, processed_image, markdown_output, json_output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
  return demo
546
 
 
547
  if __name__ == "__main__":
 
548
  demo = create_gradio_interface()
549
+ demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True, show_error=True)