import spaces  # NOTE(review): kept first as in the original; presumably required before torch on ZeroGPU Spaces — confirm
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re

import fitz  # PyMuPDF
import gradio as gr
import requests
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoProcessor, Qwen2_5_VLForConditionalGeneration

# Constants
# Default pixel-count bounds and side-length multiple used by fetch_image / smart_resize.
MIN_PIXELS = 3136
MAX_PIXELS = 11289600
IMAGE_FACTOR = 28

# Seconds fetch_image waits for a remote image before giving up.
HTTP_TIMEOUT = 30

# Prompts
prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]

2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].

3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: Format its text as HTML.
    - All Others (Text, Title, etc.): Format their text as Markdown.

4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.

5. Final Output: The entire output must be a single JSON object.
"""


# Utility Functions

def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def floor_by_factor(number: float, factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``number``."""
    return math.floor(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``number``."""
    return math.ceil(number / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600,
):
    """Compute target dimensions whose sides are multiples of ``factor``.

    The aspect ratio is kept as closely as possible. The pixel count is
    brought into ``[min_pixels, max_pixels]``, except for extremely elongated
    inputs where one side is clamped to ``factor``.

    Returns:
        Tuple ``(new_height, new_width)``.

    Raises:
        ValueError: if a side is not positive or the aspect ratio exceeds 200.
    """
    if height <= 0 or width <= 0:
        raise ValueError(f"Image dimensions must be positive, got {height}x{width}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError(f"Aspect ratio must be < 200, got {max(height, width) / min(height, width)}")
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Round *down*: round-to-nearest can land above max_pixels
        # (e.g. 3000x4000 would become 2912x3892 = 11,333,504 pixels).
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        # Round *up* so the upscaled result cannot fall below min_pixels.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar


def fetch_image(image_input, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None):
    """Load an image as RGB and optionally resize it with ``smart_resize``.

    Args:
        image_input: an ``http://``/``https://`` URL, a local file path, or a
            ``PIL.Image.Image``.
        min_pixels: lower pixel bound; ``None`` falls back to MIN_PIXELS.
        max_pixels: upper pixel bound; ``None`` falls back to MAX_PIXELS.
            Resizing only happens when at least one bound is given.

    Returns:
        An RGB ``PIL.Image.Image``.

    Raises:
        ValueError: for unsupported input types.
        requests.RequestException: if a URL cannot be fetched (HTTP error or timeout).
    """
    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):
            response = requests.get(image_input, timeout=HTTP_TIMEOUT)
            # Fail with a clear HTTP error instead of PIL choking on an error page.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            # Close the underlying file handle once the RGB copy exists.
            with Image.open(image_input) as src:
                image = src.convert('RGB')
    elif isinstance(image_input, Image.Image):
        image = image_input.convert('RGB')
    else:
        raise ValueError(f"Invalid image input type: {type(image_input)}")

    if min_pixels is not None or max_pixels is not None:
        min_pixels = MIN_PIXELS if min_pixels is None else min_pixels
        max_pixels = MAX_PIXELS if max_pixels is None else max_pixels
        height, width = smart_resize(
            image.height, image.width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels
        )
        image = image.resize((width, height), Image.LANCZOS)
    return image


def load_images_from_pdf(pdf_path: str, zoom: float = 2.0) -> List[Image.Image]:
    """Render every page of a PDF to an RGB image.

    Args:
        pdf_path: path of the PDF file.
        zoom: render scale factor applied to both axes (2.0 = double resolution).

    Returns:
        One image per page, or an empty list if the PDF cannot be opened or rendered.
    """
    images: List[Image.Image] = []
    try:
        # The context manager closes the document even if rendering a page fails.
        with fitz.open(pdf_path) as pdf_document:
            mat = fitz.Matrix(zoom, zoom)
            for page in pdf_document:
                pix = page.get_pixmap(matrix=mat)
                images.append(Image.open(BytesIO(pix.tobytes("ppm"))).convert('RGB'))
    except Exception as e:
        print(f"Error loading PDF: {e}")
        return []
    return images


def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
    """Return a copy of ``image`` with each element's bbox and category label drawn.

    Items that are not dicts with both 'bbox' and 'category' are skipped. An
    item that fails to draw (e.g. a malformed bbox) is logged and skipped
    without aborting the remaining items.
    """
    img_copy = image.copy()
    draw = ImageDraw.Draw(img_copy)
    colors = {
        'Caption': '#FF6B6B', 'Footnote': '#4ECDC4', 'Formula': '#45B7D1', 'List-item': '#96CEB4',
        'Page-footer': '#FFEAA7', 'Page-header': '#DDA0DD', 'Picture': '#FFD93D',
        'Section-header': '#6C5CE7', 'Table': '#FD79A8', 'Text': '#74B9FF', 'Title': '#E17055'
    }

    # ImageFont.truetype raises when the font cannot be loaded; it never
    # returns a falsy value, so the fallback has to be an except clause.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    try:
        items = list(layout_data)
    except TypeError as e:
        print(f"Error drawing layout: {e}")
        return img_copy

    for item in items:
        if not isinstance(item, dict) or 'bbox' not in item or 'category' not in item:
            continue
        try:
            bbox = item['bbox']
            category = item['category']
            color = colors.get(category, '#000000')
            draw.rectangle(bbox, outline=color, width=2)

            # Filled tag with the category name, placed just above the box and
            # clamped so it does not go above the top edge of the image.
            label = str(category)
            left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
            label_width, label_height = right - left, bottom - top
            label_x, label_y = bbox[0], max(0, bbox[1] - label_height - 2)
            draw.rectangle([label_x, label_y, label_x + label_width + 4, label_y + label_height + 2], fill=color)
            draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
        except Exception as e:
            print(f"Error drawing layout: {e}")
    return img_copy


def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
    """Convert parsed layout elements into a Markdown document.

    Elements are ordered by their bbox, top-to-bottom then left-to-right.
    'Picture' elements with a 4-value bbox are cropped from ``image`` and
    embedded as base64 PNG data URIs. Page headers/footers and elements
    without text are dropped.

    Args:
        image: page image the bboxes refer to.
        layout_data: list of element dicts with 'bbox', 'category' and ``text_key``.
        text_key: key holding each element's text.

    Returns:
        The Markdown string, or ``str(layout_data)`` if conversion fails.
    """
    import base64

    def _position(item) -> Tuple[Any, Any]:
        # (y1, x1) sort key; missing, empty or short bboxes sort first instead of raising.
        bbox = item.get('bbox') or []
        return (bbox[1], bbox[0]) if len(bbox) >= 2 else (0, 0)

    def _picture_markdown(bbox) -> str:
        # Crop the bbox (clamped to the image) and embed it as an inline PNG.
        try:
            x1, y1, x2, y2 = (int(v) for v in bbox)
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(image.width, x2), min(image.height, y2)
            if x2 <= x1 or y2 <= y1:
                return "\n"
            buffer = BytesIO()
            image.crop((x1, y1, x2, y2)).save(buffer, format='PNG')
            img_data = base64.b64encode(buffer.getvalue()).decode()
            return f"![Picture](data:image/png;base64,{img_data})\n"
        except Exception as e:
            print(f"Error processing image region: {e}")
            return "\n"

    def _format_text(category: str, text: str) -> str:
        # Markdown rendering for a non-picture element that has text.
        if category == 'Title':
            return f"# {text}\n"
        if category == 'Section-header':
            return f"## {text}\n"
        if category == 'List-item':
            return f"- {text}\n"
        if category == 'Table':
            return f"{text}\n" if text.strip().startswith('<') else f"**Table:** {text}\n"
        if category == 'Formula':
            stripped = text.strip()
            if stripped.startswith('$'):
                # Already math-delimited; wrapping again would produce "$$ $...$ $$".
                return f"{stripped}\n"
            if '\\' in text:
                return f"$$\n{text}\n$$\n"
            return f"**Formula:** {text}\n"
        if category == 'Caption':
            return f"*{text}*\n"
        if category == 'Footnote':
            return f"^{text}^\n"
        return f"{text}\n"  # 'Text' and any unrecognised category

    markdown_lines: List[str] = []
    try:
        # NOTE(review): sorting by position replaces the reading order the prompt
        # asks the model to produce; multi-column pages may be interleaved.
        for item in sorted(layout_data, key=_position):
            category = item.get('category', '')
            text = item.get(text_key, '')
            bbox = item.get('bbox', [])

            if category == 'Picture' and bbox and len(bbox) == 4:
                markdown_lines.append(_picture_markdown(bbox))
            elif not text or category in ('Page-header', 'Page-footer'):
                continue
            else:
                markdown_lines.append(_format_text(category, text if isinstance(text, str) else str(text)))
            markdown_lines.append("")
    except Exception as e:
        print(f"Error converting to markdown: {e}")
        return str(layout_data)
    return "\n".join(markdown_lines)


# Load Models
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load dot.ocr
model_id = "rednote-hilab/dots.ocr"
model_path = "./models/dots-ocr-local"
# NOTE(review): `local_dir_use_symlinks` is deprecated in newer huggingface_hub
# releases; confirm against the pinned version before removing it.
snapshot_download(repo_id=model_id, local_dir=model_path, local_dir_use_symlinks=False)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
Camel-Doc-OCR-062825 MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825" processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True) model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained( MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16 ).to(device).eval() # Load Megalodon-OCR-Sync-0713 MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713" processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True) model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained( MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16 ).to(device).eval() # Model Dictionary model_dict = { "dot.ocr": {"model": model, "processor": processor, "process_layout": True}, "Camel-Doc-OCR-062825": {"model": model_m, "processor": processor_m, "process_layout": False}, "Megalodon-OCR-Sync-0713": {"model": model_t, "processor": processor_t, "process_layout": False}, } # Global State pdf_cache = {"images": [], "current_page": 0, "total_pages": 0, "file_type": None, "is_parsed": False, "results": []} @spaces.GPU() def inference(model, processor, image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str: """Run inference on an image with the given prompt using the specified model and processor.""" try: messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}] text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) image_inputs, video_inputs = process_vision_info(messages) inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to(device) with torch.no_grad(): generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.1) generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False) 
return output_text[0] if output_text else "" except Exception as e: print(f"Error during inference: {e}") traceback.print_exc() return f"Error during inference: {str(e)}" def process_image( image: Image.Image, model, processor, process_layout: bool, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None ) -> Dict[str, Any]: """Process a single image with the specified model and processor.""" try: if min_pixels is not None or max_pixels is not None: image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels) raw_output = inference(model, processor, image, prompt) result = {'original_image': image, 'raw_output': raw_output, 'processed_image': image, 'layout_result': None, 'markdown_content': raw_output} if process_layout: try: layout_data = json.loads(raw_output) result['layout_result'] = layout_data result['processed_image'] = draw_layout_on_image(image, layout_data) result['markdown_content'] = layoutjson2md(image, layout_data, text_key='text') except json.JSONDecodeError: print("Failed to parse JSON output, using raw output") except Exception as e: print(f"Error processing layout: {e}") return result except Exception as e: print(f"Error processing image: {e}") traceback.print_exc() return {'original_image': image, 'raw_output': str(e), 'processed_image': image, 'layout_result': None, 'markdown_content': str(e)} def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]: """Load file for preview (supports PDF and images).""" global pdf_cache if not file_path or not os.path.exists(file_path): return None, "No file selected" file_ext = os.path.splitext(file_path)[1].lower() try: if file_ext == '.pdf': images = load_images_from_pdf(file_path) if not images: return None, "Failed to load PDF" pdf_cache.update({"images": images, "current_page": 0, "total_pages": len(images), "file_type": "pdf", "is_parsed": False, "results": []}) return images[0], f"Page 1 / {len(images)}" elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', 
'.tiff']: image = Image.open(file_path).convert('RGB') pdf_cache.update({"images": [image], "current_page": 0, "total_pages": 1, "file_type": "image", "is_parsed": False, "results": []}) return image, "Page 1 / 1" else: return None, f"Unsupported file format: {file_ext}" except Exception as e: print(f"Error loading file: {e}") return None, f"Error loading file: {str(e)}" def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, Any, Optional[Image.Image], Optional[Dict]]: """Navigate through PDF pages and update outputs.""" global pdf_cache if not pdf_cache["images"]: return None, '
Advanced vision-language model for image/PDF to markdown document processing