import spaces
import base64
import json
import re
import time
from io import BytesIO
from threading import Thread
from typing import Dict, List

import gradio as gr
import torch
from PIL import Image
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)

# Constants
MIN_PIXELS = 3136
MAX_PIXELS = 11289600
IMAGE_FACTOR = 28
MAX_INPUT_TOKEN_LENGTH = 4096

device = "cuda" if torch.cuda.is_available() else "cpu"

# Prompt for layout analysis
prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: Format its text as HTML.
    - All Others (Text, Title, etc.): Format their text as Markdown.
4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object.
"""

# Load models
MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_C = "nanonets/Nanonets-OCR-s"
processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g = AutoProcessor.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
)
model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
).to(device).eval()


# Utility functions
def is_arabic_text(text: str) -> bool:
    """Check if text contains mostly Arabic characters."""
    if not text:
        return False
    # Simplified check: count Arabic letters among all alphabetic characters
    arabic_chars = 0
    total_chars = 0
    for char in text:
        if char.isalpha():
            total_chars += 1
            if '\u0600' <= char <= '\u06FF':
                arabic_chars += 1
    return total_chars > 0 and (arabic_chars / total_chars) > 0.5


def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
    """Convert layout JSON to markdown format."""
    markdown_lines = []
    try:
        # Sort items by reading order (top to bottom, left to right)
        sorted_items = sorted(
            layout_data,
            key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]),
        )
        for item in sorted_items:
            category = item.get('category', '')
            text = item.get(text_key, '')
            bbox = item.get('bbox', [])

            if category == 'Picture':
                if bbox and len(bbox) == 4:
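                    # Crop the detected picture region and embed it inline as a
                    # base64 data URI, so the rendered markdown stays self-contained
                    # and no image files need to be written to disk.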
                    try:
                        x1, y1, x2, y2 = [int(coord) for coord in bbox]
                        cropped_img = image.crop((x1, y1, x2, y2))
                        buffer = BytesIO()
                        cropped_img.save(buffer, format='PNG')
                        img_data = base64.b64encode(buffer.getvalue()).decode()
                        markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
                    except Exception:
                        markdown_lines.append("![Image](Image region detected)\n")
            elif not text:
                continue
            elif category == 'Title':
                markdown_lines.append(f"# {text}\n")
            elif category == 'Section-header':
                markdown_lines.append(f"## {text}\n")
            elif category == 'Text':
                markdown_lines.append(f"{text}\n")
            elif category == 'List-item':
                markdown_lines.append(f"- {text}\n")
            elif category == 'Table' and text.strip().startswith('<'):
                markdown_lines.append(f"{text}\n")
            elif category == 'Formula' and (text.strip().startswith('$') or '\\' in text):
                markdown_lines.append(f"$$\n{text}\n$$\n")
            elif category == 'Caption':
                markdown_lines.append(f"*{text}*\n")
            elif category == 'Footnote':
                markdown_lines.append(f"^{text}^\n")
            elif category not in ['Page-header', 'Page-footer']:
                markdown_lines.append(f"{text}\n")
    except Exception as e:
        print(f"Error converting to markdown: {e}")
        return f"### Error converting to Markdown\n\n```\n{str(layout_data)}\n```"
    return "\n".join(markdown_lines)


@spaces.GPU
def generate_and_process(model_name: str, image: Image.Image, max_new_tokens: int):
    """
    Generate a response using streaming, then process the final output.
    Yields updates for the raw stream, the final markdown, and the JSON output.
    """
    if image is None:
        yield "Please upload an image.", "Please upload an image.", None
        return

    # 1. Select model and processor
    if model_name == "Camel-Doc-OCR-062825":
        processor, model = processor_m, model_m
    elif model_name == "Megalodon-OCR-Sync-0713":
        processor, model = processor_t, model_t
    elif model_name == "Nanonets-OCR-s":
        processor, model = processor_c, model_c
    elif model_name == "MonkeyOCR-Recognition":
        processor, model = processor_g, model_g
    else:
        yield "Invalid model selected.", "Invalid model selected.", None
        return

    # 2. Prepare inputs for the model
    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_INPUT_TOKEN_LENGTH,
    ).to(device)

    # 3. Stream the generation in a background thread
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    # Initial placeholder yield
    yield buffer, "⏳ Generating response...", None
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)  # Small delay for smoother streaming
        yield buffer, "⏳ Generating response...", None

    # 4. Process the final buffer content
    try:
        json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
        json_str = json_match.group(1) if json_match else buffer
        layout_data = json.loads(json_str)
        markdown_content = layoutjson2md(image, layout_data)
        # Final yield with all processed content
        yield buffer, markdown_content, layout_data
    except json.JSONDecodeError:
        error_msg = "❌ Failed to parse JSON from model output."
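        # Keep the raw stream visible alongside the error so the user can see
        # exactly what the model produced.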
        yield buffer, error_msg, {"error": "JSONDecodeError", "raw_output": buffer}
    except Exception as e:
        error_msg = f"❌ An error occurred during post-processing: {e}"
        yield buffer, error_msg, {"error": str(e), "raw_output": buffer}


def create_gradio_interface():
    """Create the Gradio interface."""
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
    .process-button {
        border: none !important;
        color: white !important;
        font-weight: bold !important;
        background-color: blue !important;
    }
    .process-button:hover {
        background-color: darkblue !important;
        transform: translateY(-2px) !important;
        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
    }
    """
    with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
        gr.HTML("""
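            <!-- App header; styled by the .header-text rule in the CSS above -->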

            <div class="header-text">
                <h1>DotOCR Comparator</h1>
                <p>Advanced vision-language model for image-to-markdown document processing</p>
            </div>
""") # Keep track of the uploaded image image_state = gr.State(None) with gr.Row(): # Left column - Input and controls with gr.Column(scale=1): model_choice = gr.Radio( choices=["Camel-Doc-OCR-062825", "MonkeyOCR-Recognition", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713"], label="Select Model", value="Camel-Doc-OCR-062825" ) file_input = gr.Image( label="Upload Image", type="pil", sources=['upload'] ) with gr.Accordion("Advanced Settings", open=False): max_new_tokens = gr.Slider(minimum=1000, maximum=32000, value=24000, step=1000, label="Max New Tokens") process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg") clear_btn = gr.Button("🗑️ Clear All", variant="secondary") # Right column - Results with gr.Column(scale=2): with gr.Tabs(): with gr.Tab("📝 Extracted Content"): output_stream = gr.Textbox(label="Raw Output Stream", interactive=False, lines=10, show_copy_button=True) with gr.Accordion("(Formatted Result)", open=True): markdown_output = gr.Markdown(label="Formatted Result (Result.md)") with gr.Tab("📋 Layout JSON"): json_output = gr.JSON(label="Layout Analysis Results (JSON)", value=None) # Event Handlers def handle_file_upload(image): """Store the uploaded image in the state.""" return image def clear_all(): """Clear all data and reset the interface.""" return None, None, "Click 'Process Document' to see extracted content...", None, None file_input.upload(handle_file_upload, inputs=[file_input], outputs=[image_state]) process_btn.click( generate_and_process, inputs=[model_choice, image_state, max_new_tokens], outputs=[output_stream, markdown_output, json_output] ) clear_btn.click( clear_all, outputs=[file_input, image_state, markdown_output, json_output, output_stream] ) return demo if __name__ == "__main__": demo = create_gradio_interface() demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)