import base64
import json
import math
import os
import re
import time
import traceback
from io import BytesIO
from threading import Thread
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import requests
import spaces
import torch
from PIL import Image
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)

# Constants
MIN_PIXELS = 3136
MAX_PIXELS = 11289600
IMAGE_FACTOR = 28
MAX_INPUT_TOKEN_LENGTH = 2048

device = "cuda" if torch.cuda.is_available() else "cpu"

# Prompt
prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]

2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].

3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: Format its text as HTML.
    - All Others (Text, Title, etc.): Format their text as Markdown.

4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.

5. Final Output: The entire output must be a single JSON object.
"""

# Load models
MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_C = "nanonets/Nanonets-OCR-s"
processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g = AutoProcessor.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
)
model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
).to(device).eval()

# Utility functions
def round_by_factor(number: int, factor: int) -> int:
    """Round `number` to the nearest multiple of `factor`."""
    return round(number / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> Tuple[int, int]:
    """Rescale (height, width) so both sides are multiples of `factor` and the
    total pixel count lands inside [min_pixels, max_pixels]."""
    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"Aspect ratio too extreme: {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = round_by_factor(height / beta, factor)
        w_bar = round_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = round_by_factor(height * beta, factor)
        w_bar = round_by_factor(width * beta, factor)
    return h_bar, w_bar
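
# Worked example for smart_resize (values computed from the definitions above):
#   smart_resize(2000, 1500) -> (1988, 1512)
#   2000 / 28 ~ 71.43 rounds to 71 -> 1988; 1500 / 28 ~ 53.57 rounds to 54 -> 1512;
#   1988 * 1512 = 3,005,856 pixels, already inside [MIN_PIXELS, MAX_PIXELS],
#   so neither the upscale nor the downscale branch fires.
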
def fetch_image(
    image_input,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
) -> Image.Image:
    """Load an image from a URL, a local path, or a PIL image, optionally resizing it."""
    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):
            response = requests.get(image_input)
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(image_input).convert('RGB')
    elif isinstance(image_input, Image.Image):
        image = image_input.convert('RGB')
    else:
        raise ValueError(f"Invalid image input type: {type(image_input)}")

    if min_pixels or max_pixels:
        min_pixels = min_pixels or MIN_PIXELS
        max_pixels = max_pixels or MAX_PIXELS
        height, width = smart_resize(
            image.height, image.width,
            factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels,
        )
        image = image.resize((width, height), Image.LANCZOS)
    return image


def is_arabic_text(text: str) -> bool:
    """Return True if more than half of the alphabetic characters in the
    headers and paragraphs of `text` fall in the Arabic Unicode blocks."""
    if not text:
        return False
    header_pattern = r'^#{1,6}\s+(.+)$'
    paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
    content_text = []
    for line in text.split('\n'):
        line = line.strip()
        if not line:
            continue
        header_match = re.match(header_pattern, line)
        if header_match:
            content_text.append(header_match.group(1))
            continue
        if re.match(paragraph_pattern, line):
            content_text.append(line)
    if not content_text:
        return False
    combined_text = ' '.join(content_text)
    arabic_chars = 0
    total_chars = 0
    for char in combined_text:
        if char.isalpha():
            total_chars += 1
            if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'):
                arabic_chars += 1
    return total_chars > 0 and (arabic_chars / total_chars) > 0.5


def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
    """Convert the model's layout JSON into Markdown, inlining cropped
    'Picture' regions as base64-encoded PNGs."""
    markdown_lines = []
    try:
        # Sort into human reading order: top-to-bottom, then left-to-right.
        sorted_items = sorted(
            layout_data,
            key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]),
        )
        for item in sorted_items:
            category = item.get('category', '')
            text = item.get(text_key, '')
            bbox = item.get('bbox', [])
            if category == 'Picture':
                if bbox and len(bbox) == 4:
                    try:
                        x1, y1, x2, y2 = bbox
                        # Clamp the crop box to the image bounds.
                        x1, y1 = max(0, int(x1)), max(0, int(y1))
                        x2, y2 = min(image.width, int(x2)), min(image.height, int(y2))
                        if x2 > x1 and y2 > y1:
                            cropped_img = image.crop((x1, y1, x2, y2))
                            buffer = BytesIO()
                            cropped_img.save(buffer, format='PNG')
                            img_data = base64.b64encode(buffer.getvalue()).decode()
                            markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
                        else:
                            markdown_lines.append("![Image](Image region detected)\n")
                    except Exception as e:
                        print(f"Error processing image region: {e}")
                        markdown_lines.append("![Image](Image detected)\n")
                else:
                    markdown_lines.append("![Image](Image detected)\n")
            elif not text:
                continue
            elif category == 'Title':
                markdown_lines.append(f"# {text}\n")
            elif category == 'Section-header':
                markdown_lines.append(f"## {text}\n")
            elif category == 'Text':
                markdown_lines.append(f"{text}\n")
            elif category == 'List-item':
                markdown_lines.append(f"- {text}\n")
            elif category == 'Table':
                # Tables are requested as HTML; pass HTML through untouched.
                if text.strip().startswith('<'):
                    markdown_lines.append(f"{text}\n")
                else:
                    markdown_lines.append(f"**Table:** {text}\n")
            elif category == 'Formula':
                # Formulas are requested as LaTeX; wrap them as display math.
                if text.strip().startswith('$') or '\\' in text:
                    markdown_lines.append(f"$$\n{text}\n$$\n")
                else:
                    markdown_lines.append(f"**Formula:** {text}\n")
            elif category == 'Caption':
                markdown_lines.append(f"*{text}*\n")
            elif category == 'Footnote':
                markdown_lines.append(f"^{text}^\n")
            elif category in ['Page-header', 'Page-footer']:
                continue
            else:
                markdown_lines.append(f"{text}\n")
        markdown_lines.append("")
    except Exception as e:
        print(f"Error converting to markdown: {e}")
        return str(layout_data)
    return "\n".join(markdown_lines)
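
# Illustrative input for layoutjson2md(). The field names follow the prompt
# above; the bbox values and texts here are made up for the example:
#
#   [
#       {"bbox": [40, 30, 560, 80], "category": "Title", "text": "Annual Report"},
#       {"bbox": [40, 100, 560, 300], "category": "Text", "text": "Revenue grew..."},
#       {"bbox": [40, 320, 560, 600], "category": "Picture"}
#   ]
#
# would render as an H1 heading, a paragraph, and an inlined base64 image crop.
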
@spaces.GPU
def inference(model_name: str, image: Image.Image, text: str, max_new_tokens: int = 1024):
    """Stream generated text from the selected model, yielding the growing buffer."""
    try:
        if model_name == "Camel-Doc-OCR-062825":
            processor, model = processor_m, model_m
        elif model_name == "Megalodon-OCR-Sync-0713":
            processor, model = processor_t, model_t
        elif model_name == "Nanonets-OCR-s":
            processor, model = processor_c, model_c
        elif model_name == "MonkeyOCR-Recognition":
            processor, model = processor_g, model_g
        else:
            raise ValueError(f"Invalid model selected: {model_name}")

        if image is None:
            yield "Please upload an image.", "Please upload an image."
            return

        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": text},
            ],
        }]
        prompt_full = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = processor(
            text=[prompt_full],
            images=[image],
            return_tensors="pt",
            padding=True,
            truncation=False,
            max_length=MAX_INPUT_TOKEN_LENGTH,
        ).to(device)

        # Run generation on a background thread and stream tokens as they arrive.
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        buffer = ""
        for new_text in streamer:
            buffer += new_text
            buffer = buffer.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield buffer, buffer
    except Exception as e:
        print(f"Error during inference: {e}")
        traceback.print_exc()
        yield f"Error during inference: {str(e)}", f"Error during inference: {str(e)}"


def process_image(
    model_name: str,
    image: Image.Image,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
    max_new_tokens: int = 1024,
):
    """Run layout extraction on `image`, yielding intermediate and final result dicts."""
    try:
        if min_pixels or max_pixels:
            image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)

        result = {
            'original_image': image,
            'raw_output': "",
            'layout_result': None,
            'markdown_content': None,
        }

        buffer = ""
        for raw_output, _ in inference(model_name, image, prompt, max_new_tokens):
            buffer = raw_output
            result['raw_output'] = buffer
            yield result

        # Try to parse the (possibly ```json-fenced) payload from the raw output.
        try:
            json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
            json_str = json_match.group(1) if json_match else buffer
            layout_data = json.loads(json_str)
            result['layout_result'] = layout_data
            try:
                result['markdown_content'] = layoutjson2md(image, layout_data, text_key='text')
            except Exception as e:
                print(f"Error generating markdown: {e}")
                result['markdown_content'] = buffer
        except json.JSONDecodeError:
            print("Failed to parse JSON output, using raw output")
            result['markdown_content'] = buffer
        yield result
    except Exception as e:
        print(f"Error processing image: {e}")
        traceback.print_exc()
        result = {
            'original_image': image,
            'raw_output': f"Error processing image: {str(e)}",
            'layout_result': None,
            'markdown_content': f"Error processing image: {str(e)}",
        }
        yield result


def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
    """Load a local image file for the preview pane."""
    if not file_path or not os.path.exists(file_path):
        return None, "No file selected"
    file_ext = os.path.splitext(file_path)[1].lower()
    try:
        if file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
            image = Image.open(file_path).convert('RGB')
            return image, "Image loaded"
        return None, f"Unsupported file format: {file_ext}"
    except Exception as e:
        print(f"Error loading file: {e}")
        return None, f"Error loading file: {str(e)}"
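
# Minimal headless usage sketch (assumptions: "sample_page.png" is a
# hypothetical local file, and a CUDA device is available for reasonable speed):
#
#   img, status = load_file_for_preview("sample_page.png")
#   final = None
#   for partial in process_image("Camel-Doc-OCR-062825", img, max_new_tokens=1024):
#       final = partial  # each yield carries the growing raw_output
#   print(final["markdown_content"] or final["raw_output"])
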
def create_gradio_interface():
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
    .process-button {
        border: none !important;
        color: white !important;
        font-weight: bold !important;
        background-color: blue !important;
    }
    .process-button:hover {
        background-color: darkblue !important;
        transform: translateY(-2px) !important;
        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
    }
    .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
    .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
    .model-status { padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center; font-weight: bold; }
    .status-ready { background: #d1edff; color: #0c5460; border: 1px solid #b8daff; }
    """

    with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
        gr.HTML("""

            <div class="header-text">
                <h1>DotOCR Comparator</h1>
                <p>Advanced vision-language model for image to markdown document processing</p>
            </div>

""") with gr.Row(): with gr.Column(scale=1): model_choice = gr.Radio( choices=["Camel-Doc-OCR-062825", "MonkeyOCR-Recognition", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713"], label="Select Model", value="Camel-Doc-OCR-062825" ) file_input = gr.File( label="Upload Image", file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff"], type="filepath" ) image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=300) with gr.Accordion("Advanced Settings", open=False): max_new_tokens = gr.Slider(minimum=1000, maximum=32000, value=24000, step=1000, label="Max New Tokens") min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels") max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels") process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg") clear_btn = gr.Button("🗑️ Clear All", variant="secondary") with gr.Column(scale=2): with gr.Tabs(): with gr.Tab("📝 Extracted Content"): output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=2, show_copy_button=True) with gr.Accordion("(Result.md)", open=False): markdown_output = gr.Markdown(label="Formatted Result (Result.Md)") with gr.Tab("📋 Layout JSON"): json_output = gr.JSON(label="Layout Analysis Results", value=None) def process_document(model_name, file_path, max_tokens, min_pix, max_pix): try: if not file_path: return "Please upload an image.", "Please upload an image.", None image, status = load_file_for_preview(file_path) if image is None: return status, status, None for result in process_image(model_name, image, min_pixels=int(min_pix) if min_pix else None, max_pixels=int(max_pix) if max_pix else None, max_new_tokens=max_tokens): raw_output = result['raw_output'] markdown_content = result['markdown_content'] or raw_output if is_arabic_text(markdown_content): markdown_update = gr.update(value=markdown_content, rtl=True) else: markdown_update = markdown_content yield raw_output, markdown_update, result['layout_result'] except Exception as e: error_msg = f"Error processing document: {str(e)}" print(error_msg) traceback.print_exc() yield error_msg, error_msg, None def handle_file_upload(file_path): if not file_path: return None, "No file loaded" image, page_info = load_file_for_preview(file_path) return image, page_info def clear_all(): return None, None, "No file loaded", "", "Click 'Process Document' to see extracted content...", None file_input.change(handle_file_upload, inputs=[file_input], outputs=[image_preview, output]) process_btn.click( process_document, inputs=[model_choice, file_input, max_new_tokens, min_pixels, max_pixels], outputs=[output, markdown_output, json_output] ) clear_btn.click( clear_all, outputs=[file_input, image_preview, output, markdown_output, json_output] ) return demo if __name__ == "__main__": demo = create_gradio_interface() demo.queue(max_size=10).launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True, show_error=True)