|
import spaces |
|
import json |
|
import math |
|
import os |
|
import traceback |
|
from io import BytesIO |
|
from typing import Any, Dict, List, Optional, Tuple |
|
import re |
|
|
|
import fitz |
|
import gradio as gr |
|
import requests |
|
import torch |
|
from huggingface_hub import snapshot_download |
|
from PIL import Image, ImageDraw, ImageFont |
|
from qwen_vl_utils import process_vision_info |
|
from transformers import AutoModelForCausalLM, AutoProcessor |
|
|
|
|
|
MIN_PIXELS = 3136 |
|
MAX_PIXELS = 11289600 |
|
IMAGE_FACTOR = 28 |
|
|
|
|
|
prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox. |
|
|
|
1. Bbox format: [x1, y1, x2, y2] |
|
|
|
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. |
|
|
|
3. Text Extraction & Formatting Rules: |
|
- Picture: For the 'Picture' category, the text field should be omitted. |
|
- Formula: Format its text as LaTeX. |
|
- Table: Format its text as HTML. |
|
- All Others (Text, Title, etc.): Format their text as Markdown. |
|
|
|
4. Constraints: |
|
- The output text must be the original text from the image, with no translation. |
|
- All layout elements must be sorted according to human reading order. |
|
|
|
5. Final Output: The entire output must be a single JSON object. |
|
""" |
|
|
|
|
|
def round_by_factor(number: int, factor: int) -> int: |
|
"""Returns the closest integer to 'number' that is divisible by 'factor'.""" |
|
return round(number / factor) * factor |
|
|
|
|
|
def smart_resize( |
|
height: int, |
|
width: int, |
|
factor: int = 28, |
|
min_pixels: int = 3136, |
|
max_pixels: int = 11289600, |
|
): |
|
"""Rescales the image so that the following conditions are met: |
|
1. Both dimensions (height and width) are divisible by 'factor'. |
|
2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. |
|
3. The aspect ratio of the image is maintained as closely as possible. |
|
""" |
|
if max(height, width) / min(height, width) > 200: |
|
raise ValueError( |
|
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" |
|
) |
|
h_bar = max(factor, round_by_factor(height, factor)) |
|
w_bar = max(factor, round_by_factor(width, factor)) |
|
|
|
if h_bar * w_bar > max_pixels: |
|
beta = math.sqrt((height * width) / max_pixels) |
|
h_bar = round_by_factor(height / beta, factor) |
|
w_bar = round_by_factor(width / beta, factor) |
|
elif h_bar * w_bar < min_pixels: |
|
beta = math.sqrt(min_pixels / (height * width)) |
|
h_bar = round_by_factor(height * beta, factor) |
|
w_bar = round_by_factor(width * beta, factor) |
|
return h_bar, w_bar |
|
|
|
|
|
def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None): |
|
"""Fetch and process an image""" |
|
if isinstance(image_input, str): |
|
if image_input.startswith(("http://", "https://")): |
|
response = requests.get(image_input) |
|
image = Image.open(BytesIO(response.content)).convert('RGB') |
|
else: |
|
image = Image.open(image_input).convert('RGB') |
|
elif isinstance(image_input, Image.Image): |
|
image = image_input.convert('RGB') |
|
else: |
|
raise ValueError(f"Invalid image input type: {type(image_input)}") |
|
|
|
if min_pixels is not None or max_pixels is not None: |
|
min_pixels = min_pixels or MIN_PIXELS |
|
max_pixels = max_pixels or MAX_PIXELS |
|
height, width = smart_resize( |
|
image.height, |
|
image.width, |
|
factor=IMAGE_FACTOR, |
|
min_pixels=min_pixels, |
|
max_pixels=max_pixels |
|
) |
|
image = image.resize((width, height), Image.LANCZOS) |
|
|
|
return image |
|
|
|
|
|
def load_images_from_pdf(pdf_path: str) -> List[Image.Image]: |
|
"""Load images from PDF file""" |
|
images = [] |
|
try: |
|
pdf_document = fitz.open(pdf_path) |
|
for page_num in range(len(pdf_document)): |
|
page = pdf_document.load_page(page_num) |
|
|
|
mat = fitz.Matrix(2.0, 2.0) |
|
pix = page.get_pixmap(matrix=mat) |
|
img_data = pix.tobytes("ppm") |
|
image = Image.open(BytesIO(img_data)).convert('RGB') |
|
images.append(image) |
|
pdf_document.close() |
|
except Exception as e: |
|
print(f"Error loading PDF: {e}") |
|
return [] |
|
return images |
|
|
|
|
|
def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image: |
|
"""Draw layout bounding boxes on image""" |
|
img_copy = image.copy() |
|
draw = ImageDraw.Draw(img_copy) |
|
|
|
|
|
colors = { |
|
'Caption': '#FF6B6B', |
|
'Footnote': '#4ECDC4', |
|
'Formula': '#45B7D1', |
|
'List-item': '#96CEB4', |
|
'Page-footer': '#FFEAA7', |
|
'Page-header': '#DDA0DD', |
|
'Picture': '#FFD93D', |
|
'Section-header': '#6C5CE7', |
|
'Table': '#FD79A8', |
|
'Text': '#74B9FF', |
|
'Title': '#E17055' |
|
} |
|
|
|
try: |
|
|
|
try: |
|
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12) |
|
except Exception: |
|
font = ImageFont.load_default() |
|
|
|
for item in layout_data: |
|
if 'bbox' in item and 'category' in item: |
|
bbox = item['bbox'] |
|
category = item['category'] |
|
color = colors.get(category, '#000000') |
|
|
|
|
|
draw.rectangle(bbox, outline=color, width=2) |
|
|
|
|
|
label = category |
|
label_bbox = draw.textbbox((0, 0), label, font=font) |
|
label_width = label_bbox[2] - label_bbox[0] |
|
label_height = label_bbox[3] - label_bbox[1] |
|
|
|
|
|
label_x = bbox[0] |
|
label_y = max(0, bbox[1] - label_height - 2) |
|
|
|
|
|
draw.rectangle( |
|
[label_x, label_y, label_x + label_width + 4, label_y + label_height + 2], |
|
fill=color |
|
) |
|
|
|
|
|
draw.text((label_x + 2, label_y + 1), label, fill='white', font=font) |
|
|
|
except Exception as e: |
|
print(f"Error drawing layout: {e}") |
|
|
|
return img_copy |
|
|
|
|
|
def is_arabic_text(text: str) -> bool: |
|
"""Check if text in headers and paragraphs contains mostly Arabic characters""" |
|
if not text: |
|
return False |
|
|
|
|
|
|
|
header_pattern = r'^#{1,6}\s+(.+)$' |
|
paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$' |
|
|
|
content_text = [] |
|
|
|
for line in text.split('\n'): |
|
line = line.strip() |
|
if not line: |
|
continue |
|
|
|
|
|
header_match = re.match(header_pattern, line, re.MULTILINE) |
|
if header_match: |
|
content_text.append(header_match.group(1)) |
|
continue |
|
|
|
|
|
if re.match(paragraph_pattern, line, re.MULTILINE): |
|
content_text.append(line) |
|
|
|
if not content_text: |
|
return False |
|
|
|
|
|
combined_text = ' '.join(content_text) |
|
|
|
|
|
arabic_chars = 0 |
|
total_chars = 0 |
|
|
|
for char in combined_text: |
|
if char.isalpha(): |
|
total_chars += 1 |
|
|
|
if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'): |
|
arabic_chars += 1 |
|
|
|
if total_chars == 0: |
|
return False |
|
|
|
|
|
return (arabic_chars / total_chars) > 0.5 |
|
|
|
|
|
def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str: |
|
"""Convert layout JSON to markdown format""" |
|
import base64 |
|
from io import BytesIO |
|
|
|
markdown_lines = [] |
|
|
|
try: |
|
|
|
sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0])) |
|
|
|
for item in sorted_items: |
|
category = item.get('category', '') |
|
text = item.get(text_key, '') |
|
bbox = item.get('bbox', []) |
|
|
|
if category == 'Picture': |
|
|
|
if bbox and len(bbox) == 4: |
|
try: |
|
|
|
x1, y1, x2, y2 = bbox |
|
|
|
x1, y1 = max(0, int(x1)), max(0, int(y1)) |
|
x2, y2 = min(image.width, int(x2)), min(image.height, int(y2)) |
|
|
|
if x2 > x1 and y2 > y1: |
|
cropped_img = image.crop((x1, y1, x2, y2)) |
|
|
|
|
|
buffer = BytesIO() |
|
cropped_img.save(buffer, format='PNG') |
|
img_data = base64.b64encode(buffer.getvalue()).decode() |
|
|
|
|
|
markdown_lines.append(f"\n") |
|
else: |
|
markdown_lines.append("\n") |
|
except Exception as e: |
|
print(f"Error processing image region: {e}") |
|
markdown_lines.append("\n") |
|
else: |
|
markdown_lines.append("\n") |
|
elif not text: |
|
continue |
|
elif category == 'Title': |
|
markdown_lines.append(f"# {text}\n") |
|
elif category == 'Section-header': |
|
markdown_lines.append(f"## {text}\n") |
|
elif category == 'Text': |
|
markdown_lines.append(f"{text}\n") |
|
elif category == 'List-item': |
|
markdown_lines.append(f"- {text}\n") |
|
elif category == 'Table': |
|
|
|
if text.strip().startswith('<'): |
|
markdown_lines.append(f"{text}\n") |
|
else: |
|
markdown_lines.append(f"**Table:** {text}\n") |
|
elif category == 'Formula': |
|
|
|
if text.strip().startswith('$') or '\\' in text: |
|
markdown_lines.append(f"$$\n{text}\n$$\n") |
|
else: |
|
markdown_lines.append(f"**Formula:** {text}\n") |
|
elif category == 'Caption': |
|
markdown_lines.append(f"*{text}*\n") |
|
elif category == 'Footnote': |
|
markdown_lines.append(f"^{text}^\n") |
|
elif category in ['Page-header', 'Page-footer']: |
|
|
|
continue |
|
else: |
|
markdown_lines.append(f"{text}\n") |
|
|
|
markdown_lines.append("") |
|
|
|
except Exception as e: |
|
print(f"Error converting to markdown: {e}") |
|
return str(layout_data) |
|
|
|
return "\n".join(markdown_lines) |
|
|
|
|
|
|
|
model_id = "helizac/dots.ocr-4bit" |
|
|
|
model_path = "./models/dots-ocr-local" |
|
model_path = snapshot_download( |
|
repo_id=model_id, |
|
local_dir=model_path, |
|
local_dir_use_symlinks=False, |
|
) |
|
model = AutoModelForCausalLM.from_pretrained( |
|
model_path, |
|
attn_implementation="flash_attention_2", |
|
torch_dtype=torch.bfloat16, |
|
device_map="auto", |
|
trust_remote_code=True |
|
) |
|
processor = AutoProcessor.from_pretrained( |
|
model_path, |
|
trust_remote_code=True |
|
) |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
pdf_cache = { |
|
"images": [], |
|
"current_page": 0, |
|
"total_pages": 0, |
|
"file_type": None, |
|
"is_parsed": False, |
|
"results": [] |
|
} |
|
@spaces.GPU() |
|
def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str: |
|
"""Run inference on an image with the given prompt""" |
|
try: |
|
if model is None or processor is None: |
|
raise RuntimeError("Model not loaded. Please check model initialization.") |
|
|
|
|
|
messages = [ |
|
{ |
|
"role": "user", |
|
"content": [ |
|
{ |
|
"type": "image", |
|
"image": image |
|
}, |
|
{"type": "text", "text": prompt} |
|
] |
|
} |
|
] |
|
|
|
|
|
text = processor.apply_chat_template( |
|
messages, |
|
tokenize=False, |
|
add_generation_prompt=True |
|
) |
|
|
|
|
|
image_inputs, video_inputs = process_vision_info(messages) |
|
|
|
|
|
inputs = processor( |
|
text=[text], |
|
images=image_inputs, |
|
videos=video_inputs, |
|
padding=True, |
|
return_tensors="pt", |
|
) |
|
|
|
|
|
inputs = inputs.to(device) |
|
|
|
|
|
with torch.no_grad(): |
|
generated_ids = model.generate( |
|
**inputs, |
|
max_new_tokens=max_new_tokens, |
|
do_sample=False, |
|
|
|
temperature=0.6, top_p=0.9, repetition_penalty=1.15 |
|
) |
|
|
|
|
|
generated_ids_trimmed = [ |
|
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) |
|
] |
|
|
|
output_text = processor.batch_decode( |
|
generated_ids_trimmed, |
|
skip_special_tokens=True, |
|
clean_up_tokenization_spaces=False |
|
) |
|
|
|
return output_text[0] if output_text else "" |
|
|
|
except Exception as e: |
|
print(f"Error during inference: {e}") |
|
traceback.print_exc() |
|
return f"Error during inference: {str(e)}" |
|
|
|
|
|
def process_image( |
|
image: Image.Image, |
|
min_pixels: Optional[int] = None, |
|
max_pixels: Optional[int] = None |
|
) -> Dict[str, Any]: |
|
"""Process a single image with the specified prompt mode""" |
|
try: |
|
|
|
if min_pixels is not None or max_pixels is not None: |
|
image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels) |
|
|
|
|
|
raw_output = inference(image, prompt) |
|
|
|
|
|
result = { |
|
'original_image': image, |
|
'raw_output': raw_output, |
|
'processed_image': image, |
|
'layout_result': None, |
|
'markdown_content': None |
|
} |
|
|
|
|
|
try: |
|
|
|
layout_data = json.loads(raw_output) |
|
result['layout_result'] = layout_data |
|
|
|
|
|
try: |
|
processed_image = draw_layout_on_image(image, layout_data) |
|
result['processed_image'] = processed_image |
|
except Exception as e: |
|
print(f"Error drawing layout: {e}") |
|
result['processed_image'] = image |
|
|
|
|
|
try: |
|
markdown_content = layoutjson2md(image, layout_data, text_key='text') |
|
result['markdown_content'] = markdown_content |
|
except Exception as e: |
|
print(f"Error generating markdown: {e}") |
|
result['markdown_content'] = raw_output |
|
|
|
except json.JSONDecodeError: |
|
print("Failed to parse JSON output, using raw output") |
|
result['markdown_content'] = raw_output |
|
|
|
return result |
|
|
|
except Exception as e: |
|
print(f"Error processing image: {e}") |
|
traceback.print_exc() |
|
return { |
|
'original_image': image, |
|
'raw_output': f"Error processing image: {str(e)}", |
|
'processed_image': image, |
|
'layout_result': None, |
|
'markdown_content': f"Error processing image: {str(e)}" |
|
} |
|
|
|
|
|
def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]: |
|
"""Load file for preview (supports PDF and images)""" |
|
global pdf_cache |
|
|
|
if not file_path or not os.path.exists(file_path): |
|
return None, "No file selected" |
|
|
|
file_ext = os.path.splitext(file_path)[1].lower() |
|
|
|
try: |
|
if file_ext == '.pdf': |
|
|
|
images = load_images_from_pdf(file_path) |
|
if not images: |
|
return None, "Failed to load PDF" |
|
|
|
pdf_cache.update({ |
|
"images": images, |
|
"current_page": 0, |
|
"total_pages": len(images), |
|
"file_type": "pdf", |
|
"is_parsed": False, |
|
"results": [] |
|
}) |
|
|
|
return images[0], f"Page 1 / {len(images)}" |
|
|
|
elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']: |
|
|
|
image = Image.open(file_path).convert('RGB') |
|
|
|
pdf_cache.update({ |
|
"images": [image], |
|
"current_page": 0, |
|
"total_pages": 1, |
|
"file_type": "image", |
|
"is_parsed": False, |
|
"results": [] |
|
}) |
|
|
|
return image, "Page 1 / 1" |
|
else: |
|
return None, f"Unsupported file format: {file_ext}" |
|
|
|
except Exception as e: |
|
print(f"Error loading file: {e}") |
|
return None, f"Error loading file: {str(e)}" |
|
|
|
|
|
def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, Any, Optional[Image.Image], Optional[Dict]]: |
|
"""Navigate through PDF pages and update all relevant outputs.""" |
|
global pdf_cache |
|
|
|
if not pdf_cache["images"]: |
|
return None, '<div class="page-info">No file loaded</div>', "No results yet", None, None |
|
|
|
if direction == "prev": |
|
pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1) |
|
elif direction == "next": |
|
pdf_cache["current_page"] = min( |
|
pdf_cache["total_pages"] - 1, |
|
pdf_cache["current_page"] + 1 |
|
) |
|
|
|
index = pdf_cache["current_page"] |
|
current_image_preview = pdf_cache["images"][index] |
|
page_info_html = f'<div class="page-info">Page {index + 1} / {pdf_cache["total_pages"]}</div>' |
|
|
|
|
|
markdown_content = "Page not processed yet" |
|
processed_img = None |
|
layout_json = None |
|
|
|
|
|
if (pdf_cache["is_parsed"] and |
|
index < len(pdf_cache["results"]) and |
|
pdf_cache["results"][index]): |
|
|
|
result = pdf_cache["results"][index] |
|
markdown_content = result.get('markdown_content') or result.get('raw_output', 'No content available') |
|
processed_img = result.get('processed_image', None) |
|
layout_json = result.get('layout_result', None) |
|
|
|
|
|
if is_arabic_text(markdown_content): |
|
markdown_update = gr.update(value=markdown_content, rtl=True) |
|
else: |
|
markdown_update = markdown_content |
|
|
|
return current_image_preview, page_info_html, markdown_update, processed_img, layout_json |
|
|
|
|
|
def create_gradio_interface(): |
|
"""Create the Gradio interface""" |
|
|
|
|
|
css = """ |
|
.main-container { |
|
max-width: 1400px; |
|
margin: 0 auto; |
|
} |
|
|
|
.header-text { |
|
text-align: center; |
|
color: #2c3e50; |
|
margin-bottom: 20px; |
|
} |
|
|
|
.process-button { |
|
border: none !important; |
|
color: white !important; |
|
font-weight: bold !important; |
|
} |
|
|
|
.process-button:hover { |
|
transform: translateY(-2px) !important; |
|
box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; |
|
} |
|
|
|
.info-box { |
|
border: 1px solid #dee2e6; |
|
border-radius: 8px; |
|
padding: 15px; |
|
margin: 10px 0; |
|
} |
|
|
|
.page-info { |
|
text-align: center; |
|
padding: 8px 16px; |
|
border-radius: 20px; |
|
font-weight: bold; |
|
margin: 10px 0; |
|
} |
|
|
|
.model-status { |
|
padding: 10px; |
|
border-radius: 8px; |
|
margin: 10px 0; |
|
text-align: center; |
|
font-weight: bold; |
|
} |
|
|
|
.status-ready { |
|
background: #d1edff; |
|
color: #0c5460; |
|
border: 1px solid #b8daff; |
|
} |
|
""" |
|
|
|
with gr.Blocks(theme=gr.themes.Soft(), css=css, title="Dots.OCR Demo") as demo: |
|
|
|
|
|
gr.HTML(""" |
|
<div class="title" style="text-align: center"> |
|
<h1>π Dot-OCR - Multilingual Document Text Extraction</h1> |
|
<p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;"> |
|
A state-of-the-art image/pdf-to-markdown vision language model for intelligent document processing |
|
</p> |
|
<div style="display: flex; justify-content: center; gap: 20px; margin: 15px 0;"> |
|
<a href="https://huggingface.co/helizac/dots.ocr-4bit" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;"> |
|
π Hugging Face Model |
|
</a> |
|
<a href="https://github.com/rednote-hilab/dots.ocr/blob/master/assets/blog.md" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;"> |
|
π Release Blog |
|
</a> |
|
<a href="https://github.com/rednote-hilab/dots.ocr" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;"> |
|
π» GitHub Repository |
|
</a> |
|
</div> |
|
</div> |
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
|
with gr.Column(scale=1): |
|
|
|
|
|
file_input = gr.File( |
|
label="Upload Image or PDF", |
|
file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"], |
|
type="filepath" |
|
) |
|
|
|
|
|
image_preview = gr.Image( |
|
label="Preview", |
|
type="pil", |
|
interactive=False, |
|
height=300 |
|
) |
|
|
|
|
|
with gr.Row(): |
|
prev_page_btn = gr.Button("β Previous", size="md") |
|
page_info = gr.HTML('<div class="page-info">No file loaded</div>') |
|
next_page_btn = gr.Button("Next βΆ", size="md") |
|
|
|
|
|
with gr.Accordion("Advanced Settings", open=False): |
|
max_new_tokens = gr.Slider( |
|
minimum=1000, |
|
maximum=32000, |
|
value=24000, |
|
step=1000, |
|
label="Max New Tokens", |
|
info="Maximum number of tokens to generate" |
|
) |
|
|
|
min_pixels = gr.Number( |
|
value=MIN_PIXELS, |
|
label="Min Pixels", |
|
info="Minimum image resolution" |
|
) |
|
|
|
max_pixels = gr.Number( |
|
value=MAX_PIXELS, |
|
label="Max Pixels", |
|
info="Maximum image resolution" |
|
) |
|
|
|
|
|
process_btn = gr.Button( |
|
"π Process Document", |
|
variant="primary", |
|
elem_classes=["process-button"], |
|
size="lg" |
|
) |
|
|
|
|
|
clear_btn = gr.Button("ποΈ Clear All", variant="secondary") |
|
|
|
|
|
with gr.Column(scale=2): |
|
|
|
|
|
with gr.Tabs(): |
|
|
|
with gr.Tab("πΌοΈ Processed Image"): |
|
processed_image = gr.Image( |
|
label="Image with Layout Detection", |
|
type="pil", |
|
interactive=False, |
|
height=500 |
|
) |
|
|
|
with gr.Tab("π Extracted Content"): |
|
markdown_output = gr.Markdown( |
|
value="Click 'Process Document' to see extracted content...", |
|
height=500 |
|
) |
|
|
|
with gr.Tab("π Layout JSON"): |
|
json_output = gr.JSON( |
|
label="Layout Analysis Results", |
|
value=None |
|
) |
|
|
|
|
|
def process_document(file_path, max_tokens, min_pix, max_pix): |
|
"""Process the uploaded document""" |
|
global pdf_cache |
|
|
|
try: |
|
if not file_path: |
|
return None, "Please upload a file first.", None |
|
|
|
if model is None: |
|
return None, "Model not loaded. Please refresh the page and try again.", None |
|
|
|
|
|
image, page_info = load_file_for_preview(file_path) |
|
if image is None: |
|
return None, page_info, None |
|
|
|
|
|
if pdf_cache["file_type"] == "pdf": |
|
|
|
all_results = [] |
|
all_markdown = [] |
|
|
|
for i, img in enumerate(pdf_cache["images"]): |
|
result = process_image( |
|
img, |
|
min_pixels=int(min_pix) if min_pix else None, |
|
max_pixels=int(max_pix) if max_pix else None |
|
) |
|
all_results.append(result) |
|
if result.get('markdown_content'): |
|
all_markdown.append(f"## Page {i+1}\n\n{result['markdown_content']}") |
|
|
|
pdf_cache["results"] = all_results |
|
pdf_cache["is_parsed"] = True |
|
|
|
|
|
first_result = all_results[0] |
|
combined_markdown = "\n\n---\n\n".join(all_markdown) |
|
|
|
|
|
if is_arabic_text(combined_markdown): |
|
markdown_update = gr.update(value=combined_markdown, rtl=True) |
|
else: |
|
markdown_update = combined_markdown |
|
|
|
return ( |
|
first_result['processed_image'], |
|
markdown_update, |
|
first_result['layout_result'] |
|
) |
|
else: |
|
|
|
result = process_image( |
|
image, |
|
min_pixels=int(min_pix) if min_pix else None, |
|
max_pixels=int(max_pix) if max_pix else None |
|
) |
|
|
|
pdf_cache["results"] = [result] |
|
pdf_cache["is_parsed"] = True |
|
|
|
|
|
content = result['markdown_content'] or "No content extracted" |
|
if is_arabic_text(content): |
|
markdown_update = gr.update(value=content, rtl=True) |
|
else: |
|
markdown_update = content |
|
|
|
return ( |
|
result['processed_image'], |
|
markdown_update, |
|
result['layout_result'] |
|
) |
|
|
|
except Exception as e: |
|
error_msg = f"Error processing document: {str(e)}" |
|
print(error_msg) |
|
traceback.print_exc() |
|
return None, error_msg, None |
|
|
|
def handle_file_upload(file_path): |
|
"""Handle file upload and show preview""" |
|
if not file_path: |
|
return None, "No file loaded" |
|
|
|
image, page_info = load_file_for_preview(file_path) |
|
return image, page_info |
|
|
|
def handle_page_turn(direction): |
|
"""Handle page navigation""" |
|
image, page_info, result = turn_page(direction) |
|
return image, page_info, result |
|
|
|
def clear_all(): |
|
"""Clear all data and reset interface""" |
|
global pdf_cache |
|
|
|
pdf_cache = { |
|
"images": [], "current_page": 0, "total_pages": 0, |
|
"file_type": None, "is_parsed": False, "results": [] |
|
} |
|
|
|
return ( |
|
None, |
|
None, |
|
'<div class="page-info">No file loaded</div>', |
|
None, |
|
"Click 'Process Document' to see extracted content...", |
|
None, |
|
) |
|
|
|
|
|
file_input.change( |
|
handle_file_upload, |
|
inputs=[file_input], |
|
outputs=[image_preview, page_info] |
|
) |
|
|
|
|
|
prev_page_btn.click( |
|
lambda: turn_page("prev"), |
|
outputs=[image_preview, page_info, markdown_output, processed_image, json_output] |
|
) |
|
|
|
next_page_btn.click( |
|
lambda: turn_page("next"), |
|
outputs=[image_preview, page_info, markdown_output, processed_image, json_output] |
|
) |
|
|
|
process_btn.click( |
|
process_document, |
|
inputs=[file_input, max_new_tokens, min_pixels, max_pixels], |
|
outputs=[processed_image, markdown_output, json_output] |
|
) |
|
|
|
|
|
clear_btn.click( |
|
clear_all, |
|
outputs=[ |
|
file_input, image_preview, page_info, processed_image, |
|
markdown_output, json_output |
|
] |
|
) |
|
|
|
return demo |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
demo = create_gradio_interface() |
|
demo.queue(max_size=10).launch( |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
share=False, |
|
debug=True, |
|
show_error=True |
|
) |
|
|