prithivMLmods commited on
Commit
4d0a926
·
verified ·
1 Parent(s): 74aa8d3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +253 -0
app.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import json
3
+ import math
4
+ import os
5
+ import traceback
6
+ from io import BytesIO
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+ import re
9
+ import time
10
+ from threading import Thread
11
+
12
+ import gradio as gr
13
+ import requests
14
+ import torch
15
+ from PIL import Image
16
+ from transformers import (
17
+ Qwen2_5_VLForConditionalGeneration,
18
+ AutoProcessor,
19
+ TextIteratorStreamer,
20
+ )
21
+
22
+ js_func = """
23
+ function refresh() {
24
+ const url = new URL(window.location);
25
+ if (url.searchParams.get('__theme') !== 'dark') {
26
+ url.searchParams.set('__theme', 'dark');
27
+ window.location.href = url.href;
28
+ }
29
+ }
30
+ """
31
+
32
+ # --- Constants and Model Setup ---
33
+ MAX_INPUT_TOKEN_LENGTH = 4096
34
+ device = "cuda" if torch.cuda.is_available() else "cpu"
35
+
36
+ # --- Prompts for Different Tasks ---
37
+ layout_prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
38
+
39
+ 1. Bbox format: [x1, y1, x2, y2]
40
+ 2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
41
+ 3. Text Extraction & Formatting Rules:
42
+ - For tables, provide the content in a structured JSON format.
43
+ - For all other elements, provide the plain text.
44
+ 4. Constraints:
45
+ - The output must be the original text from the image.
46
+ - All layout elements must be sorted according to human reading order.
47
+ 5. Final Output: The entire output must be a single JSON object wrapped in ```json ... ```.
48
+ """
49
+
50
+ ocr_prompt = "Perform precise OCR on the image. Extract all text content, maintaining the original structure, paragraphs, and tables as formatted markdown."
51
+
52
+ # --- Model Loading ---
53
+ MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
54
+ processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
55
+ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
56
+ MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
57
+ ).to(device).eval()
58
+
59
+ MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
60
+ processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
61
+ model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
62
+ MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
63
+ ).to(device).eval()
64
+
65
+ MODEL_ID_C = "nanonets/Nanonets-OCR-s"
66
+ processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
67
+ model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
68
+ MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
69
+ ).to(device).eval()
70
+
71
+ MODEL_ID_G = "echo840/MonkeyOCR"
72
+ SUBFOLDER = "Recognition"
73
+ processor_g = AutoProcessor.from_pretrained(
74
+ MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
75
+ )
76
+ model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
77
+ MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
78
+ ).to(device).eval()
79
+
80
+ # --- Utility Functions ---
81
+ def layoutjson2md(layout_data: List[Dict]) -> str:
82
+ """Converts the structured JSON from Layout Analysis into formatted Markdown."""
83
+ markdown_lines = []
84
+ try:
85
+ # Sort items by reading order (top-to-bottom, left-to-right)
86
+ sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0,0,0,0])[1], x.get('bbox', [0,0,0,0])[0]))
87
+ for item in sorted_items:
88
+ category = item.get('category', '')
89
+ text = item.get('text', '')
90
+ if not text: continue
91
+
92
+ if category == 'Title': markdown_lines.append(f"# {text}\n")
93
+ elif category == 'Section-header': markdown_lines.append(f"## {text}\n")
94
+ elif category == 'Table':
95
+ # Handle structured table JSON
96
+ if isinstance(text, dict) and 'header' in text and 'rows' in text:
97
+ header = '| ' + ' | '.join(map(str, text['header'])) + ' |'
98
+ separator = '| ' + ' | '.join(['---'] * len(text['header'])) + ' |'
99
+ rows = ['| ' + ' | '.join(map(str, row)) + ' |' for row in text['rows']]
100
+ markdown_lines.extend([header, separator] + rows)
101
+ markdown_lines.append("\n")
102
+ else: # Fallback for simple text
103
+ markdown_lines.append(f"{text}\n")
104
+ else:
105
+ markdown_lines.append(f"{text}\n")
106
+ except Exception as e:
107
+ print(f"Error converting to markdown: {e}")
108
+ return "### Error converting JSON to Markdown."
109
+ return "\n".join(markdown_lines)
110
+
111
+ # --- Core Application Logic ---
112
+ @spaces.GPU
113
+ def process_document_stream(model_name: str, task_choice: str, image: Image.Image, max_new_tokens: int):
114
+ """
115
+ Main generator function that handles both OCR and Layout Analysis tasks.
116
+ """
117
+ if image is None:
118
+ yield "Please upload an image.", "Please upload an image.", None
119
+ return
120
+
121
+ # 1. Select prompt based on user's task choice
122
+ text_prompt = ocr_prompt if task_choice == "Content Extraction" else layout_prompt
123
+
124
+ # 2. Select model and processor
125
+ if model_name == "Camel-Doc-OCR-062825": processor, model = processor_m, model_m
126
+ elif model_name == "Megalodon-OCR-Sync-0713": processor, model = processor_t, model_t
127
+ elif model_name == "Nanonets-OCR-s": processor, model = processor_c, model_c
128
+ elif model_name == "MonkeyOCR-Recognition": processor, model = processor_g, model_g
129
+ else:
130
+ yield "Invalid model selected.", "Invalid model selected.", None
131
+ return
132
+
133
+ # 3. Prepare model inputs and streamer
134
+ messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_prompt}]}]
135
+ prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
136
+ inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
137
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
138
+ generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
139
+
140
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
141
+ thread.start()
142
+
143
+ # 4. Stream raw output to the UI in real-time
144
+ buffer = ""
145
+ for new_text in streamer:
146
+ buffer += new_text
147
+ buffer = buffer.replace("<|im_end|>", "")
148
+ time.sleep(0.01)
149
+ yield buffer, "⏳ Processing...", {"status": "streaming"}
150
+
151
+ # 5. Post-process the final buffer based on the selected task
152
+ if task_choice == "Content Extraction":
153
+ # For OCR, the buffer is the final result.
154
+ yield buffer, buffer, None
155
+ else: # Layout Analysis
156
+ try:
157
+ json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
158
+ if not json_match:
159
+ raise json.JSONDecodeError("JSON object not found in output.", buffer, 0)
160
+
161
+ json_str = json_match.group(1)
162
+ layout_data = json.loads(json_str)
163
+ markdown_content = layoutjson2md(layout_data)
164
+
165
+ yield buffer, markdown_content, layout_data
166
+ except Exception as e:
167
+ error_md = f"❌ **Error:** Failed to parse Layout JSON.\n\n**Details:**\n`{str(e)}`"
168
+ error_json = {"error": "ProcessingError", "details": str(e), "raw_output": buffer}
169
+ yield buffer, error_md, error_json
170
+
171
+ # --- Gradio UI Definition ---
172
+ def create_gradio_interface():
173
+ """Builds and returns the Gradio web interface."""
174
+ css = """
175
+ .main-container { max-width: 1400px; margin: 0 auto; }
176
+ .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
177
+ .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
178
+ """
179
+ with gr.Blocks(theme="bethecloud/storj_theme", css=css, js=js_func) as demo:
180
+ gr.HTML("""
181
+ <div class="title" style="text-align: center">
182
+ <h1>OCR Comparator👨‍🏫</h1>
183
+ <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
184
+ Advanced Vision-Language Model for Image Content and Layout Extraction
185
+ </p>
186
+ </div>
187
+ """)
188
+
189
+ with gr.Row():
190
+ # Left Column (Inputs)
191
+ with gr.Column(scale=1):
192
+ model_choice = gr.Dropdown(
193
+ choices=["Camel-Doc-OCR-062825",
194
+ "MonkeyOCR-Recognition",
195
+ "Nanonets-OCR-s",
196
+ "Megalodon-OCR-Sync-0713"],
197
+ label="Select Model", value="Nanonets-OCR-s"
198
+ )
199
+ task_choice = gr.Dropdown(
200
+ choices=["Content Extraction", "Layout Analysis(.json)"],
201
+ label="Select Task", value="Content Extraction"
202
+ )
203
+ image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
204
+ with gr.Accordion("Advanced Settings", open=False):
205
+ max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=4096, step=256, label="Max New Tokens")
206
+
207
+ process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
208
+ clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
209
+
210
+ # Right Column (Outputs)
211
+ with gr.Column(scale=2):
212
+ with gr.Tabs() as tabs:
213
+ with gr.Tab("📝 Extracted Content"):
214
+ raw_output_stream = gr.Textbox(label="Raw Model Output Stream", interactive=False, lines=13, show_copy_button=True)
215
+ with gr.Row():
216
+ examples = gr.Examples(
217
+ examples=["examples/example_img2.png", "examples/example_img1.png"],
218
+ inputs=image_input,
219
+ label="Examples"
220
+ )
221
+ with gr.Tab("📰 README.md"):
222
+ with gr.Accordion("(Formatted Result)", open=True):
223
+ markdown_output = gr.Markdown(label="Formatted Markdown")
224
+
225
+ with gr.Tab("📋 Layout Analysis Results"):
226
+ json_output = gr.JSON(label="Structured Layout Data (JSON)")
227
+
228
+ # Event Handlers
229
+ def clear_all_outputs():
230
+ return None, "Raw output will appear here.", "Formatted results will appear here.", None
231
+
232
+ process_btn.click(
233
+ fn=process_document_stream,
234
+ inputs=[model_choice,
235
+ task_choice,
236
+ image_input,
237
+ max_new_tokens],
238
+ outputs=[raw_output_stream,
239
+ markdown_output,
240
+ json_output]
241
+ )
242
+ clear_btn.click(
243
+ clear_all_outputs,
244
+ outputs=[image_input,
245
+ raw_output_stream,
246
+ markdown_output,
247
+ json_output]
248
+ )
249
+ return demo
250
+
251
+ if __name__ == "__main__":
252
+ demo = create_gradio_interface()
253
+ demo.queue().launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)