import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re
import time
from threading import Thread

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)

# --- Constants and Model Setup ---
MAX_INPUT_TOKEN_LENGTH = 4096
device = "cuda" if torch.cuda.is_available() else "cpu"

# The detailed prompt to instruct the model to generate structured JSON
prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: For tables, provide the content in a structured format within the JSON.
    - All Others (Text, Title, etc.): Format their text as Markdown.
4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object wrapped in ```json ... ```.
"""

# Load models
MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_C = "nanonets/Nanonets-OCR-s"
processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g = AutoProcessor.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
)
model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
).to(device).eval()


# --- Utility Functions ---
def layoutjson2md(layout_data: List[Dict]) -> str:
    """Converts the structured JSON layout data into formatted Markdown."""
    markdown_lines = []
    try:
        # Sort elements top-to-bottom, then left-to-right, by bbox origin
        sorted_items = sorted(
            layout_data,
            key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0])
        )
        for item in sorted_items:
            category = item.get('category', '')
            text = item.get('text', '')
            if not text:
                continue
            if category == 'Title':
                markdown_lines.append(f"# {text}\n")
            elif category == 'Section-header':
                markdown_lines.append(f"## {text}\n")
            elif category == 'Table':
                # Check if the text is a dictionary representing a structured table
                if isinstance(text, dict) and 'header' in text and 'rows' in text:
                    header = '| ' + ' | '.join(map(str, text['header'])) + ' |'
                    separator = '| ' + ' | '.join(['---'] * len(text['header'])) + ' |'
                    rows = ['| ' + ' | '.join(map(str, row)) + ' |' for row in text['rows']]
                    markdown_lines.append(header)
                    markdown_lines.append(separator)
                    markdown_lines.extend(rows)
                    markdown_lines.append("\n")
                else:
                    # Fallback for unstructured table text
                    markdown_lines.append(f"{text}\n")
            else:
                markdown_lines.append(f"{text}\n")
    except Exception as e:
        print(f"Error converting to markdown: {e}")
        return "### Error converting JSON to Markdown."
    return "\n".join(markdown_lines)
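
# Illustrative sanity check for layoutjson2md. The sample layout below is a
# hand-written, hypothetical instance of the JSON schema the prompt above asks
# the model to emit; it is not output from any of the loaded models. Guarded by
# a hypothetical env var (DEBUG_LAYOUT2MD) so it does not run at app startup.
if os.environ.get("DEBUG_LAYOUT2MD"):
    _sample_layout = [
        {"bbox": [50, 20, 500, 60], "category": "Title", "text": "Quarterly Report"},
        {"bbox": [50, 80, 500, 120], "category": "Section-header", "text": "Revenue"},
        {
            "bbox": [50, 140, 500, 240],
            "category": "Table",
            "text": {"header": ["Quarter", "Revenue"], "rows": [["Q1", "1.2M"], ["Q2", "1.5M"]]},
        },
        {"bbox": [50, 260, 500, 300], "category": "Text", "text": "Revenue grew steadily."},
    ]
    # Expected output: an H1, an H2, a Markdown table, then a plain paragraph,
    # in reading order (sorted by bbox y, then x).
    print(layoutjson2md(_sample_layout))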
# --- Core Application Logic ---
@spaces.GPU
def process_document_stream(model_name: str, image: Image.Image, text_prompt: str, max_new_tokens: int):
    """
    Main generator function that streams raw model output and then processes it
    into formatted Markdown and structured JSON for the UI.
    """
    if image is None:
        yield "Please upload an image.", "Please upload an image.", None
        return

    # Select the model and processor
    if model_name == "Camel-Doc-OCR-062825":
        processor, model = processor_m, model_m
    elif model_name == "Megalodon-OCR-Sync-0713":
        processor, model = processor_t, model_t
    elif model_name == "Nanonets-OCR-s":
        processor, model = processor_c, model_c
    elif model_name == "MonkeyOCR-Recognition":
        processor, model = processor_g, model_g
    else:
        yield "Invalid model selected.", "Invalid model selected.", None
        return

    # Prepare model inputs
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": text_prompt},
        ],
    }]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_INPUT_TOKEN_LENGTH,
    ).to(device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}

    # Start generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Stream raw output to the UI
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        # Yield the raw stream and placeholders for the final results
        yield buffer, "⏳ Formatting Markdown...", {"status": "processing"}

    # After streaming is complete, process the final buffer
    try:
        # Extract the JSON object from the buffer
        json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
        if not json_match:
            raise json.JSONDecodeError("JSON object not found in the model's output.", buffer, 0)
        json_str = json_match.group(1)
        layout_data = json.loads(json_str)

        # Convert the parsed JSON to formatted markdown
        markdown_content = layoutjson2md(layout_data)

        # Yield the final, complete results
        yield buffer, markdown_content, layout_data
    except json.JSONDecodeError as e:
        print(f"JSON parsing failed: {e}")
        error_md = "❌ **Error:** Failed to parse JSON from the model's output.\n\nSee the raw output stream for details."
        error_json = {"error": "JSONDecodeError", "details": str(e), "raw_output": buffer}
        yield buffer, error_md, error_json
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        yield buffer, f"❌ An unexpected error occurred: {e}", None
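
# Minimal illustration of the fenced-JSON extraction used above, run against a
# hand-written sample buffer (hypothetical, not real model output). Guarded by
# a hypothetical env var (DEBUG_JSON_EXTRACT) so it does not run at startup.
if os.environ.get("DEBUG_JSON_EXTRACT"):
    _sample_buffer = (
        "Here is the layout:\n"
        "```json\n"
        '[{"bbox": [0, 0, 100, 40], "category": "Title", "text": "Hello"}]\n'
        "```"
    )
    _m = re.search(r'```json\s*([\s\S]+?)\s*```', _sample_buffer)
    assert _m is not None
    # json.loads should yield a one-element list containing the 'Title' item
    print(json.loads(_m.group(1)))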
# --- Gradio UI Definition ---
def create_gradio_interface():
    """Builds and returns the Gradio web interface."""
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .process-button {
        border: none !important;
        color: white !important;
        font-weight: bold !important;
        background-color: blue !important;
    }
    .process-button:hover {
        background-color: darkblue !important;
        transform: translateY(-2px) !important;
        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
    }
    """
    with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
        gr.HTML("""
Advanced Vision-Language Model for Image Layout Analysis