prithivMLmods commited on
Commit
b64502b
·
verified ·
1 Parent(s): c3a7356

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -456
app.py DELETED
@@ -1,456 +0,0 @@
1
- import spaces
2
- import json
3
- import math
4
- import os
5
- import traceback
6
- from io import BytesIO
7
- from typing import Any, Dict, List, Optional, Tuple
8
- import re
9
-
10
- import fitz # PyMuPDF
11
- import gradio as gr
12
- import requests
13
- from PIL import Image, ImageDraw, ImageFont
14
-
15
- from model import load_model, inference_dots_ocr, inference_dolphin
16
-
17
- # Constants
18
- MIN_PIXELS = 3136
19
- MAX_PIXELS = 11289600
20
- IMAGE_FACTOR = 28
21
-
22
- # Prompts
23
- prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
24
-
25
- 1. Bbox format: [x1, y1, x2, y2]
26
- 2. Layout Categories: ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']
27
- 3. Text Extraction & Formatting Rules:
28
- - Picture: omit the text field
29
- - Formula: format as LaTeX
30
- - Table: format as HTML
31
- - Others: format as Markdown
32
- 4. Constraints:
33
- - Use original text, no translation
34
- - Sort elements by human reading order
35
- 5. Final Output: Single JSON object
36
- """
37
-
38
- # Load models at startup
39
- models = {
40
- "dots.ocr": load_model("dots.ocr"),
41
- "Dolphin": load_model("Dolphin")
42
- }
43
-
44
- # Global state for PDF handling
45
- pdf_cache = {
46
- "images": [],
47
- "current_page": 0,
48
- "total_pages": 0,
49
- "file_type": None,
50
- "is_parsed": False,
51
- "results": []
52
- }
53
-
54
- # Utility functions
55
- def round_by_factor(number: int, factor: int) -> int:
56
- return round(number / factor) * factor
57
-
58
- def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 11289600):
59
- if max(height, width) / min(height, width) > 200:
60
- raise ValueError(f"Aspect ratio must be < 200, got {max(height, width) / min(height, width)}")
61
- h_bar = max(factor, round_by_factor(height, factor))
62
- w_bar = max(factor, round_by_factor(width, factor))
63
- if h_bar * w_bar > max_pixels:
64
- beta = math.sqrt((height * width) / max_pixels)
65
- h_bar = round_by_factor(height / beta, factor)
66
- w_bar = round_by_factor(width / beta, factor)
67
- elif h_bar * w_bar < min_pixels:
68
- beta = math.sqrt(min_pixels / (height * width))
69
- h_bar = round_by_factor(height * beta, factor)
70
- w_bar = round_by_factor(width * beta, factor)
71
- return h_bar, w_bar
72
-
73
- def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
74
- if isinstance(image_input, str):
75
- if image_input.startswith(("http://", "https://")):
76
- response = requests.get(image_input)
77
- image = Image.open(BytesIO(response.content)).convert('RGB')
78
- else:
79
- image = Image.open(image_input).convert('RGB')
80
- elif isinstance(image_input, Image.Image):
81
- image = image_input.convert('RGB')
82
- else:
83
- raise ValueError(f"Invalid image input type: {type(image_input)}")
84
- if min_pixels or max_pixels:
85
- min_pixels = min_pixels or MIN_PIXELS
86
- max_pixels = max_pixels or MAX_PIXELS
87
- height, width = smart_resize(image.height, image.width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
88
- image = image.resize((width, height), Image.LANCZOS)
89
- return image
90
-
91
- def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
92
- images = []
93
- try:
94
- pdf_document = fitz.open(pdf_path)
95
- for page_num in range(len(pdf_document)):
96
- page = pdf_document.load_page(page_num)
97
- mat = fitz.Matrix(2.0, 2.0)
98
- pix = page.get_pixmap(matrix=mat)
99
- img_data = pix.tobytes("ppm")
100
- image = Image.open(BytesIO(img_data)).convert('RGB')
101
- images.append(image)
102
- pdf_document.close()
103
- except Exception as e:
104
- print(f"Error loading PDF: {e}")
105
- return []
106
- return images
107
-
108
- def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
109
- img_copy = image.copy()
110
- draw = ImageDraw.Draw(img_copy)
111
- colors = {
112
- 'Caption': '#FF6B6B', 'Footnote': '#4ECDC4', 'Formula': '#45B7D1', 'List-item': '#96CEB4',
113
- 'Page-footer': '#FFEAA7', 'Page-header': '#DDA0DD', 'Picture': '#FFD93D', 'Section-header': '#6C5CE7',
114
- 'Table': '#FD79A8', 'Text': '#74B9FF', 'Title': '#E17055'
115
- }
116
- try:
117
- font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
118
- except Exception:
119
- font = ImageFont.load_default()
120
- try:
121
- for item in layout_data:
122
- if 'bbox' in item and 'category' in item:
123
- bbox = item['bbox']
124
- category = item['category']
125
- color = colors.get(category, '#000000')
126
- draw.rectangle(bbox, outline=color, width=2)
127
- label = category
128
- label_bbox = draw.textbbox((0, 0), label, font=font)
129
- label_width = label_bbox[2] - label_bbox[0]
130
- label_height = label_bbox[3] - label_bbox[1]
131
- label_x = bbox[0]
132
- label_y = max(0, bbox[1] - label_height - 2)
133
- draw.rectangle([label_x, label_y, label_x + label_width + 4, label_y + label_height + 2], fill=color)
134
- draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
135
- except Exception as e:
136
- print(f"Error drawing layout: {e}")
137
- return img_copy
138
-
139
- def is_arabic_text(text: str) -> bool:
140
- if not text:
141
- return False
142
- header_pattern = r'^#{1,6}\s+(.+)$'
143
- paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
144
- content_text = []
145
- for line in text.split('\n'):
146
- line = line.strip()
147
- if not line:
148
- continue
149
- header_match = re.match(header_pattern, line, re.MULTILINE)
150
- if header_match:
151
- content_text.append(header_match.group(1))
152
- continue
153
- if re.match(paragraph_pattern, line, re.MULTILINE):
154
- content_text.append(line)
155
- if not content_text:
156
- return False
157
- combined_text = ' '.join(content_text)
158
- arabic_chars = 0
159
- total_chars = 0
160
- for char in combined_text:
161
- if char.isalpha():
162
- total_chars += 1
163
- if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'):
164
- arabic_chars += 1
165
- return total_chars > 0 and (arabic_chars / total_chars) > 0.5
166
-
167
- def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
168
- import base64
169
- markdown_lines = []
170
- try:
171
- sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
172
- for item in sorted_items:
173
- category = item.get('category', '')
174
- text = item.get(text_key, '')
175
- bbox = item.get('bbox', [])
176
- if category == 'Picture':
177
- if bbox and len(bbox) == 4:
178
- try:
179
- x1, y1, x2, y2 = [max(0, int(x)) if i < 2 else min(image.width if i % 2 == 0 else image.height, int(x)) for i, x in enumerate(bbox)]
180
- if x2 > x1 and y2 > y1:
181
- cropped_img = image.crop((x1, y1, x2, y2))
182
- buffer = BytesIO()
183
- cropped_img.save(buffer, format='PNG')
184
- img_data = base64.b64encode(buffer.getvalue()).decode()
185
- markdown_lines.append(f'<image-card alt="Image" src="data:image/png;base64,{img_data}" ></image-card>\n')
186
- else:
187
- markdown_lines.append('<image-card alt="Image" src="Image region detected" ></image-card>\n')
188
- except Exception as e:
189
- print(f"Error processing image region: {e}")
190
- markdown_lines.append('<image-card alt="Image" src="Image detected" ></image-card>\n')
191
- else:
192
- markdown_lines.append('<image-card alt="Image" src="Image detected" ></image-card>\n')
193
- elif not text:
194
- continue
195
- elif category == 'Title':
196
- markdown_lines.append(f"# {text}\n")
197
- elif category == 'Section-header':
198
- markdown_lines.append(f"## {text}\n")
199
- elif category == 'Text':
200
- markdown_lines.append(f"{text}\n")
201
- elif category == 'List-item':
202
- markdown_lines.append(f"- {text}\n")
203
- elif category == 'Table':
204
- if text.strip().startswith('<'):
205
- markdown_lines.append(f"{text}\n")
206
- else:
207
- markdown_lines.append(f"**Table:** {text}\n")
208
- elif category == 'Formula':
209
- if text.strip().startswith('$') or '\\' in text:
210
- markdown_lines.append(f"$$ \n{text}\n $$\n")
211
- else:
212
- markdown_lines.append(f"**Formula:** {text}\n")
213
- elif category == 'Caption':
214
- markdown_lines.append(f"*{text}*\n")
215
- elif category == 'Footnote':
216
- markdown_lines.append(f"^{text}^\n")
217
- elif category in ['Page-header', 'Page-footer']:
218
- continue
219
- else:
220
- markdown_lines.append(f"{text}\n")
221
- markdown_lines.append("")
222
- except Exception as e:
223
- print(f"Error converting to markdown: {e}")
224
- return str(layout_data)
225
- return "\n".join(markdown_lines)
226
-
227
- def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
228
- global pdf_cache
229
- if not file_path or not os.path.exists(file_path):
230
- return None, "No file selected"
231
- file_ext = os.path.splitext(file_path)[1].lower()
232
- try:
233
- if file_ext == '.pdf':
234
- images = load_images_from_pdf(file_path)
235
- if not images:
236
- return None, "Failed to load PDF"
237
- pdf_cache.update({
238
- "images": images,
239
- "current_page": 0,
240
- "total_pages": len(images),
241
- "file_type": "pdf",
242
- "is_parsed": False,
243
- "results": []
244
- })
245
- return images[0], f"Page 1 / {len(images)}"
246
- elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
247
- image = Image.open(file_path).convert('RGB')
248
- pdf_cache.update({
249
- "images": [image],
250
- "current_page": 0,
251
- "total_pages": 1,
252
- "file_type": "image",
253
- "is_parsed": False,
254
- "results": []
255
- })
256
- return image, "Page 1 / 1"
257
- else:
258
- return None, f"Unsupported file format: {file_ext}"
259
- except Exception as e:
260
- print(f"Error loading file: {e}")
261
- return None, f"Error loading file: {str(e)}"
262
-
263
- @spaces.GPU()
264
- def process_document(file_path, model_choice, max_tokens, min_pix, max_pix):
265
- global pdf_cache
266
- if not file_path:
267
- return None, "Please upload a file first.", None
268
- model, processor = models[model_choice]
269
- image, page_info = load_file_for_preview(file_path)
270
- if image is None:
271
- return None, page_info, None
272
- if pdf_cache["file_type"] == "pdf":
273
- all_results = []
274
- for i, img in enumerate(pdf_cache["images"]):
275
- if model_choice == "dots.ocr":
276
- raw_output = inference_dots_ocr(model, processor, img, prompt, max_tokens)
277
- try:
278
- layout_data = json.loads(raw_output)
279
- processed_image = draw_layout_on_image(img, layout_data)
280
- markdown_content = layoutjson2md(img, layout_data)
281
- result = {
282
- 'processed_image': processed_image,
283
- 'markdown_content': markdown_content,
284
- 'layout_result': layout_data
285
- }
286
- except Exception:
287
- result = {
288
- 'processed_image': img,
289
- 'markdown_content': raw_output,
290
- 'layout_result': None
291
- }
292
- else: # Dolphin
293
- text = inference_dolphin(model, processor, img)
294
- result = f"## Page {i+1}\n\n{text}" if text else "No text extracted"
295
- all_results.append(result)
296
- pdf_cache["results"] = all_results
297
- pdf_cache["is_parsed"] = True
298
- first_result = all_results[0]
299
- if model_choice == "dots.ocr":
300
- markdown_update = gr.update(value=first_result['markdown_content'], rtl=is_arabic_text(first_result['markdown_content']))
301
- return first_result['processed_image'], markdown_update, first_result['layout_result']
302
- else:
303
- markdown_update = gr.update(value=first_result, rtl=is_arabic_text(first_result))
304
- return None, markdown_update, None
305
- else:
306
- if model_choice == "dots.ocr":
307
- raw_output = inference_dots_ocr(model, processor, image, prompt, max_tokens)
308
- try:
309
- layout_data = json.loads(raw_output)
310
- processed_image = draw_layout_on_image(image, layout_data)
311
- markdown_content = layoutjson2md(image, layout_data)
312
- result = {
313
- 'processed_image': processed_image,
314
- 'markdown_content': markdown_content,
315
- 'layout_result': layout_data
316
- }
317
- except Exception:
318
- result = {
319
- 'processed_image': image,
320
- 'markdown_content': raw_output,
321
- 'layout_result': None
322
- }
323
- pdf_cache["results"] = [result]
324
- else: # Dolphin
325
- text = inference_dolphin(model, processor, image)
326
- result = text if text else "No text extracted"
327
- pdf_cache["results"] = [result]
328
- pdf_cache["is_parsed"] = True
329
- if model_choice == "dots.ocr":
330
- markdown_update = gr.update(value=result['markdown_content'], rtl=is_arabic_text(result['markdown_content']))
331
- return result['processed_image'], markdown_update, result['layout_result']
332
- else:
333
- markdown_update = gr.update(value=result, rtl=is_arabic_text(result))
334
- return None, markdown_update, None
335
-
336
- def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, Any, Optional[Image.Image], Optional[Dict]]:
337
- global pdf_cache
338
- if not pdf_cache["images"]:
339
- return None, '<div class="page-info">No file loaded</div>', "No results yet", None, None
340
- if direction == "prev":
341
- pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
342
- elif direction == "next":
343
- pdf_cache["current_page"] = min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)
344
- index = pdf_cache["current_page"]
345
- current_image_preview = pdf_cache["images"][index]
346
- page_info_html = f'<div class="page-info">Page {index + 1} / {pdf_cache["total_pages"]}</div>'
347
- if pdf_cache["is_parsed"] and index < len(pdf_cache["results"]):
348
- result = pdf_cache["results"][index]
349
- if isinstance(result, dict): # dots.ocr
350
- markdown_content = result.get('markdown_content', 'No content available')
351
- processed_img = result.get('processed_image', None)
352
- layout_json = result.get('layout_result', None)
353
- else: # Dolphin
354
- markdown_content = result
355
- processed_img = None
356
- layout_json = None
357
- else:
358
- markdown_content = "Page not processed yet"
359
- processed_img = None
360
- layout_json = None
361
- markdown_update = gr.update(value=markdown_content, rtl=is_arabic_text(markdown_content))
362
- return current_image_preview, page_info_html, markdown_update, processed_img, layout_json
363
-
364
- def create_gradio_interface():
365
- css = """
366
- .main-container { max-width: 1400px; margin: 0 auto; }
367
- .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
368
- .process-button {
369
- border: none !important;
370
- color: white !important;
371
- font-weight: bold !important;
372
- background-color: blue !important;}
373
- .process-button:hover {
374
- background-color: darkblue !important;
375
- transform: translateY(-2px) !important;
376
- box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
377
- .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
378
- .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
379
- .model-status { padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center; font-weight: bold; }
380
- .status-ready { background: #d1edff; color: #0c5460; border: 1px solid #b8daff; }
381
- """
382
- with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
383
- gr.HTML("""
384
- <div class="title" style="text-align: center">
385
- <h1>Dot<span style="color: red;">●</span><strong></strong>OCR vs Dolphin🐬</h1>
386
- <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
387
- Advanced vision-language model for image/PDF to markdown document processing
388
- </p>
389
- </div>
390
- """)
391
- with gr.Row():
392
- with gr.Column(scale=1):
393
- file_input = gr.File(
394
- label="Upload Image or PDF",
395
- file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"],
396
- type="filepath"
397
- )
398
- image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=300)
399
- with gr.Row():
400
- prev_page_btn = gr.Button("⬅ Previous", size="md")
401
- page_info = gr.HTML('<div class="page-info">No file loaded</div>')
402
- next_page_btn = gr.Button("Next ➡", size="md")
403
- model_choice = gr.Radio(
404
- choices=["dots.ocr", "Dolphin"],
405
- label="Select Model",
406
- value="dots.ocr"
407
- )
408
- with gr.Accordion("Advanced Settings", open=False):
409
- max_new_tokens = gr.Slider(minimum=1000, maximum=32000, value=24000, step=1000, label="Max New Tokens")
410
- min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels")
411
- max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels")
412
- process_btn = gr.Button("🔥 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
413
- clear_btn = gr.Button("Clear Document", variant="secondary")
414
-
415
- with gr.Column(scale=2):
416
- with gr.Tabs():
417
- with gr.Tab("✦︎ Processed Image"):
418
- processed_image = gr.Image(label="Image with Layout Detection", type="pil", interactive=False, height=500)
419
- with gr.Tab("🀥 Extracted Content"):
420
- markdown_output = gr.Markdown(value="Click 'Process Document' to see extracted content...", height=500)
421
- with gr.Tab("⏲ Layout JSON"):
422
- json_output = gr.JSON(label="Layout Analysis Results", value=None)
423
-
424
- with gr.Row():
425
- examples = gr.Examples(
426
- examples=["examples/sample_image1.png", "examples/sample_image2.png", "examples/sample_pdf.pdf"],
427
- inputs=file_input,
428
- label="Example Documents"
429
- )
430
-
431
- def handle_file_upload(file_path):
432
- image, page_info = load_file_for_preview(file_path)
433
- return image, page_info
434
-
435
- def clear_all():
436
- global pdf_cache
437
- pdf_cache = {"images": [], "current_page": 0, "total_pages": 0, "file_type": None, "is_parsed": False, "results": []}
438
- return None, None, '<div class="page-info">No file loaded</div>', None, "Click 'Process Document' to see extracted content...", None
439
-
440
- file_input.change(handle_file_upload, inputs=[file_input], outputs=[image_preview, page_info])
441
- prev_page_btn.click(lambda: turn_page("prev"), outputs=[image_preview, page_info, markdown_output, processed_image, json_output])
442
- next_page_btn.click(lambda: turn_page("next"), outputs=[image_preview, page_info, markdown_output, processed_image, json_output])
443
- process_btn.click(
444
- process_document,
445
- inputs=[file_input, model_choice, max_new_tokens, min_pixels, max_pixels],
446
- outputs=[processed_image, markdown_output, json_output]
447
- )
448
- clear_btn.click(
449
- clear_all,
450
- outputs=[file_input, image_preview, page_info, processed_image, markdown_output, json_output]
451
- )
452
- return demo
453
-
454
- if __name__ == "__main__":
455
- demo = create_gradio_interface()
456
- demo.queue(max_size=30).launch(share=False, debug=True, show_error=True)