import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re
import time
from threading import Thread
import uuid

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.platypus import SimpleDocTemplate, Image as RLImage, Paragraph, Spacer
from reportlab.lib.units import inch

# NOTE(review): `spaces` is imported before `torch` exactly as in the original
# ordering; presumably the ZeroGPU package expects that — confirm before reordering.

# --- Constants and Model Setup ---
MAX_INPUT_TOKEN_LENGTH = 4096

# Run on the GPU when CUDA is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def _report_torch_environment() -> None:
    """Print CUDA / PyTorch diagnostics and the selected device to stdout."""
    cuda_ok = torch.cuda.is_available()
    print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
    print("torch.__version__ =", torch.__version__)
    print("torch.version.cuda =", torch.version.cuda)
    print("cuda available:", cuda_ok)
    print("cuda device count:", torch.cuda.device_count())
    if cuda_ok:
        current_index = torch.cuda.current_device()
        print("current device:", current_index)
        print("device name:", torch.cuda.get_device_name(current_index))
    print("Using device:", device)


_report_torch_environment()

# --- Prompts for Different Tasks ---
ocr_prompt = "Perform precise OCR on the image. Extract all text content, maintaining the original structure, paragraphs, and tables as formatted markdown."
# --- Model Loading --- MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-080125" processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True) model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained( MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16 ).to(device).eval() MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713" processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True) model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained( MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16 ).to(device).eval() MODEL_ID_C = "nanonets/Nanonets-OCR-s" processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True) model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained( MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16 ).to(device).eval() MODEL_ID_G = "echo840/MonkeyOCR" SUBFOLDER = "Recognition" processor_g = AutoProcessor.from_pretrained( MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER ) model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained( MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16 ).to(device).eval() MODEL_ID_I = "allenai/olmOCR-7B-0725" processor_i = AutoProcessor.from_pretrained(MODEL_ID_I, trust_remote_code=True) model_i = Qwen2_5_VLForConditionalGeneration.from_pretrained( MODEL_ID_I, trust_remote_code=True, torch_dtype=torch.float16 ).to(device).eval() # --- PDF Generation Utility Function --- def generate_pdf(image: Image.Image, text_content: str, font_size: int, line_spacing: float, alignment: str, image_size: str) -> str: """ Generates a PDF document with the input image and extracted text. """ if image is None or not text_content: raise gr.Error("Cannot generate PDF. 
Image or text content is missing.") filename = f"/tmp/output_{uuid.uuid4()}.pdf" doc = SimpleDocTemplate( filename, pagesize=A4, rightMargin=inch, leftMargin=inch, topMargin=inch, bottomMargin=inch ) styles = getSampleStyleSheet() style_normal = styles["Normal"] style_normal.fontSize = int(font_size) style_normal.leading = int(font_size) * line_spacing style_normal.alignment = { "Left": 0, "Center": 1, "Right": 2, "Justified": 4 }[alignment] story = [] # Handle Image # Convert PIL image to a format reportlab can use without saving to disk img_buffer = BytesIO() image.save(img_buffer, format='PNG') img_buffer.seek(0) # Image size settings page_width, _ = A4 available_width = page_width - 2 * inch image_widths = { "Small": available_width * 0.3, "Medium": available_width * 0.6, "Large": available_width * 0.9, } img = RLImage(img_buffer, width=image_widths[image_size], height=image.height * (image_widths[image_size]/image.width)) story.append(img) story.append(Spacer(1, 12)) # Handle Text - Replace markdown with spaces for PDF # A simple replacement for basic markdown, for more complex cases a proper parser would be needed cleaned_text = text_content.replace("# ", "").replace("## ", "").replace("*", "") text_paragraphs = cleaned_text.split('\n') for para in text_paragraphs: if para.strip(): story.append(Paragraph(para, style_normal)) doc.build(story) return filename # --- Core Application Logic --- @spaces.GPU def process_document_stream(model_name: str, image: Image.Image, max_new_tokens: int): """ Main generator function that handles OCR tasks. """ if image is None: yield "Please upload an image.", "Please upload an image." return # 1. Set prompt for OCR text_prompt = ocr_prompt # 2. 
Select model and processor if model_name == "Camel-Doc-OCR-080125": processor, model = processor_m, model_m elif model_name == "Megalodon-OCR-Sync-0713": processor, model = processor_t, model_t elif model_name == "Nanonets-OCR-s": processor, model = processor_c, model_c elif model_name == "MonkeyOCR-Recognition": processor, model = processor_g, model_g elif model_name == "olmOCR-7B-0725": processor, model = processor_i, model_i else: yield "Invalid model selected.", "Invalid model selected." return # 3. Prepare model inputs and streamer messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_prompt}]}] prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device) streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True) generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens} thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() # 4. Stream raw output to the UI in real-time buffer = "" for new_text in streamer: buffer += new_text buffer = buffer.replace("<|im_end|>", "") time.sleep(0.01) yield buffer , "β³ Processing..." # 5. Yield the final result for both raw and formatted outputs yield buffer, buffer # --- Gradio UI Definition --- def create_gradio_interface(): """Builds and returns the Gradio web interface.""" css = """ .main-container { max-width: 1400px; margin: 0 auto; } .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;} .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; } """ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo: gr.HTML("""
Advanced Vision-Language Model for Image Content and Layout Extraction