prithivMLmods committed on
Commit 34ad363 · verified · 1 Parent(s): 0ef351a

Upload app.py

Files changed (1)
  1. app.py +313 -0
app.py ADDED
import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re
import time
from threading import Thread

import gradio as gr
import requests
import torch
from PIL import Image

from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    AutoModelForImageTextToText,
    AutoProcessor,
    TextIteratorStreamer,
    AutoModel,
    AutoTokenizer,
)

from transformers.image_utils import load_image

# --- Constants and Model Setup ---
MAX_INPUT_TOKEN_LENGTH = 4096

# Let the environment (e.g., Hugging Face Spaces) determine the device; this
# falls back to CPU when CUDA is unavailable and avoids conflicts with the
# CUDA environment set up by the platform.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("current device:", torch.cuda.current_device())
    print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

print("Using device:", device)

# --- Prompts for Different Tasks ---
layout_prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
    - For tables, provide the content in a structured JSON format.
    - For all other elements, provide the plain text.
4. Constraints:
    - The output must be the original text from the image.
    - All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object wrapped in ```json ... ```.
"""

ocr_prompt = "Perform precise OCR on the image. Extract all text content, maintaining the original structure, paragraphs, and tables as formatted markdown."

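# For reference, the layout prompt above asks the model for a JSON list shaped
# roughly like the following (an illustrative example, not real model output):
#
#   [
#     {"bbox": [40, 30, 560, 70], "category": "Title", "text": "Quarterly Report"},
#     {"bbox": [40, 90, 560, 240], "category": "Table",
#      "text": {"header": ["Item", "Cost"], "rows": [["Paper", "$10"]]}}
#   ]
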
# --- Model Loading ---
MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-080125"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_C = "nanonets/Nanonets-OCR-s"
processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g = AutoProcessor.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
)
model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_I = "allenai/olmOCR-7B-0725"
processor_i = AutoProcessor.from_pretrained(MODEL_ID_I, trust_remote_code=True)
model_i = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_I, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

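# All five checkpoints share the Qwen2.5-VL architecture, so a single model
# class covers them. They are loaded eagerly at startup in float16 to roughly
# halve memory versus float32, and .eval() disables dropout for inference.
# Holding several multi-billion-parameter models resident at once requires
# correspondingly large GPU memory.
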
# --- Utility Functions ---
def layoutjson2md(layout_data: Any) -> str:
    """
    Converts the structured JSON from Layout Analysis into formatted Markdown.
    Robust against malformed or unexpectedly shaped JSON from the model.
    """
    markdown_lines = []

    # If the model wraps the list in a dictionary, find and extract the list.
    if isinstance(layout_data, dict):
        found_list = None
        for value in layout_data.values():
            if isinstance(value, list):
                found_list = value
                break
        if found_list is not None:
            layout_data = found_list
        else:
            return "### Error: Could not find a list of layout items in the JSON object."

    if not isinstance(layout_data, list):
        return f"### Error: Expected a list of layout items, but received type {type(layout_data).__name__}."

    try:
        # Filter out any non-dictionary items and sort top-to-bottom, then
        # left-to-right by bbox to approximate human reading order.
        valid_items = [item for item in layout_data if isinstance(item, dict)]
        sorted_items = sorted(
            valid_items,
            key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]),
        )

        for item in sorted_items:
            category = item.get('category', 'Text')  # Default to 'Text' if no category
            text = item.get('text', '')
            if not text:
                continue

            if category == 'Title':
                markdown_lines.append(f"# {text}\n")
            elif category == 'Section-header':
                markdown_lines.append(f"## {text}\n")
            elif category == 'Table':
                if isinstance(text, dict) and 'header' in text and 'rows' in text:
                    header = '| ' + ' | '.join(map(str, text['header'])) + ' |'
                    separator = '| ' + ' | '.join(['---'] * len(text['header'])) + ' |'
                    rows = ['| ' + ' | '.join(map(str, row)) + ' |' for row in text['rows']]
                    markdown_lines.extend([header, separator] + rows)
                    markdown_lines.append("\n")
                else:  # Fallback for simple text or malformed tables
                    markdown_lines.append(f"{text}\n")
            else:
                markdown_lines.append(f"{text}\n")

    except Exception as e:
        print(f"Error converting to markdown: {e}")
        traceback.print_exc()
        return "### Error: An unexpected error occurred while converting JSON to Markdown."

    return "\n".join(markdown_lines)

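# Worked example: feeding layoutjson2md the illustrative JSON shown above
# would yield Markdown along these lines:
#
#   # Quarterly Report
#
#   | Item | Cost |
#   | --- | --- |
#   | Paper | $10 |
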
# --- Core Application Logic ---
@spaces.GPU
def process_document_stream(model_name: str, task_choice: str, image: Image.Image, max_new_tokens: int):
    """
    Main generator function that handles both OCR and Layout Analysis tasks.
    """
    if image is None:
        yield "Please upload an image.", "Please upload an image.", None
        return

    # 1. Select the prompt based on the user's task choice.
    text_prompt = ocr_prompt if task_choice == "Content Extraction" else layout_prompt

    # 2. Select the model and processor.
    model_map = {
        "Camel-Doc-OCR-080125": (processor_m, model_m),
        "Megalodon-OCR-Sync-0713": (processor_t, model_t),
        "Nanonets-OCR-s": (processor_c, model_c),
        "MonkeyOCR-Recognition": (processor_g, model_g),
        "olmOCR-7B-0725": (processor_i, model_i),
    }
    if model_name not in model_map:
        yield "Invalid model selected.", "Invalid model selected.", None
        return
    processor, model = model_map[model_name]

    # 3. Prepare model inputs and the streamer.
    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_prompt}]}]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_INPUT_TOKEN_LENGTH,
    ).to(device)
    # The processor proxies decode() to its tokenizer, so it can stand in as
    # the streamer's tokenizer here.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}

    # Run generation in a background thread so the streamer can be consumed here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # 4. Stream raw output to the UI in real time.
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, "⏳ Processing...", {"status": "streaming"}

    # 5. Post-process the final buffer based on the selected task.
    if task_choice == "Content Extraction":
        # For OCR, the buffer is the final result.
        yield buffer, buffer, None
    else:  # Layout Analysis
        try:
            json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
            if not json_match:
                # If no fenced JSON block is found, try to parse the whole buffer as a fallback.
                try:
                    layout_data = json.loads(buffer)
                    markdown_content = layoutjson2md(layout_data)
                    yield buffer, markdown_content, layout_data
                    return
                except json.JSONDecodeError:
                    raise ValueError("JSON object not found in the model's output.")

            json_str = json_match.group(1)
            layout_data = json.loads(json_str)
            markdown_content = layoutjson2md(layout_data)

            yield buffer, markdown_content, layout_data
        except Exception as e:
            error_md = f"❌ **Error:** Failed to parse Layout JSON.\n\n**Details:**\n`{str(e)}`\n\n**Raw Output:**\n```\n{buffer}\n```"
            error_json = {"error": "ProcessingError", "details": str(e), "raw_output": buffer}
            yield buffer, error_md, error_json

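# A minimal (hypothetical) way to exercise the generator without the UI,
# assuming an image exists at examples/1.png:
#
#   img = Image.open("examples/1.png")
#   for raw, formatted, data in process_document_stream(
#           "Nanonets-OCR-s", "Content Extraction", img, max_new_tokens=512):
#       pass  # `raw` grows as tokens stream in
#   print(formatted)
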
# --- Gradio UI Definition ---
def create_gradio_interface():
    """Builds and returns the Gradio web interface."""
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important; }
    .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
    """
    with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
        gr.HTML("""
        <div class="title" style="text-align: center">
            <h1>Tiny VLMs Lab🧪</h1>
            <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
                Advanced Vision-Language Model for Image Content and Layout Extraction
            </p>
        </div>
        """)

        with gr.Row():
            # Left Column (Inputs)
            with gr.Column(scale=1):
                model_choice = gr.Dropdown(
                    choices=[
                        "Camel-Doc-OCR-080125",
                        "MonkeyOCR-Recognition",
                        "olmOCR-7B-0725",
                        "Nanonets-OCR-s",
                        "Megalodon-OCR-Sync-0713",
                    ],
                    label="Select Model",
                    value="Nanonets-OCR-s"
                )
                task_choice = gr.Dropdown(
                    choices=["Content Extraction", "Layout Analysis (.json)"],
                    label="Select Task",
                    value="Content Extraction"
                )
                image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
                with gr.Accordion("Advanced Settings", open=False):
                    max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=4096, step=256, label="Max New Tokens")

                process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")

            # Right Column (Outputs)
            with gr.Column(scale=2):
                with gr.Tabs() as tabs:
                    with gr.Tab("📝 Extracted Content"):
                        raw_output_stream = gr.Textbox(label="Raw Model Output Stream", interactive=False, lines=13, show_copy_button=True)
                        with gr.Row():
                            examples = gr.Examples(
                                examples=["examples/1.png", "examples/2.png", "examples/3.png", "examples/4.png", "examples/5.png"],
                                inputs=image_input,
                                label="Examples"
                            )
                        gr.Markdown("[Report-Bug💻](https://huggingface.co/spaces/prithivMLmods/OCR-Comparator/discussions)")
                    with gr.Tab("📰 README.md"):
                        with gr.Accordion("(Formatted Result)", open=True):
                            markdown_output = gr.Markdown(label="Formatted Markdown")

                    with gr.Tab("📋 Layout Analysis Results"):
                        json_output = gr.JSON(label="Structured Layout Data (JSON)")

        # Event Handlers
        def clear_all_outputs():
            return None, "Raw output will appear here.", "Formatted results will appear here.", None

        process_btn.click(
            fn=process_document_stream,
            inputs=[model_choice, task_choice, image_input, max_new_tokens],
            outputs=[raw_output_stream, markdown_output, json_output]
        )
        clear_btn.click(
            clear_all_outputs,
            outputs=[image_input, raw_output_stream, markdown_output, json_output]
        )
    return demo

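# Entry point: running `python app.py` builds the UI and serves it.
# queue(max_size=50) caps the number of pending requests, and share=True asks
# Gradio for a temporary public link in addition to the local server.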
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.queue(max_size=50).launch(share=True, ssr_mode=False, show_error=True)