import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re
import time
from threading import Thread

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)
from qwen_vl_utils import process_vision_info

# Constants
MIN_PIXELS = 3136  # 4 * 28 * 28
MAX_PIXELS = 11289600  # 14400 * 28 * 28
IMAGE_FACTOR = 28
MAX_INPUT_TOKEN_LENGTH = 2048

device = "cuda" if torch.cuda.is_available() else "cpu"

# Prompts
prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]

2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].

3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: Format its text as HTML.
    - All Others (Text, Title, etc.): Format their text as Markdown.

4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.

5. Final Output: The entire output must be a single JSON object.
"""


# Load models
def _load_processor_and_model(model_id: str, **extra_kwargs: Any):
    """Load the processor and fp16 Qwen2.5-VL model for ``model_id``.

    The model is moved to ``device`` and put in eval mode. ``extra_kwargs``
    (e.g. ``subfolder``) are forwarded to both ``from_pretrained`` calls.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True, **extra_kwargs
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id, trust_remote_code=True, torch_dtype=torch.float16, **extra_kwargs
    ).to(device).eval()
    return processor, model


MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
processor_m, model_m = _load_processor_and_model(MODEL_ID_M)

MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
processor_t, model_t = _load_processor_and_model(MODEL_ID_T)

MODEL_ID_C = "nanonets/Nanonets-OCR-s"
processor_c, model_c = _load_processor_and_model(MODEL_ID_C)

MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g, model_g = _load_processor_and_model(MODEL_ID_G, subfolder=SUBFOLDER)


# Utility functions
def round_by_factor(number: float, factor: int) -> int:
    """Return the multiple of ``factor`` closest to ``number``."""
    return round(number / factor) * factor


def floor_by_factor(number: float, factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``number``."""
    return math.floor(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``number``."""
    return math.ceil(number / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600,
) -> Tuple[int, int]:
    """Compute a target ``(height, width)`` whose sides are multiples of ``factor``.

    The aspect ratio is approximately preserved. The pixel count is pulled into
    ``[min_pixels, max_pixels]``: shrinking rounds down and growing rounds up,
    so the rescaled size does not overshoot the bound it was scaled toward.
    Each side is at least ``factor``.

    Args:
        height: Source image height in pixels.
        width: Source image width in pixels.
        factor: Grid size both output sides must be multiples of.
        min_pixels: Minimum allowed ``height * width`` of the result.
        max_pixels: Maximum allowed ``height * width`` of the result.

    Returns:
        ``(new_height, new_width)``.

    Raises:
        ValueError: If a dimension is not positive or the aspect ratio exceeds 200.
    """
    if height <= 0 or width <= 0:
        raise ValueError(f"Image dimensions must be positive, got {height}x{width}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError(f"Aspect ratio too extreme: {max(height, width) / min(height, width)}")
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Floor so the result cannot exceed max_pixels; clamp so no side becomes 0.
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        # Ceil so the result cannot fall below min_pixels.
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar
def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None): if isinstance(image_input, str): if image_input.startswith(("http://", "https://")): response = requests.get(image_input) image = Image.open(BytesIO(response.content)).convert('RGB') else: image = Image.open(image_input).convert('RGB') elif isinstance(image_input, Image.Image): image = image_input.convert('RGB') else: raise ValueError(f"Invalid image input type: {type(image_input)}") if min_pixels or max_pixels: min_pixels = min_pixels or MIN_PIXELS max_pixels = max_pixels or MAX_PIXELS height, width = smart_resize(image.height, image.width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels) image = image.resize((width, height), Image.LANCZOS) return image def is_arabic_text(text: str) -> bool: if not text: return False header_pattern = r'^#{1,6}\s+(.+)$' paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$' content_text = [] for line in text.split('\n'): line = line.strip() if not line: continue header_match = re.match(header_pattern, line, re.MULTILINE) if header_match: content_text.append(header_match.group(1)) continue if re.match(paragraph_pattern, line, re.MULTILINE): content_text.append(line) if not content_text: return False combined_text = ' '.join(content_text) arabic_chars = 0 total_chars = 0 for char in combined_text: if char.isalpha(): total_chars += 1 if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'): arabic_chars += 1 return total_chars > 0 and (arabic_chars / total_chars) > 0.5 def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str: import base64 from io import BytesIO markdown_lines = [] try: sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0])) for item in sorted_items: category = item.get('category', '') text = item.get(text_key, '') bbox = item.get('bbox', []) if category == 
'Picture': if bbox and len(bbox) == 4: try: x1, y1, x2, y2 = bbox x1, y1 = max(0, int(x1)), max(0, int(y1)) x2, y2 = min(image.width, int(x2)), min(image.height, int(y2)) if x2 > x1 and y2 > y1: cropped_img = image.crop((x1, y1, x2, y2)) buffer = BytesIO() cropped_img.save(buffer, format='PNG') img_data = base64.b64encode(buffer.getvalue()).decode() markdown_lines.append(f"\n") else: markdown_lines.append("\n") except Exception as e: print(f"Error processing image region: {e}") markdown_lines.append("\n") else: markdown_lines.append("\n") elif not text: continue elif category == 'Title': markdown_lines.append(f"# {text}\n") elif category == 'Section-header': markdown_lines.append(f"## {text}\n") elif category == 'Text': markdown_lines.append(f"{text}\n") elif category == 'List-item': markdown_lines.append(f"- {text}\n") elif category == 'Table': if text.strip().startswith('<'): markdown_lines.append(f"{text}\n") else: markdown_lines.append(f"**Table:** {text}\n") elif category == 'Formula': if text.strip().startswith('$') or '\\' in text: markdown_lines.append(f"$$\n{text}\n$$\n") else: markdown_lines.append(f"**Formula:** {text}\n") elif category == 'Caption': markdown_lines.append(f"*{text}*\n") elif category == 'Footnote': markdown_lines.append(f"^{text}^\n") elif category in ['Page-header', 'Page-footer']: continue else: markdown_lines.append(f"{text}\n") markdown_lines.append("") except Exception as e: print(f"Error converting to markdown: {e}") return str(layout_data) return "\n".join(markdown_lines) @spaces.GPU def inference(model_name: str, image: Image.Image, text: str, max_new_tokens: int = 1024) -> str: try: if model_name == "Camel-Doc-OCR-062825": processor = processor_m model = model_m elif model_name == "Megalodon-OCR-Sync-0713": processor = processor_t model = model_t elif model_name == "Nanonets-OCR-s": processor = processor_c model = model_c elif model_name == "MonkeyOCR-Recognition": processor = processor_g model = model_g else: raise 
ValueError(f"Invalid model selected: {model_name}") if image is None: yield "Please upload an image.", "Please upload an image." return messages = [{ "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": text}, ] }] prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = processor( text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=False, max_length=MAX_INPUT_TOKEN_LENGTH ).to(device) streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True) generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens} thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() buffer = "" for new_text in streamer: buffer += new_text buffer = buffer.replace("<|im_end|>", "") time.sleep(0.01) yield buffer, buffer except Exception as e: print(f"Error during inference: {e}") traceback.print_exc() yield f"Error during inference: {str(e)}", f"Error during inference: {str(e)}" def process_image( model_name: str, image: Image.Image, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, max_new_tokens: int = 1024 ) -> Dict[str, Any]: try: if min_pixels or max_pixels: image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels) result = { 'original_image': image, 'raw_output': "", 'layout_result': None, 'markdown_content': None } buffer = "" for raw_output, _ in inference(model_name, image, prompt, max_new_tokens): buffer = raw_output result['raw_output'] = buffer yield result try: json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer) json_str = json_match.group(1) if json_match else buffer layout_data = json.loads(json_str) result['layout_result'] = layout_data try: markdown_content = layoutjson2md(image, layout_data, text_key='text') result['markdown_content'] = markdown_content except Exception as e: print(f"Error generating markdown: {e}") result['markdown_content'] 
= buffer except json.JSONDecodeError: print("Failed to parse JSON output, using raw output") result['markdown_content'] = buffer yield result except Exception as e: print(f"Error processing image: {e}") traceback.print_exc() result = { 'original_image': image, 'raw_output': f"Error processing image: {str(e)}", 'layout_result': None, 'markdown_content': f"Error processing image: {str(e)}" } yield result def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]: if not file_path or not os.path.exists(file_path): return None, "No file selected" file_ext = os.path.splitext(file_path)[1].lower() try: if file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']: image = Image.open(file_path).convert('RGB') return image, "Image loaded" else: return None, f"Unsupported file format: {file_ext}" except Exception as e: print(f"Error loading file: {e}") return None, f"Error loading file: {str(e)}" def create_gradio_interface(): css = """ .main-container { max-width: 1400px; margin: 0 auto; } .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; } .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;} .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; } .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; } .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; } .model-status { padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center; font-weight: bold; } .status-ready { background: #d1edff; color: #0c5460; border: 1px solid #b8daff; } """ with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo: gr.HTML("""
Advanced vision-language model for image to markdown document processing