prithivMLmods committed
Commit 3960214 · verified · 1 Parent(s): 821d99a

Update app.py

Files changed (1)
  1. app.py +435 -292
app.py CHANGED
@@ -1,315 +1,458 @@
- import os
- import time
- import threading
- import gradio as gr
  import spaces
- import torch
- import numpy as np
- from PIL import Image
- import cv2
- from transformers import (
-     Qwen2_5_VLForConditionalGeneration,
-     Qwen2VLForConditionalGeneration,
-     Glm4vForConditionalGeneration,
-     AutoProcessor,
-     TextIteratorStreamer,
- )
- from qwen_vl_utils import process_vision_info
-
- # Constants for text generation
- MAX_MAX_NEW_TOKENS = 4096
- DEFAULT_MAX_NEW_TOKENS = 3584
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

- # Load Camel-Doc-OCR-062825
- MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
- processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
- model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_M,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()

- # Load Qwen2.5-VL-3B-Instruct-abliterated
- MODEL_ID_X = "huihui-ai/Qwen2.5-VL-3B-Instruct-abliterated"
- processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
- model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_X,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()

- # Load Megalodon-OCR-Sync-0713
- MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
- processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
- model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_T,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()

- # Load GLM-4.1V-9B-Thinking
- MODEL_ID_S = "zai-org/GLM-4.1V-9B-Thinking"
- processor_s = AutoProcessor.from_pretrained(MODEL_ID_S, trust_remote_code=True)
- model_s = Glm4vForConditionalGeneration.from_pretrained(
-     MODEL_ID_S,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()

- # Load DeepEyes-7B
- MODEL_ID_Y = "ChenShawn/DeepEyes-7B"
- processor_y = AutoProcessor.from_pretrained(MODEL_ID_Y, trust_remote_code=True)
- model_y = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_Y,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()

- def downsample_video(video_path):
-     """
-     Downsample a video to evenly spaced frames, returning each as a PIL image with its timestamp.
-     """
-     vidcap = cv2.VideoCapture(video_path)
-     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-     fps = vidcap.get(cv2.CAP_PROP_FPS)
-     frames = []
-     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-     for i in frame_indices:
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         success, image = vidcap.read()
-         if success:
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             pil_image = Image.fromarray(image)
-             timestamp = round(i / fps, 2)
-             frames.append((pil_image, timestamp))
-     vidcap.release()
-     return frames

- @spaces.GPU
- def generate_image(model_name: str, text: str, image: Image.Image,
-                    max_new_tokens: int = 1024,
-                    temperature: float = 0.6,
-                    top_p: float = 0.9,
-                    top_k: int = 50,
-                    repetition_penalty: float = 1.2):
-     """
-     Generate responses using the selected model for image input.
-     """
-     if model_name == "Camel-Doc-OCR-062825":
-         processor = processor_m
-         model = model_m
-     elif model_name == "Megalodon-OCR-Sync-0713":
-         processor = processor_t
-         model = model_t
-     elif model_name == "GLM-4.1V-9B-Thinking":
-         processor = processor_s
-         model = model_s
-     elif model_name == "DeepEyes-7B-Thinking":
-         processor = processor_y
-         model = model_y
-     elif model_name == "Qwen2.5-VL-3B-Instruct-abliterated":
-         processor = processor_x
-         model = model_x
-     else:
-         yield "Invalid model selected.", "Invalid model selected."
-         return

-     if image is None:
-         yield "Please upload an image.", "Please upload an image."
-         return

-     messages = [{
-         "role": "user",
-         "content": [
-             {"type": "image", "image": image},
-             {"type": "text", "text": text},
-         ]
-     }]
-     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     inputs = processor(
-         text=[prompt_full],
-         images=[image],
-         return_tensors="pt",
-         padding=True,
-         truncation=False,
-         max_length=MAX_INPUT_TOKEN_LENGTH
-     ).to(device)
-     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-     thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         time.sleep(0.01)
-         yield buffer, buffer

- @spaces.GPU
- def generate_video(model_name: str, text: str, video_path: str,
-                    max_new_tokens: int = 1024,
-                    temperature: float = 0.6,
-                    top_p: float = 0.9,
-                    top_k: int = 50,
-                    repetition_penalty: float = 1.2):
-     """
-     Generate responses using the selected model for video input.
-     """
-     if model_name == "Camel-Doc-OCR-062825":
-         processor = processor_m
-         model = model_m
-     elif model_name == "Megalodon-OCR-Sync-0713":
-         processor = processor_t
-         model = model_t
-     elif model_name == "GLM-4.1V-9B-Thinking":
-         processor = processor_s
-         model = model_s
-     elif model_name == "DeepEyes-7B-Thinking":
-         processor = processor_y
-         model = model_y
-     elif model_name == "Qwen2.5-VL-3B-Instruct-abliterated":
-         processor = processor_x
-         model = model_x
  else:
-         yield "Invalid model selected.", "Invalid model selected."
-         return

-     if video_path is None:
-         yield "Please upload a video.", "Please upload a video."
-         return

-     frames = downsample_video(video_path)
-     messages = [
-         {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-         {"role": "user", "content": [{"type": "text", "text": text}]}
-     ]
-     for frame in frames:
-         image, timestamp = frame
-         messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
-         messages[1]["content"].append({"type": "image", "image": image})
-     inputs = processor.apply_chat_template(
-         messages,
-         tokenize=True,
-         add_generation_prompt=True,
-         return_dict=True,
-         return_tensors="pt",
-         truncation=False,
-         max_length=MAX_INPUT_TOKEN_LENGTH
-     ).to(device)
-     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = {
-         **inputs,
-         "streamer": streamer,
-         "max_new_tokens": max_new_tokens,
-         "do_sample": True,
-         "temperature": temperature,
-         "top_p": top_p,
-         "top_k": top_k,
-         "repetition_penalty": repetition_penalty,
  }
-     thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         buffer = buffer.replace("<|im_end|>", "")
-         time.sleep(0.01)
-         yield buffer, buffer

- # Define examples for image and video inference
- image_examples = [
-     ["explain the movie shot in detail.", "images/5.jpg"],
-     ["convert this page to doc [text] precisely for markdown.", "images/1.png"],
-     ["convert this page to doc [table] precisely for markdown.", "images/2.png"],
-     ["explain the movie shot in detail.", "images/3.png"],
-     ["fill the correct numbers.", "images/4.png"]
- ]

- video_examples = [
-     ["explain the video in detail.", "videos/b.mp4"],
-     ["explain the ad video in detail.", "videos/a.mp4"]
- ]

- # Updated CSS with model choice highlighting
- css = """
- .submit-btn {
-     background-color: #2980b9 !important;
-     color: white !important;
- }
- .submit-btn:hover {
-     background-color: #3498db !important;
- }
- .canvas-output {
-     border: 2px solid #4682B4;
-     border-radius: 10px;
-     padding: 20px;
- }
- """
-
- # Create the Gradio Interface
- with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
-     gr.Markdown("# **[Multimodal VLM OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
-     with gr.Row():
-         with gr.Column():
-             with gr.Tabs():
-                 with gr.TabItem("Image Inference"):
-                     image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                     image_upload = gr.Image(type="pil", label="Image")
-                     image_submit = gr.Button("Submit", elem_classes="submit-btn")
-                     gr.Examples(
-                         examples=image_examples,
-                         inputs=[image_query, image_upload]
-                     )
-                 with gr.TabItem("Video Inference"):
-                     video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                     video_upload = gr.Video(label="Video")
-                     video_submit = gr.Button("Submit", elem_classes="submit-btn")
-                     gr.Examples(
-                         examples=video_examples,
-                         inputs=[video_query, video_upload]
-                     )

-             with gr.Accordion("Advanced options", open=False):
-                 max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
-                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
-                 top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
-                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
-                 repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)

-         with gr.Column():
-             with gr.Column(elem_classes="canvas-output"):
-                 gr.Markdown("## Output")
-                 output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=2)
-                 with gr.Accordion("(Result.md)", open=False):
-                     markdown_output = gr.Markdown(label="(Result.md)")

-             model_choice = gr.Radio(
-                 choices=["Camel-Doc-OCR-062825", "GLM-4.1V-9B-Thinking", "Megalodon-OCR-Sync-0713", "DeepEyes-7B-Thinking", "Qwen2.5-VL-3B-Instruct-abliterated"],
-                 label="Select Model",
-                 value="Camel-Doc-OCR-062825"
-             )
-             gr.Markdown("**Model Info 💻** | [Report Bug](https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR-Comparator/discussions)")
-             gr.Markdown("> Camel-Doc-OCR-062825 and Megalodon-OCR-Sync-0713 are both fine-tuned versions of the Qwen2.5-VL series focused on document retrieval, content extraction, analysis recognition, and excelling in OCR and visual document analysis tasks for structured and unstructured content—Camel-Doc-OCR-062825 leveraging the Qwen2.5-VL-7B-Instruct as its base, while Megalodon-OCR-Sync-0713 uses Qwen2.5-VL-3B-Instruct and is especially trained on diverse captioning datasets.")
-             gr.Markdown("> GLM-4.1V-9B-Thinking is a vision-language model (VLM) based on the GLM-4-9B-0414 foundation, with a strong emphasis on advanced reasoning capabilities, chain-of-thought inference, and robust bilingual (Chinese/English) performance on complex multimodal benchmarks.")
-             gr.Markdown("> DeepEyes-7B stands out for its agentic reinforcement learning approach, focusing on thinking with images for better visual reasoning, math problem-solving, and mitigating hallucination using Qwen2.5-VL-7B-Instruct as its foundation. Finally, Qwen2.5-VL-3B-Instruct-abliterated is part of the Qwen2.5-VL family, known for its versatile vision-language understanding and generation, serving as the foundational architecture for several of these fine-tuned vision-language and OCR models.")
-
-     # Define the submit button actions
-     image_submit.click(fn=generate_image,
-                        inputs=[
-                            model_choice, image_query, image_upload,
-                            max_new_tokens, temperature, top_p, top_k,
-                            repetition_penalty
-                        ],
-                        outputs=[output, markdown_output])
-     video_submit.click(fn=generate_video,
-                        inputs=[
-                            model_choice, video_query, video_upload,
-                            max_new_tokens, temperature, top_p, top_k,
-                            repetition_penalty
-                        ],
-                        outputs=[output, markdown_output])

  if __name__ == "__main__":
-     demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)

  import spaces
+ import json
+ import math
+ import os
+ import traceback
+ from io import BytesIO
+ from typing import Any, Dict, List, Optional, Tuple
+ import re

+ import fitz # PyMuPDF
+ import gradio as gr
+ import requests
+ from PIL import Image, ImageDraw, ImageFont

+ from model import load_model, inference_dots_ocr, inference_dolphin

+ # Constants
+ MIN_PIXELS = 3136
+ MAX_PIXELS = 11289600
+ IMAGE_FACTOR = 28

+ # Prompts
+ prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
+
+ 1. Bbox format: [x1, y1, x2, y2]
+ 2. Layout Categories: ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']
+ 3. Text Extraction & Formatting Rules:
+ - Picture: Omit the text field
+ - Formula: format as LaTeX
+ - Table: format as HTML
+ - Others: format as Markdown
+ 4. Constraints:
+ - Use original text, no translation
+ - Sort elements by human reading order
+ 5. Final Output: Single JSON object
+ """
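# Illustration only (not app.py code; values invented): the prompt above asks for a single
# JSON value that json.loads(raw_output) in process_document can parse, and the drawing and
# markdown helpers below iterate it as a list of elements carrying "bbox", "category",
# and (except for Picture) "text":
example_layout = [
    {"bbox": [72, 64, 523, 98], "category": "Title", "text": "Quarterly Report"},
    {"bbox": [72, 120, 523, 380], "category": "Text", "text": "Revenue grew in the third quarter..."},
    {"bbox": [72, 400, 523, 640], "category": "Picture"},
]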

+ # Load models at startup
+ models = {
+     "dots.ocr": load_model("dots.ocr"),
+     "Dolphin": load_model("Dolphin")
+ }

+ # Global state for PDF handling
+ pdf_cache = {
+     "images": [],
+     "current_page": 0,
+     "total_pages": 0,
+     "file_type": None,
+     "is_parsed": False,
+     "results": []
+ }

+ # Utility functions
+ def round_by_factor(number: int, factor: int) -> int:
+     return round(number / factor) * factor

+ def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 11289600):
+     if max(height, width) / min(height, width) > 200:
+         raise ValueError(f"Aspect ratio must be < 200, got {max(height, width) / min(height, width)}")
+     h_bar = max(factor, round_by_factor(height, factor))
+     w_bar = max(factor, round_by_factor(width, factor))
+     if h_bar * w_bar > max_pixels:
+         beta = math.sqrt((height * width) / max_pixels)
+         h_bar = round_by_factor(height / beta, factor)
+         w_bar = round_by_factor(width / beta, factor)
+     elif h_bar * w_bar < min_pixels:
+         beta = math.sqrt(min_pixels / (height * width))
+         h_bar = round_by_factor(height * beta, factor)
+         w_bar = round_by_factor(width * beta, factor)
+     return h_bar, w_bar
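# Illustration only (not app.py code): smart_resize snaps both sides to multiples of
# IMAGE_FACTOR (28) while keeping the pixel count between MIN_PIXELS and MAX_PIXELS.
# A 1000x800 page is already inside that range, so only the rounding applies:
#   round_by_factor(800, 28)  -> 812
#   round_by_factor(1000, 28) -> 1008
#   smart_resize(800, 1000)   -> (812, 1008)   # returned as (height, width)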

+ def fetch_image(image_input, min_pixels: int = None, max_pixels: int = None):
+     if isinstance(image_input, str):
+         if image_input.startswith(("http://", "https://")):
+             response = requests.get(image_input)
+             image = Image.open(BytesIO(response.content)).convert('RGB')
+         else:
+             image = Image.open(image_input).convert('RGB')
+     elif isinstance(image_input, Image.Image):
+         image = image_input.convert('RGB')
  else:
+         raise ValueError(f"Invalid image input type: {type(image_input)}")
+     if min_pixels or max_pixels:
+         min_pixels = min_pixels or MIN_PIXELS
+         max_pixels = max_pixels or MAX_PIXELS
+         height, width = smart_resize(image.height, image.width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
+         image = image.resize((width, height), Image.LANCZOS)
+     return image

+ def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
+     images = []
+     try:
+         pdf_document = fitz.open(pdf_path)
+         for page_num in range(len(pdf_document)):
+             page = pdf_document.load_page(page_num)
+             mat = fitz.Matrix(2.0, 2.0)
+             pix = page.get_pixmap(matrix=mat)
+             img_data = pix.tobytes("ppm")
+             image = Image.open(BytesIO(img_data)).convert('RGB')
+             images.append(image)
+         pdf_document.close()
+     except Exception as e:
+         print(f"Error loading PDF: {e}")
+         return []
+     return images
 
108
+ def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
109
+ img_copy = image.copy()
110
+ draw = ImageDraw.Draw(img_copy)
111
+ colors = {
112
+ 'Caption': '#FF6B6B', 'Footnote': '#4ECDC4', 'Formula': '#45B7D1', 'List-item': '#96CEB4',
113
+ 'Page-footer': '#FFEAA7', 'Page-header': '#DDA0DD', 'Picture': '#FFD93D', 'Section-header': '#6C5CE7',
114
+ 'Table': '#FD79A8', 'Text': '#74B9FF', 'Title': '#E17055'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  }
116
+ try:
117
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12)
118
+ except Exception:
119
+ font = ImageFont.load_default()
120
+ try:
121
+ for item in layout_data:
122
+ if 'bbox' in item and 'category' in item:
123
+ bbox = item['bbox']
124
+ category = item['category']
125
+ color = colors.get(category, '#000000')
126
+ draw.rectangle(bbox, outline=color, width=2)
127
+ label = category
128
+ label_bbox = draw.textbbox((0, 0), label, font=font)
129
+ label_width = label_bbox[2] - label_bbox[0]
130
+ label_height = label_bbox[3] - label_bbox[1]
131
+ label_x = bbox[0]
132
+ label_y = max(0, bbox[1] - label_height - 2)
133
+ draw.rectangle([label_x, label_y, label_x + label_width + 4, label_y + label_height + 2], fill=color)
134
+ draw.text((label_x + 2, label_y + 1), label, fill='white', font=font)
135
+ except Exception as e:
136
+ print(f"Error drawing layout: {e}")
137
+ return img_copy
138
 
139
+ def is_arabic_text(text: str) -> bool:
140
+ if not text:
141
+ return False
142
+ header_pattern = r'^#{1,6}\s+(.+)$'
143
+ paragraph_pattern = r'^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$'
144
+ content_text = []
145
+ for line in text.split('\n'):
146
+ line = line.strip()
147
+ if not line:
148
+ continue
149
+ header_match = re.match(header_pattern, line, re.MULTILINE)
150
+ if header_match:
151
+ content_text.append(header_match.group(1))
152
+ continue
153
+ if re.match(paragraph_pattern, line, re.MULTILINE):
154
+ content_text.append(line)
155
+ if not content_text:
156
+ return False
157
+ combined_text = ' '.join(content_text)
158
+ arabic_chars = 0
159
+ total_chars = 0
160
+ for char in combined_text:
161
+ if char.isalpha():
162
+ total_chars += 1
163
+ if ('\u0600' <= char <= '\u06FF') or ('\u0750' <= char <= '\u077F') or ('\u08A0' <= char <= '\u08FF'):
164
+ arabic_chars += 1
165
+ return total_chars > 0 and (arabic_chars / total_chars) > 0.5
166
 
167
+ def layoutjson2md(image: Image.Image, layout_data: List[Dict], text_key: str = 'text') -> str:
+     import base64
+     markdown_lines = []
+     try:
+         sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
+         for item in sorted_items:
+             category = item.get('category', '')
+             text = item.get(text_key, '')
+             bbox = item.get('bbox', [])
+             if category == 'Picture':
+                 if bbox and len(bbox) == 4:
+                     try:
+                         x1, y1, x2, y2 = [max(0, int(x)) if i < 2 else min(image.width if i % 2 == 0 else image.height, int(x)) for i, x in enumerate(bbox)]
+                         if x2 > x1 and y2 > y1:
+                             cropped_img = image.crop((x1, y1, x2, y2))
+                             buffer = BytesIO()
+                             cropped_img.save(buffer, format='PNG')
+                             img_data = base64.b64encode(buffer.getvalue()).decode()
+                             markdown_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
+                         else:
+                             markdown_lines.append("![Image](Image region detected)\n")
+                     except Exception as e:
+                         print(f"Error processing image region: {e}")
+                         markdown_lines.append("![Image](Image detected)\n")
+                 else:
+                     markdown_lines.append("![Image](Image detected)\n")
+             elif not text:
+                 continue
+             elif category == 'Title':
+                 markdown_lines.append(f"# {text}\n")
+             elif category == 'Section-header':
+                 markdown_lines.append(f"## {text}\n")
+             elif category == 'Text':
+                 markdown_lines.append(f"{text}\n")
+             elif category == 'List-item':
+                 markdown_lines.append(f"- {text}\n")
+             elif category == 'Table':
+                 if text.strip().startswith('<'):
+                     markdown_lines.append(f"{text}\n")
+                 else:
+                     markdown_lines.append(f"**Table:** {text}\n")
+             elif category == 'Formula':
+                 if text.strip().startswith('$') or '\\' in text:
+                     markdown_lines.append(f"$$ \n{text}\n $$\n")
+                 else:
+                     markdown_lines.append(f"**Formula:** {text}\n")
+             elif category == 'Caption':
+                 markdown_lines.append(f"*{text}*\n")
+             elif category == 'Footnote':
+                 markdown_lines.append(f"^{text}^\n")
+             elif category in ['Page-header', 'Page-footer']:
+                 continue
+             else:
+                 markdown_lines.append(f"{text}\n")
+             markdown_lines.append("")
+     except Exception as e:
+         print(f"Error converting to markdown: {e}")
+         return str(layout_data)
+     return "\n".join(markdown_lines)
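# Illustration only (not app.py code; values invented): given a parsed layout list such as
#   [{"bbox": [10, 10, 500, 40], "category": "Title", "text": "Annual Report"},
#    {"bbox": [10, 60, 500, 90], "category": "Section-header", "text": "Overview"},
#    {"bbox": [10, 780, 500, 800], "category": "Page-footer", "text": "3"}]
# layoutjson2md() emits "# Annual Report" and "## Overview" in reading order and silently
# drops the page footer, so only body content reaches the Markdown tab.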

+ def load_file_for_preview(file_path: str) -> Tuple[Optional[Image.Image], str]:
+     global pdf_cache
+     if not file_path or not os.path.exists(file_path):
+         return None, "No file selected"
+     file_ext = os.path.splitext(file_path)[1].lower()
+     try:
+         if file_ext == '.pdf':
+             images = load_images_from_pdf(file_path)
+             if not images:
+                 return None, "Failed to load PDF"
+             pdf_cache.update({
+                 "images": images,
+                 "current_page": 0,
+                 "total_pages": len(images),
+                 "file_type": "pdf",
+                 "is_parsed": False,
+                 "results": []
+             })
+             return images[0], f"Page 1 / {len(images)}"
+         elif file_ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
+             image = Image.open(file_path).convert('RGB')
+             pdf_cache.update({
+                 "images": [image],
+                 "current_page": 0,
+                 "total_pages": 1,
+                 "file_type": "image",
+                 "is_parsed": False,
+                 "results": []
+             })
+             return image, "Page 1 / 1"
+         else:
+             return None, f"Unsupported file format: {file_ext}"
+     except Exception as e:
+         print(f"Error loading file: {e}")
+         return None, f"Error loading file: {str(e)}"

+ @spaces.GPU()
+ def process_document(file_path, model_choice, max_tokens, min_pix, max_pix):
+     global pdf_cache
+     if not file_path:
+         return None, "Please upload a file first.", None
+     model, processor = models[model_choice]
+     image, page_info = load_file_for_preview(file_path)
+     if image is None:
+         return None, page_info, None
+     if pdf_cache["file_type"] == "pdf":
+         all_results = []
+         for i, img in enumerate(pdf_cache["images"]):
+             if model_choice == "dots.ocr":
+                 raw_output = inference_dots_ocr(model, processor, img, prompt, max_tokens)
+                 try:
+                     layout_data = json.loads(raw_output)
+                     processed_image = draw_layout_on_image(img, layout_data)
+                     markdown_content = layoutjson2md(img, layout_data)
+                     result = {
+                         'processed_image': processed_image,
+                         'markdown_content': markdown_content,
+                         'layout_result': layout_data
+                     }
+                 except Exception:
+                     result = {
+                         'processed_image': img,
+                         'markdown_content': raw_output,
+                         'layout_result': None
+                     }
+             else:  # Dolphin
+                 text = inference_dolphin(model, processor, img)
+                 result = f"## Page {i+1}\n\n{text}" if text else "No text extracted"
+             all_results.append(result)
+         pdf_cache["results"] = all_results
+         pdf_cache["is_parsed"] = True
+         first_result = all_results[0]
+         if model_choice == "dots.ocr":
+             markdown_update = gr.update(value=first_result['markdown_content'], rtl=is_arabic_text(first_result['markdown_content']))
+             return first_result['processed_image'], markdown_update, first_result['layout_result']
+         else:
+             markdown_update = gr.update(value=first_result, rtl=is_arabic_text(first_result))
+             return None, markdown_update, None
+     else:
+         if model_choice == "dots.ocr":
+             raw_output = inference_dots_ocr(model, processor, image, prompt, max_tokens)
+             try:
+                 layout_data = json.loads(raw_output)
+                 processed_image = draw_layout_on_image(image, layout_data)
+                 markdown_content = layoutjson2md(image, layout_data)
+                 result = {
+                     'processed_image': processed_image,
+                     'markdown_content': markdown_content,
+                     'layout_result': layout_data
+                 }
+             except Exception:
+                 result = {
+                     'processed_image': image,
+                     'markdown_content': raw_output,
+                     'layout_result': None
+                 }
+             pdf_cache["results"] = [result]
+         else:  # Dolphin
+             text = inference_dolphin(model, processor, image)
+             result = text if text else "No text extracted"
+             pdf_cache["results"] = [result]
+         pdf_cache["is_parsed"] = True
+         if model_choice == "dots.ocr":
+             markdown_update = gr.update(value=result['markdown_content'], rtl=is_arabic_text(result['markdown_content']))
+             return result['processed_image'], markdown_update, result['layout_result']
+         else:
+             markdown_update = gr.update(value=result, rtl=is_arabic_text(result))
+             return None, markdown_update, None

+ def turn_page(direction: str) -> Tuple[Optional[Image.Image], str, Any, Optional[Image.Image], Optional[Dict]]:
+     global pdf_cache
+     if not pdf_cache["images"]:
+         return None, '<div class="page-info">No file loaded</div>', "No results yet", None, None
+     if direction == "prev":
+         pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
+     elif direction == "next":
+         pdf_cache["current_page"] = min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)
+     index = pdf_cache["current_page"]
+     current_image_preview = pdf_cache["images"][index]
+     page_info_html = f'<div class="page-info">Page {index + 1} / {pdf_cache["total_pages"]}</div>'
+     if pdf_cache["is_parsed"] and index < len(pdf_cache["results"]):
+         result = pdf_cache["results"][index]
+         if isinstance(result, dict):  # dots.ocr
+             markdown_content = result.get('markdown_content', 'No content available')
+             processed_img = result.get('processed_image', None)
+             layout_json = result.get('layout_result', None)
+         else:  # Dolphin
+             markdown_content = result
+             processed_img = None
+             layout_json = None
+     else:
+         markdown_content = "Page not processed yet"
+         processed_img = None
+         layout_json = None
+     markdown_update = gr.update(value=markdown_content, rtl=is_arabic_text(markdown_content))
+     return current_image_preview, page_info_html, markdown_update, processed_img, layout_json

+ def create_gradio_interface():
+     css = """
+     .main-container { max-width: 1400px; margin: 0 auto; }
+     .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
+     .process-button { border: none !important; color: white !important; font-weight: bold !important; }
+     .process-button:hover { transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
+     .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
+     .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
+     .model-status { padding: 10px; border-radius: 8px; margin: 10px 0; text-align: center; font-weight: bold; }
+     .status-ready { background: #d1edff; color: #0c5460; border: 1px solid #b8daff; }
+     """
+     with gr.Blocks(theme=gr.themes.Soft(), css=css, title="Dots.OCR Demo") as demo:
+         gr.HTML("""
+         <div class="title" style="text-align: center">
+             <h1>🔍 Dot-OCR - Multilingual Document Text Extraction</h1>
+             <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
+                 A state-of-the-art image/pdf-to-markdown vision language model for intelligent document processing
+             </p>
+             <div style="display: flex; justify-content: center; gap: 20px; margin: 15px 0;">
+                 <a href="https://huggingface.co/rednote-hilab/dots.ocr" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
+                     📚 Hugging Face Model
+                 </a>
+                 <a href="https://github.com/rednote-hilab/dots.ocr/blob/master/assets/blog.md" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
+                     📝 Release Blog
+                 </a>
+                 <a href="https://github.com/rednote-hilab/dots.ocr" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
+                     💻 GitHub Repository
+                 </a>
+             </div>
+         </div>
+         """)
+         with gr.Row():
+             with gr.Column(scale=1):
+                 file_input = gr.File(
+                     label="Upload Image or PDF",
+                     file_types=[".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".pdf"],
+                     type="filepath"
+                 )
+                 image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=300)
+                 with gr.Row():
+                     prev_page_btn = gr.Button("◀ Previous", size="md")
+                     page_info = gr.HTML('<div class="page-info">No file loaded</div>')
+                     next_page_btn = gr.Button("Next ▶", size="md")
+                 model_choice = gr.Radio(
+                     choices=["dots.ocr", "Dolphin"],
+                     label="Select Model",
+                     value="dots.ocr"
+                 )
+                 with gr.Accordion("Advanced Settings", open=False):
+                     max_new_tokens = gr.Slider(minimum=1000, maximum=32000, value=24000, step=1000, label="Max New Tokens")
+                     min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels")
+                     max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels")
+                 process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
+                 clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
+             with gr.Column(scale=2):
+                 with gr.Tabs():
+                     with gr.Tab("🖼️ Processed Image"):
+                         processed_image = gr.Image(label="Image with Layout Detection", type="pil", interactive=False, height=500)
+                     with gr.Tab("📝 Extracted Content"):
+                         markdown_output = gr.Markdown(value="Click 'Process Document' to see extracted content...", height=500)
+                     with gr.Tab("📋 Layout JSON"):
+                         json_output = gr.JSON(label="Layout Analysis Results", value=None)
+
+         def handle_file_upload(file_path):
+             image, page_info = load_file_for_preview(file_path)
+             return image, page_info
+
+         def clear_all():
+             global pdf_cache
+             pdf_cache = {"images": [], "current_page": 0, "total_pages": 0, "file_type": None, "is_parsed": False, "results": []}
+             return None, None, '<div class="page-info">No file loaded</div>', None, "Click 'Process Document' to see extracted content...", None
+
+         file_input.change(handle_file_upload, inputs=[file_input], outputs=[image_preview, page_info])
+         prev_page_btn.click(lambda: turn_page("prev"), outputs=[image_preview, page_info, markdown_output, processed_image, json_output])
+         next_page_btn.click(lambda: turn_page("next"), outputs=[image_preview, page_info, markdown_output, processed_image, json_output])
+         process_btn.click(
+             process_document,
+             inputs=[file_input, model_choice, max_new_tokens, min_pixels, max_pixels],
+             outputs=[processed_image, markdown_output, json_output]
+         )
+         clear_btn.click(
+             clear_all,
+             outputs=[file_input, image_preview, page_info, processed_image, markdown_output, json_output]
+         )
+     return demo

  if __name__ == "__main__":
+     demo = create_gradio_interface()
+     demo.queue(max_size=10).launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         debug=True,
+         show_error=True
+     )