import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re
import time
from threading import Thread

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    AutoModelForImageTextToText,
    AutoProcessor,
    TextIteratorStreamer,
    AutoModel,
    AutoTokenizer,
)

# --- Constants and Model Setup ---
MAX_INPUT_TOKEN_LENGTH = 4096
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Prompts for Different Tasks ---
layout_prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
    - For tables, provide the content in a structured JSON format.
    - For all other elements, provide the plain text.
4. Constraints:
    - The output must be the original text from the image.
    - All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object wrapped in ```json ... ```.
"""

ocr_prompt = "Perform precise OCR on the image. Extract all text content, maintaining the original structure, paragraphs, and tables as formatted markdown."

# --- Model Loading ---
MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-080125"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_C = "nanonets/Nanonets-OCR-s"
processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g = AutoProcessor.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
)
model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_I = "allenai/olmOCR-7B-0725"
processor_i = AutoProcessor.from_pretrained(MODEL_ID_I, trust_remote_code=True)
model_i = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_I, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()
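
# The five loads above share one recipe; below is a minimal sketch of how that
# boilerplate could be factored out. `load_ocr_model` is a hypothetical helper,
# not used by the app below; it only illustrates the shared loading pattern.
def load_ocr_model(model_id: str, subfolder: str = ""):
    """Load a Qwen2.5-VL-style OCR checkpoint and its processor in fp16."""
    processor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True, subfolder=subfolder
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id, trust_remote_code=True, subfolder=subfolder, torch_dtype=torch.float16
    ).to(device).eval()
    return processor, model
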
# --- Utility Functions ---
def layoutjson2md(layout_data: List[Dict]) -> str:
    """Converts the structured JSON from Layout Analysis into formatted Markdown."""
    markdown_lines = []
    try:
        # Sort items by reading order (top-to-bottom, left-to-right)
        sorted_items = sorted(
            layout_data,
            key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]),
        )
        for item in sorted_items:
            category = item.get('category', '')
            text = item.get('text', '')
            if not text:
                continue
            if category == 'Title':
                markdown_lines.append(f"# {text}\n")
            elif category == 'Section-header':
                markdown_lines.append(f"## {text}\n")
            elif category == 'Table':
                # Handle structured table JSON
                if isinstance(text, dict) and 'header' in text and 'rows' in text:
                    header = '| ' + ' | '.join(map(str, text['header'])) + ' |'
                    separator = '| ' + ' | '.join(['---'] * len(text['header'])) + ' |'
                    rows = ['| ' + ' | '.join(map(str, row)) + ' |' for row in text['rows']]
                    markdown_lines.extend([header, separator] + rows)
                    markdown_lines.append("\n")
                else:
                    # Fallback for tables returned as simple text
                    markdown_lines.append(f"{text}\n")
            else:
                markdown_lines.append(f"{text}\n")
    except Exception as e:
        print(f"Error converting to markdown: {e}")
        return "### Error converting JSON to Markdown."
    return "\n".join(markdown_lines)
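
# A hedged, illustrative example of the layout JSON this converter expects
# (field names mirror layout_prompt above; the concrete values are invented):
#
#     layoutjson2md([
#         {"category": "Title", "bbox": [40, 20, 500, 60], "text": "Quarterly Report"},
#         {"category": "Table", "bbox": [40, 80, 500, 200],
#          "text": {"header": ["Metric", "Value"], "rows": [["Revenue", "1.2M"]]}},
#     ])
#
# would yield a "# Quarterly Report" heading followed by a two-column
# Markdown table.
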
# --- Core Application Logic ---
@spaces.GPU
def process_document_stream(model_name: str, task_choice: str, image: Image.Image, max_new_tokens: int):
    """Main generator function that handles both OCR and Layout Analysis tasks."""
    if image is None:
        yield "Please upload an image.", "Please upload an image.", None
        return

    # 1. Select prompt based on the user's task choice
    text_prompt = ocr_prompt if task_choice == "Content Extraction" else layout_prompt

    # 2. Select model and processor
    if model_name == "Camel-Doc-OCR-080125":
        processor, model = processor_m, model_m
    elif model_name == "Megalodon-OCR-Sync-0713":
        processor, model = processor_t, model_t
    elif model_name == "Nanonets-OCR-s":
        processor, model = processor_c, model_c
    elif model_name == "MonkeyOCR-Recognition":
        processor, model = processor_g, model_g
    elif model_name == "olmOCR-7B-0725":
        processor, model = processor_i, model_i
    else:
        yield "Invalid model selected.", "Invalid model selected.", None
        return

    # 3. Prepare model inputs and streamer
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": text_prompt},
        ],
    }]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt_full], images=[image], return_tensors="pt",
        padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH,
    ).to(device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # 4. Stream raw output to the UI in real time
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, "⏳ Processing...", {"status": "streaming"}

    # 5. Post-process the final buffer based on the selected task
    if task_choice == "Content Extraction":
        # For OCR, the buffer is the final result.
        yield buffer, buffer, None
    else:  # Layout Analysis
        try:
            json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
            if not json_match:
                raise json.JSONDecodeError("JSON object not found in output.", buffer, 0)
            json_str = json_match.group(1)
            layout_data = json.loads(json_str)
            markdown_content = layoutjson2md(layout_data)
            yield buffer, markdown_content, layout_data
        except Exception as e:
            error_md = f"❌ **Error:** Failed to parse Layout JSON.\n\n**Details:**\n`{str(e)}`"
            error_json = {"error": "ProcessingError", "details": str(e), "raw_output": buffer}
            yield buffer, error_md, error_json
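
# The parser above requires the fenced ```json block demanded by layout_prompt.
# A minimal, hypothetical fallback (not wired into the app) for models that
# return bare JSON without the fence:
def extract_layout_json(buffer: str):
    """Extract layout JSON from a fenced block, falling back to the raw text."""
    match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
    candidate = match.group(1) if match else buffer.strip()
    return json.loads(candidate)  # raises json.JSONDecodeError if still invalid
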
# --- Gradio UI Definition ---
def create_gradio_interface():
    """Builds and returns the Gradio web interface."""
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important; }
    .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
    """
    with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
        gr.HTML("""
        <div style="text-align: center;">
            <h1>OCR Comparator 🥠</h1>
            <p>Advanced Vision-Language Model for Image Content and Layout Extraction</p>
        </div>
        """)
""") with gr.Row(): # Left Column (Inputs) with gr.Column(scale=1): model_choice = gr.Dropdown( choices=["Camel-Doc-OCR-080125", "MonkeyOCR-Recognition", "olmOCR-7B-0725", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713" ], label="Select Model", value="Nanonets-OCR-s" ) task_choice = gr.Dropdown( choices=["Content Extraction", "Layout Analysis(.json)"], label="Select Task", value="Content Extraction" ) image_input = gr.Image(label="Upload Image", type="pil", sources=['upload']) with gr.Accordion("Advanced Settings", open=False): max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=4096, step=256, label="Max New Tokens") process_btn = gr.Button("๐Ÿš€ Process Document", variant="primary", elem_classes=["process-button"], size="lg") clear_btn = gr.Button("๐Ÿ—‘๏ธ Clear All", variant="secondary") # Right Column (Outputs) with gr.Column(scale=2): with gr.Tabs() as tabs: with gr.Tab("๐Ÿ“ Extracted Content"): raw_output_stream = gr.Textbox(label="Raw Model Output Stream", interactive=False, lines=13, show_copy_button=True) with gr.Row(): examples = gr.Examples( examples=["examples/1.png", "examples/2.png", "examples/3.png", "examples/4.png", "examples/5.png"], inputs=image_input, label="Examples" ) with gr.Tab("๐Ÿ“ฐ README.md"): with gr.Accordion("(Formatted Result)", open=True): markdown_output = gr.Markdown(label="Formatted Markdown") with gr.Tab("๐Ÿ“‹ Layout Analysis Results"): json_output = gr.JSON(label="Structured Layout Data (JSON)") # Event Handlers def clear_all_outputs(): return None, "Raw output will appear here.", "Formatted results will appear here.", None process_btn.click( fn=process_document_stream, inputs=[model_choice, task_choice, image_input, max_new_tokens], outputs=[raw_output_stream, markdown_output, json_output] ) clear_btn.click( clear_all_outputs, outputs=[image_input, raw_output_stream, markdown_output, json_output] ) return demo if __name__ == "__main__": demo = create_gradio_interface() demo.queue(max_size=40).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)