File size: 38,290 Bytes
f17f462
 
 
 
c152910
 
 
 
 
 
 
 
f17f462
 
 
3c4fefe
c152910
f17f462
c152910
 
 
9180057
278dfd1
9180057
4148e9b
f17f462
9180057
60f59d6
f17f462
 
 
 
 
 
 
278dfd1
 
 
 
 
 
 
 
 
f17f462
c152910
 
 
f17f462
c152910
4148e9b
c152910
4148e9b
 
 
 
 
 
 
 
 
f17f462
 
 
 
 
 
 
 
 
 
 
c152910
f17f462
 
 
c152910
f17f462
c152910
f17f462
 
c152910
f17f462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4148e9b
f17f462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c152910
f17f462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4148e9b
f17f462
 
 
 
 
 
4148e9b
f17f462
 
 
 
4148e9b
f17f462
 
4148e9b
f17f462
 
 
4148e9b
f17f462
4148e9b
f17f462
 
 
 
 
 
 
 
4148e9b
f17f462
 
 
 
 
c152910
f17f462
 
 
 
 
 
 
 
 
c152910
f17f462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4148e9b
f17f462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4148e9b
f17f462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4148e9b
f17f462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c152910
 
f17f462
 
c152910
 
 
f17f462
 
c152910
 
 
f17f462
 
c152910
 
f17f462
c152910
 
 
 
 
4148e9b
f17f462
c152910
f17f462
 
4148e9b
 
 
 
 
 
f17f462
 
 
 
c152910
f17f462
c152910
f17f462
 
 
 
 
 
 
 
 
4148e9b
c152910
 
4148e9b
f17f462
 
 
4148e9b
f17f462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4148e9b
f17f462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4148e9b
f17f462
 
4148e9b
f17f462
4148e9b
f17f462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4148e9b
c152910
f17f462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c152910
 
 
f17f462
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
# app.py
# All code combined into a single file for convenience.

# --- Imports ---
import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import re
import base64
import copy
from dataclasses import dataclass
#import flash_attn_2_cuda as flash_attn_gpu

# Vision and ML Libraries
import fitz  # PyMuPDF
import gradio as gr
import requests
import torch
import subprocess
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModelForCausalLM, AutoProcessor, VisionEncoderDecoderModel
from qwen_vl_utils import process_vision_info

# Image Processing Libraries
import cv2
import numpy as np
import albumentations as alb
from albumentations.pytorch import ToTensorV2
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


# Attempt to install flash-attn at start-up (best effort: a failure is only logged).
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is handed to pip through the environment.
# The flag is merged into a copy of the current environment because `env=`
# REPLACES the child's environment; passing only the flag would drop PATH, HOME,
# virtualenv variables, proxy settings, etc. for the pip process.
try:
    subprocess.run(
        'pip install flash-attn --no-build-isolation',
        env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
        check=True,
        shell=True,
    )
except subprocess.CalledProcessError as e:
    print(f"Error installing flash-attn: {e}")
    print("Continuing without flash-attn.")


# --- Constants & Global State ---
# Pixel-count bounds and rounding factor for image resizing.
#   MIN_PIXELS = 3136      == (2 * 28) ** 2
#   MAX_PIXELS = 11289600  == (120 * 28) ** 2
# IMAGE_FACTOR is what DotOcrModel.fetch_image passes to smart_resize as `factor`.
# NOTE(review): MIN_PIXELS / MAX_PIXELS are presumably the min_pixels / max_pixels
# arguments given to DotOcrModel.process_image — the call site is not visible here; confirm.
MIN_PIXELS = 3136
MAX_PIXELS = 11289600
IMAGE_FACTOR = 28
# Instruction prompt that DotOcrModel.process_image sends along with each image.
DOT_OCR_PROMPT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: Format its text as HTML.
    - All Others (Text, Title, etc.): Format their text as Markdown.
4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object.
"""
# "cuda" when torch reports an available CUDA device, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Module-level mutable state describing the currently loaded document.
# NOTE(review): the code that reads/writes this dict is outside the visible part
# of the file (presumably the Gradio event handlers) — confirm key semantics there.
PDF_CACHE = {
    "images": [],
    "current_page": 0,
    "total_pages": 0,
    "file_type": None,
    "is_parsed": False,
    "results": [],
    "model_used": None,
}
# Empty registry at import time.
# NOTE(review): presumably filled with model wrapper instances keyed by name — confirm against the code that populates it.
MODELS = {}

# =================================================================================
# --- UTILITY FUNCTIONS (from markdown_utils.py and utils.py) ---
# =================================================================================

# --- Markdown Conversion Utilities ---

def extract_table_from_html(html_string):
    """Return all ``<table>...</table>`` fragments of *html_string*.

    Every fragment has the attributes of its opening ``<table ...>`` tag(s)
    stripped, and the fragments are joined with newlines.  Text outside of
    tables is discarded; an input without tables yields ``""``.

    Any exception is printed and turned into a one-cell error table instead
    of being raised.
    """
    try:
        cleaned_fragments = []
        for found in re.finditer(r'<table.*?>.*?</table>', html_string, flags=re.DOTALL):
            fragment = found.group(0)
            # Normalise every opening tag inside the fragment to a bare <table>.
            cleaned_fragments.append(re.sub(r'<table[^>]*>', '<table>', fragment))
        return '\n'.join(cleaned_fragments)
    except Exception as exc:
        print(f"extract_table_from_html error: {str(exc)}")
        return f"<table><tr><td>Error extracting table: {str(exc)}</td></tr></table>"


class MarkdownConverter:
    """Convert structured recognition results to Markdown format.

    Each recognition result is a dict with a ``label`` and a ``text`` entry.
    :meth:`convert` dispatches on the label to one of the ``_handle_*``
    helpers and concatenates the produced fragments.  All helpers are
    best-effort: an exception is printed and replaced by a fallback fragment
    rather than propagated.
    """

    def __init__(self):
        # Markdown heading prefix for each heading-like label.
        self.heading_levels = {'title': '#', 'sec': '##', 'sub_sec': '###'}
        # Labels that convert() does NOT render as plain text.  A label listed
        # here without its own branch in convert() (currently 'reference')
        # produces no output at all.
        self.special_labels = {'tab', 'fig', 'title', 'sec', 'sub_sec', 'list', 'formula', 'reference', 'alg'}

    @staticmethod
    def _is_chinese(char: str) -> bool:
        """Return True if *char* lies in the U+4E00..U+9FFF range."""
        return '\u4e00' <= char <= '\u9fff'

    def try_remove_newline(self, text: str) -> str:
        """Merge the hard-wrapped lines of *text* into flowing paragraphs.

        A hyphen directly followed by a newline is removed.  Consecutive
        non-empty lines are joined with a space, or with nothing when the
        characters on both sides of the break are Chinese.  Empty lines are
        kept as newlines.  On any error the text is returned as it stands.
        """
        try:
            text = text.strip().replace('-\n', '')
            lines = text.split('\n')
            processed_lines = []
            for raw_current, raw_next in zip(lines, lines[1:]):
                current_line, next_line = raw_current.strip(), raw_next.strip()
                if not current_line:
                    processed_lines.append('\n')
                elif not next_line:
                    processed_lines.append(current_line + '\n')
                elif self._is_chinese(current_line[-1]) and self._is_chinese(next_line[0]):
                    processed_lines.append(current_line)
                else:
                    processed_lines.append(current_line + ' ')
            # The loop above only emits lines that have a successor.
            if lines and lines[-1].strip():
                processed_lines.append(lines[-1].strip())
            return ''.join(processed_lines)
        except Exception as e:
            print(f"try_remove_newline error: {str(e)}")
            return text

    def _handle_text(self, text: str) -> str:
        """Render a plain-text element.

        Text that is a bare ``array`` environment is wrapped in ``$$``; text
        that contains LaTeX-looking markup but no ``$`` and no ``\\begin`` is
        wrapped in ``$``.  Formula newlines and paragraph line breaks are then
        normalised.
        """
        try:
            if not text:
                return ""
            if text.strip().startswith("\\begin{array}") and text.strip().endswith("\\end{array}"):
                text = "$$" + text + "$$"
            elif ("_{" in text or "^{" in text or "\\" in text or "_ {" in text or "^ {" in text) and ("$" not in text) and ("\\begin" not in text):
                text = "$" + text + "$"
            text = self._process_formulas_in_text(text)
            text = self.try_remove_newline(text)
            return text
        except Exception as e:
            print(f"_handle_text error: {str(e)}")
            return text

    def _process_formulas_in_text(self, text: str) -> str:
        """Replace newlines inside delimited formulas with `` \\\\ ``.

        The delimiter pairs are processed one after another, each over the
        result of the previous pass.  An opening delimiter without a matching
        closing one leaves the remainder of the text untouched.
        """
        try:
            delimiters = [('$$', '$$'), ('\\[', '\\]'), ('$', '$'), ('\\(', '\\)')]
            result = text
            for start_delim, end_delim in delimiters:
                current_pos, processed_parts = 0, []
                while current_pos < len(result):
                    start_pos = result.find(start_delim, current_pos)
                    if start_pos == -1:
                        processed_parts.append(result[current_pos:])
                        break
                    processed_parts.append(result[current_pos:start_pos])
                    end_pos = result.find(end_delim, start_pos + len(start_delim))
                    if end_pos == -1:
                        processed_parts.append(result[start_pos:])
                        break
                    formula_content = result[start_pos + len(start_delim):end_pos]
                    processed_formula = formula_content.replace('\n', ' \\\\ ')
                    processed_parts.append(f"{start_delim}{processed_formula}{end_delim}")
                    current_pos = end_pos + len(end_delim)
                result = ''.join(processed_parts)
            return result
        except Exception as e:
            print(f"_process_formulas_in_text error: {str(e)}")
            return text

    def _remove_newline_in_heading(self, text: str) -> str:
        """Drop newlines from a heading (replace with a space unless it contains Chinese)."""
        try:
            if any(self._is_chinese(char) for char in text):
                return text.replace('\n', '')
            return text.replace('\n', ' ')
        except Exception as e:
            print(f"_remove_newline_in_heading error: {str(e)}")
            return text

    def _handle_heading(self, text: str, label: str) -> str:
        """Render a heading; unknown labels fall back to a level-1 ``#``."""
        try:
            level = self.heading_levels.get(label, '#')
            text = self._remove_newline_in_heading(text.strip())
            text = self._handle_text(text)
            return f"{level} {text}\n\n"
        except Exception as e:
            print(f"_handle_heading error: {str(e)}")
            return f"# Error processing heading: {text}\n\n"

    def _handle_list_item(self, text: str) -> str:
        """Render one bullet list item."""
        try:
            return f"- {text.strip()}\n"
        except Exception as e:
            print(f"_handle_list_item error: {str(e)}")
            return f"- Error processing list item: {text}\n"

    def _handle_figure(self, text: str, section_count: int) -> str:
        """Render a figure as an inline image.

        *text* is used as the image source: verbatim when it already starts
        with ``data:image/``, otherwise prefixed with a PNG base64 data-URI
        header.  Empty text yields an image with an empty data URI.
        """
        try:
            if not text.strip():
                return f"![Figure {section_count}](data:image/png;base64,)\n\n"
            if text.startswith("data:image/"):
                return f"![Figure {section_count}]({text})\n\n"
            return f"![Figure {section_count}](data:image/png;base64,{text})\n\n"
        except Exception as e:
            print(f"_handle_figure error: {str(e)}")
            return f"*[Error processing figure: {str(e)}]*\n\n"

    def _handle_table(self, text: str) -> str:
        """Render a table.

        HTML input (contains ``<table`` or ``<tr``, case-insensitively) is
        passed through ``extract_table_from_html``.  Anything else is treated
        as whitespace-separated columns, one row per line, with the first
        line as the header of a Markdown pipe table.
        """
        try:
            if '<table' in text.lower() or '<tr' in text.lower():
                return extract_table_from_html(text) + "\n\n"
            # str.split always returns at least one element, so table_lines[0] exists.
            table_lines = text.split('\n')
            col_count = len(table_lines[0].split()) if table_lines[0] else 1
            header = '| ' + ' | '.join(table_lines[0].split()) + ' |'
            separator = '| ' + ' | '.join(['---'] * col_count) + ' |'
            rows = [f"| {' | '.join(line.split())} |" for line in table_lines[1:]]
            return '\n'.join([header, separator] + rows) + '\n\n'
        except Exception as e:
            print(f"_handle_table error: {str(e)}")
            return f"*[Error processing table: {str(e)}]*\n\n"

    def _handle_algorithm(self, text: str) -> str:
        """Render a LaTeX algorithm as an optional bold caption plus a fenced code block."""
        try:
            text = re.sub(r'\\begin\{algorithm\}(.*?)\\end\{algorithm\}', r'\1', text, flags=re.DOTALL)
            text = text.replace('\\begin{algorithmic}', '').replace('\\end{algorithmic}', '')
            caption_match = re.search(r'\\caption\{(.*?)\}', text)
            caption = f"**{caption_match.group(1)}**\n\n" if caption_match else ""
            algorithm_text = re.sub(r'\\caption\{.*?\}', '', text).strip()
            return f"{caption}```\n{algorithm_text}\n```\n\n"
        except Exception as e:
            print(f"_handle_algorithm error: {str(e)}")
            return f"*[Error processing algorithm: {str(e)}]*\n\n{text}\n\n"

    def _handle_formula(self, text: str) -> str:
        """Render a display formula, adding ``$$`` when no ``$$`` or ``\\[`` is present."""
        try:
            processed_text = self._process_formulas_in_text(text)
            if '$$' not in processed_text and '\\[' not in processed_text:
                processed_text = f'$${processed_text}$$'
            return f"{processed_text}\n\n"
        except Exception as e:
            print(f"_handle_formula error: {str(e)}")
            return f"*[Error processing formula: {str(e)}]*\n\n"

    def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
        """Convert a list of recognition results to one Markdown string.

        Args:
            recognition_results: dicts providing ``label`` and ``text``.

        Returns:
            The post-processed Markdown.  Figures are always emitted; every
            other element is skipped when its text is empty.  An element
            that raises is replaced by an error marker.
        """
        markdown_content = []
        for i, result in enumerate(recognition_results):
            try:
                label, text = result.get('label', ''), result.get('text', '').strip()
                if label == 'fig':
                    # The element index doubles as the figure number.
                    markdown_content.append(self._handle_figure(text, i))
                    continue
                if not text:
                    continue

                if label in {'title', 'sec', 'sub_sec'}:
                    markdown_content.append(self._handle_heading(text, label))
                elif label == 'list':
                    markdown_content.append(self._handle_list_item(text))
                elif label == 'tab':
                    markdown_content.append(self._handle_table(text))
                elif label == 'alg':
                    markdown_content.append(self._handle_algorithm(text))
                elif label == 'formula':
                    markdown_content.append(self._handle_formula(text))
                elif label not in self.special_labels:
                    markdown_content.append(f"{self._handle_text(text)}\n\n")
            except Exception as e:
                print(f"Error processing item {i}: {str(e)}")
                markdown_content.append("*[Error processing content]*\n\n")
        return self._post_process(''.join(markdown_content))

    def _post_process(self, md: str) -> str:
        """Clean up LaTeX leftovers in the assembled Markdown.

        Unwraps ``\\author{...}``, rewrites ``abstract`` environments and
        ``\\eqno{(n)}``, converts ``\\[ \\\\`` / ``\\\\ \\]`` to ``$$``,
        removes the space in ``_ {`` / ``^ {`` and collapses runs of three or
        more newlines to two.
        """
        try:
            # Math-wrapped authors must be unwrapped BEFORE plain \author{}:
            # the plain substitution consumes every \author{...}, so running it
            # first would leave the surrounding '$' behind and make this
            # substitution unreachable.
            md = re.sub(
                r'\$(\\author\{.*?\})\$',
                lambda m: self._handle_text(re.search(r'\\author\{(.*?)\}', m.group(1), re.DOTALL).group(1)),
                md,
                flags=re.DOTALL,
            )
            md = re.sub(r'\\author\{(.*?)\}', lambda m: self._handle_text(m.group(1)), md, flags=re.DOTALL)
            md = re.sub(r'\\begin\{abstract\}(.*?)\\end\{abstract\}', r'**Abstract** \1', md, flags=re.DOTALL)
            md = re.sub(r'\\begin\{abstract\}', r'**Abstract**', md)
            md = re.sub(r'\\eqno\{\((.*?)\)\}', r'\\tag{\1}', md)
            # Same string values as before, spelled with valid escape sequences.
            md = md.replace("\\[ \\\\", "$$ \\\\").replace("\\\\ \\]", "\\\\ $$")
            # Literal replacements: as a regex, '^ {' anchors at the start of
            # the string and therefore never matched a caret inside the text.
            md = md.replace('_ {', '_{').replace('^ {', '^{')
            md = re.sub(r'\n{3,}', r'\n\n', md)
            return md
        except Exception as e:
            print(f"_post_process error: {str(e)}")
            return md

# --- General Processing Utilities ---
@dataclass
class ImageDimensions:
    """Pixel sizes of an image before and after padding.

    ``prepare_image`` fills the ``original_*`` fields from the input image and
    the ``padded_*`` fields from the image returned by ``cv2.copyMakeBorder``.
    """
    original_w: int  # width of the unpadded image
    original_h: int  # height of the unpadded image
    padded_w: int  # width after padding
    padded_h: int  # height after padding

def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2):
    """Nudge each box edge outwards until it no longer cuts through content.

    The image is binarised once (inverted Otsu threshold).  For every box the
    four edges are examined in the order left, right, top, bottom.  An edge
    whose transition score (``np.sum(np.abs(np.diff(line))) / len(...)`` along
    the edge line) is above *threshold* is moved outwards one pixel at a
    time, for at most *max_pixels* steps, stopping early once the score drops
    to *threshold* or below.

    Args:
        image: image array, or a path that is loaded with ``cv2.imread``.
            The array is converted with ``cv2.COLOR_BGR2GRAY``.
        boxes: boxes as ``[x1, y1, x2, y2]``; the values are used as array
            indices after clamping to the image.
        max_pixels: maximum number of one-pixel steps per edge.
        threshold: score at or below which an edge is accepted.

    Returns:
        One box per input box.  A box whose edges never improved is returned
        as an unmodified copy of the input box.
    """
    if isinstance(image, str):
        image = cv2.imread(image)
    img_h, img_w = image.shape[:2]

    boxes = list(boxes)
    if not boxes:
        return []

    # The binarised image does not depend on the box or the edge, so compute it
    # once here instead of on every score evaluation (up to 4 * (max_pixels + 1)
    # full-image conversions per box otherwise).
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    def check_edge(current_box, i, is_vertical):
        """Score the pixel line under edge *i* of *current_box*."""
        edge = current_box[i]
        if is_vertical:
            line = binary[current_box[1] : current_box[3] + 1, edge]
        else:
            line = binary[edge, current_box[0] : current_box[2] + 1]
        transitions = np.abs(np.diff(line))
        return np.sum(transitions) / len(transitions)

    # (index into the box, step direction, whether the edge is a vertical line)
    edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
    limits = [img_w, img_h, img_w, img_h]

    new_boxes = []
    for box in boxes:
        best_box = copy.deepcopy(box)
        # Clamp into the image so every coordinate is a valid index.
        current_box = [min(max(c, 0), d - 1) for c, d in zip(box, limits)]

        for i, direction, is_vertical in edges:
            best_score = check_edge(current_box, i, is_vertical)
            if best_score <= threshold:
                continue
            for _ in range(max_pixels):
                current_box[i] += direction
                dim = img_w if i in [0, 2] else img_h
                current_box[i] = min(max(current_box[i], 0), dim - 1)
                score = check_edge(current_box, i, is_vertical)
                if score < best_score:
                    best_score, best_box = score, copy.deepcopy(current_box)
                if score <= threshold:
                    break
        new_boxes.append(best_box)
    return new_boxes

def parse_layout_string(bbox_str):
    """Parse ``[x1, y1, x2, y2] label`` entries out of *bbox_str*.

    Returns a list of ``([x1, y1, x2, y2], label)`` tuples in order of
    appearance, with the coordinates converted to ``float``.  Text that does
    not match the entry format is ignored.
    """
    pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
    parsed = []
    for entry in re.finditer(pattern, bbox_str):
        *raw_coords, label = entry.groups()
        parsed.append(([float(value) for value in raw_coords], label.strip()))
    return parsed

def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]:
    """Translate a box from padded-image to original-image coordinates.

    Half of the padding (integer division) is subtracted on each axis, the
    result is clamped to the original image, and the box is forced to be at
    least one pixel wide and high where the image bounds allow it.

    On any exception the error is printed and a box starting at ``(0, 0)``
    with a side of at most 100 pixels is returned.
    """
    try:
        pad_top = (dims.padded_h - dims.original_h) // 2
        pad_left = (dims.padded_w - dims.original_w) // 2

        left_x = max(0, x1 - pad_left)
        top_y = max(0, y1 - pad_top)
        right_x = min(dims.original_w, x2 - pad_left)
        bottom_y = min(dims.original_h, y2 - pad_top)

        # Keep the box non-degenerate.
        if right_x <= left_x:
            right_x = min(left_x + 1, dims.original_w)
        if bottom_y <= top_y:
            bottom_y = min(top_y + 1, dims.original_h)

        return int(left_x), int(top_y), int(right_x), int(bottom_y)
    except Exception as exc:
        print(f"map_to_original_coordinates error: {str(exc)}")
        return 0, 0, min(100, dims.original_w), min(100, dims.original_h)

def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
    """Turn a relative box into pixel boxes on the padded and original image.

    *coords* is scaled by the padded image size, clamped, refined with
    ``adjust_box_edges`` and, when it overlaps *previous_box* on both axes,
    has its top moved down to the bottom of *previous_box*.

    Returns a 9-tuple: ``x1, y1, x2, y2`` on the padded image, the four
    values from ``map_to_original_coordinates``, and ``[x1, y1, x2, y2]`` as
    a list (usable as the next call's *previous_box*).  On any exception the
    error is printed and a fixed fallback box is returned.
    """
    try:
        width, height = dims.padded_w, dims.padded_h

        # Scale to pixels and clamp into the padded image.
        x1 = max(0, int(coords[0] * width))
        y1 = max(0, int(coords[1] * height))
        x2 = min(width, int(coords[2] * width))
        y2 = min(height, int(coords[3] * height))
        if x2 <= x1:
            x2 = min(x1 + 1, width)
        if y2 <= y1:
            y2 = min(y1 + 1, height)

        refined = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
        x1, y1, x2, y2 = refined[0]

        if previous_box:
            prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
            if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
                # Start below the previous box instead of overlapping it.
                y1 = min(prev_y2, height - 1)
                if y2 <= y1:
                    y2 = min(y1 + 1, height)

        original_box = map_to_original_coordinates(x1, y1, x2, y2, dims)
        return (x1, y1, x2, y2, *original_box, [x1, y1, x2, y2])
    except Exception as exc:
        print(f"process_coordinates error: {str(exc)}")
        fallback_original = (0, 0, min(100, dims.original_w), min(100, dims.original_h))
        return (0, 0, 100, 100, *fallback_original, [0, 0, 100, 100])

def prepare_image(image) -> Tuple[np.ndarray, ImageDimensions]:
    """Convert *image* to a BGR array padded with black to a square.

    The padding is split between both sides of the shorter axis; when the
    amount is odd the extra pixel goes to the bottom/right.

    Returns the padded array and an ``ImageDimensions`` with the original and
    padded sizes.  On any exception the error is printed and an all-zero
    array of the input's size is returned, with padded size equal to the
    original size.
    """
    try:
        bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        original_h, original_w = bgr.shape[:2]
        side = max(original_h, original_w)

        extra_h, extra_w = side - original_h, side - original_w
        top, left = extra_h // 2, extra_w // 2
        bottom, right = extra_h - top, extra_w - left

        padded = cv2.copyMakeBorder(bgr, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0))
        padded_h, padded_w = padded.shape[:2]
        return padded, ImageDimensions(original_w, original_h, padded_w, padded_h)
    except Exception as exc:
        print(f"prepare_image error: {str(exc)}")
        fallback_dims = ImageDimensions(image.width, image.height, image.width, image.height)
        blank = np.zeros((image.height, image.width, 3), dtype=np.uint8)
        return blank, fallback_dims


# =================================================================================
# --- MODEL WRAPPER CLASSES ---
# =================================================================================

class DotOcrModel:
    """Lazy-loading wrapper around the ``rednote-hilab/dots.ocr`` model.

    ``process_image`` resizes a page image, prompts the model with
    ``DOT_OCR_PROMPT``, parses the JSON layout it returns and derives an
    annotated image and a Markdown rendering from it.
    """

    def __init__(self, device: str):
        """Store the target *device*; the model itself is loaded on first use."""
        self.model, self.processor, self.device = None, None, device
        self.model_id, self.model_path = "rednote-hilab/dots.ocr", "./models/dots-ocr-local"

    @spaces.GPU()
    def load_model(self):
        """Download the model snapshot and load model + processor (no-op once loaded)."""
        if self.model is None:
            print("Loading dot.ocr model...")
            snapshot_download(repo_id=self.model_id, local_dir=self.model_path, local_dir_use_symlinks=False)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path, attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
            )
            self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
            print("dot.ocr model loaded.")

    @staticmethod
    def smart_resize(height, width, factor, min_pixels, max_pixels):
        """Compute a target size whose sides are multiples of *factor*.

        Both sides are first rounded to the nearest multiple of *factor*
        (at least *factor*).  If the area then exceeds *max_pixels* both
        sides are scaled down by the same ratio and floored to a multiple of
        *factor*; if it is below *min_pixels* both sides are scaled up by the
        same ratio and ceiled.  Floor/ceil (rather than round) keep the
        result on the requested side of the bound.

        Returns:
            ``(height, width)`` of the resized image.

        Raises:
            ValueError: if the aspect ratio exceeds 200.
        """
        if max(height, width) / min(height, width) > 200:
            raise ValueError("Aspect ratio too high")
        h_bar = max(factor, round(height / factor) * factor)
        w_bar = max(factor, round(width / factor) * factor)
        if h_bar * w_bar > max_pixels:
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = max(factor, math.floor(height / beta / factor) * factor)
            w_bar = max(factor, math.floor(width / beta / factor) * factor)
        elif h_bar * w_bar < min_pixels:
            # beta > 1 here: BOTH sides must be multiplied by it to grow the image.
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / factor) * factor
            w_bar = math.ceil(width * beta / factor) * factor
        return h_bar, w_bar

    def fetch_image(self, image_input, min_pixels, max_pixels):
        """Return *image_input* as RGB, resized to the size chosen by ``smart_resize``."""
        image = image_input.convert('RGB')
        height, width = self.smart_resize(image.height, image.width, IMAGE_FACTOR, min_pixels, max_pixels)
        return image.resize((width, height), Image.LANCZOS)

    @spaces.GPU()
    def inference(self, image: Image.Image, prompt: str, max_new_tokens: int = 24000) -> str:
        """Run the model on one image + prompt and return the decoded completion."""
        self.load_model()
        messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)
        inputs = self.processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            # NOTE(review): do_sample=False is passed together with temperature — presumably
            # the temperature has no effect under greedy decoding; confirm against transformers docs.
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, temperature=0.1)
        # Drop the prompt tokens so that only the newly generated part is decoded.
        generated_ids_trimmed = [out[len(ins):] for ins, out in zip(inputs.input_ids, generated_ids)]
        return self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    def process_image(self, image: Image.Image, min_pixels: int, max_pixels: int):
        """Run layout OCR on *image* and build all derived outputs.

        Returns:
            A dict with ``original_image``, ``raw_output``, ``layout_result``,
            ``processed_image`` and ``markdown_content``.  When the model
            output is not valid JSON, or does not have the structure the
            renderers expect, ``processed_image`` is the unannotated input
            image and ``markdown_content`` shows the raw output.
        """
        resized_image = self.fetch_image(image, min_pixels, max_pixels)
        raw_output = self.inference(resized_image, DOT_OCR_PROMPT)
        result = {'original_image': image, 'raw_output': raw_output, 'layout_result': None}
        try:
            layout_data = json.loads(raw_output)
            result['layout_result'] = layout_data
            result['processed_image'] = self.draw_layout_on_image(image, layout_data)
            result['markdown_content'] = self.layoutjson2md(image, layout_data)
        except (json.JSONDecodeError, KeyError, TypeError, AttributeError, ValueError, IndexError) as e:
            # The layout comes from free-form model output: besides invalid JSON,
            # a well-formed but unexpectedly shaped document (e.g. a dict instead
            # of a list of items, or a malformed bbox) must also degrade gracefully.
            print(f"Failed to parse or process dot.ocr layout: {e}")
            result['processed_image'] = image
            result['markdown_content'] = f"### Error processing output\nRaw model output:\n```json\n{raw_output}\n```"
        return result

    def draw_layout_on_image(self, image: Image.Image, layout_data: List[Dict]) -> Image.Image:
        """Return a copy of *image* with a coloured, labelled box per layout item.

        Items lacking ``bbox`` or ``category`` are skipped; categories without
        a configured colour are drawn in black.
        """
        # Two statements on purpose: in a tuple assignment the whole right-hand
        # side is evaluated first, so the copy cannot be referenced there.
        img_copy = image.copy()
        draw = ImageDraw.Draw(img_copy)
        colors = {'Caption': '#FF6B6B', 'Footnote': '#4ECDC4', 'Formula': '#45B7D1', 'List-item': '#96CEB4',
                  'Page-footer': '#FFEAA7', 'Page-header': '#DDA0DD', 'Picture': '#FFD93D', 'Section-header': '#6C5CE7',
                  'Table': '#FD79A8', 'Text': '#74B9FF', 'Title': '#E17055'}
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 15)
        except (OSError, ImportError):
            # Font file missing/unreadable or FreeType support unavailable.
            font = ImageFont.load_default()
        for item in layout_data:
            if 'bbox' in item and 'category' in item:
                bbox, category = item['bbox'], item['category']
                color = colors.get(category, '#000000')
                draw.rectangle(bbox, outline=color, width=3)
                label_bbox = draw.textbbox((0, 0), category, font=font)
                label_width, label_height = label_bbox[2] - label_bbox[0], label_bbox[3] - label_bbox[1]
                # Place the label above the box, clamped to the top of the image.
                label_x, label_y = bbox[0], max(0, bbox[1] - label_height - 5)
                draw.rectangle([label_x, label_y, label_x + label_width + 4, label_y + label_height + 4], fill=color)
                draw.text((label_x + 2, label_y + 2), category, fill='white', font=font)
        return img_copy

    def layoutjson2md(self, image: Image.Image, layout_data: List[Dict]) -> str:
        """Render layout items as Markdown, ordered by bbox top then left.

        ``Picture`` items with a bbox are cropped from *image* and embedded as
        base64 PNG data URIs; other items without text are skipped.
        """
        md_lines = []
        sorted_items = sorted(layout_data, key=lambda x: (x.get('bbox', [0]*4)[1], x.get('bbox', [0]*4)[0]))
        for item in sorted_items:
            cat, txt, bbox = item.get('category'), item.get('text'), item.get('bbox')
            if cat == 'Picture' and bbox:
                try:
                    x1, y1 = max(0, int(bbox[0])), max(0, int(bbox[1]))
                    x2, y2 = min(image.width, int(bbox[2])), min(image.height, int(bbox[3]))
                    if x2 > x1 and y2 > y1:
                        cropped = image.crop((x1, y1, x2, y2))
                        buffer = BytesIO()
                        cropped.save(buffer, format='PNG')
                        img_data = base64.b64encode(buffer.getvalue()).decode()
                        md_lines.append(f"![Image](data:image/png;base64,{img_data})\n")
                except Exception:
                    # Best effort: keep a placeholder when the crop cannot be produced.
                    md_lines.append("![Image](Image region detected)\n")
            elif not txt:
                continue
            elif cat == 'Title':
                md_lines.append(f"# {txt}\n")
            elif cat == 'Section-header':
                md_lines.append(f"## {txt}\n")
            elif cat == 'List-item':
                md_lines.append(f"- {txt}\n")
            elif cat == 'Formula':
                md_lines.append(f"$$\n{txt}\n$$\n")
            elif cat == 'Caption':
                md_lines.append(f"*{txt}*\n")
            elif cat == 'Footnote':
                md_lines.append(f"^{txt}^\n")
            elif cat in ['Text', 'Table']:
                md_lines.append(f"{txt}\n")
        return "\n".join(md_lines)

class DolphinModel:
    """Lazy-loading wrapper around the ``ByteDance/Dolphin`` model.

    ``process_image`` first asks the model for the page layout, then crops
    every detected element, recognises the crops in batches and converts the
    results to Markdown with ``MarkdownConverter``.
    """

    def __init__(self, device: str):
        """Store the target *device*; the model itself is loaded on first use."""
        self.model, self.processor, self.tokenizer, self.device = None, None, None, device
        self.model_id = "ByteDance/Dolphin"

    @spaces.GPU()
    def load_model(self):
        """Load processor, half-precision model and tokenizer (no-op once loaded)."""
        if self.model is None:
            print("Loading Dolphin model...")
            self.processor = AutoProcessor.from_pretrained(self.model_id)
            self.model = VisionEncoderDecoderModel.from_pretrained(self.model_id).eval().to(self.device).half()
            self.tokenizer = self.processor.tokenizer
            print("Dolphin model loaded.")

    @spaces.GPU()
    def model_chat(self, prompt, image):
        """Generate text for one image or a batch of images.

        Args:
            prompt: a single prompt (applied to every image) or a list of prompts.
            image: a single image or a list of images.

        Returns:
            A list of strings when *image* is a list, otherwise one string.
        """
        self.load_model()
        images = image if isinstance(image, list) else [image]
        prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
        batch_inputs = self.processor(images, return_tensors="pt", padding=True)
        batch_pixel_values = batch_inputs.pixel_values.half().to(self.device)
        prompts = [f"<s>{p} <Answer/>" for p in prompts]
        batch_prompt_inputs = self.tokenizer(prompts, add_special_tokens=False, return_tensors="pt")
        batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
        batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)
        outputs = self.model.generate(
            pixel_values=batch_pixel_values, decoder_input_ids=batch_prompt_ids,
            decoder_attention_mask=batch_attention_mask, max_length=4096,
            pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id,
            use_cache=True, bad_words_ids=[[self.tokenizer.unk_token_id]],
            return_dict_in_generate=True, do_sample=False, num_beams=1, repetition_penalty=1.1
        )
        sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
        # The decoded sequence still contains the prompt and padding/end tokens.
        results = [seq.replace(p, "").replace("<pad>", "").replace("</s>", "").strip() for p, seq in zip(prompts, sequences)]
        return results if isinstance(image, list) else results[0]

    def process_elements(self, layout_str: str, image: Image.Image, max_batch_size: int = 16):
        """Crop and recognise every element listed in *layout_str*.

        Each box is chained to the one before it through ``previous_box`` so
        that ``process_coordinates`` can move the top of an overlapping box
        below its predecessor; without this the overlap handling implemented
        there would never run.

        Returns:
            Result dicts (``label``, ``bbox`` in original-image coordinates,
            ``text``, ``reading_order``) sorted by reading order.  Figures
            get an empty ``text``; crops of 3 pixels or less on a side are
            dropped.
        """
        padded_image, dims = prepare_image(image)
        layout_results = parse_layout_string(layout_str)
        elements, reading_order = [], 0
        previous_box = None
        for bbox, label in layout_results:
            try:
                coords = process_coordinates(bbox, padded_image, dims, previous_box)
                x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2 = coords[:8]
                previous_box = coords[8]
                cropped = padded_image[y1:y2, x1:x2]
                if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
                    pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
                    elements.append({"crop": pil_crop, "label": label, "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], "reading_order": reading_order})
                # The position counts even when the crop was too small to keep.
                reading_order += 1
            except Exception as e:
                print(f"Error processing Dolphin element bbox {bbox}: {e}")

        text_elems = self.process_element_batch([e for e in elements if e['label'] != 'tab' and e['label'] != 'fig'], "Read text in the image.", max_batch_size)
        table_elems = self.process_element_batch([e for e in elements if e['label'] == 'tab'], "Parse the table in the image.", max_batch_size)
        fig_elems = [{"label": e['label'], "bbox": e['bbox'], "text": "", "reading_order": e['reading_order']} for e in elements if e['label'] == 'fig']

        all_results = sorted(text_elems + table_elems + fig_elems, key=lambda x: x['reading_order'])
        return all_results

    def process_element_batch(self, elements, prompt, max_batch_size=16):
        """Recognise *elements* with *prompt*, at most *max_batch_size* crops per model call."""
        results = []
        for i in range(0, len(elements), max_batch_size):
            batch = elements[i:i+max_batch_size]
            crops = [elem["crop"] for elem in batch]
            prompts = [prompt] * len(crops)
            batch_results = self.model_chat(prompts, crops)
            for elem, res_text in zip(batch, batch_results):
                results.append({"label": elem["label"], "bbox": elem["bbox"], "text": res_text.strip(), "reading_order": elem["reading_order"]})
        return results

    def process_image(self, image: Image.Image):
        """Run layout parsing and element recognition on *image*.

        Returns:
            A dict with ``original_image``, ``processed_image`` (the input
            image, unannotated), ``markdown_content``, ``layout_result`` and
            ``raw_output`` (the layout string returned by the model).
        """
        layout_output = self.model_chat("Parse the reading order of this document.", image)
        recognition_results = self.process_elements(layout_output, image)
        markdown_content = MarkdownConverter().convert(recognition_results)
        return {
            'original_image': image, 'processed_image': image, 'markdown_content': markdown_content,
            'layout_result': recognition_results, 'raw_output': layout_output
        }


# =================================================================================
# --- GRADIO UI AND EVENT HANDLERS ---
# =================================================================================

def create_gradio_interface():
    """Build the Gradio Blocks app and wire up all of its event handlers.

    The handlers share state through the module-level ``PDF_CACHE`` dict
    (page images, current page index, per-page results) and lazily create
    model wrappers in the module-level ``MODELS`` dict.

    Returns:
        The configured ``gr.Blocks`` instance; the caller launches it.
    """

    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .header-text { text-align: center; color: #2c3e50; margin-bottom: 20px; }
    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
    .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
    .info-box { border: 1px solid #dee2e6; border-radius: 8px; padding: 15px; margin: 10px 0; }
    .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: bold; margin: 10px 0; }
    """

    placeholder_md = "Click 'Process Document' to see extracted content..."

    def _empty_cache():
        """Return a fresh cache dict describing 'no document loaded'."""
        return {
            "images": [], "current_page": 0, "total_pages": 0,
            "is_parsed": False, "results": [], "model_used": None
        }

    def _page_info_div(text):
        """Wrap status text in the styled page-info container."""
        return f'<div class="page-info">{text}</div>'

    with gr.Blocks(theme="bethecloud/storj_theme", css=css, title="Dot.OCR Comparator") as demo:
        gr.HTML("""
        <div class="title" style="text-align: center">
            <h1>Dot<span style="color: red;">●</span><strong></strong>OCR Comparator</h1>
            <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
                Advanced vision-language model for image/PDF to markdown document processing
            </p>
        </div>
        """)

        with gr.Row(elem_classes=["main-container"]):
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload Image or PDF", file_types=[".jpg", ".jpeg", ".png", ".pdf"], type="filepath")

                with gr.Row():
                    gr.Examples(
                        examples=["examples/sample_image1.png", "examples/sample_image2.png", "examples/sample_pdf.pdf"],
                        inputs=file_input,
                        label="Example Documents"
                    )

                model_choice = gr.Radio(choices=["dot.ocr", "Dolphin"], label="Select Model", value="dot.ocr")
                image_preview = gr.Image(label="Preview", type="pil", interactive=False, height=400)

                with gr.Row():
                    # Labels were stored as double-encoded (mojibake) text; restored to the intended glyphs.
                    prev_page_btn = gr.Button("◀ Previous")
                    page_info = gr.HTML(_page_info_div("No file loaded"))
                    next_page_btn = gr.Button("Next ▶")

                with gr.Accordion("Advanced Settings (dot.ocr only)", open=False):
                    min_pixels = gr.Number(value=MIN_PIXELS, label="Min Pixels", step=1)
                    max_pixels = gr.Number(value=MAX_PIXELS, label="Max Pixels", step=1)

                with gr.Row():
                    process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], scale=2)
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")

            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.Tab("📝 Extracted Content"):
                        markdown_output = gr.Markdown(value=placeholder_md, elem_id="markdown_output")
                    with gr.Tab("🖼️ Processed Image"):
                        processed_image_output = gr.Image(label="Image with Layout Detection", type="pil", interactive=False)
                    with gr.Tab("📋 Layout JSON"):
                        json_output = gr.JSON(label="Layout Analysis Results")

        def load_file_for_preview(file_path: str) -> Tuple[List[Image.Image], str]:
            """Load every page of a PDF, or a single image, as RGB PIL images.

            Returns the list of page images and a status string; the list is
            empty when nothing could be loaded.
            """
            images = []
            if not file_path or not os.path.exists(file_path):
                return [], "No file selected"
            try:
                ext = os.path.splitext(file_path)[1].lower()
                if ext == '.pdf':
                    doc = fitz.open(file_path)
                    try:
                        for page in doc:
                            # 2x zoom matrix renders the page at twice its base resolution.
                            pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
                            images.append(Image.open(BytesIO(pix.tobytes("ppm"))).convert('RGB'))
                    finally:
                        # Close the document even when rendering a page fails.
                        doc.close()
                elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
                    # convert() returns a new in-memory image, so the file handle can be released.
                    with Image.open(file_path) as img:
                        images.append(img.convert('RGB'))
                if not images:
                    # Unknown extension or a PDF without pages: avoid reporting "Page 1 / 0".
                    return [], "Unsupported or empty file"
                return images, f"Page 1 / {len(images)}"
            except Exception as e:
                print(f"Error loading file for preview: {e}")
                return [], f"Error loading file: {e}"

        def handle_file_upload(file_path):
            """Load the uploaded file into the cache and show its first page."""
            global PDF_CACHE
            images, page_info_str = load_file_for_preview(file_path)
            # Always start from a clean cache so a failed or cleared upload
            # cannot leave the previous document available for processing.
            PDF_CACHE = _empty_cache()
            if not images:
                return None, _page_info_div(page_info_str)
            PDF_CACHE.update({"images": images, "total_pages": len(images)})
            return images[0], _page_info_div(page_info_str)

        def process_document(file_path, model_name, min_pix, max_pix):
            """Run the selected model over every cached page.

            Returns the combined markdown of all pages plus the processed
            image and layout result of the page currently shown in the preview.
            """
            if not file_path or not PDF_CACHE["images"]:
                return "Please upload a file first.", None, None

            # Model wrappers are created on first use and reused afterwards.
            if model_name not in MODELS:
                if model_name == 'dot.ocr':
                    MODELS[model_name] = DotOcrModel(DEVICE)
                elif model_name == 'Dolphin':
                    MODELS[model_name] = DolphinModel(DEVICE)
                else:
                    return f"Unknown model: {model_name}", None, None
            model = MODELS[model_name]

            pages = PDF_CACHE["images"]
            all_results, all_markdown = [], []
            for i, img in enumerate(pages):
                gr.Info(f"Processing page {i+1}/{len(pages)} with {model_name}...")
                if model_name == 'dot.ocr':
                    # A cleared number field arrives as None; fall back to the defaults.
                    result = model.process_image(
                        img,
                        int(min_pix if min_pix is not None else MIN_PIXELS),
                        int(max_pix if max_pix is not None else MAX_PIXELS),
                    )
                else:  # Dolphin
                    result = model.process_image(img)
                all_results.append(result)
                if result.get('markdown_content'):
                    all_markdown.append(f"### Page {i+1}\n\n{result['markdown_content']}")

            PDF_CACHE.update({"results": all_results, "is_parsed": True, "model_used": model_name})

            # Show the result of the page the preview is on, not always the first one.
            shown_idx = min(PDF_CACHE["current_page"], len(all_results) - 1)
            shown_result = all_results[shown_idx]
            combined_md = "\n\n---\n\n".join(all_markdown)

            return combined_md, shown_result.get('processed_image'), shown_result.get('layout_result')

        def turn_page(direction):
            """Move the preview one page back ("prev") or forward (anything else)."""
            if not PDF_CACHE["images"]:
                return None, _page_info_div("No file parsed"), "No results yet", None, None

            if direction == "prev":
                PDF_CACHE["current_page"] = max(0, PDF_CACHE["current_page"] - 1)
            else:
                PDF_CACHE["current_page"] = min(PDF_CACHE["total_pages"] - 1, PDF_CACHE["current_page"] + 1)

            idx = PDF_CACHE["current_page"]
            page_info_html = _page_info_div(f'Page {idx + 1} / {PDF_CACHE["total_pages"]}')
            preview_img = PDF_CACHE["images"][idx]

            if not PDF_CACHE["is_parsed"]:
                # Pages can be browsed before processing; there are simply no results to show yet.
                return preview_img, page_info_html, "No results yet", None, None

            result = PDF_CACHE["results"][idx]

            all_md = [f"### Page {i+1}\n\n{res.get('markdown_content', '')}" for i, res in enumerate(PDF_CACHE["results"])]
            md_content = "\n\n---\n\n".join(all_md) if PDF_CACHE["total_pages"] > 1 else result.get('markdown_content', 'No content')

            return preview_img, page_info_html, md_content, result.get('processed_image'), result.get('layout_result')

        def clear_all():
            """Reset the cache and restore every output to its initial state."""
            global PDF_CACHE
            PDF_CACHE = _empty_cache()
            return None, None, _page_info_div("No file loaded"), placeholder_md, None, None

        # --- Wire UI components ---
        file_input.change(handle_file_upload, inputs=file_input, outputs=[image_preview, page_info])
        process_btn.click(
            process_document,
            inputs=[file_input, model_choice, min_pixels, max_pixels],
            outputs=[markdown_output, processed_image_output, json_output]
        )
        page_outputs = [image_preview, page_info, markdown_output, processed_image_output, json_output]
        prev_page_btn.click(lambda: turn_page("prev"), outputs=page_outputs)
        next_page_btn.click(lambda: turn_page("next"), outputs=page_outputs)
        clear_btn.click(clear_all, outputs=[file_input, image_preview, page_info, markdown_output, processed_image_output, json_output])

    return demo

if __name__ == "__main__":
    # The example gallery refers to files under ./examples, so make sure the
    # folder exists and tell the user to populate it when it had to be created.
    examples_dir = "examples"
    examples_missing = not os.path.exists(examples_dir)
    if examples_missing:
        os.makedirs(examples_dir)
        print("Created 'examples' directory. Please add sample images/PDFs there.")

    # Build the UI, enable request queuing, then start the server.
    app = create_gradio_interface()
    queued_app = app.queue()
    queued_app.launch(debug=True, show_error=True)