import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re
import warnings
import fitz # PyMuPDF
import gradio as gr
import requests
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModelForCausalLM, AutoProcessor, VisionEncoderDecoderModel
from huggingface_hub import snapshot_download
from qwen_vl_utils import process_vision_info
# Silence transformers' FutureWarning about `num_logits_to_keep` /
# `logits_to_keep` both being set during generate(); it is harmless noise
# for this app and would otherwise spam the console on every inference.
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message="Both `num_logits_to_keep` and `logits_to_keep` are set"
)
# JavaScript injected into the Gradio page: forces the dark theme by adding
# `__theme=dark` to the URL and reloading when it is not already present.
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""
# Image-resizing constraints used by smart_resize/fetch_image:
MIN_PIXELS = 3136        # minimum total pixels after resizing (56 x 56)
MAX_PIXELS = 11289600    # maximum total pixels after resizing
IMAGE_FACTOR = 28        # both dimensions are rounded to multiples of this factor
# Layout-analysis prompt sent to dots.ocr; the model is expected to answer
# with a single JSON object containing bbox/category/text entries, which
# draw_layout_on_image and layoutjson2md then consume.
prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']
3. Text Extraction & Formatting Rules:
- Picture: omit the text field
- Formula: format as LaTeX
- Table: format as HTML
- Others: format as Markdown
4. Constraints:
- Use original text, no translation
- Sort elements by human reading order
5. Final Output: Single JSON object
"""
# Model loading functions
def load_model(model_name):
    """Download (if needed) and instantiate the requested OCR model.

    Supported names:
      - "dots.ocr": Qwen-style causal VLM, bf16, flash-attention, auto device map.
      - "Dolphin": vision encoder-decoder, evaluated in fp16 on CUDA if available.

    Returns:
        (model, processor) tuple ready for inference.

    Raises:
        ValueError: if `model_name` is not one of the supported names.
    """
    if model_name == "dots.ocr":
        local_path = "./models/dots-ocr-local"
        # Materialize the weights locally so trust_remote_code modules resolve.
        snapshot_download(
            repo_id="rednote-hilab/dots.ocr",
            local_dir=local_path,
            local_dir_use_symlinks=False,
        )
        ocr_model = AutoModelForCausalLM.from_pretrained(
            local_path,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        ocr_processor = AutoProcessor.from_pretrained(local_path, trust_remote_code=True)
        return ocr_model, ocr_processor

    if model_name == "Dolphin":
        repo = "ByteDance/Dolphin"
        ocr_processor = AutoProcessor.from_pretrained(repo)
        ocr_model = VisionEncoderDecoderModel.from_pretrained(repo)
        ocr_model.eval()
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        ocr_model.to(target_device)
        ocr_model = ocr_model.half()  # fp16 inference for speed/memory
        return ocr_model, ocr_processor

    raise ValueError(f"Unknown model: {model_name}")
# Inference functions
def inference_dots_ocr(model, processor, image, prompt, max_new_tokens):
    """Run the dots.ocr vision-language model on one image with `prompt`.

    Builds a chat-formatted request, generates greedily (deterministic output,
    which matters for JSON layout parsing), strips the prompt tokens from the
    generated sequence, and returns the decoded completion string.
    """
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }]
    chat_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    model_inputs = processor(
        text=[chat_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        generated = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False  # Temperature removed previously to fix another warning
        )

    # Drop the echoed prompt tokens so only the completion remains.
    completions = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(model_inputs.input_ids, generated)
    ]
    decoded = processor.batch_decode(
        completions,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0] if decoded else ""
def inference_dolphin(model, processor, image):
    """Run the Dolphin encoder-decoder model on `image` and return the decoded text.

    The pixel values are cast to fp16 to match the half-precision model
    produced by load_model.
    """
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(model.device).half()
    # no_grad: this is pure inference — avoids building an autograd graph
    # (memory growth) and matches inference_dots_ocr's behavior.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)
    generated_text = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
# Load models at startup
# Eagerly load both models at import time so the first user request is fast;
# maps model name -> (model, processor) tuple as returned by load_model.
models = {
    "dots.ocr": load_model("dots.ocr"),
    "Dolphin": load_model("Dolphin")
}
# Global state for PDF handling
# Global, mutable state shared across Gradio callbacks for the current upload.
pdf_cache = {
    "images": [],         # rendered page images (PIL) for the loaded document
    "current_page": 0,    # index of the page currently shown in the UI
    "total_pages": 0,     # number of entries in "images"
    "file_type": None,    # presumably "pdf" vs image upload — verify against caller
    "is_parsed": False,   # True once OCR results exist for this upload
    "results": []         # per-page OCR outputs, aligned with "images"
}
# Utility functions
def round_by_factor(number: int, factor: int) -> int:
    """Round `number` to the nearest integer multiple of `factor`."""
    return factor * round(number / factor)
def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 11289600):
    """Compute target (height, width) whose sides are multiples of `factor`
    and whose area lies close to the [min_pixels, max_pixels] range, keeping
    the aspect ratio approximately intact.

    Returns:
        (h_bar, w_bar) tuple of ints, each >= factor.

    Raises:
        ValueError: if the aspect ratio exceeds 200.
    """
    if max(height, width) / min(height, width) > 200:
        raise ValueError(f"Aspect ratio must be < 200, got {max(height, width) / min(height, width)}")
    # Nearest multiple of `factor`, never below `factor` itself.
    h_bar = max(factor, round(height / factor) * factor)
    w_bar = max(factor, round(width / factor) * factor)
    if h_bar * w_bar > max_pixels:
        # Scale both sides down by the same beta so the area approaches max_pixels.
        beta = math.sqrt((height * width) / max_pixels)
        # BUG FIX: the unclamped round could yield 0 for the short side of a
        # very elongated image (e.g. 10x1000 with a small max_pixels), which
        # is an invalid resize target. Clamp to `factor` like the initial pass.
        h_bar = max(factor, round(height / beta / factor) * factor)
        w_bar = max(factor, round(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        # Scale both sides up so the area reaches min_pixels.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = max(factor, round(height * beta / factor) * factor)
        w_bar = max(factor, round(width * beta / factor) * factor)
    return h_bar, w_bar
def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
    """Load an image from a URL, a file path, or a PIL Image, as RGB.

    When `min_pixels` and/or `max_pixels` is given, the image is resized via
    smart_resize so both dimensions are multiples of IMAGE_FACTOR and the
    total pixel count falls near the requested range.

    Raises:
        ValueError: if `image_input` is neither a str nor a PIL Image.
        requests.HTTPError: if a URL download returns an HTTP error status.
    """
    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):
            # timeout: prevent the app from hanging forever on a dead host.
            response = requests.get(image_input, timeout=30)
            # Fail fast on HTTP errors instead of letting PIL choke on the
            # bytes of an HTML error page.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(image_input).convert('RGB')
    elif isinstance(image_input, Image.Image):
        image = image_input.convert('RGB')
    else:
        raise ValueError(f"Invalid image input type: {type(image_input)}")
    if min_pixels or max_pixels:
        # Fill in defaults for whichever bound was omitted.
        min_pixels = min_pixels or MIN_PIXELS
        max_pixels = max_pixels or MAX_PIXELS
        height, width = smart_resize(image.height, image.width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
        image = image.resize((width, height), Image.LANCZOS)
    return image
def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
    """Render every page of the PDF at 2x zoom and return them as RGB PIL images.

    Returns an empty list (and prints the error) if the PDF cannot be opened
    or a page fails to render.
    """
    images = []
    try:
        # BUG FIX: the original called pdf_document.close() only on the
        # success path, leaking the document handle when rendering raised.
        # The context manager guarantees the file is closed either way.
        with fitz.open(pdf_path) as pdf_document:
            for page_num in range(len(pdf_document)):
                page = pdf_document.load_page(page_num)
                mat = fitz.Matrix(2.0, 2.0)  # 2x zoom for better OCR resolution
                pix = page.get_pixmap(matrix=mat)
                img_data = pix.tobytes("ppm")
                image = Image.open(BytesIO(img_data)).convert('RGB')
                images.append(image)
    except Exception as e:
        print(f"Error loading PDF: {e}")
        return []
    return images
def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
    """Return a copy of `image` annotated with one colored rectangle and a
    category label per layout item; items lacking 'bbox' or 'category' are
    skipped. Drawing errors are printed and the partially annotated copy is
    still returned.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    # One fixed color per layout category; unknown categories fall back to black.
    colors = {
        'Caption': '#FF6B6B', 'Footnote': '#4ECDC4', 'Formula': '#45B7D1', 'List-item': '#96CEB4',
        'Page-footer': '#FFEAA7', 'Page-header': '#DDA0DD', 'Picture': '#FFD93D', 'Section-header': '#6C5CE7',
        'Table': '#FD79A8', 'Text': '#74B9FF', 'Title': '#E17055'
    }
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
    except Exception:
        # Fall back to PIL's built-in bitmap font when DejaVu is absent.
        font = ImageFont.load_default()
    try:
        for entry in layout_data:
            if 'bbox' not in entry or 'category' not in entry:
                continue
            box = entry['bbox']
            label = entry['category']
            color = colors.get(label, '#000000')
            draw.rectangle(box, outline=color, width=2)
            # Measure the label so its filled background fits snugly.
            left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
            text_w = right - left
            text_h = bottom - top
            anchor_x = box[0]
            # Place the label just above the box, clamped to the image top.
            anchor_y = max(0, box[1] - text_h - 2)
            draw.rectangle(
                [anchor_x, anchor_y, anchor_x + text_w + 4, anchor_y + text_h + 2],
                fill=color,
            )
            draw.text((anchor_x + 2, anchor_y + 1), label, fill='white', font=font)
    except Exception as e:
        print(f"Error drawing layout: {e}")
    return annotated
def is_arabic_text(text: str) -> bool:
    """Heuristically decide whether the prose content of markdown `text` is Arabic.

    Only headers and plain paragraphs are inspected; images, code fences,
    tables, and list items are ignored. Returns True when more than half of
    the alphabetic characters in the collected prose fall in the Arabic
    Unicode blocks (U+0600-06FF, U+0750-077F, U+08A0-08FF).
    """
    if not text:
        return False
    header_pattern = r'^#{1,6}\s+(.+)$'
    paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
    prose_parts = []
    for raw_line in text.split('\n'):
        stripped = raw_line.strip()
        if not stripped:
            continue
        header = re.match(header_pattern, stripped, re.MULTILINE)
        if header:
            # Keep only the header's text, without the leading '#' markers.
            prose_parts.append(header.group(1))
        elif re.match(paragraph_pattern, stripped, re.MULTILINE):
            prose_parts.append(stripped)
    if not prose_parts:
        return False
    alpha_count = 0
    arabic_count = 0
    for ch in ' '.join(prose_parts):
        if not ch.isalpha():
            continue
        alpha_count += 1
        if ('\u0600' <= ch <= '\u06FF') or ('\u0750' <= ch <= '\u077F') or ('\u08A0' <= ch <= '\u08FF'):
            arabic_count += 1
    return alpha_count > 0 and (arabic_count / alpha_count) > 0.5
def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
import base64
markdown_lines = []
try:
sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
for item in sorted_items:
category = item.get('category', '')
text = item.get(text_key, '')
bbox = item.get('bbox', [])
if category == 'Picture':
if bbox and len(bbox) == 4:
try:
x1, y1, x2, y2 = [max(0, int(x)) if i < 2 else min(image.width if i % 2 == 0 else image.height, int(x)) for i, x in enumerate(bbox)]
if x2 > x1 and y2 > y1:
cropped_img = image.crop((x1, y1, x2, y2))
buffer = BytesIO()
cropped_img.save(buffer, format='PNG')
img_data = base64.b64encode(buffer.getvalue()).decode()
markdown_lines.append(f'
Advanced vision-language model for image/PDF to markdown document processing