prithivMLmods committed
Commit c152910 · verified
1 Parent(s): 6280dc1
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/sample_image1.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,456 @@
import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re

import fitz  # PyMuPDF
import gradio as gr
import requests
from PIL import Image, ImageDraw, ImageFont

from model import load_model, inference_dots_ocr, inference_dolphin

# Constants
MIN_PIXELS = 3136
MAX_PIXELS = 11289600
IMAGE_FACTOR = 28

# Prompts
prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']
3. Text Extraction & Formatting Rules:
- Picture: omit the text field
- Formula: format as LaTeX
- Table: format as HTML
- Others: format as Markdown
4. Constraints:
- Use original text, no translation
- Sort elements by human reading order
5. Final Output: Single JSON object
"""
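
# For reference, a minimal, hypothetical example of the JSON the prompt above asks the
# model to return: a list of layout elements (values are illustrative, not real output).
#
# [
#     {"bbox": [70, 60, 540, 95], "category": "Title", "text": "# Example Document"},
#     {"bbox": [70, 120, 540, 260], "category": "Text", "text": "First paragraph ..."},
#     {"bbox": [70, 280, 300, 520], "category": "Picture"},
#     {"bbox": [70, 540, 540, 700], "category": "Table", "text": "<table>...</table>"}
# ]
#
# process_document() below parses this with json.loads() and feeds it to
# draw_layout_on_image() and layoutjson2md().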

# Load models at startup
models = {
    "dots.ocr": load_model("dots.ocr"),
    "Dolphin": load_model("Dolphin")
}

# Global state for PDF handling
pdf_cache = {
    "images": [],
    "current_page": 0,
    "total_pages": 0,
    "file_type": None,
    "is_parsed": False,
    "results": []
}

# Utility functions
def round_by_factor(number: int, factor: int) -> int:
    return round(number / factor) * factor

def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 11289600):
    if max(height, width) / min(height, width) > 200:
        raise ValueError(f"Aspect ratio must be < 200, got {max(height, width) / min(height, width)}")
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = round_by_factor(height / beta, factor)
        w_bar = round_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = round_by_factor(height * beta, factor)
        w_bar = round_by_factor(width * beta, factor)
    return h_bar, w_bar

def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):
            response = requests.get(image_input)
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(image_input).convert('RGB')
    elif isinstance(image_input, Image.Image):
        image = image_input.convert('RGB')
    else:
        raise ValueError(f"Invalid image input type: {type(image_input)}")
    if min_pixels or max_pixels:
        min_pixels = min_pixels or MIN_PIXELS
        max_pixels = max_pixels or MAX_PIXELS
        height, width = smart_resize(image.height, image.width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
        image = image.resize((width, height), Image.LANCZOS)
    return image

def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
    images = []
    try:
        pdf_document = fitz.open(pdf_path)
        for page_num in range(len(pdf_document)):
            page = pdf_document.load_page(page_num)
            mat = fitz.Matrix(2.0, 2.0)
            pix = page.get_pixmap(matrix=mat)
            img_data = pix.tobytes("ppm")
            image = Image.open(BytesIO(img_data)).convert('RGB')
            images.append(image)
        pdf_document.close()
    except Exception as e:
        print(f"Error loading PDF: {e}")
        return []
    return images

def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
    img_copy = image.copy()
    draw = ImageDraw.Draw(img_copy)
    colors = {
        'Caption': '#FF6B6B', 'Footnote': '#4ECDC4', 'Formula': '#45B7D1', 'List-item': '#96CEB4',
        'Page-footer': '#FFEAA7', 'Page-header': '#DDA0DD', 'Picture': '#FFD93D', 'Section-header': '#6C5CE7',
        'Table': '#FD79A8', 'Text': '#74B9FF', 'Title': '#E17055'
    }
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
    except Exception:
        font = ImageFont.load_default()
    try:
        for item in layout_data:
            if 'bbox' in item and 'category' in item:
                bbox = item['bbox']
                category = item['category']
                color = colors.get(category, '#000000')
                draw.rectangle(bbox, outline=color, width=2)
                label = category
                label_bbox = draw.textbbox((0, 0), label, font=font)
                label_width = label_bbox[2] - label_bbox[0]
                label_height = label_bbox[3] - label_bbox[1]
                label_x = bbox[0]
                label_y = max(0, bbox[1] - label_height - 2)
                draw.rectangle([label_x, label_y, label_x + label_width + 4, label_y + label_height + 2], fill=color)
                draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
    except Exception as e:
        print(f"Error drawing layout: {e}")
    return img_copy

def is_arabic_text(text: str) -> bool:
    if not text:
        return False
    header_pattern = r'^#{1,6}\s+(.+)$'
    paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
    content_text = []
    for line in text.split('\n'):
        line = line.strip()
        if not line:
            continue
        header_match = re.match(header_pattern, line, re.MULTILINE)
        if header_match:
            content_text.append(header_match.group(1))
            continue
        if re.match(paragraph_pattern, line, re.MULTILINE):
            content_text.append(line)
    if not content_text:
        return False
    combined_text = ' '.join(content_text)
    arabic_chars = 0
    total_chars = 0
    for char in combined_text:
        if char.isalpha():
            total_chars += 1
            if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'):
                arabic_chars += 1
    return total_chars > 0 and (arabic_chars / total_chars) > 0.5

def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
    import base64
    markdown_lines = []
    try:
        sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
        for item in sorted_items:
            category = item.get('category', '')
            text = item.get(text_key, '')
            bbox = item.get('bbox', [])
            if category == 'Picture':
                if bbox and len(bbox) == 4:
                    try:
                        x1, y1, x2, y2 = [max(0, int(x)) if i < 2 else min(image.width if i % 2 == 0 else image.height, int(x)) for i, x in enumerate(bbox)]
                        if x2 > x1 and y2 > y1:
                            cropped_img = image.crop((x1, y1, x2, y2))
                            buffer = BytesIO()
                            cropped_img.save(buffer, format='PNG')
                            img_data = base64.b64encode(buffer.getvalue()).decode()
                            markdown_lines.append(f'<image-card alt="Image" src="data:image/png;base64,{img_data}" ></image-card>\n')
                        else:
                            markdown_lines.append('<image-card alt="Image" src="Image region detected" ></image-card>\n')
                    except Exception as e:
                        print(f"Error processing image region: {e}")
                        markdown_lines.append('<image-card alt="Image" src="Image detected" ></image-card>\n')
                else:
                    markdown_lines.append('<image-card alt="Image" src="Image detected" ></image-card>\n')
            elif not text:
                continue
            elif category == 'Title':
                markdown_lines.append(f"# {text}\n")
            elif category == 'Section-header':
                markdown_lines.append(f"## {text}\n")
            elif category == 'Text':
                markdown_lines.append(f"{text}\n")
            elif category == 'List-item':
                markdown_lines.append(f"- {text}\n")
            elif category == 'Table':
                if text.strip().startswith('<'):
                    markdown_lines.append(f"{text}\n")
                else:
                    markdown_lines.append(f"**Table:** {text}\n")
            elif category == 'Formula':
                if text.strip().startswith('$') or '\\' in text:
                    markdown_lines.append(f"$$ \n{text}\n $$\n")
                else:
                    markdown_lines.append(f"**Formula:** {text}\n")
            elif category == 'Caption':
                markdown_lines.append(f"*{text}*\n")
            elif category == 'Footnote':
                markdown_lines.append(f"^{text}^\n")
            elif category in ['Page-header', 'Page-footer']:
                continue
            else:
                markdown_lines.append(f"{text}\n")
            markdown_lines.append("")
    except Exception as e:
        print(f"Error converting to markdown: {e}")
        return str(layout_data)
    return "\n".join(markdown_lines)

def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
    global pdf_cache
    if not file_path or not os.path.exists(file_path):
        return None, "No file selected"
    file_ext = os.path.splitext(file_path)[1].lower()
    try:
        if file_ext == '.pdf':
            images = load_images_from_pdf(file_path)
            if not images:
                return None, "Failed to load PDF"
            pdf_cache.update({
                "images": images,
                "current_page": 0,
                "total_pages": len(images),
                "file_type": "pdf",
                "is_parsed": False,
                "results": []
            })
            return images[0], f"Page 1 / {len(images)}"
        elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
            image = Image.open(file_path).convert('RGB')
            pdf_cache.update({
                "images": [image],
                "current_page": 0,
                "total_pages": 1,
                "file_type": "image",
                "is_parsed": False,
                "results": []
            })
            return image, "Page 1 / 1"
        else:
            return None, f"Unsupported file format: {file_ext}"
    except Exception as e:
        print(f"Error loading file: {e}")
        return None, f"Error loading file: {str(e)}"

@spaces.GPU()
def process_document(file_path, model_choice, max_tokens, min_pix, max_pix):
    global pdf_cache
    if not file_path:
        return None, "Please upload a file first.", None
    model, processor = models[model_choice]
    image, page_info = load_file_for_preview(file_path)
    if image is None:
        return None, page_info, None
    if pdf_cache["file_type"] == "pdf":
        all_results = []
        for i, img in enumerate(pdf_cache["images"]):
            if model_choice == "dots.ocr":
                raw_output = inference_dots_ocr(model, processor, img, prompt, max_tokens)
                try:
                    layout_data = json.loads(raw_output)
                    processed_image = draw_layout_on_image(img, layout_data)
                    markdown_content = layoutjson2md(img, layout_data)
                    result = {
                        'processed_image': processed_image,
                        'markdown_content': markdown_content,
                        'layout_result': layout_data
                    }
                except Exception:
                    result = {
                        'processed_image': img,
                        'markdown_content': raw_output,
                        'layout_result': None
                    }
            else:  # Dolphin
                text = inference_dolphin(model, processor, img)
                result = f"## Page {i+1}\n\n{text}" if text else "No text extracted"
            all_results.append(result)
        pdf_cache["results"] = all_results
        pdf_cache["is_parsed"] = True
        first_result = all_results[0]
        if model_choice == "dots.ocr":
            markdown_update = gr.update(value=first_result['markdown_content'], rtl=is_arabic_text(first_result['markdown_content']))
            return first_result['processed_image'], markdown_update, first_result['layout_result']
        else:
            markdown_update = gr.update(value=first_result, rtl=is_arabic_text(first_result))
            return None, markdown_update, None
    else:
        if model_choice == "dots.ocr":
            raw_output = inference_dots_ocr(model, processor, image, prompt, max_tokens)
            try:
                layout_data = json.loads(raw_output)
                processed_image = draw_layout_on_image(image, layout_data)
                markdown_content = layoutjson2md(image, layout_data)
                result = {
                    'processed_image': processed_image,
                    'markdown_content': markdown_content,
                    'layout_result': layout_data
                }
            except Exception:
                result = {
                    'processed_image': image,
                    'markdown_content': raw_output,
                    'layout_result': None
                }
            pdf_cache["results"] = [result]
        else:  # Dolphin
            text = inference_dolphin(model, processor, image)
            result = text if text else "No text extracted"
            pdf_cache["results"] = [result]
        pdf_cache["is_parsed"] = True
        if model_choice == "dots.ocr":
            markdown_update = gr.update(value=result['markdown_content'], rtl=is_arabic_text(result['markdown_content']))
            return result['processed_image'], markdown_update, result['layout_result']
        else:
            markdown_update = gr.update(value=result, rtl=is_arabic_text(result))
            return None, markdown_update, None

def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, Any, Optional[Image.Image], Optional[Dict]]:
    global pdf_cache
    if not pdf_cache["images"]:
        return None, '<div class="page-info">No file loaded</div>', "No results yet", None, None
    if direction == "prev":
        pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
    elif direction == "next":
        pdf_cache["current_page"] = min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)
    index = pdf_cache["current_page"]
    current_image_preview = pdf_cache["images"][index]
    page_info_html = f'<div class="page-info">Page {index + 1} / {pdf_cache["total_pages"]}</div>'
    if pdf_cache["is_parsed"] and index < len(pdf_cache["results"]):
        result = pdf_cache["results"][index]
        if isinstance(result, dict):  # dots.ocr
            markdown_content = result.get('markdown_content', 'No content available')
            processed_img = result.get('processed_image', None)
            layout_json = result.get('layout_result', None)
        else:  # Dolphin
            markdown_content = result
            processed_img = None
            layout_json = None
    else:
        markdown_content = "Page not processed yet"
        processed_img = None
        layout_json = None
    markdown_update = gr.update(value=markdown_content, rtl=is_arabic_text(markdown_content))
    return current_image_preview, page_info_html, markdown_update, processed_img, layout_json

def create_gradio_interface():
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
    .process-button {
        border: none !important;
        color: white !important;
        font-weight: bold !important;
        background-color: blue !important;}
    .process-button:hover {
        background-color: darkblue !important;
        transform: translateY(-2px) !important;
        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
    .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
    .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
    .model-status { padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center; font-weight: bold; }
    .status-ready { background: #d1edff; color: #0c5460; border: 1px solid #b8daff; }
    """
    with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
        gr.HTML("""
        <div class="title" style="text-align: center">
            <h1>Dot<span style="color: red;">●</span>OCR vs Dolphin🐬</h1>
            <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
                Advanced vision-language model for image/PDF to markdown document processing
            </p>
        </div>
        """)
        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="Upload Image or PDF",
                    file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"],
                    type="filepath"
                )
                image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=300)
                with gr.Row():
                    prev_page_btn = gr.Button("⬅ Previous", size="md")
                    page_info = gr.HTML('<div class="page-info">No file loaded</div>')
                    next_page_btn = gr.Button("Next ➡", size="md")
                model_choice = gr.Radio(
                    choices=["dots.ocr", "Dolphin"],
                    label="Select Model",
                    value="dots.ocr"
                )
                with gr.Accordion("Advanced Settings", open=False):
                    max_new_tokens = gr.Slider(minimum=1000, maximum=32000, value=24000, step=1000, label="Max New Tokens")
                    min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels")
                    max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels")
                process_btn = gr.Button("🔥 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
                clear_btn = gr.Button("Clear Document", variant="secondary")

            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.Tab("✦︎ Processed Image"):
                        processed_image = gr.Image(label="Image with Layout Detection", type="pil", interactive=False, height=500)
                    with gr.Tab("🀥 Extracted Content"):
                        markdown_output = gr.Markdown(value="Click 'Process Document' to see extracted content...", height=500)
                    with gr.Tab("⏲ Layout JSON"):
                        json_output = gr.JSON(label="Layout Analysis Results", value=None)

        with gr.Row():
            examples = gr.Examples(
                examples=["examples/sample_image1.png", "examples/sample_image2.png", "examples/sample_pdf.pdf"],
                inputs=file_input,
                label="Example Documents"
            )

        def handle_file_upload(file_path):
            image, page_info = load_file_for_preview(file_path)
            return image, page_info

        def clear_all():
            global pdf_cache
            pdf_cache = {"images": [], "current_page": 0, "total_pages": 0, "file_type": None, "is_parsed": False, "results": []}
            return None, None, '<div class="page-info">No file loaded</div>', None, "Click 'Process Document' to see extracted content...", None

        file_input.change(handle_file_upload, inputs=[file_input], outputs=[image_preview, page_info])
        prev_page_btn.click(lambda: turn_page("prev"), outputs=[image_preview, page_info, markdown_output, processed_image, json_output])
        next_page_btn.click(lambda: turn_page("next"), outputs=[image_preview, page_info, markdown_output, processed_image, json_output])
        process_btn.click(
            process_document,
            inputs=[file_input, model_choice, max_new_tokens, min_pixels, max_pixels],
            outputs=[processed_image, markdown_output, json_output]
        )
        clear_btn.click(
            clear_all,
            outputs=[file_input, image_preview, page_info, processed_image, markdown_output, json_output]
        )
    return demo

if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.queue(max_size=30).launch(share=False, debug=True, show_error=True)
examples/sample_image1.png ADDED

Git LFS Details

  • SHA256: 337fb445fc7cde5d02fa0cc2e8c1805943cf094ca33511a783c8676912a84d1d
  • Pointer size: 131 Bytes
  • Size of remote file: 414 kB
examples/sample_image2.png ADDED
examples/sample_pdf.pdf ADDED
Binary file (11.9 kB).
 
model.py ADDED
@@ -0,0 +1,107 @@
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, VisionEncoderDecoderModel
from huggingface_hub import snapshot_download
from qwen_vl_utils import process_vision_info

def load_model(model_name):
    """
    Load the specified model and its processor based on the model name.

    Args:
        model_name (str): Name of the model ("dots.ocr" or "Dolphin").

    Returns:
        tuple: (model, processor) for the specified model.
    """
    if model_name == "dots.ocr":
        model_id = "rednote-hilab/dots.ocr"
        model_path = "./models/dots-ocr-local"
        snapshot_download(
            repo_id=model_id,
            local_dir=model_path,
            local_dir_use_symlinks=False,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    elif model_name == "Dolphin":
        model_id = "ByteDance/Dolphin"
        processor = AutoProcessor.from_pretrained(model_id)
        model = VisionEncoderDecoderModel.from_pretrained(model_id)
        model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model = model.half()  # Use half precision
    else:
        raise ValueError(f"Unknown model: {model_name}")
    return model, processor
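
# Example usage (hypothetical snippet; app.py wires these helpers together the same way):
#   model, processor = load_model("dots.ocr")
#   text = inference_dots_ocr(model, processor, pil_image, prompt, max_new_tokens=24000)
# or, for Dolphin:
#   model, processor = load_model("Dolphin")
#   text = inference_dolphin(model, processor, pil_image)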

def inference_dots_ocr(model, processor, image, prompt, max_new_tokens):
    """
    Perform inference using the dots.ocr model.

    Args:
        model: The loaded dots.ocr model.
        processor: The corresponding processor.
        image (PIL.Image): Input image.
        prompt (str): Prompt for inference.
        max_new_tokens (int): Maximum number of tokens to generate.

    Returns:
        str: Generated text output.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=0.1
        )
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )
    return output_text[0] if output_text else ""

def inference_dolphin(model, processor, image):
    """
    Perform inference using the Dolphin model.

    Args:
        model: The loaded Dolphin model.
        processor: The corresponding processor.
        image (PIL.Image): Input image.

    Returns:
        str: Generated text output.
    """
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(model.device).half()
    generated_ids = model.generate(pixel_values)
    generated_text = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
requirements.txt ADDED
@@ -0,0 +1,20 @@
torch
torchvision
flash-attn==2.8.0.post2
transformers==4.51.3
transformers-stream-generator
qwen-vl-utils
modelscope
accelerate
openai
huggingface-hub
spaces
numpy
pillow
opencv-python
av
timm
PyMuPDF
requests
gradio
gradio_image_annotation