prithivMLmods committed
Commit 74aa8d3 · verified · 1 Parent(s): 5ecf0f4

Delete app.py

Files changed (1): app.py +0 -289
app.py DELETED
@@ -1,289 +0,0 @@
import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re
import time
from threading import Thread

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
    AutoModel,
    AutoTokenizer
)

js_func = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.href = url.href;
    }
}
"""

# --- Constants and Model Setup ---
MAX_INPUT_TOKEN_LENGTH = 4096
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Prompts for Different Tasks ---
layout_prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
    - For tables, provide the content in a structured JSON format.
    - For all other elements, provide the plain text.
4. Constraints:
    - The output must be the original text from the image.
    - All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object wrapped in ```json ... ```.
"""
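
# Illustrative sketch of the layout JSON the prompt above requests (hypothetical
# values; the downstream parser expects a list of {bbox, category, text} items):
# [
#     {"bbox": [72, 40, 540, 80], "category": "Title", "text": "Quarterly Report"},
#     {"bbox": [72, 100, 540, 320], "category": "Text", "text": "Revenue grew ..."}
# ]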

ocr_prompt = "Perform precise OCR on the image. Extract all text content, maintaining the original structure, paragraphs, and tables as formatted markdown."

# --- Model Loading ---
MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_C = "nanonets/Nanonets-OCR-s"
processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_G = "echo840/MonkeyOCR"
SUBFOLDER = "Recognition"
processor_g = AutoProcessor.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
)
model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
).to(device).eval()
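
# The four checkpoints above are all loaded as Qwen2.5-VL models, so they share
# the same chat-template / generate() streaming path in process_document_stream.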

# --- New Model ---
MODEL_ID_V4 = 'openbmb/MiniCPM-V-4'
model_v4 = AutoModel.from_pretrained(
    MODEL_ID_V4,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation='sdpa'  # Use 'flash_attention_2' if available and supported
).eval().to(device)
tokenizer_v4 = AutoTokenizer.from_pretrained(MODEL_ID_V4, trust_remote_code=True)
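
# Unlike the Qwen-style checkpoints, MiniCPM-V-4 exposes a custom chat() method
# via trust_remote_code rather than the processor/generate API, so it gets its
# own branch in process_document_stream below.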


# --- Utility Functions ---
def layoutjson2md(layout_data: List[Dict]) -> str:
    """Converts the structured JSON from Layout Analysis into formatted Markdown."""
    markdown_lines = []
    try:
        # Sort items by reading order (top-to-bottom, left-to-right)
        sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
        for item in sorted_items:
            category = item.get('category', '')
            text = item.get('text', '')
            if not text:
                continue

            if category == 'Title':
                markdown_lines.append(f"# {text}\n")
            elif category == 'Section-header':
                markdown_lines.append(f"## {text}\n")
            elif category == 'Table':
                # Handle structured table JSON
                if isinstance(text, dict) and 'header' in text and 'rows' in text:
                    header = '| ' + ' | '.join(map(str, text['header'])) + ' |'
                    separator = '| ' + ' | '.join(['---'] * len(text['header'])) + ' |'
                    rows = ['| ' + ' | '.join(map(str, row)) + ' |' for row in text['rows']]
                    markdown_lines.extend([header, separator] + rows)
                    markdown_lines.append("\n")
                else:  # Fallback for simple text
                    markdown_lines.append(f"{text}\n")
            else:
                markdown_lines.append(f"{text}\n")
    except Exception as e:
        print(f"Error converting to markdown: {e}")
        return "### Error converting JSON to Markdown."
    return "\n".join(markdown_lines)
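
# Illustrative usage (hypothetical input): a lone Title element becomes a heading:
# layoutjson2md([{"bbox": [0, 0, 100, 20], "category": "Title", "text": "Report"}])
# -> "# Report\n"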

# --- Core Application Logic ---
@spaces.GPU
def process_document_stream(model_name: str, task_choice: str, image: Image.Image, max_new_tokens: int):
    """
    Main generator function that handles both OCR and Layout Analysis tasks.
    """
    if image is None:
        yield "Please upload an image.", "Please upload an image.", None
        return

    # 1. Select prompt based on user's task choice
    text_prompt = ocr_prompt if task_choice == "Content Extraction" else layout_prompt

    # --- New Model Handling ---
    if model_name == "openbmb/MiniCPM-V-4":
        if task_choice == "Layout Analysis(.json)":
            yield "This model is not optimized for Layout Analysis.", "Task not supported for this model.", None
            return

        question = "What is in this image?"
        msgs = [{'role': 'user', 'content': [image, question]}]

        # Since this model's .chat() method isn't a generator, we call it
        # synchronously and yield the final result once. A more advanced
        # implementation could stream it.
        try:
            answer = model_v4.chat(
                image=image.convert('RGB'),
                msgs=msgs,
                tokenizer=tokenizer_v4
            )
            yield answer, answer, None
        except Exception as e:
            yield f"Error: {str(e)}", "An error occurred.", None
        return

    # 2. Select model and processor
    if model_name == "Camel-Doc-OCR-062825":
        processor, model = processor_m, model_m
    elif model_name == "Megalodon-OCR-Sync-0713":
        processor, model = processor_t, model_t
    elif model_name == "Nanonets-OCR-s":
        processor, model = processor_c, model_c
    elif model_name == "MonkeyOCR-Recognition":
        processor, model = processor_g, model_g
    else:
        yield "Invalid model selected.", "Invalid model selected.", None
        return

    # 3. Prepare model inputs and streamer
    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_prompt}]}]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
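
    # generate() blocks until completion, so it runs in a background thread while
    # the TextIteratorStreamer hands decoded text fragments back to this generator.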
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # 4. Stream raw output to the UI in real-time
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, "⏳ Processing...", {"status": "streaming"}

    # 5. Post-process the final buffer based on the selected task
    if task_choice == "Content Extraction":
        # For OCR, the buffer is the final result.
        yield buffer, buffer, None
    else:  # Layout Analysis
        try:
191
- try:
192
- json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
193
- if not json_match:
194
- raise json.JSONDecodeError("JSON object not found in output.", buffer, 0)
195
-
196
- json_str = json_match.group(1)
197
- layout_data = json.loads(json_str)
198
- markdown_content = layoutjson2md(layout_data)
199
-
200
- yield buffer, markdown_content, layout_data
201
- except Exception as e:
202
- error_md = f"❌ **Error:** Failed to parse Layout JSON.\n\n**Details:**\n`{str(e)}`"
203
- error_json = {"error": "ProcessingError", "details": str(e), "raw_output": buffer}
204
- yield buffer, error_md, error_json
205
-
# --- Gradio UI Definition ---
def create_gradio_interface():
    """Builds and returns the Gradio web interface."""
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important; }
    .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
    """
    with gr.Blocks(theme="bethecloud/storj_theme", css=css, js=js_func) as demo:
        gr.HTML("""
        <div class="title" style="text-align: center">
            <h1>OCR Comparator👨‍🏫</h1>
            <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
                Advanced Vision-Language Models for Image Content and Layout Extraction
            </p>
        </div>
        """)

        with gr.Row():
            # Left Column (Inputs)
            with gr.Column(scale=1):
                model_choice = gr.Dropdown(
                    choices=["Camel-Doc-OCR-062825",
                             "MonkeyOCR-Recognition",
                             "Nanonets-OCR-s",
                             "Megalodon-OCR-Sync-0713",
                             "openbmb/MiniCPM-V-4"],
                    label="Select Model", value="Nanonets-OCR-s"
                )
                task_choice = gr.Dropdown(
                    choices=["Content Extraction", "Layout Analysis(.json)"],
                    label="Select Task", value="Content Extraction"
                )
                image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
                with gr.Accordion("Advanced Settings", open=False):
                    max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=4096, step=256, label="Max New Tokens")

                process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")

            # Right Column (Outputs)
            with gr.Column(scale=2):
                with gr.Tabs() as tabs:
                    with gr.Tab("📝 Extracted Content"):
                        raw_output_stream = gr.Textbox(label="Raw Model Output Stream", interactive=False, lines=13, show_copy_button=True)
                        with gr.Row():
                            examples = gr.Examples(
                                examples=["examples/example_img2.png", "examples/example_img1.png"],
                                inputs=image_input,
                                label="Examples"
                            )
                    with gr.Tab("📰 README.md"):
                        with gr.Accordion("(Formatted Result)", open=True):
                            markdown_output = gr.Markdown(label="Formatted Markdown")

                    with gr.Tab("📋 Layout Analysis Results"):
                        json_output = gr.JSON(label="Structured Layout Data (JSON)")

        # Event Handlers
        def clear_all_outputs():
            return None, "Raw output will appear here.", "Formatted results will appear here.", None

        process_btn.click(
            fn=process_document_stream,
            inputs=[model_choice,
                    task_choice,
                    image_input,
                    max_new_tokens],
            outputs=[raw_output_stream,
                     markdown_output,
                     json_output]
        )
        clear_btn.click(
            clear_all_outputs,
            outputs=[image_input,
                     raw_output_stream,
                     markdown_output,
                     json_output]
        )
    return demo

if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.queue().launch(share=True, ssr_mode=False, show_error=True)
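
# Note: share=True creates a public gradio.live tunnel only when running locally;
# a hosted Space serves the app directly. ssr_mode=False opts out of Gradio's
# server-side rendering.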