# OCR-Comparator / app.py
import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re
import time
from threading import Thread
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    AutoModelForImageTextToText,
    AutoProcessor,
    TextIteratorStreamer,
    AutoModel,
    AutoTokenizer,
)
from transformers.image_utils import load_image
# --- Constants and Model Setup ---
MAX_INPUT_TOKEN_LENGTH = 4096
# Let the environment (e.g., Hugging Face Spaces) determine the device; this
# falls back to CPU when CUDA is unavailable and avoids conflicts with the
# CUDA environment set up by the platform.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("current device:", torch.cuda.current_device())
    print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
print("Using device:", device)
# --- Prompts for Different Tasks ---
layout_prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
- For tables, provide the content in a structured JSON format.
- For all other elements, provide the plain text.
4. Constraints:
- The output must be the original text from the image.
- All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object wrapped in ```json ... ```.
"""
ocr_prompt = "Perform precise OCR on the image. Extract all text content, maintaining the original structure, paragraphs, and tables as formatted markdown."
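# For reference, the layout prompt above asks the model for JSON shaped roughly
# like this (an illustrative sketch of one possible output, not a guaranteed
# schema; layoutjson2md() below consumes exactly this shape):
#
# [
#   {"bbox": [72, 40, 540, 80], "category": "Title", "text": "Quarterly Report"},
#   {"bbox": [72, 100, 540, 300], "category": "Table",
#    "text": {"header": ["Item", "Total"], "rows": [["Revenue", "1.2M"]]}},
#   {"bbox": [72, 320, 540, 500], "category": "Text", "text": "Body paragraph..."}
# ]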
# --- Model Loading ---
MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-080125"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_C = "nanonets/Nanonets-OCR-s"
processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g = AutoProcessor.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
)
model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_I = "allenai/olmOCR-7B-0725"
processor_i = AutoProcessor.from_pretrained(MODEL_ID_I, trust_remote_code=True)
model_i = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_I, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()
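# All five checkpoints above load through the same Qwen2.5-VL pattern. A small
# helper like this sketch would remove the repetition (shown for reference only;
# load_qwen_vl is an assumption, not part of the original app, and is unused below):
#
# def load_qwen_vl(model_id, subfolder=None):
#     kwargs = {"trust_remote_code": True}
#     if subfolder is not None:
#         kwargs["subfolder"] = subfolder
#     processor = AutoProcessor.from_pretrained(model_id, **kwargs)
#     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#         model_id, torch_dtype=torch.float16, **kwargs
#     ).to(device).eval()
#     return processor, model
#
# e.g. processor_m, model_m = load_qwen_vl("prithivMLmods/Camel-Doc-OCR-080125")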
# --- Utility Functions ---
def layoutjson2md(layout_data: Any) -> str:
    """
    Convert the structured JSON from Layout Analysis into formatted Markdown.
    Robust against malformed JSON from the model.
    """
    markdown_lines = []
    # If the model wraps the list in a dictionary, find and extract the list.
    if isinstance(layout_data, dict):
        found_list = None
        for value in layout_data.values():
            if isinstance(value, list):
                found_list = value
                break
        if found_list is not None:
            layout_data = found_list
        else:
            return "### Error: Could not find a list of layout items in the JSON object."
    if not isinstance(layout_data, list):
        return f"### Error: Expected a list of layout items, but received type {type(layout_data).__name__}."
    try:
        # Filter out any non-dictionary items and sort by reading order
        # (top-to-bottom, then left-to-right, using the bbox origin).
        valid_items = [item for item in layout_data if isinstance(item, dict)]
        sorted_items = sorted(
            valid_items,
            key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]),
        )
        for item in sorted_items:
            category = item.get('category', 'Text')  # Default to 'Text' if no category
            text = item.get('text', '')
            if not text:
                continue
            if category == 'Title':
                markdown_lines.append(f"# {text}\n")
            elif category == 'Section-header':
                markdown_lines.append(f"## {text}\n")
            elif category == 'Table':
                if isinstance(text, dict) and 'header' in text and 'rows' in text:
                    header = '| ' + ' | '.join(map(str, text['header'])) + ' |'
                    separator = '| ' + ' | '.join(['---'] * len(text['header'])) + ' |'
                    rows = ['| ' + ' | '.join(map(str, row)) + ' |' for row in text['rows']]
                    markdown_lines.extend([header, separator] + rows)
                    markdown_lines.append("\n")
                else:  # Fallback for simple text or malformed tables
                    markdown_lines.append(f"{text}\n")
            else:
                markdown_lines.append(f"{text}\n")
    except Exception as e:
        print(f"Error converting to markdown: {e}")
        traceback.print_exc()
        return "### Error: An unexpected error occurred while converting JSON to Markdown."
    return "\n".join(markdown_lines)
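# Illustrative example (made-up values):
#   layoutjson2md([
#       {"bbox": [0, 30, 100, 60], "category": "Text", "text": "Body text."},
#       {"bbox": [0, 0, 100, 20], "category": "Title", "text": "Report"},
#   ])
# sorts the items into reading order by bbox and returns "# Report\n\nBody text.\n".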
# --- Core Application Logic ---
@spaces.GPU(duration=140)  # 2 min 20 s
def process_document_stream(model_name: str, task_choice: str, image: Image.Image, max_new_tokens: int):
    """
    Main generator function that handles both OCR and Layout Analysis tasks.
    """
    if image is None:
        yield "Please upload an image.", "Please upload an image.", None
        return

    # 1. Select the prompt based on the user's task choice.
    text_prompt = ocr_prompt if task_choice == "Content Extraction" else layout_prompt

    # 2. Select the model and processor.
    if model_name == "Camel-Doc-OCR-080125":
        processor, model = processor_m, model_m
    elif model_name == "Megalodon-OCR-Sync-0713":
        processor, model = processor_t, model_t
    elif model_name == "Nanonets-OCR-s":
        processor, model = processor_c, model_c
    elif model_name == "MonkeyOCR-Recognition":
        processor, model = processor_g, model_g
    elif model_name == "olmOCR-7B-0725":
        processor, model = processor_i, model_i
    else:
        yield "Invalid model selected.", "Invalid model selected.", None
        return

    # 3. Prepare model inputs and the streamer.
    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_prompt}]}]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt_full], images=[image], return_tensors="pt",
        padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH,
    ).to(device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # 4. Stream raw output to the UI in real time.
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, "⏳ Processing...", {"status": "streaming"}

    # 5. Post-process the final buffer based on the selected task.
    if task_choice == "Content Extraction":
        # For OCR, the buffer is the final result.
        yield buffer, buffer, None
    else:  # Layout Analysis
        try:
            json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
            if not json_match:
                # If no fenced JSON block is found, try to parse the whole buffer as a fallback.
                try:
                    layout_data = json.loads(buffer)
                    markdown_content = layoutjson2md(layout_data)
                    yield buffer, markdown_content, layout_data
                    return
                except json.JSONDecodeError:
                    raise ValueError("JSON object not found in the model's output.")
            json_str = json_match.group(1)
            layout_data = json.loads(json_str)
            markdown_content = layoutjson2md(layout_data)
            yield buffer, markdown_content, layout_data
        except Exception as e:
            error_md = f"❌ **Error:** Failed to parse Layout JSON.\n\n**Details:**\n`{str(e)}`\n\n**Raw Output:**\n```\n{buffer}\n```"
            error_json = {"error": "ProcessingError", "details": str(e), "raw_output": buffer}
            yield buffer, error_md, error_json
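# The fenced-JSON extraction above mirrors constraint 5 of layout_prompt: the
# model is instructed to wrap its JSON in a ```json ... ``` block, so re.search
# pulls out the first such block for json.loads, with a whole-buffer parse as
# the fallback when the model omits the fence.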
# --- Gradio UI Definition ---
def create_gradio_interface():
    """Builds and returns the Gradio web interface."""
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important; }
    .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
    """
    with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
        gr.HTML("""
        <div class="title" style="text-align: center">
            <h1>OCR Comparator🥠</h1>
            <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
                Advanced Vision-Language Models for Image Content and Layout Extraction
            </p>
        </div>
        """)
        with gr.Row():
            # Left Column (Inputs)
            with gr.Column(scale=1):
                model_choice = gr.Dropdown(
                    choices=[
                        "Camel-Doc-OCR-080125",
                        "MonkeyOCR-Recognition",
                        "olmOCR-7B-0725",
                        "Nanonets-OCR-s",
                        "Megalodon-OCR-Sync-0713",
                    ],
                    label="Select Model",
                    value="Nanonets-OCR-s",
                )
                task_choice = gr.Dropdown(
                    choices=["Content Extraction", "Layout Analysis (.json)"],
                    label="Select Task",
                    value="Content Extraction",
                )
                image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
                with gr.Accordion("Advanced Settings", open=False):
                    max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=4096, step=256, label="Max New Tokens")
                process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
            # Right Column (Outputs)
            with gr.Column(scale=2):
                with gr.Tabs() as tabs:
                    with gr.Tab("📝 Extracted Content"):
                        raw_output_stream = gr.Textbox(label="Raw Model Output Stream", interactive=False, lines=13, show_copy_button=True)
                        with gr.Row():
                            examples = gr.Examples(
                                examples=["examples/1.png", "examples/2.png", "examples/3.png", "examples/4.png", "examples/5.png"],
                                inputs=image_input,
                                label="Examples",
                            )
                        gr.Markdown("[Report-Bug💻](https://huggingface.co/spaces/prithivMLmods/OCR-Comparator/discussions)")
                    with gr.Tab("📰 README.md"):
                        with gr.Accordion("(Formatted Result)", open=True):
                            markdown_output = gr.Markdown(label="Formatted Markdown")
                    with gr.Tab("📋 Layout Analysis Results"):
                        json_output = gr.JSON(label="Structured Layout Data (JSON)")

        # Event Handlers
        def clear_all_outputs():
            return None, "Raw output will appear here.", "Formatted results will appear here.", None

        process_btn.click(
            fn=process_document_stream,
            inputs=[model_choice, task_choice, image_input, max_new_tokens],
            outputs=[raw_output_stream, markdown_output, json_output],
        )
        clear_btn.click(
            clear_all_outputs,
            outputs=[image_input, raw_output_stream, markdown_output, json_output],
        )
    return demo
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.queue(max_size=50).launch(share=True, ssr_mode=False, show_error=True)