Commit e6c24fd
Parent(s): 89faab5

Refactor app.py: Clean up unused prompts and reorganize imports for clarity

app.py CHANGED

@@ -1,30 +1,17 @@
-#!/usr/bin/env python3
-"""
-Dots.OCR Gradio Demo Application
-
-A Gradio-based web interface for demonstrating the Dots.OCR model using Hugging Face transformers.
-This application provides OCR and layout analysis capabilities for documents and images.
-"""
-
-import os
 import json
-import traceback
 import math
+import os
+import traceback
 from io import BytesIO
-from typing import
+from typing import Any, Dict, List, Optional, Tuple
+from huggingface_hub import snapshot_download
+import fitz  # PyMuPDF
+import gradio as gr
 import requests
-
-# Set LOCAL_RANK for transformers
-if "LOCAL_RANK" not in os.environ:
-    os.environ["LOCAL_RANK"] = "0"
-
 import torch
-import gradio as gr
 from PIL import Image, ImageDraw, ImageFont
-from transformers import AutoModelForCausalLM, AutoProcessor
 from qwen_vl_utils import process_vision_info
-import
-
+from transformers import AutoModelForCausalLM, AutoProcessor
 
 # Constants
 MIN_PIXELS = 3136
@@ -32,8 +19,7 @@ MAX_PIXELS = 11289600
 IMAGE_FACTOR = 28
 
 # Prompts
-dict_promptmode_to_prompt = {
-    "prompt_layout_all_en": """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
+prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
 
 1. Bbox format: [x1, y1, x2, y2]
 
@@ -50,15 +36,7 @@ dict_promptmode_to_prompt = {
 - All layout elements must be sorted according to human reading order.
 
 5. Final Output: The entire output must be a single JSON object.
-""",
-
-    "prompt_layout_only_en": """Please output the layout information from this PDF image, including each layout's bbox and its category. The bbox should be in the format [x1, y1, x2, y2]. The layout categories for the PDF document include ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. Do not output the corresponding text. The layout result should be in JSON format.""",
-
-    "prompt_ocr": """Extract the text content from this image.""",
-
-    "prompt_grounding_ocr": """Extract text from the given bounding box on the image (format: [x1, y1, x2, y2]).\nBounding Box:\n""",
-}
-
+"""
 
 # Utility functions
 def round_by_factor(number: int, factor: int) -> int:
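
The only prompt that survives the cleanup asks for a single JSON object per page, with each layout element's bbox, category, and text. A hedged sketch of consuming that output (field names are inferred from the prompt wording; the actual top-level shape, list vs. dict, isn't pinned down by this diff):

```python
import json

# `raw_output` would be the model's response to the surviving layout prompt.
layout = json.loads(raw_output)
for element in layout:  # assumes a list of per-element dicts
    x1, y1, x2, y2 = element["bbox"]  # [x1, y1, x2, y2], per the prompt
    category = element["category"]    # e.g. 'Text', 'Table', 'Section-header'
    text = element.get("text", "")    # text content within the bbox
```
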
@@ -263,15 +241,21 @@ def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = '
 
 # Initialize model and processor at script level
 model_id = "rednote-hilab/dots.ocr"
+model_path = "./models/dots-ocr-local"
+snapshot_download(
+    repo_id=model_id,
+    local_dir=model_path,
+    local_dir_use_symlinks=False,  # Recommended to set to False to avoid symlink issues
+)
 model = AutoModelForCausalLM.from_pretrained(
-    model_id,
+    model_path,
     attn_implementation="flash_attention_2",
     torch_dtype=torch.bfloat16,
     device_map="auto",
     trust_remote_code=True
 )
 processor = AutoProcessor.from_pretrained(
-    model_id,
+    model_path,
     trust_remote_code=True
 )
 
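
This hunk is the substantive change: the repo is mirrored into a local directory with `snapshot_download`, and both model and processor then load from that path instead of the hub id. A self-contained sketch of the same pattern; the `try/except` fallback to default attention is my addition for machines without flash-attn, not part of the commit (note also that newer `huggingface_hub` releases deprecate and ignore `local_dir_use_symlinks`):

```python
import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "rednote-hilab/dots.ocr"
model_path = "./models/dots-ocr-local"

# One-time local mirror; subsequent runs reuse the downloaded files.
snapshot_download(repo_id=model_id, local_dir=model_path, local_dir_use_symlinks=False)

try:
    # flash_attention_2 requires the flash-attn package and a supported GPU.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
except (ImportError, ValueError):
    # Fallback (my assumption, not in the commit): let transformers pick the default attention.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
```
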
@@ -378,9 +362,6 @@ def process_image(
     if min_pixels is not None or max_pixels is not None:
         image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
 
-    # Get prompt
-    prompt = dict_promptmode_to_prompt[prompt_mode]
-
     # Run inference
     raw_output = inference(image, prompt)
 
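
`process_image` now passes the single module-level `prompt` straight to `inference`, whose body isn't part of this diff. Given the `qwen_vl_utils.process_vision_info` import, it most likely follows the standard Qwen-VL chat flow; a sketch under that assumption (every name below except the module-level `model`, `processor`, and `process_vision_info` is illustrative):

```python
def inference(image, prompt, max_new_tokens=2048):
    """Sketch of a Qwen-VL style generate call; the app's real helper may differ."""
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text], images=image_inputs, videos=video_inputs,
        padding=True, return_tensors="pt",
    ).to(model.device)
    with torch.no_grad():
        generated = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Strip the prompt tokens so only the newly generated text is decoded.
    trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated)]
    return processor.batch_decode(trimmed, skip_special_tokens=True)[0]
```
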
@@ -640,15 +621,7 @@ def create_gradio_interface():
                 next_page_btn = gr.Button("Next ▶", size="sm")
 
             gr.Markdown("### ⚙️ Settings")
-
-            # Prompt mode selection
-            prompt_mode = gr.Dropdown(
-                choices=list(dict_promptmode_to_prompt.keys()),
-                value="prompt_layout_all_en",
-                label="Task Mode",
-                info="Choose the type of analysis to perform"
-            )
-
+
             # Advanced settings
             with gr.Accordion("Advanced Settings", open=False):
                 max_new_tokens = gr.Slider(
@@ -721,16 +694,6 @@ def create_gradio_interface():
                 value=None
             )
 
-            # Prompt display
-            gr.Markdown("### 💬 Current Prompt")
-            prompt_display = gr.Textbox(
-                value=dict_promptmode_to_prompt["prompt_layout_all_en"],
-                label="Prompt Text",
-                lines=8,
-                interactive=False,
-                info="This is the prompt that will be sent to the model"
-            )
-
     # Event handlers
     def load_model_on_startup():
         """Load model when the interface starts"""
@@ -839,8 +802,8 @@ def create_gradio_interface():
 
     def update_prompt_display(mode):
         """Update the prompt display when mode changes"""
-        return dict_promptmode_to_prompt[mode]
-
+        return prompt
+
     def handle_file_upload(file_path):
         """Handle file upload and show preview"""
         if not file_path:
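
One loose end: with the task-mode dropdown and the prompt textbox removed in the earlier hunks, `update_prompt_display` now unconditionally returns the module-level `prompt`, and its `mode` argument is dead. In the previous revision it was presumably wired up along these lines (the wiring itself doesn't appear in the diff, so this is an assumption; component names come from the removed code):

```python
# Presumed pre-commit wiring (not shown in the diff): selecting a task mode
# refreshed the read-only prompt textbox with the matching prompt string.
prompt_mode.change(
    fn=update_prompt_display,  # previously returned dict_promptmode_to_prompt[mode]
    inputs=prompt_mode,        # the removed "Task Mode" dropdown
    outputs=prompt_display,    # the removed "Prompt Text" textbox
)
```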