# app.py
# All code combined into a single file for convenience.
# --- Imports ---
import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re
import base64
import copy
from dataclasses import dataclass
#import flash_attn_2_cuda as flash_attn_gpu
# Vision and ML Libraries
import fitz # PyMuPDF
import gradio as gr
import requests
import torch
import subprocess
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModelForCausalLM, AutoProcessor, VisionEncoderDecoderModel
from qwen_vl_utils import process_vision_info
# Image Processing Libraries
import cv2
import numpy as np
import albumentations as alb
from albumentations.pytorch import ToTensorV2
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
# Attempt to install flash-attn at startup. This is best effort: any failure is
# reported and the app keeps running (DotOcrModel then has to load without it).
import sys  # local to this bootstrap step: install into the interpreter running the app

try:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
        # Extend the current environment instead of replacing it: passing a bare
        # dict here dropped PATH, HOME and virtualenv settings for the child.
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        check=True,
    )
except (subprocess.CalledProcessError, OSError) as e:
    print(f"Error installing flash-attn: {e}")
    print("Continuing without flash-attn.")
# --- Constants & Global State ---
# Default pixel budget for dot.ocr. These are the initial values of the
# "Min Pixels" / "Max Pixels" inputs, which are forwarded to
# DotOcrModel.fetch_image -> smart_resize. 3136 = 56 * 56, 11289600 = 3360 * 3360.
MIN_PIXELS = 3136
MAX_PIXELS = 11289600
# smart_resize rounds both sides of the resized page to a multiple of this value.
IMAGE_FACTOR = 28
# Instruction sent with every page to dots.ocr (DotOcrModel.process_image);
# the model's reply is parsed with json.loads.
DOT_OCR_PROMPT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
- Picture: For the 'Picture' category, the text field should be omitted.
- Formula: Format its text as LaTeX.
- Table: Format its text as HTML.
- All Others (Text, Title, etc.): Format their text as Markdown.
4. Constraints:
- The output text must be the original text from the image, with no translation.
- All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object.
"""
# Device handed to the model wrappers: "cuda" when torch sees a GPU, else "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Document state read and replaced by the Gradio event handlers.
# NOTE(review): this is a module-global, so every session served by this
# process shares it.
PDF_CACHE = {
    "images": [],  # one PIL RGB image per page of the loaded file
    "current_page": 0,  # zero-based index of the page being displayed
    "total_pages": 0,  # number of entries in "images"
    "file_type": None,  # NOTE(review): not read by any handler visible in this file
    "is_parsed": False,  # True once process_document has filled "results"
    "results": [],  # per-page result dicts returned by the model wrappers
    "model_used": None,  # UI name of the model that produced "results"
}
# Model wrappers created on first use, keyed by UI model name ("dot.ocr" / "Dolphin").
MODELS = {}
# =================================================================================
# --- UTILITY FUNCTIONS (from markdown_utils.py and utils.py) ---
# =================================================================================
# --- Markdown Conversion Utilities ---
def extract_table_from_html(html_string):
    """Extract every ``<table>...</table>`` element from an HTML string.

    Opening ``<table ...>`` tags are normalised to a bare ``<table>`` (their
    attributes are dropped). Tag matching is case-insensitive.

    Args:
        html_string: HTML text that may contain zero or more tables.

    Returns:
        The cleaned tables joined with ``"\\n"``. Returns an empty string when
        the input is empty or contains no table, and a one-cell error table if
        extraction fails.
    """
    if not html_string:
        return ""
    try:
        table_pattern = re.compile(r'<table\b[^>]*>.*?</table>', re.DOTALL | re.IGNORECASE)
        tables = table_pattern.findall(html_string)
        tables = [re.sub(r'<table\b[^>]*>', '<table>', table, flags=re.IGNORECASE) for table in tables]
        return '\n'.join(tables)
    except Exception as e:
        print(f"extract_table_from_html error: {str(e)}")
        return f"<table><tr><td>Error extracting table: {str(e)}</td></tr></table>"
class MarkdownConverter:
    """Convert structured recognition results (label/text dicts) to Markdown.

    ``convert`` takes the per-element dicts produced by
    ``DolphinModel.process_elements`` (keys ``label`` and ``text``) and renders
    each element according to its label. Every helper is best-effort: an
    unexpected error is printed, and the helper returns its input or an inline
    error marker, so one bad element never aborts the whole page.
    """

    def __init__(self):
        # Markdown heading prefix for each heading label.
        self.heading_levels = {'title': '#', 'sec': '##', 'sub_sec': '###'}
        # Labels that convert() does not send through the generic paragraph path.
        self.special_labels = {'tab', 'fig', 'title', 'sec', 'sub_sec', 'list', 'formula', 'reference', 'alg'}

    @staticmethod
    def _is_chinese(char: str) -> bool:
        """Return True for characters in the CJK Unified Ideographs block."""
        return '\u4e00' <= char <= '\u9fff'

    def try_remove_newline(self, text: str) -> str:
        """Join soft-wrapped lines into paragraphs.

        Hyphenated line breaks (``"-\\n"``) are removed. Consecutive non-empty
        lines are joined with a space, or with nothing when both sides of the
        break are Chinese characters. Blank lines are kept as paragraph breaks.
        """
        try:
            text = text.strip().replace('-\n', '')
            lines = text.split('\n')
            processed_lines = []
            for current_raw, next_raw in zip(lines, lines[1:]):
                current_line, next_line = current_raw.strip(), next_raw.strip()
                if not current_line:
                    processed_lines.append('\n')
                elif not next_line:
                    processed_lines.append(current_line + '\n')
                elif self._is_chinese(current_line[-1]) and self._is_chinese(next_line[0]):
                    processed_lines.append(current_line)
                else:
                    processed_lines.append(current_line + ' ')
            if lines and lines[-1].strip():
                processed_lines.append(lines[-1].strip())
            return ''.join(processed_lines)
        except Exception as e:
            print(f"try_remove_newline error: {str(e)}")
            return text

    def _handle_text(self, text: str) -> str:
        """Render a paragraph.

        LaTeX-looking text without ``$`` is wrapped in math delimiters, and
        newlines inside formulas are converted to ``\\\\``.
        """
        try:
            if not text:
                return ""
            stripped = text.strip()
            if stripped.startswith("\\begin{array}") and stripped.endswith("\\end{array}"):
                text = "$$" + text + "$$"
            elif (any(token in text for token in ("_{", "^{", "\\", "_ {", "^ {"))
                  and "$" not in text and "\\begin" not in text):
                text = "$" + text + "$"
            text = self._process_formulas_in_text(text)
            return self.try_remove_newline(text)
        except Exception as e:
            print(f"_handle_text error: {str(e)}")
            return text

    def _process_formulas_in_text(self, text: str) -> str:
        """Replace newlines inside ``$$``, ``\\[``, ``$`` and ``\\(`` formulas with ``\\\\``."""
        try:
            delimiters = [('$$', '$$'), ('\\[', '\\]'), ('$', '$'), ('\\(', '\\)')]
            result = text
            for start_delim, end_delim in delimiters:
                current_pos, processed_parts = 0, []
                while current_pos < len(result):
                    start_pos = result.find(start_delim, current_pos)
                    if start_pos == -1:
                        processed_parts.append(result[current_pos:])
                        break
                    processed_parts.append(result[current_pos:start_pos])
                    end_pos = result.find(end_delim, start_pos + len(start_delim))
                    if end_pos == -1:
                        # Unterminated formula: keep the rest verbatim.
                        processed_parts.append(result[start_pos:])
                        break
                    formula_content = result[start_pos + len(start_delim):end_pos]
                    processed_formula = formula_content.replace('\n', ' \\\\ ')
                    processed_parts.append(f"{start_delim}{processed_formula}{end_delim}")
                    current_pos = end_pos + len(end_delim)
                result = ''.join(processed_parts)
            return result
        except Exception as e:
            print(f"_process_formulas_in_text error: {str(e)}")
            return text

    def _remove_newline_in_heading(self, text: str) -> str:
        """Collapse newlines in a heading: dropped for Chinese text, replaced by spaces otherwise."""
        try:
            if any(self._is_chinese(char) for char in text):
                return text.replace('\n', '')
            return text.replace('\n', ' ')
        except Exception as e:
            print(f"_remove_newline_in_heading error: {str(e)}")
            return text

    def _handle_heading(self, text: str, label: str) -> str:
        """Render a title/section heading with the prefix from ``heading_levels``."""
        try:
            level = self.heading_levels.get(label, '#')
            text = self._remove_newline_in_heading(text.strip())
            text = self._handle_text(text)
            return f"{level} {text}\n\n"
        except Exception as e:
            print(f"_handle_heading error: {str(e)}")
            return f"# Error processing heading: {text}\n\n"

    def _handle_list_item(self, text: str) -> str:
        """Render a single bullet-list item."""
        try:
            return f"- {text.strip()}\n"
        except Exception as e:
            print(f"_handle_list_item error: {str(e)}")
            return f"- Error processing list item: {text}\n"

    def _handle_figure(self, text: str, section_count: int) -> str:
        """Render a figure element.

        Empty text yields a placeholder. Text that is already a Markdown image
        is kept as-is. Anything else (a data URI, path or URL) is used as the
        image target.
        """
        try:
            text = text.strip()
            if not text:
                return f"*[Figure {section_count}]*\n\n"
            if text.startswith("!["):
                return f"{text}\n\n"
            return f"![Figure {section_count}]({text})\n\n"
        except Exception as e:
            print(f"_handle_figure error: {str(e)}")
            return f"*[Error processing figure: {str(e)}]*\n\n"

    def _handle_table(self, text: str) -> str:
        """Render a table.

        HTML tables are cleaned with ``extract_table_from_html``. Other text is
        emitted unchanged.
        """
        try:
            lowered = text.lower()
            if '<table' in lowered or '<tr' in lowered:
                tables = extract_table_from_html(text)
                # Bare <tr> rows without an enclosing <table> yield nothing above;
                # fall back to the raw HTML rather than dropping the content.
                return f"{tables or text}\n\n"
            return f"{text}\n\n"
        except Exception as e:
            print(f"_handle_table error: {str(e)}")
            return f"*[Error processing table: {str(e)}]*\n\n"

    def _handle_algorithm(self, text: str) -> str:
        """Render LaTeX algorithm environments as a fenced code block with a bold caption."""
        try:
            text = re.sub(r'\\begin\{algorithm\}(.*?)\\end\{algorithm\}', r'\1', text, flags=re.DOTALL)
            text = text.replace('\\begin{algorithmic}', '').replace('\\end{algorithmic}', '')
            caption_match = re.search(r'\\caption\{(.*?)\}', text)
            caption = f"**{caption_match.group(1)}**\n\n" if caption_match else ""
            algorithm_text = re.sub(r'\\caption\{.*?\}', '', text).strip()
            return f"{caption}```\n{algorithm_text}\n```\n\n"
        except Exception as e:
            print(f"_handle_algorithm error: {str(e)}")
            return f"*[Error processing algorithm: {str(e)}]*\n\n{text}\n\n"

    def _handle_formula(self, text: str) -> str:
        """Render a display formula, adding ``$$`` delimiters when none are present."""
        try:
            processed_text = self._process_formulas_in_text(text)
            if '$$' not in processed_text and '\\[' not in processed_text:
                processed_text = f'$${processed_text}$$'
            return f"{processed_text}\n\n"
        except Exception as e:
            print(f"_handle_formula error: {str(e)}")
            return f"*[Error processing formula: {str(e)}]*\n\n"

    def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
        """Render a list of ``{'label': ..., 'text': ...}`` dicts as one Markdown string."""
        markdown_content = []
        for i, result in enumerate(recognition_results):
            try:
                label = result.get('label', '')
                # A missing or None text is treated as empty instead of crashing the item.
                text = (result.get('text') or '').strip()
                if label == 'fig':
                    # Figures are rendered even without text (placeholder).
                    markdown_content.append(self._handle_figure(text, i))
                    continue
                if not text:
                    continue
                if label in {'title', 'sec', 'sub_sec'}:
                    markdown_content.append(self._handle_heading(text, label))
                elif label == 'list':
                    markdown_content.append(self._handle_list_item(text))
                elif label == 'tab':
                    markdown_content.append(self._handle_table(text))
                elif label == 'alg':
                    markdown_content.append(self._handle_algorithm(text))
                elif label == 'formula':
                    markdown_content.append(self._handle_formula(text))
                elif label == 'reference' or label not in self.special_labels:
                    # 'reference' had no handler and was silently dropped; render it as text.
                    markdown_content.append(f"{self._handle_text(text)}\n\n")
            except Exception as e:
                print(f"Error processing item {i}: {str(e)}")
                markdown_content.append("*[Error processing content]*\n\n")
        return self._post_process(''.join(markdown_content))

    def _post_process(self, md: str) -> str:
        """Clean up leftover LaTeX (author, abstract, eqno, spacing) in the final Markdown."""
        try:
            def author_text(match):
                return self._handle_text(match.group(1))

            # Unwrap "$\author{...}$" before bare "\author{...}". In the reverse
            # order the bare rule consumed the inner command first, so the math
            # variant could never match and "$...$" was left behind.
            md = re.sub(r'\$\\author\{(.*?)\}\$', author_text, md, flags=re.DOTALL)
            md = re.sub(r'\\author\{(.*?)\}', author_text, md, flags=re.DOTALL)
            md = re.sub(r'\\begin\{abstract\}(.*?)\\end\{abstract\}', r'**Abstract** \1', md, flags=re.DOTALL)
            md = re.sub(r'\\begin\{abstract\}', r'**Abstract**', md)
            md = re.sub(r'\\eqno\{\((.*?)\)\}', r'\\tag{\1}', md)
            md = md.replace(r"\[ \\", r"$$ \\").replace(r"\\ \]", r"\\ $$")
            # Literal replacements. In the previous regex form "^" was a
            # start-of-string anchor, so "^ {" was never repaired.
            md = md.replace('_ {', '_{').replace('^ {', '^{')
            md = re.sub(r'\n{3,}', '\n\n', md)
            return md
        except Exception as e:
            print(f"_post_process error: {str(e)}")
            return md
# --- General Processing Utilities ---
@dataclass()
class ImageDimensions:
    """Pixel size of a page image before and after square padding.

    ``prepare_image`` fills this in; ``map_to_original_coordinates`` uses it to
    undo the centred padding.
    """

    original_w: int  # width of the image as received
    original_h: int  # height of the image as received
    padded_w: int  # width after the border was added
    padded_h: int  # height after the border was added
def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2):
    """Nudge box edges outward so they stop cutting through foreground.

    Each edge of every box is scored, in the order left, right, top, bottom.
    The score is the mean absolute step between neighbouring pixels along that
    edge in an Otsu-binarised (inverted) copy of the image. An edge scoring
    above ``threshold`` is moved outward one pixel at a time, for at most
    ``max_pixels`` steps and clamped to the image. The lowest-scoring position
    found is kept.

    Args:
        image: BGR image array, or a file path loaded with ``cv2.imread``.
        boxes: ``[x1, y1, x2, y2]`` integer pixel boxes.
        max_pixels: Maximum number of one-pixel steps tried per edge.
        threshold: Score at or below which an edge counts as clean. Scores are
            on the binarised 0-255 scale (one transition contributes 255).

    Returns:
        One box per input. A box is a copy of the input if no edge improved;
        otherwise it is the adjusted box, with every coordinate clamped to the
        image.
    """
    if isinstance(image, str):
        image = cv2.imread(image)
    img_h, img_w = image.shape[:2]

    # Binarise once. The result depends only on `image`, but it used to be
    # recomputed for every single edge probe (up to 64 times per box).
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    def edge_score(current_box, i, is_vertical):
        """Mean absolute pixel step along edge ``i`` of ``current_box``."""
        edge = current_box[i]
        if is_vertical:
            line = binary[current_box[1]: current_box[3] + 1, edge]
        else:
            line = binary[edge, current_box[0]: current_box[2] + 1]
        if line.size < 2:
            # Too short to contain a transition (previously 0/0 -> nan + warning).
            return 0.0
        # Widen before differencing: uint8 subtraction wraps, which counted a
        # 255->0 step as 1 but a 0->255 step as 255.
        transitions = np.abs(np.diff(line.astype(np.int16)))
        return float(transitions.sum()) / transitions.size

    # (edge index in the box, outward direction, edge is a vertical line)
    edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
    new_boxes = []
    for box in boxes:
        best_box = copy.deepcopy(box)
        current_box = [min(max(c, 0), d - 1) for c, d in zip(box, [img_w, img_h, img_w, img_h])]
        for i, direction, is_vertical in edges:
            best_score = edge_score(current_box, i, is_vertical)
            if best_score <= threshold:
                continue
            limit = img_w if i in (0, 2) else img_h
            best_pos = current_box[i]
            for _ in range(max_pixels):
                current_box[i] = min(max(current_box[i] + direction, 0), limit - 1)
                score = edge_score(current_box, i, is_vertical)
                if score < best_score:
                    best_score, best_pos = score, current_box[i]
                    best_box = list(current_box)
                if score <= threshold:
                    break
            # Park this edge at its best position, not wherever the search stopped.
            # Otherwise a later edge's improvement snapshot inherits up to
            # `max_pixels` of pointless drift.
            current_box[i] = best_pos
        new_boxes.append(best_box)
    return new_boxes
def parse_layout_string(bbox_str):
    """Parse Dolphin's reading-order output into ``(bbox, label)`` pairs.

    Each entry looks like ``[x1, y1, x2, y2] label``. The coordinates are
    unsigned decimal numbers and are returned as floats. Text that does not
    fit this shape is ignored.
    """
    entry_pattern = re.compile(
        r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
    )
    parsed = []
    for match in entry_pattern.finditer(bbox_str):
        *numbers, label = match.groups()
        parsed.append((list(map(float, numbers)), label.strip()))
    return parsed
def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]:
    """Translate a box on the padded image back onto the unpadded image.

    The padding is assumed to be centred, with the offset computed from
    ``dims``. The result is clipped to the original image. A zero-size side is
    widened to one pixel where the image allows it.

    Returns:
        ``(x1, y1, x2, y2)`` as ints. On any error, a box of at most
        100x100 px at the origin is returned.
    """
    try:
        pad_left = (dims.padded_w - dims.original_w) // 2
        pad_top = (dims.padded_h - dims.original_h) // 2

        left_edge = max(0, x1 - pad_left)
        top_edge = max(0, y1 - pad_top)
        right_edge = min(dims.original_w, x2 - pad_left)
        bottom_edge = min(dims.original_h, y2 - pad_top)

        if right_edge <= left_edge:
            right_edge = min(left_edge + 1, dims.original_w)
        if bottom_edge <= top_edge:
            bottom_edge = min(top_edge + 1, dims.original_h)

        return tuple(int(value) for value in (left_edge, top_edge, right_edge, bottom_edge))
    except Exception as e:
        print(f"map_to_original_coordinates error: {str(e)}")
        return 0, 0, min(100, dims.original_w), min(100, dims.original_h)
def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
    """Turn one normalised layout box into pixel boxes on the padded and original image.

    Args:
        coords: ``[x1, y1, x2, y2]`` as fractions of the padded image size
            (they are multiplied by ``dims.padded_w`` / ``dims.padded_h``).
        padded_image: The padded BGR image the box refers to. It is passed to
            ``adjust_box_edges`` to refine the edges.
        dims: Original and padded sizes from ``prepare_image``.
        previous_box: Optional padded-pixel box. If the new box overlaps it,
            the new top edge is moved down to the previous box's bottom.

    Returns:
        ``(x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, [x1, y1, x2, y2])``:
        the padded-image box, the same box in original-image pixels, and the
        padded box again as a list. On any error, a fixed fallback box at the
        origin is returned.
    """
    try:
        width, height = dims.padded_w, dims.padded_h
        left = max(0, int(coords[0] * width))
        top = max(0, int(coords[1] * height))
        right = min(width, int(coords[2] * width))
        bottom = min(height, int(coords[3] * height))

        # Guarantee a non-empty box before edge refinement.
        if right <= left:
            right = min(left + 1, width)
        if bottom <= top:
            bottom = min(top + 1, height)

        left, top, right, bottom = adjust_box_edges(padded_image, [[left, top, right, bottom]])[0]

        if previous_box:
            prev_left, prev_top, prev_right, prev_bottom = previous_box
            if (left < prev_right and right > prev_left) and (top < prev_bottom and bottom > prev_top):
                # Remove the overlap by starting this box below the previous one.
                top = min(prev_bottom, height - 1)
                if bottom <= top:
                    bottom = min(top + 1, height)

        original_box = map_to_original_coordinates(left, top, right, bottom, dims)
        return (left, top, right, bottom, *original_box, [left, top, right, bottom])
    except Exception as e:
        print(f"process_coordinates error: {str(e)}")
        fallback = (0, 0, min(100, dims.original_w), min(100, dims.original_h))
        return (0, 0, 100, 100, *fallback, [0, 0, 100, 100])
def prepare_image(image) -> Tuple[np.ndarray, ImageDimensions]:
    """Pad a page image to a centred square with a black border.

    Args:
        image: A PIL image (any mode; non-RGB modes are converted to RGB first)
            or an RGB numpy array.

    Returns:
        ``(padded_bgr, dims)``. ``padded_bgr`` is the BGR square image and
        ``dims`` records the sizes before and after padding. On error, an
        all-black image of the original size is returned, with padded size
        equal to original size.
    """
    try:
        # Grayscale/palette/CMYK pages do not produce the 3/4-channel array that
        # cvtColor(RGB2BGR) needs. They used to fall into the except branch and
        # silently become an all-black page.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        original_h, original_w = image_cv.shape[:2]
        side = max(original_h, original_w)
        pad_v, pad_h = side - original_h, side - original_w
        top, left = pad_v // 2, pad_h // 2
        bottom, right = pad_v - top, pad_h - left
        padded_image = cv2.copyMakeBorder(image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0))
        padded_h, padded_w = padded_image.shape[:2]
        return padded_image, ImageDimensions(original_w, original_h, padded_w, padded_h)
    except Exception as e:
        print(f"prepare_image error: {str(e)}")
        # Works for both numpy arrays and PIL images (the old code assumed PIL).
        if isinstance(image, np.ndarray):
            height, width = image.shape[:2]
        else:
            width, height = image.width, image.height
        dims = ImageDimensions(width, height, width, height)
        return np.zeros((height, width, 3), dtype=np.uint8), dims
# =================================================================================
# --- MODEL WRAPPER CLASSES ---
# =================================================================================
class DotOcrModel:
    """Lazy-loading wrapper around the ``rednote-hilab/dots.ocr`` model.

    ``process_image`` resizes a page, asks the model for its layout with
    DOT_OCR_PROMPT, parses the JSON reply, maps the boxes back to the original
    page, and renders an annotated image plus Markdown.
    """

    def __init__(self, device: str):
        self.model, self.processor, self.device = None, None, device
        self.model_id, self.model_path = "rednote-hilab/dots.ocr", "./models/dots-ocr-local"

    @spaces.GPU()
    def load_model(self):
        """Download the checkpoint into ``model_path`` and load it (first call only)."""
        if self.model is not None:
            return
        print("Loading dot.ocr model...")
        snapshot_download(repo_id=self.model_id, local_dir=self.model_path, local_dir_use_symlinks=False)
        load_kwargs = dict(torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path, attn_implementation="flash_attention_2", **load_kwargs
            )
        except (ImportError, ValueError) as e:
            # The flash-attn install at startup is best effort. Fall back to
            # PyTorch SDPA rather than leaving the model unusable.
            print(f"flash_attention_2 unavailable ({e}); falling back to sdpa attention.")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path, attn_implementation="sdpa", **load_kwargs
            )
        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
        print("dot.ocr model loaded.")

    @staticmethod
    def smart_resize(height, width, factor, min_pixels, max_pixels):
        """Pick a model input size close to ``(height, width)``.

        Both sides are multiples of ``factor`` (and at least ``factor``). The
        area is kept within ``[min_pixels, max_pixels]``.

        Raises:
            ValueError: if the aspect ratio exceeds 200.
        """
        if max(height, width) / min(height, width) > 200:
            raise ValueError("Aspect ratio too high")
        h_bar = max(factor, round(height / factor) * factor)
        w_bar = max(factor, round(width / factor) * factor)
        if h_bar * w_bar > max_pixels:
            # Shrink and round *down*, so rounding cannot push the area above max_pixels.
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = max(factor, math.floor(height / beta / factor) * factor)
            w_bar = max(factor, math.floor(width / beta / factor) * factor)
        elif h_bar * w_bar < min_pixels:
            # Enlarge both sides by beta and round *up* so the area reaches min_pixels.
            # The width used to be divided by beta here, which shrank it (often to 0).
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / factor) * factor
            w_bar = math.ceil(width * beta / factor) * factor
        return h_bar, w_bar

    def fetch_image(self, image_input, min_pixels, max_pixels):
        """Convert to RGB and resize to the size chosen by ``smart_resize``."""
        image = image_input.convert('RGB')
        height, width = self.smart_resize(image.height, image.width, IMAGE_FACTOR, min_pixels, max_pixels)
        return image.resize((width, height), Image.LANCZOS)

    @spaces.GPU()
    def inference(self, image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
        """Run one greedy chat turn (image + prompt) and return only the generated text."""
        self.load_model()
        messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        inputs = self.processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            # `temperature` was dropped: it is ignored when do_sample=False and only
            # makes transformers emit a warning.
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # Strip the echoed prompt tokens from each sequence.
        generated_ids_trimmed = [out[len(ins):] for ins, out in zip(inputs.input_ids, generated_ids)]
        return self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

    @staticmethod
    def _scale_layout(layout_data: List[Dict], scale_x: float, scale_y: float) -> List[Dict]:
        """Return a copy of the layout with every 4-number bbox scaled by ``(scale_x, scale_y)``."""
        scaled = []
        for item in layout_data:
            bbox = item.get('bbox') if isinstance(item, dict) else None
            if isinstance(bbox, (list, tuple)) and len(bbox) == 4:
                x1, y1, x2, y2 = bbox
                item = {**item, 'bbox': [round(x1 * scale_x), round(y1 * scale_y),
                                         round(x2 * scale_x), round(y2 * scale_y)]}
            scaled.append(item)
        return scaled

    def process_image(self, image: Image.Image, min_pixels: int, max_pixels: int):
        """OCR one page.

        Returns a dict with ``original_image``, ``raw_output``,
        ``layout_result``, ``processed_image`` and ``markdown_content``. If the
        model output cannot be used, ``layout_result`` is None and
        ``markdown_content`` shows the raw output.
        """
        resized_image = self.fetch_image(image, min_pixels, max_pixels)
        raw_output = self.inference(resized_image, DOT_OCR_PROMPT)
        result = {'original_image': image, 'raw_output': raw_output, 'layout_result': None}
        try:
            layout_data = json.loads(raw_output)
            if not isinstance(layout_data, list):
                raise TypeError(f"expected a JSON list of layout elements, got {type(layout_data).__name__}")
            # The model saw `resized_image`; map its boxes onto the original page
            # before drawing and cropping on `image`.
            layout_data = self._scale_layout(
                layout_data, image.width / resized_image.width, image.height / resized_image.height
            )
            result['layout_result'] = layout_data
            result['processed_image'] = self.draw_layout_on_image(image, layout_data)
            result['markdown_content'] = self.layoutjson2md(image, layout_data)
        except (ValueError, KeyError, TypeError, AttributeError, IndexError) as e:
            # ValueError covers json.JSONDecodeError (e.g. truncated output);
            # the others cover JSON that does not have the expected shape.
            print(f"Failed to parse or process dot.ocr layout: {e}")
            result['processed_image'] = image
            result['markdown_content'] = f"### Error processing output\nRaw model output:\n```json\n{raw_output}\n```"
        return result

    def draw_layout_on_image(self, image: Image.Image, layout_data: List[Dict]) -> Image.Image:
        """Return a copy of ``image`` with each element's bbox and category label drawn on it."""
        # Two statements: the previous one-line tuple assignment referenced
        # img_copy before it existed (NameError on every call).
        img_copy = image.copy()
        draw = ImageDraw.Draw(img_copy)
        colors = {'Caption': '#FF6B6B', 'Footnote': '#4ECDC4', 'Formula': '#45B7D1', 'List-item': '#96CEB4',
                  'Page-footer': '#FFEAA7', 'Page-header': '#DDA0DD', 'Picture': '#FFD93D', 'Section-header': '#6C5CE7',
                  'Table': '#FD79A8', 'Text': '#74B9FF', 'Title': '#E17055'}
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 15)
        except OSError:
            font = ImageFont.load_default()
        for item in layout_data:
            if not isinstance(item, dict) or 'bbox' not in item or 'category' not in item:
                continue
            x1, y1, x2, y2 = item['bbox']
            # Normalise corner order; Pillow rejects boxes with x2 < x1 or y2 < y1.
            bbox = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]
            category = item['category']
            color = colors.get(category, '#000000')
            draw.rectangle(bbox, outline=color, width=3)
            label_bbox = draw.textbbox((0, 0), category, font=font)
            label_width, label_height = label_bbox[2] - label_bbox[0], label_bbox[3] - label_bbox[1]
            label_x, label_y = bbox[0], max(0, bbox[1] - label_height - 5)
            draw.rectangle([label_x, label_y, label_x + label_width + 4, label_y + label_height + 4], fill=color)
            draw.text((label_x + 2, label_y + 2), category, fill='white', font=font)
        return img_copy

    def layoutjson2md(self, image: Image.Image, layout_data: List[Dict]) -> str:
        """Render layout elements as Markdown, ordered top-to-bottom then left-to-right.

        Pictures are cropped from ``image`` and embedded as base64 PNG data
        URIs. Categories without a template (e.g. Page-header/Page-footer) are
        skipped.
        """
        templates = {
            'Title': "# {}\n",
            'Section-header': "## {}\n",
            'List-item': "- {}\n",
            'Formula': "$$\n{}\n$$\n",
            'Caption': "*{}*\n",
            'Footnote': "^{}^\n",
            'Text': "{}\n",
            'Table': "{}\n",
        }

        def position(item):
            # A missing or null bbox sorts first instead of raising.
            bbox = item.get('bbox') or [0, 0, 0, 0]
            return (bbox[1], bbox[0])

        md_lines = []
        for item in sorted(layout_data, key=position):
            cat, txt, bbox = item.get('category'), item.get('text'), item.get('bbox')
            if cat == 'Picture' and bbox:
                try:
                    x1, y1 = max(0, int(bbox[0])), max(0, int(bbox[1]))
                    x2, y2 = min(image.width, int(bbox[2])), min(image.height, int(bbox[3]))
                    if x2 > x1 and y2 > y1:
                        buffer = BytesIO()
                        image.crop((x1, y1, x2, y2)).save(buffer, format='PNG')
                        img_data = base64.b64encode(buffer.getvalue()).decode()
                        md_lines.append(f"![Picture](data:image/png;base64,{img_data})\n")
                except Exception as e:
                    print(f"Failed to crop picture {bbox}: {e}")
                    md_lines.append("*[Picture could not be extracted]*\n")
            elif txt and cat in templates:
                md_lines.append(templates[cat].format(txt))
        return "\n".join(md_lines)
class DolphinModel:
    """Lazy-loading wrapper around ``ByteDance/Dolphin``.

    A page is processed in two stages. First the model lists the layout
    elements in reading order. Then every cropped element is recognised: text
    and tables in batches, while figures are kept as empty placeholders. The
    results are rendered with MarkdownConverter.
    """

    def __init__(self, device: str):
        self.model, self.processor, self.tokenizer, self.device = None, None, None, device
        self.model_id = "ByteDance/Dolphin"

    @spaces.GPU()
    def load_model(self):
        """Load processor, fp16 model and tokenizer (first call only)."""
        if self.model is not None:
            return
        print("Loading Dolphin model...")
        self.processor = AutoProcessor.from_pretrained(self.model_id)
        self.model = VisionEncoderDecoderModel.from_pretrained(self.model_id).eval().to(self.device).half()
        self.tokenizer = self.processor.tokenizer
        print("Dolphin model loaded.")

    @spaces.GPU()
    def model_chat(self, prompt, image):
        """Run Dolphin on one image or on a batch.

        Args:
            prompt: An instruction string, or a list with one prompt per image.
            image: A PIL image, or a list of them.

        Returns:
            The cleaned decoded answer. When ``image`` is a list, a list with
            one answer per image is returned.
        """
        self.load_model()
        images = image if isinstance(image, list) else [image]
        prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
        batch_inputs = self.processor(images, return_tensors="pt", padding=True)
        batch_pixel_values = batch_inputs.pixel_values.half().to(self.device)
        # Dolphin's decoder prefix is "<s>{instruction} <Answer/>". The tags had been
        # lost, leaving the model with an unformatted prompt.
        prompts = [f"<s>{p} <Answer/>" for p in prompts]
        batch_prompt_inputs = self.tokenizer(prompts, add_special_tokens=False, return_tensors="pt")
        batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
        batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)
        outputs = self.model.generate(
            pixel_values=batch_pixel_values, decoder_input_ids=batch_prompt_ids,
            decoder_attention_mask=batch_attention_mask, max_length=4096,
            pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id,
            use_cache=True, bad_words_ids=[[self.tokenizer.unk_token_id]],
            return_dict_in_generate=True, do_sample=False, num_beams=1, repetition_penalty=1.1
        )
        sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
        # Special tokens are kept by batch_decode above, so strip the echoed
        # prompt plus the padding / end-of-sequence markers explicitly.
        results = [
            seq.replace(p, "").replace("<pad>", "").replace("</s>", "").strip()
            for p, seq in zip(prompts, sequences)
        ]
        return results if isinstance(image, list) else results[0]

    def process_elements(self, layout_str: str, image: Image.Image, max_batch_size: int = 16):
        """Crop every layout element and recognise it.

        Returns dicts with ``label``, ``bbox`` (original-image pixels),
        ``text`` and ``reading_order``, sorted by reading order.
        """
        padded_image, dims = prepare_image(image)
        elements = []
        for bbox, label in parse_layout_string(layout_str):
            try:
                x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2 = process_coordinates(bbox, padded_image, dims)[:8]
                cropped = padded_image[y1:y2, x1:x2]
                # Skip empty or sliver crops (3 px or less on a side).
                if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
                    pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
                    elements.append({
                        "crop": pil_crop, "label": label,
                        "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
                        "reading_order": len(elements),
                    })
            except Exception as e:
                print(f"Error processing Dolphin element bbox {bbox}: {e}")
        text_inputs = [e for e in elements if e['label'] not in ('tab', 'fig')]
        table_inputs = [e for e in elements if e['label'] == 'tab']
        text_elems = self.process_element_batch(text_inputs, "Read text in the image.", max_batch_size)
        table_elems = self.process_element_batch(table_inputs, "Parse the table in the image.", max_batch_size)
        fig_elems = [{"label": e['label'], "bbox": e['bbox'], "text": "", "reading_order": e['reading_order']}
                     for e in elements if e['label'] == 'fig']
        return sorted(text_elems + table_elems + fig_elems, key=lambda x: x['reading_order'])

    def process_element_batch(self, elements, prompt, max_batch_size=16):
        """Recognise element crops with one shared ``prompt``, in chunks of ``max_batch_size``."""
        results = []
        for start in range(0, len(elements), max_batch_size):
            batch = elements[start:start + max_batch_size]
            batch_results = self.model_chat([prompt] * len(batch), [elem["crop"] for elem in batch])
            for elem, res_text in zip(batch, batch_results):
                results.append({"label": elem["label"], "bbox": elem["bbox"],
                                "text": res_text.strip(), "reading_order": elem["reading_order"]})
        return results

    def process_image(self, image: Image.Image):
        """OCR one page and return the same result keys as DotOcrModel.process_image."""
        layout_output = self.model_chat("Parse the reading order of this document.", image)
        recognition_results = self.process_elements(layout_output, image)
        markdown_content = MarkdownConverter().convert(recognition_results)
        return {
            'original_image': image, 'processed_image': image, 'markdown_content': markdown_content,
            'layout_result': recognition_results, 'raw_output': layout_output
        }
# =================================================================================
# --- GRADIO UI AND EVENT HANDLERS ---
# =================================================================================
def create_gradio_interface():
    """Create the main Gradio interface and define all event handlers.

    NOTE(review): the handlers keep document state in the module-global
    PDF_CACHE, so all sessions served by this process share (and overwrite)
    one another's pages and results. Per-session ``gr.State`` would be needed
    to isolate users.
    """
    from html import escape as html_escape

    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
    .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
    .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
    .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
    """

    def page_info_div(text):
        """Wrap a status string (HTML-escaped) in the styled page-info container."""
        return f'<div class="page-info">{html_escape(text)}</div>'

    def empty_cache():
        """Return a fresh PDF_CACHE value with the same keys as the module default."""
        return {"images": [], "current_page": 0, "total_pages": 0, "file_type": None,
                "is_parsed": False, "results": [], "model_used": None}

    with gr.Blocks(theme="bethecloud/storj_theme", css=css, title="Dot.OCR Comparator") as demo:
        gr.HTML("""
        <div class="header-text">
            <h1>Dot●OCR Comparator</h1>
            <p>Advanced vision-language model for image/PDF to markdown document processing</p>
        </div>
        """)
        with gr.Row(elem_classes=["main-container"]):
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload Image or PDF", file_types=[".jpg", ".jpeg", ".png", ".pdf"], type="filepath")
                with gr.Row():
                    gr.Examples(
                        examples=["examples/sample_image1.png", "examples/sample_image2.png", "examples/sample_pdf.pdf"],
                        inputs=file_input,
                        label="Example Documents"
                    )
                model_choice = gr.Radio(choices=["dot.ocr", "Dolphin"], label="Select Model", value="dot.ocr")
                image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=400)
                with gr.Row():
                    prev_page_btn = gr.Button("◀ Previous")
                    page_info = gr.HTML(page_info_div("No file loaded"))
                    next_page_btn = gr.Button("Next ▶")
                with gr.Accordion("Advanced Settings (dot.ocr only)", open=False):
                    min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels", step=1)
                    max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels", step=1)
                with gr.Row():
                    process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], scale=2)
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.Tab("📝 Extracted Content"):
                        markdown_output = gr.Markdown(value="Click 'Process Document' to see extracted content...", elem_id="markdown_output")
                    with gr.Tab("🖼️ Processed Image"):
                        processed_image_output = gr.Image(label="Image with Layout Detection", type="pil", interactive=False)
                    with gr.Tab("📋 Layout JSON"):
                        json_output = gr.JSON(label="Layout Analysis Results")

        def load_file_for_preview(file_path: str) -> Tuple[List[Image.Image], str]:
            """Render an uploaded PDF (2x zoom) or image into RGB page images plus a status string."""
            if not file_path or not os.path.exists(file_path):
                return [], "No file selected"
            ext = os.path.splitext(file_path)[1].lower()
            try:
                if ext == '.pdf':
                    images = []
                    # Context manager so the document is closed even if a page fails to render.
                    with fitz.open(file_path) as doc:
                        for page in doc:
                            pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
                            images.append(Image.open(BytesIO(pix.tobytes("ppm"))).convert('RGB'))
                elif ext in ('.jpg', '.jpeg', '.png', '.bmp', '.tiff'):
                    with Image.open(file_path) as img:
                        images = [img.convert('RGB')]
                else:
                    return [], f"Unsupported file type: {ext or 'unknown'}"
                if not images:
                    return [], "The document has no pages"
                return images, f"Page 1 / {len(images)}"
            except Exception as e:
                print(f"Error loading file for preview: {e}")
                return [], f"Error loading file: {e}"

        def handle_file_upload(file_path):
            """Load a new file into PDF_CACHE and show its first page."""
            global PDF_CACHE
            images, page_info_str = load_file_for_preview(file_path)
            if not images:
                # Drop the previous document so "Process" cannot silently run on stale pages.
                PDF_CACHE = empty_cache()
                return None, page_info_div(page_info_str)
            PDF_CACHE = {**empty_cache(), "images": images, "total_pages": len(images),
                         "file_type": os.path.splitext(file_path)[1].lower()}
            return images[0], page_info_div(page_info_str)

        def process_document(file_path, model_name, min_pix, max_pix):
            """Run the selected model on every cached page and show page 1's results."""
            cache = PDF_CACHE
            no_change = gr.update()
            if not file_path or not cache["images"]:
                return "Please upload a file first.", None, None, no_change, no_change
            model = MODELS.get(model_name)
            if model is None:
                model_cls = {'dot.ocr': DotOcrModel, 'Dolphin': DolphinModel}.get(model_name)
                if model_cls is None:
                    return f"Unknown model: {model_name}", None, None, no_change, no_change
                model = MODELS[model_name] = model_cls(DEVICE)
            # An emptied number box arrives as None; fall back to the defaults.
            min_pix = int(min_pix) if min_pix is not None else MIN_PIXELS
            max_pix = int(max_pix) if max_pix is not None else MAX_PIXELS
            images = cache["images"]
            all_results, all_markdown = [], []
            for i, img in enumerate(images):
                gr.Info(f"Processing page {i+1}/{len(images)} with {model_name}...")
                if model_name == 'dot.ocr':
                    result = model.process_image(img, min_pix, max_pix)
                else:  # Dolphin
                    result = model.process_image(img)
                all_results.append(result)
                if result.get('markdown_content'):
                    all_markdown.append(f"### Page {i+1}\n\n{result['markdown_content']}")
            if not all_results:
                return "Processing failed.", None, None, no_change, no_change
            # Results are shown for page 1, so the page cursor and preview must follow.
            cache.update({"results": all_results, "is_parsed": True, "model_used": model_name, "current_page": 0})
            first_result = all_results[0]
            combined_md = "\n\n---\n\n".join(all_markdown)
            return (combined_md, first_result.get('processed_image'), first_result.get('layout_result'),
                    images[0], page_info_div(f"Page 1 / {len(images)}"))

        def turn_page(direction):
            """Move one page back/forward. Results are updated only once the document is parsed."""
            no_change = gr.update()
            if not PDF_CACHE["images"]:
                return None, page_info_div("No file loaded"), no_change, no_change, no_change
            step = -1 if direction == "prev" else 1
            idx = min(max(PDF_CACHE["current_page"] + step, 0), PDF_CACHE["total_pages"] - 1)
            PDF_CACHE["current_page"] = idx
            preview_img = PDF_CACHE["images"][idx]
            page_info_html = page_info_div(f'Page {idx + 1} / {PDF_CACHE["total_pages"]}')
            if not PDF_CACHE["is_parsed"]:
                # Browse the preview without blanking it or touching the result panes.
                return preview_img, page_info_html, no_change, no_change, no_change
            result = PDF_CACHE["results"][idx]
            if PDF_CACHE["total_pages"] > 1:
                all_md = [f"### Page {i+1}\n\n{res.get('markdown_content', '')}" for i, res in enumerate(PDF_CACHE["results"])]
                md_content = "\n\n---\n\n".join(all_md)
            else:
                md_content = result.get('markdown_content', 'No content')
            return preview_img, page_info_html, md_content, result.get('processed_image'), result.get('layout_result')

        def clear_all():
            """Reset PDF_CACHE and every input/output component."""
            global PDF_CACHE
            PDF_CACHE = empty_cache()
            return (None, None, page_info_div("No file loaded"),
                    "Click 'Process Document' to see extracted content...", None, None)

        # --- Wire UI components ---
        page_outputs = [image_preview, page_info, markdown_output, processed_image_output, json_output]
        file_input.change(handle_file_upload, inputs=file_input, outputs=[image_preview, page_info])
        process_btn.click(
            process_document,
            inputs=[file_input, model_choice, min_pixels, max_pixels],
            outputs=[markdown_output, processed_image_output, json_output, image_preview, page_info]
        )
        prev_page_btn.click(lambda: turn_page("prev"), outputs=page_outputs)
        next_page_btn.click(lambda: turn_page("next"), outputs=page_outputs)
        clear_btn.click(clear_all, outputs=[file_input, image_preview, page_info, markdown_output, processed_image_output, json_output])
    return demo
if __name__ == "__main__":
    # The example gallery built in create_gradio_interface() reads from ./examples;
    # create the folder on first run so the user knows where to put sample files.
    if not os.path.exists("examples"):
        os.makedirs("examples")
        print("Created 'examples' directory. Please add sample images/PDFs there.")

    # Build the UI, enable the request queue and start the server with verbose errors.
    app = create_gradio_interface()
    app.queue().launch(show_error=True, debug=True)