import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re
import time
from threading import Thread
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
Qwen2_5_VLForConditionalGeneration,
AutoProcessor,
TextIteratorStreamer,
)
# --- Constants and Model Setup ---
MAX_INPUT_TOKEN_LENGTH = 4096
device = "cuda" if torch.cuda.is_available() else "cpu"
# The detailed prompt to instruct the model to generate structured JSON
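# Illustrative example (assumed shape, not produced by this code): the prompt below asks the
# model to reply with a single fenced block along the lines of
#
# ```json
# [
#   {"category": "Title", "bbox": [72, 40, 540, 90], "text": "# Quarterly Report"},
#   {"category": "Picture", "bbox": [72, 110, 300, 320]},
#   {"category": "Text", "bbox": [72, 340, 540, 420], "text": "Revenue grew in Q2..."}
# ]
# ```
#
# The coordinates and texts are made up; only the keys and categories follow the prompt.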
prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
- Picture: For the 'Picture' category, the text field should be omitted.
- Formula: Format its text as LaTeX.
- Table: For tables, provide the content in a structured format within the JSON.
- All Others (Text, Title, etc.): Format their text as Markdown.
4. Constraints:
- The output text must be the original text from the image, with no translation.
- All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object wrapped in ```json ... ```.
"""
# Load models
MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()
MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()
MODEL_ID_C = "nanonets/Nanonets-OCR-s"
processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()
MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g = AutoProcessor.from_pretrained(
MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
)
model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
).to(device).eval()
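# All four checkpoints above are served through the same Qwen2.5-VL architecture, loaded in
# float16, moved to the selected device, and put in eval mode.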
# --- Utility Functions ---
def layoutjson2md(layout_data: List[Dict]) -> str:
"""Converts the structured JSON layout data into formatted Markdown."""
markdown_lines = []
try:
sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
for item in sorted_items:
category = item.get('category', '')
text = item.get('text', '')
if not text:
continue
if category == 'Title':
markdown_lines.append(f"# {text}\n")
elif category == 'Section-header':
markdown_lines.append(f"## {text}\n")
elif category == 'Table':
# Check if the text is a dictionary representing a structured table
if isinstance(text, dict) and 'header' in text and 'rows' in text:
header = '| ' + ' | '.join(map(str, text['header'])) + ' |'
separator = '| ' + ' | '.join(['---'] * len(text['header'])) + ' |'
rows = ['| ' + ' | '.join(map(str, row)) + ' |' for row in text['rows']]
markdown_lines.append(header)
markdown_lines.append(separator)
markdown_lines.extend(rows)
markdown_lines.append("\n")
else:
# Fallback for unstructured table text
markdown_lines.append(f"{text}\n")
else:
markdown_lines.append(f"{text}\n")
except Exception as e:
print(f"Error converting to markdown: {e}")
return "### Error converting JSON to Markdown."
return "\n".join(markdown_lines)
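# Illustrative usage of layoutjson2md (input values are made up):
#   layoutjson2md([
#       {"category": "Section-header", "bbox": [0, 0, 100, 20], "text": "Results"},
#       {"category": "Table", "bbox": [0, 30, 100, 80],
#        "text": {"header": ["Metric", "Value"], "rows": [["Accuracy", "0.91"]]}},
#   ])
#   # -> "## Results\n\n| Metric | Value |\n| --- | --- |\n| Accuracy | 0.91 |\n\n"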
# --- Core Application Logic ---
@spaces.GPU
def process_document_stream(model_name: str, image: Image.Image, text_prompt: str, max_new_tokens: int):
"""
Main generator function that streams raw model output and then processes it into
formatted Markdown and structured JSON for the UI.
"""
if image is None:
yield "Please upload an image.", "Please upload an image.", None
return
    # Select the model and processor
    if model_name == "Camel-Doc-OCR-062825":
        processor, model = processor_m, model_m
    elif model_name == "Megalodon-OCR-Sync-0713":
        processor, model = processor_t, model_t
    elif model_name == "Nanonets-OCR-s":
        processor, model = processor_c, model_c
    elif model_name == "MonkeyOCR-Recognition":
        processor, model = processor_g, model_g
    else:
        yield "Invalid model selected.", "Invalid model selected.", None
        return
# Prepare model inputs
messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_prompt}]}]
prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
# Start generation in a separate thread
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# Stream raw output to the UI
buffer = ""
for new_text in streamer:
buffer += new_text
buffer = buffer.replace("<|im_end|>", "")
time.sleep(0.01)
# Yield the raw stream and placeholders for the final results
        yield buffer, "⏳ Formatting Markdown...", {"status": "processing"}
# After streaming is complete, process the final buffer
try:
# Extract the JSON object from the buffer
json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
if not json_match:
raise json.JSONDecodeError("JSON object not found in the model's output.", buffer, 0)
json_str = json_match.group(1)
layout_data = json.loads(json_str)
# Convert the parsed JSON to formatted markdown
markdown_content = layoutjson2md(layout_data)
# Yield the final, complete results
yield buffer, markdown_content, layout_data
except json.JSONDecodeError as e:
print(f"JSON parsing failed: {e}")
        error_md = "❌ **Error:** Failed to parse JSON from the model's output.\n\nSee the raw output stream for details."
error_json = {"error": "JSONDecodeError", "details": str(e), "raw_output": buffer}
yield buffer, error_md, error_json
except Exception as e:
print(f"An unexpected error occurred: {e}")
        yield buffer, f"❌ An unexpected error occurred: {e}", None
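# Illustrative note on the JSON extraction above (sample string is made up):
#   sample = 'Sure.\n```json\n[{"category": "Text", "bbox": [0, 0, 10, 10], "text": "Hi"}]\n```'
#   re.search(r'```json\s*([\s\S]+?)\s*```', sample).group(1)
#   # -> '[{"category": "Text", "bbox": [0, 0, 10, 10], "text": "Hi"}]'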
# --- Gradio UI Definition ---
def create_gradio_interface():
"""Builds and returns the Gradio web interface."""
css = """
.main-container { max-width: 1400px; margin: 0 auto; }
.process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
.process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
"""
with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
gr.HTML("""
<div class="title" style="text-align: center">
            <h1>Dot<span style="color: red;">●</span>OCR Comparator</h1>
<p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
                Advanced Vision-Language Models for Image Layout Analysis
</p>
</div>
""")
with gr.Row():
# --- Left Column (Inputs) ---
with gr.Column(scale=1):
model_choice = gr.Radio(
choices=["Camel-Doc-OCR-062825", "MonkeyOCR-Recognition", "Nanonets-OCR-s", "Megalodon-OCR-Sync-0713"],
label="Select Model",
value="Camel-Doc-OCR-062825"
)
image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
with gr.Accordion("Advanced Settings", open=False):
max_new_tokens = gr.Slider(minimum=1000, maximum=8192, value=4096, step=256, label="Max New Tokens")
                process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
# --- Right Column (Outputs) ---
with gr.Column(scale=2):
with gr.Tabs():
                    with gr.Tab("📄 Extracted Content"):
raw_output_stream = gr.Textbox(label="Raw Model Output Stream", interactive=False, lines=15, show_copy_button=True)
with gr.Accordion("(Formatted Result)", open=True):
markdown_output = gr.Markdown(label="Formatted Markdown (from JSON)")
                    with gr.Tab("📊 Layout Analysis Results"):
json_output = gr.JSON(label="Structured Layout Data (JSON)", value=None)
# --- Event Handlers ---
def clear_all_outputs():
"""Resets all input and output fields to their default state."""
return None, "Raw output will appear here.", "Formatted results will appear here.", None
# Connect the process button to the main generator function
process_btn.click(
fn=process_document_stream,
inputs=[model_choice, image_input, gr.Textbox(value=prompt, visible=False), max_new_tokens],
outputs=[raw_output_stream, markdown_output, json_output]
)
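        # Each tuple yielded by process_document_stream maps positionally onto
        # (raw_output_stream, markdown_output, json_output) above.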
# Connect the clear button
clear_btn.click(
clear_all_outputs,
outputs=[image_input, raw_output_stream, markdown_output, json_output]
)
return demo
if __name__ == "__main__":
demo = create_gradio_interface()
demo.queue().launch(server_name="0.0.0.0", server_port=7860, show_error=True) |