prithivMLmods committed
Commit 30d6225 · verified · 1 Parent(s): c56b1ad

Update app.py

Files changed (1):
  1. app.py +469 -437
app.py CHANGED
@@ -1,50 +1,30 @@
-import io
 import os
-import tempfile
-import time
-import uuid
-import cv2
 import gradio as gr
-import pymupdf
-import spaces
 import torch
-from PIL import Image, ImageDraw, ImageFont
-from transformers import AutoModelForCausalLM, AutoProcessor, VisionEncoderDecoderModel
 from huggingface_hub import snapshot_download
-from qwen_vl_utils import process_vision_info
-from utils.utils import prepare_image, parse_layout_string, process_coordinates, ImageDimensions
-from utils.markdown_utils import MarkdownConverter
-
-# Define device
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Load dot.ocr model
-dot_ocr_model_id = "rednote-hilab/dots.ocr"
-dot_ocr_model = AutoModelForCausalLM.from_pretrained(
-    dot_ocr_model_id,
-    attn_implementation="flash_attention_2",
-    torch_dtype=torch.bfloat16,
-    device_map="auto",
-    trust_remote_code=True
-)
-dot_ocr_processor = AutoProcessor.from_pretrained(
-    dot_ocr_model_id,
-    trust_remote_code=True
 )
-
-# Load Dolphin model
-dolphin_model_id = "ByteDance/Dolphin"
-dolphin_processor = AutoProcessor.from_pretrained(dolphin_model_id)
-dolphin_model = VisionEncoderDecoderModel.from_pretrained(dolphin_model_id)
-dolphin_model.eval()
-dolphin_model.to(device)
-dolphin_model = dolphin_model.half()
-dolphin_tokenizer = dolphin_processor.tokenizer
 
 # Constants
 MIN_PIXELS = 3136
 MAX_PIXELS = 11289600
 IMAGE_FACTOR = 28
 
 # Prompts
 prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
@@ -66,11 +46,55 @@ prompt = """Please output the layout information from the PDF image, including e
 5. Final Output: The entire output must be a single JSON object.
 """
 
 # Utility functions
 def round_by_factor(number: int, factor: int) -> int:
     """Returns the closest integer to 'number' that is divisible by 'factor'."""
     return round(number / factor) * factor
 
 def smart_resize(
     height: int,
     width: int,
@@ -100,6 +124,7 @@ def smart_resize(
     w_bar = round_by_factor(width * beta, factor)
     return h_bar, w_bar
 
 def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
     """Fetch and process an image"""
     if isinstance(image_input, str):
@@ -112,29 +137,31 @@ def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
         image = image_input.convert('RGB')
     else:
         raise ValueError(f"Invalid image input type: {type(image_input)}")
-
     if min_pixels is not None or max_pixels is not None:
         min_pixels = min_pixels or MIN_PIXELS
         max_pixels = max_pixels or MAX_PIXELS
         height, width = smart_resize(
-            image.height,
-            image.width,
             factor=IMAGE_FACTOR,
             min_pixels=min_pixels,
             max_pixels=max_pixels
         )
         image = image.resize((width, height), Image.LANCZOS)
-
     return image
 
 def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
     """Load images from PDF file"""
     images = []
     try:
-        pdf_document = pymupdf.open(pdf_path)
         for page_num in range(len(pdf_document)):
             page = pdf_document.load_page(page_num)
-            mat = pymupdf.Matrix(2.0, 2.0)  # Increase resolution
             pix = page.get_pixmap(matrix=mat)
             img_data = pix.tobytes("ppm")
             image = Image.open(BytesIO(img_data)).convert('RGB')
@@ -145,14 +172,16 @@ def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
         return []
     return images
 
 def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
     """Draw layout bounding boxes on image"""
     img_copy = image.copy()
     draw = ImageDraw.Draw(img_copy)
-
     colors = {
         'Caption': '#FF6B6B',
-        'Footnote': '#4ECDC4',
         'Formula': '#45B7D1',
         'List-item': '#96CEB4',
         'Page-footer': '#FFEAA7',
@@ -163,58 +192,134 @@ def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.I
         'Text': '#74B9FF',
         'Title': '#E17055'
     }
-
     try:
-        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
-    except Exception:
-        font = ImageFont.load_default()
-
-    for item in layout_data:
-        if 'bbox' in item and 'category' in item:
-            bbox = item['bbox']
-            category = item['category']
-            color = colors.get(category, '#000000')
-            draw.rectangle(bbox, outline=color, width=2)
-            label = category
-            label_bbox = draw.textbbox((0, 0), label, font=font)
-            label_width = label_bbox[2] - label_bbox[0]
-            label_height = label_bbox[3] - label_bbox[1]
-            label_x = bbox[0]
-            label_y = max(0, bbox[1] - label_height - 2)
-            draw.rectangle(
-                [label_x, label_y, label_x + label_width + 4, label_y + label_height + 2],
-                fill=color
-            )
-            draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
     return img_copy
 
 def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
     """Convert layout JSON to markdown format"""
     import base64
     from io import BytesIO
-
     markdown_lines = []
-
     try:
-        sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
-
         for item in sorted_items:
             category = item.get('category', '')
             text = item.get(text_key, '')
             bbox = item.get('bbox', [])
-
             if category == 'Picture':
                 if bbox and len(bbox) == 4:
                     try:
                         x1, y1, x2, y2 = bbox
                         x1, y1 = max(0, int(x1)), max(0, int(y1))
                         x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))
-
                         if x2 > x1 and y2 > y1:
                             cropped_img = image.crop((x1, y1, x2, y2))
                             buffer = BytesIO()
                             cropped_img.save(buffer, format='PNG')
                             img_data = base64.b64encode(buffer.getvalue()).decode()
                             markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
                         else:
                             markdown_lines.append("![Image](Image region detected)\n")
@@ -234,11 +339,13 @@ def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = '
             elif category == 'List-item':
                 markdown_lines.append(f"- {text}\n")
             elif category == 'Table':
                 if text.strip().startswith('<'):
                     markdown_lines.append(f"{text}\n")
                 else:
                     markdown_lines.append(f"**Table:** {text}\n")
             elif category == 'Formula':
                 if text.strip().startswith('$') or '\\' in text:
                     markdown_lines.append(f"$$\n{text}\n$$\n")
                 else:
@@ -248,16 +355,21 @@ def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = '
             elif category == 'Footnote':
                 markdown_lines.append(f"^{text}^\n")
             elif category in ['Page-header', 'Page-footer']:
                 continue
             else:
                 markdown_lines.append(f"{text}\n")
-            markdown_lines.append("")
     except Exception as e:
         print(f"Error converting to markdown: {e}")
         return str(layout_data)
     return "\n".join(markdown_lines)
 
-# Global state variables
 pdf_cache = {
     "images": [],
     "current_page": 0,
@@ -266,60 +378,74 @@ pdf_cache = {
     "is_parsed": False,
     "results": []
 }
-
-@spaces.GPU()
-def dot_ocr_inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
-    """Run inference on an image with the given prompt using dot.ocr model"""
     try:
         messages = [
             {
                 "role": "user",
                 "content": [
-                    {"type": "image", "image": image},
-                    {"type": "text", "text": prompt}
                 ]
             }
         ]
-        text = dot_ocr_processor.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True
-        )
-        image_inputs, video_inputs = process_vision_info(messages)
-        inputs = dot_ocr_processor(
-            text=[text],
-            images=image_inputs,
-            videos=video_inputs,
-            padding=True,
-            return_tensors="pt",
-        )
-        inputs = inputs.to(device)
         with torch.no_grad():
-            generated_ids = dot_ocr_model.generate(
                 **inputs,
                 max_new_tokens=max_new_tokens,
                 do_sample=False,
                 temperature=0.1
             )
-        generated_ids_trimmed = [
             out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
         ]
-        output_text = dot_ocr_processor.batch_decode(
-            generated_ids_trimmed,
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=False
-        )
-        return output_text[0] if output_text else ""
     except Exception as e:
-        print(f"Error during dot.ocr inference: {e}")
         return f"Error during inference: {str(e)}"
 
-def process_image_dot_ocr(image: Image.Image, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None) -> Dict[str, Any]:
-    """Process a single image with the dot.ocr model"""
     try:
         if min_pixels is not None or max_pixels is not None:
             image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
-        raw_output = dot_ocr_inference(image, prompt)
         result = {
             'original_image': image,
             'raw_output': raw_output,
@@ -327,19 +453,45 @@ def process_image_dot_ocr(image: Image.Image, min_pixels: Optional[int] = None,
             'layout_result': None,
             'markdown_content': None
         }
         try:
-            layout_data = json.loads(raw_output)
             result['layout_result'] = layout_data
-            processed_image = draw_layout_on_image(image, layout_data)
-            result['processed_image'] = processed_image
-            markdown_content = layoutjson2md(image, layout_data, text_key='text')
-            result['markdown_content'] = markdown_content
         except json.JSONDecodeError:
             print("Failed to parse JSON output, using raw output")
             result['markdown_content'] = raw_output
         return result
     except Exception as e:
-        print(f"Error processing image with dot.ocr: {e}")
         return {
             'original_image': image,
             'raw_output': f"Error processing image: {str(e)}",
@@ -348,279 +500,23 @@ def process_image_dot_ocr(image: Image.Image, min_pixels: Optional[int] = None,
             'markdown_content': f"Error processing image: {str(e)}"
         }
 
-def process_all_pages_dot_ocr(file_path, min_pixels, max_pixels):
-    """Process all pages of a document with dot.ocr model"""
-    if file_path.lower().endswith('.pdf'):
-        images = load_images_from_pdf(file_path)
-    else:
-        images = [Image.open(file_path).convert('RGB')]
-    results = []
-    for img in images:
-        result = process_image_dot_ocr(img, min_pixels, max_pixels)
-        results.append(result)
-    return results
-
-# Dolphin model functions
-@spaces.GPU()
-def dolphin_model_chat(prompt, image):
-    """Process an image or batch of images with the given prompt(s) using Dolphin model"""
-    is_batch = isinstance(image, list)
-    if not is_batch:
-        images = [image]
-        prompts = [prompt]
-    else:
-        images = image
-        prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
-    batch_inputs = dolphin_processor(images, return_tensors="pt", padding=True)
-    batch_pixel_values = batch_inputs.pixel_values.half().to(device)
-    prompts = [f"<s>{p} <Answer/>" for p in prompts]
-    batch_prompt_inputs = dolphin_tokenizer(
-        prompts,
-        add_special_tokens=False,
-        return_tensors="pt"
-    )
-    batch_prompt_ids = batch_prompt_inputs.input_ids.to(device)
-    batch_attention_mask = batch_prompt_inputs.attention_mask.to(device)
-    outputs = dolphin_model.generate(
-        pixel_values=batch_pixel_values,
-        decoder_input_ids=batch_prompt_ids,
-        decoder_attention_mask=batch_attention_mask,
-        min_length=1,
-        max_length=4096,
-        pad_token_id=dolphin_tokenizer.pad_token_id,
-        eos_token_id=dolphin_tokenizer.eos_token_id,
-        use_cache=True,
-        bad_words_ids=[[dolphin_tokenizer.unk_token_id]],
-        return_dict_in_generate=True,
-        do_sample=False,
-        num_beams=1,
-        repetition_penalty=1.1
-    )
-    sequences = dolphin_tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
-    results = []
-    for i, sequence in enumerate(sequences):
-        cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
-        results.append(cleaned)
-    if not is_batch:
-        return results[0]
-    return results
-
-def process_element_batch_dolphin(elements, prompt, max_batch_size=16):
-    """Process elements of the same type in batches for Dolphin model"""
-    results = []
-    batch_size = min(len(elements), max_batch_size)
-    for i in range(0, len(elements), batch_size):
-        batch_elements = elements[i:i+batch_size]
-        crops_list = [elem["crop"] for elem in batch_elements]
-        prompts_list = [prompt] * len(crops_list)
-        batch_results = dolphin_model_chat(prompts_list, crops_list)
-        for j, result in enumerate(batch_results):
-            elem = batch_elements[j]
-            results.append({
-                "label": elem["label"],
-                "bbox": elem["bbox"],
-                "text": result.strip(),
-                "reading_order": elem["reading_order"],
-            })
-    return results
-
-def process_page_dolphin(image_path):
-    """Process a single page with Dolphin model"""
-    pil_image = Image.open(image_path).convert("RGB")
-    layout_output = dolphin_model_chat("Parse the reading order of this document.", pil_image)
-    padded_image, dims = prepare_image(pil_image)
-    recognition_results = process_elements_dolphin(layout_output, padded_image, dims)
-    return recognition_results
-
-def process_elements_dolphin(layout_results, padded_image, dims):
-    """Parse all document elements for Dolphin model"""
-    layout_results = parse_layout_string(layout_results)
-    text_elements = []
-    table_elements = []
-    figure_results = []
-    previous_box = None
-    reading_order = 0
-    for bbox, label in layout_results:
-        try:
-            x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
-                bbox, padded_image, dims, previous_box
-            )
-            cropped = padded_image[y1:y2, x1:x2]
-            if cropped.size > 0 and (cropped.shape[0] > 3 and cropped.shape[1] > 3):
-                if label == "fig":
-                    try:
-                        pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
-                        buffered = io.BytesIO()
-                        pil_crop.save(buffered, format="PNG")
-                        img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-                        figure_results.append(
-                            {
-                                "label": label,
-                                "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
-                                "text": img_base64,
-                                "reading_order": reading_order,
-                            }
-                        )
-                    except Exception as e:
-                        print(f"Error encoding figure to base64: {e}")
-                        figure_results.append(
-                            {
-                                "label": label,
-                                "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
-                                "text": "",
-                                "reading_order": reading_order,
-                            }
-                        )
-                else:
-                    pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
-                    element_info = {
-                        "crop": pil_crop,
-                        "label": label,
-                        "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
-                        "reading_order": reading_order,
-                    }
-                    if label == "tab":
-                        table_elements.append(element_info)
-                    else:
-                        text_elements.append(element_info)
-            reading_order += 1
-        except Exception as e:
-            print(f"Error processing bbox with label {label}: {str(e)}")
-            continue
-    recognition_results = figure_results.copy()
-    if text_elements:
-        text_results = process_element_batch_dolphin(text_elements, "Read text in the image.")
-        recognition_results.extend(text_results)
-    if table_elements:
-        table_results = process_element_batch_dolphin(table_elements, "Parse the table in the image.")
-        recognition_results.extend(table_results)
-    recognition_results.sort(key=lambda x: x.get("reading_order", 0))
-    return recognition_results
-
-def generate_markdown(recognition_results):
-    """Generate markdown from recognition results for Dolphin model"""
-    converter = MarkdownConverter()
-    return converter.convert(recognition_results)
-
-def convert_all_pdf_pages_to_images(file_path, target_size=896):
-    """Convert all pages of a PDF to images for Dolphin model"""
-    if file_path is None:
-        return []
-    try:
-        file_ext = os.path.splitext(file_path)[1].lower()
-        if file_ext == '.pdf':
-            doc = pymupdf.open(file_path)
-            image_paths = []
-            for page_num in range(len(doc)):
-                page = doc[page_num]
-                rect = page.rect
-                scale = target_size / max(rect.width, rect.height)
-                mat = pymupdf.Matrix(scale, scale)
-                pix = page.get_pixmap(matrix=mat)
-                img_data = pix.tobytes("png")
-                pil_image = Image.open(io.BytesIO(img_data))
-                with tempfile.NamedTemporaryFile(suffix=f"_page_{page_num}.png", delete=False) as tmp_file:
-                    pil_image.save(tmp_file.name, "PNG")
-                    image_paths.append(tmp_file.name)
-            doc.close()
-            return image_paths
-        else:
-            converted_path = convert_to_image(file_path, target_size)
-            return [converted_path] if converted_path else []
-    except Exception as e:
-        print(f"Error converting PDF pages to images: {e}")
-        return []
-
-def convert_to_image(file_path, target_size=896, page_num=0):
-    """Convert input file to image format for Dolphin model"""
-    if file_path is None:
-        return None
-    try:
-        file_ext = os.path.splitext(file_path)[1].lower()
-        if file_ext == '.pdf':
-            doc = pymupdf.open(file_path)
-            if page_num >= len(doc):
-                page_num = 0
-            page = doc[page_num]
-            rect = page.rect
-            scale = target_size / max(rect.width, rect.height)
-            mat = pymupdf.Matrix(scale, scale)
-            pix = page.get_pixmap(matrix=mat)
-            img_data = pix.tobytes("png")
-            pil_image = Image.open(io.BytesIO(img_data))
-            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
-                pil_image.save(tmp_file.name, "PNG")
-            doc.close()
-            return tmp_file.name
-        else:
-            pil_image = Image.open(file_path).convert("RGB")
-            w, h = pil_image.size
-            if max(w, h) > target_size:
-                if w > h:
-                    new_w, new_h = target_size, int(h * target_size / w)
-                else:
-                    new_w, new_h = int(w * target_size / h), target_size
-                pil_image = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
-            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
-                pil_image.save(tmp_file.name, "PNG")
-            return tmp_file.name
-    except Exception as e:
-        print(f"Error converting file to image: {e}")
-        return file_path
-
-def process_all_pages_dolphin(file_path):
-    """Process all pages of a document with Dolphin model"""
-    image_paths = convert_all_pdf_pages_to_images(file_path)
-    per_page_results = []
-    for image_path in image_paths:
-        try:
-            original_image = Image.open(image_path).convert('RGB')
-            recognition_results = process_page_dolphin(image_path)
-            markdown_content = generate_markdown(recognition_results)
-            placeholder_text = "Layout visualization not available for Dolphin model"
-            processed_image = create_placeholder_image(placeholder_text, size=(original_image.width, original_image.height))
-            per_page_results.append({
-                'original_image': original_image,
-                'processed_image': processed_image,
-                'markdown_content': markdown_content,
-                'layout_result': recognition_results
-            })
-        except Exception as e:
-            print(f"Error processing page: {e}")
-            per_page_results.append({
-                'original_image': Image.new('RGB', (100, 100), color='white'),
-                'processed_image': create_placeholder_image("Error processing page", size=(100, 100)),
-                'markdown_content': f"Error processing page: {str(e)}",
-                'layout_result': None
-            })
-        finally:
-            if os.path.exists(image_path):
-                os.remove(image_path)
-    return per_page_results
-
-def create_placeholder_image(text, size=(400, 200)):
-    """Create a placeholder image with text"""
-    img = Image.new('RGB', size, color='white')
-    draw = ImageDraw.Draw(img)
-    try:
-        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
-    except Exception:
-        font = ImageFont.load_default()
-    draw.text((10, 10), text, fill='black', font=font)
-    return img
 
-# Gradio interface functions
 def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
     """Load file for preview (supports PDF and images)"""
     global pdf_cache
     if not file_path or not os.path.exists(file_path):
         return None, "No file selected"
-    file_ext = os.path.splitext(file_path)[1].lower()
     try:
         if file_ext == '.pdf':
             images = load_images_from_pdf(file_path)
             if not images:
                 return None, "Failed to load PDF"
             pdf_cache.update({
                 "images": images,
                 "current_page": 0,
@@ -629,9 +525,13 @@ def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
                 "is_parsed": False,
                 "results": []
             })
-            return images[0], f"Page 1 / {len(images)}"
         elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
             image = Image.open(file_path).convert('RGB')
             pdf_cache.update({
                 "images": [image],
                 "current_page": 0,
@@ -640,78 +540,73 @@ def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
                 "is_parsed": False,
                 "results": []
             })
             return image, "Page 1 / 1"
         else:
             return None, f"Unsupported file format: {file_ext}"
     except Exception as e:
         print(f"Error loading file: {e}")
         return None, f"Error loading file: {str(e)}"
 
-def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, str, Optional[Image.Image], Optional[Dict]]:
     """Navigate through PDF pages and update all relevant outputs."""
     global pdf_cache
     if not pdf_cache["images"]:
-        return None, "No file loaded", "No results yet", None, None
     if direction == "prev":
         pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
     elif direction == "next":
-        pdf_cache["current_page"] = min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)
     index = pdf_cache["current_page"]
     current_image_preview = pdf_cache["images"][index]
-    page_info_html = f"Page {index + 1} / {pdf_cache['total_pages']}"
-    if pdf_cache["is_parsed"] and index < len(pdf_cache["results"]):
         result = pdf_cache["results"][index]
-        processed_img = result['processed_image']
-        markdown_content = result['markdown_content'] or "No content available"
-        layout_json = result['layout_result']
     else:
-        processed_img = None
-        markdown_content = "Page not processed yet"
-        layout_json = None
-    return current_image_preview, page_info_html, markdown_content, processed_img, layout_json
 
-def process_document(model_choice, file_path, max_tokens, min_pix, max_pix):
-    """Process the uploaded document with the selected model"""
-    global pdf_cache
-    try:
-        if not file_path:
-            return None, "Please upload a file first.", None
-        if model_choice == "dot.ocr":
-            results = process_all_pages_dot_ocr(file_path, min_pix, max_pix)
-        elif model_choice == "Dolphin":
-            results = process_all_pages_dolphin(file_path)
-        else:
-            raise ValueError("Invalid model choice")
-        pdf_cache["results"] = results
-        pdf_cache["is_parsed"] = True
-        first_result = results[0]
-        if model_choice == "dot.ocr":
-            processed_img = first_result['processed_image']
-            markdown_content = first_result['markdown_content']
-            layout_json = first_result['layout_result']
-        else:
-            processed_img = first_result['processed_image']
-            markdown_content = first_result['markdown_content']
-            layout_json = first_result['layout_result']
-        return processed_img, markdown_content, layout_json
-    except Exception as e:
-        error_msg = f"Error processing document: {str(e)}"
-        print(error_msg)
-        return None, error_msg, None
 
 def create_gradio_interface():
     """Create the Gradio interface"""
     css = """
     .main-container { max-width: 1400px; margin: 0 auto; }
     .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
-    .process-button {
-        border: none !important;
-        color: white !important;
-        font-weight: bold !important;
-        background-color: blue !important;}
-    .process-button:hover {
         background-color: darkblue !important;
-        transform: translateY(-2px) !important;
         box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
     .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
     .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
@@ -727,34 +622,41 @@ def create_gradio_interface():
         </p>
         </div>
     """)
 
     with gr.Row():
         with gr.Column(scale=1):
             model_choice = gr.Radio(
-                choices=["dot.ocr", "Dolphin"],
                 label="Select Model",
-                value="dot.ocr"
             )
             file_input = gr.File(
                 label="Upload Image or PDF",
                 file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"],
                 type="filepath"
             )
-            with gr.Row():
-                examples = gr.Examples(
-                    examples=["examples/sample_image1.png", "examples/sample_image2.png", "examples/sample_pdf.pdf"],
-                    inputs=file_input,
-                    label="Example Documents"
-                )
             image_preview = gr.Image(
                 label="Preview",
                 type="pil",
                 interactive=False,
                 height=300
             )
             with gr.Row():
                 prev_page_btn = gr.Button("◀ Previous", size="md")
-                page_info = gr.HTML("No file loaded")
                 next_page_btn = gr.Button("Next ▶", size="md")
             with gr.Accordion("Advanced Settings", open=False):
                 max_new_tokens = gr.Slider(
                     minimum=1000,
@@ -764,25 +666,36 @@ def create_gradio_interface():
                     label="Max New Tokens",
                     info="Maximum number of tokens to generate"
                 )
                 min_pixels = gr.Number(
                     value=MIN_PIXELS,
                     label="Min Pixels",
                     info="Minimum image resolution"
                 )
                 max_pixels = gr.Number(
                     value=MAX_PIXELS,
-                    label="Max Pixels",
                     info="Maximum image resolution"
                 )
             process_btn = gr.Button(
                 "🚀 Process Document",
                 variant="primary",
                 elem_classes=["process-button"],
                 size="lg"
             )
             clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
         with gr.Column(scale=2):
             with gr.Tabs():
                 with gr.Tab("🖼️ Processed Image"):
                     processed_image = gr.Image(
                         label="Image with Layout Detection",
@@ -790,11 +703,13 @@ def create_gradio_interface():
                         interactive=False,
                         height=500
                     )
                 with gr.Tab("📝 Extracted Content"):
                     markdown_output = gr.Markdown(
                         value="Click 'Process Document' to see extracted content...",
                         height=500
                     )
                 with gr.Tab("📋 Layout JSON"):
                     json_output = gr.JSON(
                         label="Layout Analysis Results",
@@ -802,8 +717,114 @@ def create_gradio_interface():
                     )
 
     # Event handlers
     file_input.change(
-        lambda file_path: load_file_for_preview(file_path),
         inputs=[file_input],
         outputs=[image_preview, page_info]
     )
@@ -825,12 +846,23 @@ def create_gradio_interface():
     )
 
     clear_btn.click(
-        lambda: (None, None, "No file loaded", None, "Click 'Process Document' to see extracted content...", None),
-        outputs=[file_input, image_preview, page_info, processed_image, markdown_output, json_output]
     )
 
     return demo
 
 if __name__ == "__main__":
     demo = create_gradio_interface()
-    demo.queue(max_size=10).launch(share=False, debug=True, show_error=True)
 
+import spaces
+import json
+import math
 import os
+import traceback
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Tuple
+import re
+
+import fitz
 import gradio as gr
+import requests
 import torch
 from huggingface_hub import snapshot_download
+from PIL import Image, ImageDraw, ImageFont
+from transformers import (
+    Qwen2_5_VLForConditionalGeneration,
+    AutoProcessor,
+    TextIteratorStreamer,
 )
+from qwen_vl_utils import process_vision_info
 
 # Constants
 MIN_PIXELS = 3136
 MAX_PIXELS = 11289600
 IMAGE_FACTOR = 28
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Prompts
 prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
 
 5. Final Output: The entire output must be a single JSON object.
 """
 
+# Load Camel-Doc-OCR-062825
+MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
+processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
+model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_M,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
+
+# Load Megalodon-OCR-Sync-0713
+MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
+processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
+model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_T,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
+
+# Load Nanonets-OCR-s
+MODEL_ID_C = "nanonets/Nanonets-OCR-s"
+processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
+model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_C,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
+
+# Load MonkeyOCR
+MODEL_ID_G = "echo840/MonkeyOCR"
+SUBFOLDER = "Recognition"
+processor_g = AutoProcessor.from_pretrained(
+    MODEL_ID_G,
+    trust_remote_code=True,
+    subfolder=SUBFOLDER
+)
+model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_G,
+    trust_remote_code=True,
+    subfolder=SUBFOLDER,
+    torch_dtype=torch.float16
+).to(device).eval()
+
+
 # Utility functions
 def round_by_factor(number: int, factor: int) -> int:
     """Returns the closest integer to 'number' that is divisible by 'factor'."""
     return round(number / factor) * factor
 
+
 def smart_resize(
     height: int,
     width: int,
 
     w_bar = round_by_factor(width * beta, factor)
     return h_bar, w_bar
 
+
 def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
     """Fetch and process an image"""
     if isinstance(image_input, str):
 
         image = image_input.convert('RGB')
     else:
         raise ValueError(f"Invalid image input type: {type(image_input)}")
+
     if min_pixels is not None or max_pixels is not None:
         min_pixels = min_pixels or MIN_PIXELS
         max_pixels = max_pixels or MAX_PIXELS
         height, width = smart_resize(
+            image.height,
+            image.width,
             factor=IMAGE_FACTOR,
             min_pixels=min_pixels,
             max_pixels=max_pixels
         )
         image = image.resize((width, height), Image.LANCZOS)
+
     return image
 
+
 def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
     """Load images from PDF file"""
     images = []
     try:
+        pdf_document = fitz.open(pdf_path)
         for page_num in range(len(pdf_document)):
             page = pdf_document.load_page(page_num)
+            # Convert page to image
+            mat = fitz.Matrix(2.0, 2.0)  # Increase resolution
             pix = page.get_pixmap(matrix=mat)
             img_data = pix.tobytes("ppm")
             image = Image.open(BytesIO(img_data)).convert('RGB')
 
         return []
     return images
 
+
 def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
     """Draw layout bounding boxes on image"""
     img_copy = image.copy()
     draw = ImageDraw.Draw(img_copy)
+
+    # Colors for different categories
     colors = {
         'Caption': '#FF6B6B',
+        'Footnote': '#4ECDC4',
         'Formula': '#45B7D1',
         'List-item': '#96CEB4',
         'Page-footer': '#FFEAA7',
 
         'Text': '#74B9FF',
         'Title': '#E17055'
     }
+
     try:
+        # Load a font
+        try:
+            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
+        except Exception:
+            font = ImageFont.load_default()
+
+        for item in layout_data:
+            if 'bbox' in item and 'category' in item:
+                bbox = item['bbox']
+                category = item['category']
+                color = colors.get(category, '#000000')
+
+                # Draw rectangle
+                draw.rectangle(bbox, outline=color, width=2)
+
+                # Draw label
+                label = category
+                label_bbox = draw.textbbox((0, 0), label, font=font)
+                label_width = label_bbox[2] - label_bbox[0]
+                label_height = label_bbox[3] - label_bbox[1]
+
+                # Position label above the box
+                label_x = bbox[0]
+                label_y = max(0, bbox[1] - label_height - 2)
+
+                # Draw background for label
+                draw.rectangle(
+                    [label_x, label_y, label_x + label_width + 4, label_y + label_height + 2],
+                    fill=color
+                )
+
+                # Draw text
+                draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
+
+    except Exception as e:
+        print(f"Error drawing layout: {e}")
+
     return img_copy
 
+
+def is_arabic_text(text: str) -> bool:
+    """Check if text in headers and paragraphs contains mostly Arabic characters"""
+    if not text:
+        return False
+
+    # Extract text from headers and paragraphs only
+    # Match markdown headers (# ## ###) and regular paragraph text
+    header_pattern = r'^#{1,6}\s+(.+)$'
+    paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
+
+    content_text = []
+
+    for line in text.split('\n'):
+        line = line.strip()
+        if not line:
+            continue
+
+        # Check for headers
+        header_match = re.match(header_pattern, line, re.MULTILINE)
+        if header_match:
+            content_text.append(header_match.group(1))
+            continue
+
+        # Check for paragraph text (exclude lists, tables, code blocks, images)
+        if re.match(paragraph_pattern, line, re.MULTILINE):
+            content_text.append(line)
+
+    if not content_text:
+        return False
+
+    # Join all content text and check for Arabic characters
+    combined_text = ' '.join(content_text)
+
+    # Arabic Unicode ranges
+    arabic_chars = 0
+    total_chars = 0
+
+    for char in combined_text:
+        if char.isalpha():
+            total_chars += 1
+            # Arabic script ranges
+            if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'):
+                arabic_chars += 1
+
+    if total_chars == 0:
+        return False
+
+    # Consider text as Arabic if more than 50% of alphabetic characters are Arabic
+    return (arabic_chars / total_chars) > 0.5
+
+
 def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
     """Convert layout JSON to markdown format"""
     import base64
     from io import BytesIO
+
     markdown_lines = []
+
     try:
+        # Sort items by reading order (top to bottom, left to right)
+        sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
+
         for item in sorted_items:
             category = item.get('category', '')
             text = item.get(text_key, '')
             bbox = item.get('bbox', [])
+
             if category == 'Picture':
+                # Extract image region and embed it
                 if bbox and len(bbox) == 4:
                     try:
+                        # Extract the image region
                         x1, y1, x2, y2 = bbox
+                        # Ensure coordinates are within image bounds
                         x1, y1 = max(0, int(x1)), max(0, int(y1))
                         x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))
+
                         if x2 > x1 and y2 > y1:
                             cropped_img = image.crop((x1, y1, x2, y2))
+
+                            # Convert to base64 for embedding
                             buffer = BytesIO()
                             cropped_img.save(buffer, format='PNG')
                             img_data = base64.b64encode(buffer.getvalue()).decode()
+
+                            # Add as markdown image
                             markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
                         else:
                             markdown_lines.append("![Image](Image region detected)\n")
 
             elif category == 'List-item':
                 markdown_lines.append(f"- {text}\n")
             elif category == 'Table':
+                # If text is already HTML, keep it as is
                 if text.strip().startswith('<'):
                     markdown_lines.append(f"{text}\n")
                 else:
                     markdown_lines.append(f"**Table:** {text}\n")
             elif category == 'Formula':
+                # If text is LaTeX, format it properly
                 if text.strip().startswith('$') or '\\' in text:
                     markdown_lines.append(f"$$\n{text}\n$$\n")
                 else:
 
             elif category == 'Footnote':
                 markdown_lines.append(f"^{text}^\n")
             elif category in ['Page-header', 'Page-footer']:
+                # Skip headers and footers in main content
                 continue
             else:
                 markdown_lines.append(f"{text}\n")
+
+            markdown_lines.append("")  # Add spacing
+
     except Exception as e:
         print(f"Error converting to markdown: {e}")
         return str(layout_data)
+
     return "\n".join(markdown_lines)
 
+
+# PDF handling state
 pdf_cache = {
     "images": [],
     "current_page": 0,
 
     "is_parsed": False,
     "results": []
 }
+@spaces.GPU
+def inference(model_name: str, image: Image.Image, prompt: str, max_new_tokens: int = 1024) -> str:
+    """Run inference on an image with the given prompt using the selected model."""
     try:
+        if model_name == "Camel-Doc-OCR-062825":
+            processor = processor_m
+            model = model_m
+        elif model_name == "Megalodon-OCR-Sync-0713":
+            processor = processor_t
+            model = model_t
+        elif model_name == "Nanonets-OCR-s":
+            processor = processor_c
+            model = model_c
+        elif model_name == "MonkeyOCR-Recognition":
+            processor = processor_g
+            model = model_g
+        else:
+            raise ValueError(f"Invalid model selected: {model_name}")
+
         messages = [
             {
                 "role": "user",
                 "content": [
+                    {"type": "text", "text": prompt},
+                    {"type": "image"}
                 ]
             }
         ]
+
+        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(text, [image], return_tensors="pt").to(device)
+
         with torch.no_grad():
+            generated_ids = model.generate(
                 **inputs,
                 max_new_tokens=max_new_tokens,
                 do_sample=False,
                 temperature=0.1
             )
+
+        generated_ids = [
             out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
         ]
+        output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
+        return output_text[0]
+
     except Exception as e:
+        print(f"Error during inference: {e}")
+        traceback.print_exc()
         return f"Error during inference: {str(e)}"
 
+
+def process_image(
+    model_name: str,
+    image: Image.Image,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None
+) -> Dict[str, Any]:
+    """Process a single image with the specified prompt mode"""
     try:
+        # Resize image if needed
         if min_pixels is not None or max_pixels is not None:
             image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
+
+        # Run inference with the default prompt
+        raw_output = inference(model_name, image, prompt)
+
+        # Process results based on prompt mode
         result = {
             'original_image': image,
             'raw_output': raw_output,
 
             'layout_result': None,
             'markdown_content': None
         }
+
+        # Try to parse JSON and create visualizations (since we're doing layout analysis)
         try:
+            # Clean the output to be valid JSON
+            # Models sometimes add ```json ... ``` markers
+            json_match = re.search(r'```json\s*([\s\S]+?)\s*```', raw_output)
+            if json_match:
+                json_str = json_match.group(1)
+            else:
+                json_str = raw_output
+
+            layout_data = json.loads(json_str)
             result['layout_result'] = layout_data
+
+            # Create visualization with bounding boxes
+            try:
+                processed_image = draw_layout_on_image(image, layout_data)
+                result['processed_image'] = processed_image
+            except Exception as e:
+                print(f"Error drawing layout: {e}")
+                result['processed_image'] = image
+
+            # Generate markdown from layout data
+            try:
+                markdown_content = layoutjson2md(image, layout_data, text_key='text')
+                result['markdown_content'] = markdown_content
+            except Exception as e:
+                print(f"Error generating markdown: {e}")
+                result['markdown_content'] = raw_output
+
         except json.JSONDecodeError:
             print("Failed to parse JSON output, using raw output")
             result['markdown_content'] = raw_output
+
         return result
+
     except Exception as e:
+        print(f"Error processing image: {e}")
+        traceback.print_exc()
         return {
             'original_image': image,
             'raw_output': f"Error processing image: {str(e)}",
 
             'markdown_content': f"Error processing image: {str(e)}"
         }
 
 def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
     """Load file for preview (supports PDF and images)"""
     global pdf_cache
+
     if not file_path or not os.path.exists(file_path):
         return None, "No file selected"
+
+    file_ext = os.path.splitext(file_path)[1].lower()
+
     try:
         if file_ext == '.pdf':
+            # Load PDF pages
             images = load_images_from_pdf(file_path)
             if not images:
                 return None, "Failed to load PDF"
+
             pdf_cache.update({
                 "images": images,
                 "current_page": 0,
 
                 "is_parsed": False,
                 "results": []
             })
+
+            return images[0], f"Page 1 / {len(images)}"
+
         elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
+            # Load single image
             image = Image.open(file_path).convert('RGB')
+
             pdf_cache.update({
                 "images": [image],
                 "current_page": 0,
 
                 "is_parsed": False,
                 "results": []
             })
+
             return image, "Page 1 / 1"
         else:
             return None, f"Unsupported file format: {file_ext}"
+
     except Exception as e:
         print(f"Error loading file: {e}")
         return None, f"Error loading file: {str(e)}"
 
+
+def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, Any, Optional[Image.Image], Optional[Dict]]:
     """Navigate through PDF pages and update all relevant outputs."""
     global pdf_cache
+
     if not pdf_cache["images"]:
+        return None, '<div class="page-info">No file loaded</div>', "No results yet", None, None
+
     if direction == "prev":
         pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
     elif direction == "next":
+        pdf_cache["current_page"] = min(
+            pdf_cache["total_pages"] - 1,
+            pdf_cache["current_page"] + 1
+        )
+
     index = pdf_cache["current_page"]
     current_image_preview = pdf_cache["images"][index]
+    page_info_html = f'<div class="page-info">Page {index + 1} / {pdf_cache["total_pages"]}</div>'
+
+    # Initialize default result values
+    markdown_content = "Page not processed yet"
+    processed_img = None
+    layout_json = None
+
+    # Get results for current page if available
+    if (pdf_cache["is_parsed"] and
+        index < len(pdf_cache["results"]) and
+        pdf_cache["results"][index]):
+
         result = pdf_cache["results"][index]
+        markdown_content = result.get('markdown_content') or result.get('raw_output', 'No content available')
+        processed_img = result.get('processed_image', None)  # Get the processed image
+        layout_json = result.get('layout_result', None)  # Get the layout JSON
+
+    # Check for Arabic text to set RTL property
+    if is_arabic_text(markdown_content):
+        markdown_update = gr.update(value=markdown_content, rtl=True)
     else:
+        markdown_update = markdown_content
+
+    return current_image_preview, page_info_html, markdown_update, processed_img, layout_json
 
 def create_gradio_interface():
     """Create the Gradio interface"""
+
     css = """
     .main-container { max-width: 1400px; margin: 0 auto; }
     .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
+    .process-button {
+        border: none !important;
+        color: white !important;
+        font-weight: bold !important;
+        background-color: blue !important;}
+    .process-button:hover {
         background-color: darkblue !important;
+        transform: translateY(-2px) !important;
         box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
     .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
     .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
 
     </p>
     </div>
     """)
+
+    # Main interface
     with gr.Row():
+        # Left column - Input and controls
         with gr.Column(scale=1):
+
+            # Model selection
             model_choice = gr.Radio(
+                choices=["Camel-Doc-OCR-062825", "MonkeyOCR-Recognition", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713"],
                 label="Select Model",
+                value="Camel-Doc-OCR-062825"
             )
+
+            # File input
             file_input = gr.File(
                 label="Upload Image or PDF",
                 file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"],
                 type="filepath"
             )
+
+            # Image preview
             image_preview = gr.Image(
                 label="Preview",
                 type="pil",
                 interactive=False,
                 height=300
             )
+
+            # Page navigation for PDFs
             with gr.Row():
                 prev_page_btn = gr.Button("◀ Previous", size="md")
+                page_info = gr.HTML('<div class="page-info">No file loaded</div>')
                 next_page_btn = gr.Button("Next ▶", size="md")
+
+            # Advanced settings
             with gr.Accordion("Advanced Settings", open=False):
                 max_new_tokens = gr.Slider(
                     minimum=1000,
 
                     label="Max New Tokens",
                     info="Maximum number of tokens to generate"
                 )
+
                 min_pixels = gr.Number(
                     value=MIN_PIXELS,
                     label="Min Pixels",
                     info="Minimum image resolution"
                 )
+
                 max_pixels = gr.Number(
                     value=MAX_PIXELS,
+                    label="Max Pixels",
                     info="Maximum image resolution"
                 )
+
+            # Process button
             process_btn = gr.Button(
                 "🚀 Process Document",
                 variant="primary",
                 elem_classes=["process-button"],
                 size="lg"
             )
+
+            # Clear button
             clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
+
+        # Right column - Results
         with gr.Column(scale=2):
+
+            # Results tabs
             with gr.Tabs():
+                # Processed image tab
                 with gr.Tab("🖼️ Processed Image"):
                     processed_image = gr.Image(
                         label="Image with Layout Detection",
 
                         interactive=False,
                         height=500
                     )
+                # Markdown output tab
                 with gr.Tab("📝 Extracted Content"):
                     markdown_output = gr.Markdown(
                         value="Click 'Process Document' to see extracted content...",
                         height=500
                     )
+                # JSON layout tab
                 with gr.Tab("📋 Layout JSON"):
                     json_output = gr.JSON(
                         label="Layout Analysis Results",
 
                     )
 
     # Event handlers
+    def process_document(model_name, file_path, max_tokens, min_pix, max_pix):
+        """Process the uploaded document"""
+        global pdf_cache
+
+        try:
+            if not file_path:
+                return None, "Please upload a file first.", None
+
+            # Load and preview file
+            image, page_info = load_file_for_preview(file_path)
+            if image is None:
+                return None, page_info, None
+
+            # Process the image(s)
+            if pdf_cache["file_type"] == "pdf":
+                # Process all pages for PDF
+                all_results = []
+                all_markdown = []
+
+                for i, img in enumerate(pdf_cache["images"]):
+                    result = process_image(
+                        model_name,
+                        img,
+                        min_pixels=int(min_pix) if min_pix else None,
+                        max_pixels=int(max_pix) if max_pix else None
+                    )
+                    all_results.append(result)
+                    if result.get('markdown_content'):
+                        all_markdown.append(f"## Page {i+1}\n\n{result['markdown_content']}")
+
+                pdf_cache["results"] = all_results
+                pdf_cache["is_parsed"] = True
+
+                # Show results for first page
+                first_result = all_results[0]
+                combined_markdown = "\n\n---\n\n".join(all_markdown)
+
+                # Check if the combined markdown contains mostly Arabic text
+                if is_arabic_text(combined_markdown):
+                    markdown_update = gr.update(value=combined_markdown, rtl=True)
+                else:
+                    markdown_update = combined_markdown
+
+                return (
+                    first_result['processed_image'],
+                    markdown_update,
+                    first_result['layout_result']
+                )
+            else:
+                # Process single image
+                result = process_image(
+                    model_name,
+                    image,
+                    min_pixels=int(min_pix) if min_pix else None,
+                    max_pixels=int(max_pix) if max_pix else None
+                )
+
+                pdf_cache["results"] = [result]
+                pdf_cache["is_parsed"] = True
+
+                # Check if the content contains mostly Arabic text
+                content = result['markdown_content'] or "No content extracted"
+                if is_arabic_text(content):
+                    markdown_update = gr.update(value=content, rtl=True)
+                else:
+                    markdown_update = content
+
+                return (
+                    result['processed_image'],
+                    markdown_update,
+                    result['layout_result']
+                )
+
+        except Exception as e:
+            error_msg = f"Error processing document: {str(e)}"
+            print(error_msg)
+            traceback.print_exc()
+            return None, error_msg, None
+
+    def handle_file_upload(file_path):
+        """Handle file upload and show preview"""
+        if not file_path:
+            return None, "No file loaded"
+
+        image, page_info = load_file_for_preview(file_path)
+        return image, page_info
+
+    def clear_all():
+        """Clear all data and reset interface"""
+        global pdf_cache
+
+        pdf_cache = {
+            "images": [], "current_page": 0, "total_pages": 0,
+            "file_type": None, "is_parsed": False, "results": []
+        }
+
+        return (
+            None,  # file_input
+            None,  # image_preview
+            '<div class="page-info">No file loaded</div>',  # page_info
+            None,  # processed_image
+            "Click 'Process Document' to see extracted content...",  # markdown_output
+            None,  # json_output
+        )
+
+    # Wire up event handlers
     file_input.change(
+        handle_file_upload,
         inputs=[file_input],
         outputs=[image_preview, page_info]
     )
 
     )
 
     clear_btn.click(
+        clear_all,
+        outputs=[
+            file_input, image_preview, page_info, processed_image,
+            markdown_output, json_output
+        ]
     )
 
     return demo
 
+
 if __name__ == "__main__":
+    # Create and launch the interface
     demo = create_gradio_interface()
+    demo.queue(max_size=10).launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,
+        debug=True,
+        show_error=True
+    )