MohamedRashad committed on
Commit
02c7af0
·
1 Parent(s): 60c056c

Add initial implementation of Dots.OCR Gradio demo application and requirements

Browse files
Files changed (2) hide show
  1. app.py +939 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,939 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Dots.OCR Gradio Demo Application
4
+
5
+ A Gradio-based web interface for demonstrating the Dots.OCR model using Hugging Face transformers.
6
+ This application provides OCR and layout analysis capabilities for documents and images.
7
+ """
8
+
9
+ import os
10
+ import json
11
+ import traceback
12
+ import math
13
+ from io import BytesIO
14
+ from typing import Optional, Dict, Any, Tuple, List
15
+ import requests
16
+
17
# Provide a default LOCAL_RANK for transformers; an existing value is left untouched.
os.environ.setdefault("LOCAL_RANK", "0")
20
+
21
+ import torch
22
+ import gradio as gr
23
+ from PIL import Image, ImageDraw, ImageFont
24
+ from transformers import AutoModelForCausalLM, AutoProcessor
25
+ from qwen_vl_utils import process_vision_info
26
+ import fitz # PyMuPDF
27
+
28
+
29
# Image sizing limits used when rescaling inputs before inference.
IMAGE_FACTOR = 28                                  # resized sides are multiples of this
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR       # 3136
MAX_PIXELS = 14400 * IMAGE_FACTOR * IMAGE_FACTOR   # 11289600
33
+
34
# Prompts
# Each task mode offered by the UI maps to the exact instruction text sent to the model.
# NOTE(review): leading whitespace on the bullet lines inside the full-layout prompt
# should be checked against the model's reference prompt — TODO confirm.

# Full layout analysis: bounding boxes, categories and the text inside each box.
_PROMPT_LAYOUT_ALL_EN = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]

2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].

3. Text Extraction & Formatting Rules:
- Picture: For the 'Picture' category, the text field should be omitted.
- Formula: Format its text as LaTeX.
- Table: Format its text as HTML.
- All Others (Text, Title, etc.): Format their text as Markdown.

4. Constraints:
- The output text must be the original text from the image, with no translation.
- All layout elements must be sorted according to human reading order.

5. Final Output: The entire output must be a single JSON object.
"""

# Layout detection only: bounding boxes and categories, no text.
_PROMPT_LAYOUT_ONLY_EN = """Please output the layout information from this PDF image, including each layout's bbox and its category. The bbox should be in the format [x1, y1, x2, y2]. The layout categories for the PDF document include ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. Do not output the corresponding text. The layout result should be in JSON format."""

# Plain text extraction.
_PROMPT_OCR = """Extract the text content from this image."""

# Text extraction restricted to a bounding box; the prompt ends right after "Bounding Box:".
_PROMPT_GROUNDING_OCR = """Extract text from the given bounding box on the image (format: [x1, y1, x2, y2]).\nBounding Box:\n"""

dict_promptmode_to_prompt = {
    "prompt_layout_all_en": _PROMPT_LAYOUT_ALL_EN,
    "prompt_layout_only_en": _PROMPT_LAYOUT_ONLY_EN,
    "prompt_ocr": _PROMPT_OCR,
    "prompt_grounding_ocr": _PROMPT_GROUNDING_OCR,
}
61
+
62
+
63
+ # Utility functions
64
def round_by_factor(number: int, factor: int) -> int:
    """Snap ``number`` to the nearest multiple of ``factor``.

    Rounding is done with the built-in ``round``, so an exact half-way
    quotient resolves to the even multiple count.
    """
    multiples = round(number / factor)
    return multiples * factor
67
+
68
+
69
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600,
):
    """Compute target dimensions for rescaling an image.

    The returned size satisfies:
    1. Both dimensions are multiples of ``factor`` (and at least ``factor``).
    2. The pixel count stays within ``[min_pixels, max_pixels]``: an oversized
       image is rounded *down* and an undersized image is rounded *up*, so the
       rounding step itself cannot push the result back outside the range.
    3. The aspect ratio is preserved as closely as the factor grid allows.

    Args:
        height: Source image height in pixels.
        width: Source image width in pixels.
        factor: Grid size both output dimensions must be divisible by.
        min_pixels: Lower bound on ``height * width`` of the result.
        max_pixels: Upper bound on ``height * width`` of the result.

    Returns:
        A ``(height, width)`` tuple of the resized dimensions.

    Raises:
        ValueError: If a dimension is not positive, or the aspect ratio
            exceeds 200.
    """
    if height <= 0 or width <= 0:
        raise ValueError(f"height and width must be positive, got {height}x{width}")

    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {aspect_ratio}"
        )

    # First guess: nearest multiple of `factor` for each side.
    h_bar = max(factor, round(height / factor) * factor)
    w_bar = max(factor, round(width / factor) * factor)

    if h_bar * w_bar > max_pixels:
        # Shrink; floor so the product cannot land above max_pixels again.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        # Enlarge; ceil so the product cannot land below min_pixels again.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
97
+
98
+
99
def fetch_image(image_input, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None):
    """Load an image as RGB and optionally rescale it for the model.

    Args:
        image_input: An ``http(s)://`` URL, a local file path, or a PIL image.
        min_pixels: Lower pixel-count bound for resizing. When only one bound
            is given, the other falls back to ``MIN_PIXELS`` / ``MAX_PIXELS``.
        max_pixels: Upper pixel-count bound for resizing.

    Returns:
        The RGB image, resized with ``smart_resize`` when at least one bound
        was supplied; otherwise at its original size.

    Raises:
        ValueError: If ``image_input`` is neither a string nor a PIL image.
        requests.RequestException: If a URL cannot be fetched, times out, or
            returns an HTTP error status.
    """
    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):
            # Bound the wait and fail on HTTP errors instead of trying to
            # decode an error page as an image.
            response = requests.get(image_input, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            # convert() returns an independent image, so the file can be closed.
            with Image.open(image_input) as source:
                image = source.convert('RGB')
    elif isinstance(image_input, Image.Image):
        image = image_input.convert('RGB')
    else:
        raise ValueError(f"Invalid image input type: {type(image_input)}")

    if min_pixels is not None or max_pixels is not None:
        min_pixels = min_pixels or MIN_PIXELS
        max_pixels = max_pixels or MAX_PIXELS
        height, width = smart_resize(
            image.height,
            image.width,
            factor=IMAGE_FACTOR,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
        # PIL takes (width, height), the reverse of smart_resize's return order.
        image = image.resize((width, height), Image.LANCZOS)

    return image
125
+
126
+
127
def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
    """Render every page of a PDF to an RGB PIL image.

    Args:
        pdf_path: Path of the PDF file to open.

    Returns:
        One image per page, in page order. An empty list is returned when the
        file cannot be opened or any page fails to render (the error is
        printed); partial results are not returned.
    """
    images = []
    # 2x zoom on both axes for a higher-resolution render.
    render_matrix = fitz.Matrix(2.0, 2.0)
    try:
        # The context manager closes the document even if a page fails mid-way.
        with fitz.open(pdf_path) as pdf_document:
            for page in pdf_document:
                pix = page.get_pixmap(matrix=render_matrix)
                img_data = pix.tobytes("ppm")
                images.append(Image.open(BytesIO(img_data)).convert('RGB'))
    except Exception as e:
        print(f"Error loading PDF: {e}")
        return []
    return images
145
+
146
+
147
def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
    """Return a copy of ``image`` with layout boxes and category labels drawn on it.

    Args:
        image: Source image; it is copied, never modified.
        layout_data: Layout items, each expected to carry a ``bbox``
            (``[x1, y1, x2, y2]``) and a ``category``. A single dict is
            treated as a one-item list. Items that are not dicts or lack
            either key are skipped.

    Returns:
        The annotated copy. An item that fails to draw (for example a
        malformed bbox) is reported and skipped; the remaining items are
        still drawn.
    """
    img_copy = image.copy()
    draw = ImageDraw.Draw(img_copy)

    # Colors for different categories; unknown categories fall back to black.
    colors = {
        'Caption': '#FF6B6B',
        'Footnote': '#4ECDC4',
        'Formula': '#45B7D1',
        'List-item': '#96CEB4',
        'Page-footer': '#FFEAA7',
        'Page-header': '#DDA0DD',
        'Picture': '#FFD93D',
        'Section-header': '#6C5CE7',
        'Table': '#FD79A8',
        'Text': '#74B9FF',
        'Title': '#E17055',
    }

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
    except (OSError, ImportError):
        # Font file missing or FreeType unavailable: use PIL's built-in font.
        font = ImageFont.load_default()

    items = layout_data if isinstance(layout_data, list) else [layout_data]

    for item in items:
        if not isinstance(item, dict) or 'bbox' not in item or 'category' not in item:
            continue

        # Guard each item separately so one bad box does not stop the rest.
        try:
            bbox = item['bbox']
            category = item['category']
            color = colors.get(category, '#000000')

            # Draw rectangle
            draw.rectangle(bbox, outline=color, width=2)

            # Measure the label so its background fits the text.
            label = category
            label_bbox = draw.textbbox((0, 0), label, font=font)
            label_width = label_bbox[2] - label_bbox[0]
            label_height = label_bbox[3] - label_bbox[1]

            # Position label above the box, clamped to the top edge.
            label_x = bbox[0]
            label_y = max(0, bbox[1] - label_height - 2)

            # Draw background for label
            draw.rectangle(
                [label_x, label_y, label_x + label_width + 4, label_y + label_height + 2],
                fill=color,
            )

            # Draw text
            draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
        except Exception as e:
            print(f"Error drawing layout: {e}")

    return img_copy
206
+
207
+
208
def layoutjson2md(image: "Image.Image", layout_data: List[Dict], text_key: str = 'text', no_page_hf: bool = False) -> str:
    """Convert parsed layout items to a Markdown string.

    Items are emitted in the order they appear in ``layout_data``. The layout
    prompt asks the model for human reading order, and re-sorting by
    top-left corner would interleave the columns of a multi-column page.
    Bounding boxes are therefore not consulted, so a malformed bbox cannot
    make the conversion fail.

    Args:
        image: Unused; kept for interface compatibility.
        layout_data: Layout items (dicts with ``category`` and a text field).
            A single dict is treated as a one-item list; non-dict entries
            are skipped.
        text_key: Key under which each item stores its text.
        no_page_hf: When False, a "# Document Content" heading is prepended.

    Returns:
        The Markdown text. Items without text and page headers/footers are
        omitted. If conversion fails unexpectedly, ``str(layout_data)`` is
        returned instead.
    """
    markdown_lines = []

    if not no_page_hf:
        markdown_lines.append("# Document Content\n")

    # Categories rendered by a plain template; anything unlisted is emitted as-is.
    simple_templates = {
        'Title': "# {}\n",
        'Section-header': "## {}\n",
        'Text': "{}\n",
        'List-item': "- {}\n",
        'Caption': "*{}*\n",
        'Footnote': "^{}^\n",
    }

    try:
        items = [layout_data] if isinstance(layout_data, dict) else layout_data

        for item in items:
            if not isinstance(item, dict):
                continue

            category = item.get('category', '')
            text = item.get(text_key, '')

            if not text:
                continue
            if not isinstance(text, str):
                text = str(text)

            if category in ('Page-header', 'Page-footer'):
                # Skip headers and footers in main content
                continue

            if category == 'Table':
                # If text is already HTML, keep it as is
                if text.strip().startswith('<'):
                    markdown_lines.append(f"{text}\n")
                else:
                    markdown_lines.append(f"**Table:** {text}\n")
            elif category == 'Formula':
                # Text that looks like LaTeX is wrapped in a display-math block
                if text.strip().startswith('$') or '\\' in text:
                    markdown_lines.append(f"$$\n{text}\n$$\n")
                else:
                    markdown_lines.append(f"**Formula:** {text}\n")
            else:
                markdown_lines.append(simple_templates.get(category, "{}\n").format(text))

            markdown_lines.append("")  # Add spacing

    except Exception as e:
        print(f"Error converting to markdown: {e}")
        return str(layout_data)

    return "\n".join(markdown_lines)
263
+
264
# Load the model and its processor once, at import time.
model_id = "rednote-hilab/dots.ocr"
_model_kwargs = dict(
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,  # the checkpoint's own modelling code is executed
)
model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Global state variables
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# State for the currently loaded document (PDF pages or a single image)
pdf_cache = dict(
    images=[],
    current_page=0,
    total_pages=0,
    file_type=None,
    is_parsed=False,
    results=[],
)

# Processing state: every slot starts out empty
processing_results = dict.fromkeys(
    (
        'original_image',
        'processed_image',
        'layout_result',
        'markdown_content',
        'raw_output',
    )
)
299
def inference(image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
    """Run the model on one image with the given prompt and return the decoded text.

    Args:
        image: Image passed to the model as the user message's image part.
        prompt: Instruction text passed as the user message's text part.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated text with the prompt tokens removed, or ``""`` when
        nothing was decoded. On failure the exception is printed and a string
        starting with ``"Error during inference: "`` is returned instead of
        raising.
    """
    try:
        if model is None or processor is None:
            raise RuntimeError("Model not loaded. Please check model initialization.")

        # Prepare messages in the expected format
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Apply chat template
        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Process vision information
        image_inputs, video_inputs = process_vision_info(messages)

        # Prepare inputs
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # The model is placed by device_map="auto"; send the inputs to the
        # device it actually lives on rather than guessing.
        inputs = inputs.to(model.device)

        # Greedy decoding. No temperature is passed: it only applies when
        # sampling, and combining it with do_sample=False is flagged by
        # transformers as an inconsistent generation config.
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        # generate() returns prompt + continuation; keep only the continuation.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        return output_text[0] if output_text else ""

    except Exception as e:
        print(f"Error during inference: {e}")
        traceback.print_exc()
        return f"Error during inference: {str(e)}"
367
+
368
+
369
def process_image(
    image: Image.Image,
    prompt_mode: str,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
    max_new_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """Process a single image with the specified prompt mode.

    Args:
        image: Image to analyse.
        prompt_mode: Key into ``dict_promptmode_to_prompt`` selecting the task.
        min_pixels: Optional lower pixel bound; the image is resized through
            ``fetch_image`` when either bound is given.
        max_pixels: Optional upper pixel bound.
        max_new_tokens: Optional generation limit forwarded to ``inference``.
            When omitted, ``inference`` uses its own default.

    Returns:
        A dict with keys ``original_image``, ``raw_output``, ``prompt_mode``,
        ``processed_image``, ``layout_result`` and ``markdown_content``.
        For layout modes, JSON output is parsed into ``layout_result`` and
        drawn onto ``processed_image``; ``prompt_layout_all_en`` also fills
        ``markdown_content``. If the output is not valid JSON, or for OCR
        modes, ``markdown_content`` is the raw model output. Errors are
        printed and reported inside the returned dict rather than raised.
    """
    try:
        # Resize image if needed
        if min_pixels is not None or max_pixels is not None:
            image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)

        # Get prompt
        prompt = dict_promptmode_to_prompt[prompt_mode]

        # Run inference, passing a token limit only when the caller set one
        if max_new_tokens is None:
            raw_output = inference(image, prompt)
        else:
            raw_output = inference(image, prompt, max_new_tokens=int(max_new_tokens))

        result = {
            'original_image': image,
            'raw_output': raw_output,
            'prompt_mode': prompt_mode,
            'processed_image': image,
            'layout_result': None,
            'markdown_content': None,
        }

        if prompt_mode not in ('prompt_layout_all_en', 'prompt_layout_only_en'):
            # For OCR prompts, use raw output as markdown
            result['markdown_content'] = raw_output
            return result

        # Layout prompts: try to parse JSON and create visualizations
        try:
            layout_data = json.loads(raw_output)
        except json.JSONDecodeError:
            print("Failed to parse JSON output, using raw output")
            result['markdown_content'] = raw_output
            return result

        result['layout_result'] = layout_data

        # Create visualization with bounding boxes
        try:
            result['processed_image'] = draw_layout_on_image(image, layout_data)
        except Exception as e:
            print(f"Error drawing layout: {e}")
            result['processed_image'] = image

        # Generate markdown if text is available
        if prompt_mode == 'prompt_layout_all_en':
            try:
                result['markdown_content'] = layoutjson2md(image, layout_data, text_key='text')
            except Exception as e:
                print(f"Error generating markdown: {e}")
                result['markdown_content'] = raw_output

        return result

    except Exception as e:
        print(f"Error processing image: {e}")
        traceback.print_exc()
        return {
            'original_image': image,
            'raw_output': f"Error processing image: {str(e)}",
            'prompt_mode': prompt_mode,
            'processed_image': image,
            'layout_result': None,
            'markdown_content': f"Error processing image: {str(e)}",
        }
441
+
442
+
443
def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
    """Load a PDF or image for preview and reset the shared document cache.

    Args:
        file_path: Path of the uploaded file.

    Returns:
        ``(first_image, status_text)``. On success the status is a
        ``"Page 1 / N"`` label; on failure the image is ``None`` and the
        status describes the problem (missing file, unsupported extension,
        or a load error).

    Side effects:
        On success ``pdf_cache`` is updated in place with the loaded pages,
        and any previously parsed results are discarded.
    """
    if not file_path or not os.path.exists(file_path):
        return None, "No file selected"

    file_ext = os.path.splitext(file_path)[1].lower()

    try:
        if file_ext == '.pdf':
            # Load PDF pages
            images = load_images_from_pdf(file_path)
            if not images:
                return None, "Failed to load PDF"
            file_type = "pdf"
        elif file_ext in ('.jpg', '.jpeg', '.png', '.bmp', '.tiff'):
            # Load single image; convert() yields an independent copy, so the
            # file handle can be released right away.
            with Image.open(file_path) as source:
                images = [source.convert('RGB')]
            file_type = "image"
        else:
            return None, f"Unsupported file format: {file_ext}"

        pdf_cache.update({
            "images": images,
            "current_page": 0,
            "total_pages": len(images),
            "file_type": file_type,
            "is_parsed": False,
            "results": [],
        })

        return images[0], f"Page 1 / {len(images)}"

    except Exception as e:
        print(f"Error loading file: {e}")
        return None, f"Error loading file: {str(e)}"
490
+
491
+
492
def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, str]:
    """Move the cached document one page back or forward.

    ``direction`` is ``"prev"`` or ``"next"``; any other value leaves the
    current page unchanged. The page index is clamped to the document.

    Returns ``(page_image, page_label, page_text)`` where ``page_text`` is the
    page's markdown (or, failing that, its raw output) once the document has
    been parsed, and a placeholder message otherwise.
    """
    global pdf_cache

    pages = pdf_cache["images"]
    if not pages:
        return None, "No file loaded", "No results yet"

    total = pdf_cache["total_pages"]
    if direction == "prev":
        pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
    elif direction == "next":
        pdf_cache["current_page"] = min(total - 1, pdf_cache["current_page"] + 1)

    index = pdf_cache["current_page"]
    label = f"Page {index + 1} / {pdf_cache['total_pages']}"

    # Results exist for this page only after a processing run stored them.
    has_result = (
        pdf_cache["is_parsed"]
        and index < len(pdf_cache["results"])
        and pdf_cache["results"][index]
    )
    if not has_result:
        return pages[index], label, "Page not processed yet"

    page_result = pdf_cache["results"][index]
    content = page_result.get('markdown_content') or page_result.get(
        'raw_output', 'No content available'
    )
    return pages[index], label, content
525
+
526
+
527
def create_gradio_interface():
    """Build the Gradio Blocks application and wire up its event handlers.

    The layout has an input column (file upload, preview, page navigation,
    task mode and advanced settings) and a results column (annotated image,
    extracted content, raw output, layout JSON and the active prompt).

    Returns:
        The configured ``gr.Blocks`` instance; the caller launches it.
    """
    # Local import: only needed to escape text embedded in the HTML snippets below.
    import html

    # Custom CSS
    css = """
    .main-container {
        max-width: 1400px;
        margin: 0 auto;
    }

    .header-text {
        text-align: center;
        color: #2c3e50;
        margin-bottom: 20px;
    }

    .process-button {
        background: linear-gradient(45deg, #667eea 0%, #764ba2 100%) !important;
        border: none !important;
        color: white !important;
        font-weight: bold !important;
    }

    .process-button:hover {
        transform: translateY(-2px) !important;
        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
    }

    .info-box {
        background: #f8f9fa;
        border: 1px solid #dee2e6;
        border-radius: 8px;
        padding: 15px;
        margin: 10px 0;
    }

    .page-info {
        text-align: center;
        padding: 8px 16px;
        background: #e9ecef;
        border-radius: 20px;
        font-weight: bold;
        margin: 10px 0;
    }

    .model-status {
        padding: 10px;
        border-radius: 8px;
        margin: 10px 0;
        text-align: center;
        font-weight: bold;
    }

    .status-loading {
        background: #fff3cd;
        color: #856404;
        border: 1px solid #ffeaa7;
    }

    .status-ready {
        background: #d1edff;
        color: #0c5460;
        border: 1px solid #b8daff;
    }

    .status-error {
        background: #f8d7da;
        color: #721c24;
        border: 1px solid #f5c6cb;
    }
    """

    def status_html(kind, message):
        """Render the status banner; ``kind`` is 'loading', 'ready' or 'error'."""
        return f'<div class="model-status status-{kind}">{message}</div>'

    def page_info_html(text):
        """Wrap page-indicator text in the styled container used for its initial value."""
        return f'<div class="page-info">{html.escape(text)}</div>'

    with gr.Blocks(theme=gr.themes.Soft(), css=css, title="Dots.OCR Demo") as demo:

        # Header
        gr.HTML("""
        <div class="header-text">
            <h1>🔍 Dots.OCR Hugging Face Demo</h1>
            <p>Advanced OCR and Document Layout Analysis powered by Hugging Face Transformers</p>
        </div>
        """)

        # Model status
        model_status = gr.HTML(
            status_html("loading", "🔄 Initializing model..."),
            elem_id="model_status"
        )

        # Main interface
        with gr.Row():
            # Left column - Input and controls
            with gr.Column(scale=1):
                gr.Markdown("### 📁 Input")

                # File input
                file_input = gr.File(
                    label="Upload Image or PDF",
                    file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"],
                    type="filepath"
                )

                # Image preview
                image_preview = gr.Image(
                    label="Preview",
                    type="pil",
                    interactive=False,
                    height=300
                )

                # Page navigation for PDFs
                with gr.Row():
                    prev_page_btn = gr.Button("◀ Previous", size="sm")
                    page_info = gr.HTML(page_info_html("No file loaded"))
                    next_page_btn = gr.Button("Next ▶", size="sm")

                gr.Markdown("### ⚙️ Settings")

                # Prompt mode selection
                prompt_mode = gr.Dropdown(
                    choices=list(dict_promptmode_to_prompt.keys()),
                    value="prompt_layout_all_en",
                    label="Task Mode",
                    info="Choose the type of analysis to perform"
                )

                # Advanced settings
                with gr.Accordion("Advanced Settings", open=False):
                    max_new_tokens = gr.Slider(
                        minimum=1000,
                        maximum=32000,
                        value=24000,
                        step=1000,
                        label="Max New Tokens",
                        info="Maximum number of tokens to generate"
                    )

                    min_pixels = gr.Number(
                        value=MIN_PIXELS,
                        label="Min Pixels",
                        info="Minimum image resolution"
                    )

                    max_pixels = gr.Number(
                        value=MAX_PIXELS,
                        label="Max Pixels",
                        info="Maximum image resolution"
                    )

                # Process button
                process_btn = gr.Button(
                    "🚀 Process Document",
                    variant="primary",
                    elem_classes=["process-button"],
                    size="lg"
                )

                # Clear button
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")

            # Right column - Results
            with gr.Column(scale=2):
                gr.Markdown("### 📊 Results")

                # Results tabs
                with gr.Tabs():
                    # Processed image tab
                    with gr.Tab("🖼️ Processed Image"):
                        processed_image = gr.Image(
                            label="Image with Layout Detection",
                            type="pil",
                            interactive=False,
                            height=500
                        )

                    # Markdown output tab
                    with gr.Tab("📝 Extracted Content"):
                        markdown_output = gr.Markdown(
                            value="Click 'Process Document' to see extracted content...",
                            height=500
                        )

                    # Raw output tab
                    with gr.Tab("🔧 Raw Output"):
                        raw_output = gr.Textbox(
                            label="Raw Model Output",
                            lines=20,
                            max_lines=30,
                            interactive=False
                        )

                    # JSON layout tab
                    with gr.Tab("📋 Layout JSON"):
                        json_output = gr.JSON(
                            label="Layout Analysis Results",
                            value=None
                        )

        # Prompt display
        gr.Markdown("### 💬 Current Prompt")
        prompt_display = gr.Textbox(
            value=dict_promptmode_to_prompt["prompt_layout_all_en"],
            label="Prompt Text",
            lines=8,
            interactive=False,
            info="This is the prompt that will be sent to the model"
        )

        # Event handlers
        def load_model_on_startup():
            """Report readiness on page load; the model is loaded at script level."""
            return status_html("ready", "✅ Model loaded successfully!")

        def process_document(file_path, prompt_mode_val, max_tokens, min_pix, max_pix):
            """Process the uploaded document.

            Returns values for (processed_image, markdown_output, raw_output,
            json_output, model_status). For a PDF every page is processed; the
            image, raw output and JSON shown are those of the first page, and
            the markdown is the concatenation of all pages.
            """
            # NOTE(review): max_tokens is received from the slider but is not
            # forwarded to process_image below, so the setting has no effect
            # on generation. Wire it through once process_image accepts a
            # token limit.
            global pdf_cache

            try:
                if not file_path:
                    return (
                        None,
                        "Please upload a file first.",
                        "No file uploaded",
                        None,
                        status_html("error", "❌ No file uploaded")
                    )

                if model is None:
                    return (
                        None,
                        "Model not loaded. Please refresh the page and try again.",
                        "Model not loaded",
                        None,
                        status_html("error", "❌ Model not loaded")
                    )

                # Load and preview file (this also resets the cache)
                image, load_message = load_file_for_preview(file_path)
                if image is None:
                    return (
                        None,
                        load_message,
                        "Failed to load file",
                        None,
                        status_html("error", "❌ Failed to load file")
                    )

                # An empty or zero field means "no bound".
                min_px = int(min_pix) if min_pix else None
                max_px = int(max_pix) if max_pix else None

                # Process the image(s)
                if pdf_cache["file_type"] == "pdf":
                    # Process all pages for PDF
                    all_results = []
                    all_markdown = []

                    for i, img in enumerate(pdf_cache["images"]):
                        result = process_image(
                            img,
                            prompt_mode_val,
                            min_pixels=min_px,
                            max_pixels=max_px
                        )
                        all_results.append(result)
                        if result.get('markdown_content'):
                            all_markdown.append(f"## Page {i+1}\n\n{result['markdown_content']}")

                    pdf_cache["results"] = all_results
                    pdf_cache["is_parsed"] = True

                    # Show results for first page
                    first_result = all_results[0]
                    combined_markdown = "\n\n---\n\n".join(all_markdown)

                    return (
                        first_result['processed_image'],
                        combined_markdown,
                        first_result['raw_output'],
                        first_result['layout_result'],
                        status_html("ready", "✅ Processing completed!")
                    )

                # Process single image
                result = process_image(
                    image,
                    prompt_mode_val,
                    min_pixels=min_px,
                    max_pixels=max_px
                )

                pdf_cache["results"] = [result]
                pdf_cache["is_parsed"] = True

                return (
                    result['processed_image'],
                    result['markdown_content'] or "No content extracted",
                    result['raw_output'],
                    result['layout_result'],
                    status_html("ready", "✅ Processing completed!")
                )

            except Exception as e:
                error_msg = f"Error processing document: {str(e)}"
                print(error_msg)
                traceback.print_exc()
                return (
                    None,
                    error_msg,
                    error_msg,
                    None,
                    # Exception text is arbitrary; escape it before embedding in HTML.
                    status_html("error", f"❌ {html.escape(error_msg)}")
                )

        def update_prompt_display(mode):
            """Update the prompt display when mode changes"""
            return dict_promptmode_to_prompt[mode]

        def handle_file_upload(file_path):
            """Handle file upload and show preview"""
            if not file_path:
                return None, page_info_html("No file loaded")

            image, info_text = load_file_for_preview(file_path)
            return image, page_info_html(info_text)

        def handle_page_turn(direction):
            """Handle page navigation"""
            image, info_text, result = turn_page(direction)
            return image, page_info_html(info_text), result

        def clear_all():
            """Clear all data and reset interface"""
            global pdf_cache, processing_results

            pdf_cache = {
                "images": [],
                "current_page": 0,
                "total_pages": 0,
                "file_type": None,
                "is_parsed": False,
                "results": []
            }
            processing_results = {
                'original_image': None,
                'processed_image': None,
                'layout_result': None,
                'markdown_content': None,
                'raw_output': None,
            }

            return (
                None,  # file_input
                None,  # image_preview
                page_info_html("No file loaded"),  # page_info
                None,  # processed_image
                "Click 'Process Document' to see extracted content...",  # markdown_output
                "",  # raw_output
                None,  # json_output
                status_html("ready", "✅ Interface cleared")  # model_status
            )

        # Wire up event handlers
        demo.load(load_model_on_startup, outputs=[model_status])

        file_input.change(
            handle_file_upload,
            inputs=[file_input],
            outputs=[image_preview, page_info]
        )

        prev_page_btn.click(
            lambda: handle_page_turn("prev"),
            outputs=[image_preview, page_info, markdown_output]
        )

        next_page_btn.click(
            lambda: handle_page_turn("next"),
            outputs=[image_preview, page_info, markdown_output]
        )

        prompt_mode.change(
            update_prompt_display,
            inputs=[prompt_mode],
            outputs=[prompt_display]
        )

        process_btn.click(
            process_document,
            inputs=[file_input, prompt_mode, max_new_tokens, min_pixels, max_pixels],
            outputs=[processed_image, markdown_output, raw_output, json_output, model_status]
        )

        clear_btn.click(
            clear_all,
            outputs=[
                file_input, image_preview, page_info, processed_image,
                markdown_output, raw_output, json_output, model_status
            ]
        )

    return demo
928
+
929
+
930
if __name__ == "__main__":
    # Build the UI, enable a bounded request queue, then serve it.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": True,
        "show_error": True,
    }
    demo = create_gradio_interface()
    demo.queue(max_size=10).launch(**launch_options)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ huggingface_hub
2
+ torch
3
+ transformers
4
+ qwen_vl_utils
5
+ Pillow
6
+ PyMuPDF
7
+ accelerate
8
+ https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.7-cp310-cp310-linux_x86_64.whl