File size: 4,554 Bytes
9453eac
2bf547d
24f0403
9453eac
5c3f634
2bf547d
5c3f634
768d260
 
9453eac
768d260
 
 
 
 
 
 
 
 
 
 
9453eac
2bf547d
dd4c7df
768d260
dd4c7df
 
768d260
5c3f634
 
 
768d260
5c3f634
 
768d260
9453eac
2bf547d
dd4c7df
768d260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd4c7df
768d260
2bf547d
 
 
 
 
 
 
24f0403
2bf547d
768d260
 
 
 
 
 
 
 
 
2bf547d
5c3f634
768d260
 
 
 
 
 
 
 
 
 
2bf547d
768d260
5c3f634
 
 
768d260
2bf547d
768d260
2bf547d
 
 
768d260
9453eac
768d260
 
 
9453eac
 
 
768d260
 
 
 
 
 
 
279ab91
 
768d260
 
 
 
2bf547d
 
 
768d260
 
9453eac
 
 
2bf547d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import html
import os
import time

import gradio as gr
import numpy as np
from PIL import Image
from paddleocr import PaddleOCR
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import easyocr
import pytesseract
from doctr.models import ocr_predictor

# Load every OCR engine once at import time so each request reuses the same
# instances. Key insertion order matters: the UI creates one output textbox per
# key of this dict, in this order, and compare_models returns texts in the
# same order.
_TROCR_CHECKPOINT = "microsoft/trocr-base-printed"

models = {}
models["PaddleOCR"] = PaddleOCR(lang='en')
models["EasyOCR"] = easyocr.Reader(['en'])
models["TrOCR"] = dict(
    processor=TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT),
    model=VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT),
)
# No preloaded object is needed: run_tesseract calls pytesseract directly.
models["Tesseract"] = None
models["DocTR"] = ocr_predictor(
    det_arch='db_resnet50',
    reco_arch='crnn_vgg16_bn',
    pretrained=True,
)

def run_paddleocr(image):
    """Recognize text in *image* with PaddleOCR and return it as one string.

    Args:
        image: A PIL image or any array-like accepted by ``np.array``.

    Returns:
        The recognized text fragments joined by single spaces, ``''`` when
        nothing was detected, or an ``"Error: ..."`` string if PaddleOCR (or
        parsing its result) raised.
    """
    try:
        result = models["PaddleOCR"].ocr(np.array(image))
        # The result holds one entry per page. For an image with no detected
        # text that entry can be None rather than an empty list (observed with
        # PaddleOCR 2.x), so check both levels before iterating.
        if not result or not result[0]:
            return ''
        # Each line entry keeps its text at line[1][0]
        # (PaddleOCR 2.x layout: [box, (text, score)]).
        return ' '.join(line[1][0] for line in result[0])
    except Exception as e:
        return f"Error: {str(e)}"

def run_easyocr(image):
    """Recognize text in *image* with EasyOCR.

    Returns the detected strings separated by spaces, ``''`` if nothing was
    detected, or an ``"Error: ..."`` message if EasyOCR raised.
    """
    try:
        # detail=0 asks EasyOCR for plain strings instead of (box, text, score).
        detected = models["EasyOCR"].readtext(np.array(image), detail=0)
        if not detected:
            return ''
        return ' '.join(detected)
    except Exception as exc:
        return f"Error: {str(exc)}"

def run_trocr(image):
    """Transcribe *image* with the TrOCR encoder-decoder model.

    Returns the first decoded sequence with special tokens removed, or an
    ``"Error: ..."`` message if preprocessing, generation or decoding raised.

    NOTE(review): TrOCR checkpoints are presumably meant for single text-line
    crops; verify output quality when whole documents are passed in.
    """
    try:
        trocr = models["TrOCR"]
        processor = trocr["processor"]
        model = trocr["model"]
        encoded = processor(image, return_tensors="pt")
        token_ids = model.generate(encoded.pixel_values)
        decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
        return decoded[0]
    except Exception as err:
        return f"Error: {str(err)}"

def run_tesseract(image):
    """Recognize English text in *image* with Tesseract via pytesseract.

    Returns:
        The recognized text with leading/trailing whitespace removed, or an
        ``"Error: ..."`` string if pytesseract raised.
    """
    try:
        # Tesseract's plain-text output typically ends with a newline and a
        # form-feed page separator. Strip it so the value displays cleanly in
        # the comparison table, like the other engines' results.
        return pytesseract.image_to_string(image, lang='eng').strip()
    except Exception as e:
        return f"Error: {str(e)}"

def run_doctr(image):
    """Recognize text in *image* with the docTR OCR predictor.

    Args:
        image: A PIL image (converted to a NumPy array) or an array already.

    Returns:
        Every recognized word, in document order, joined by single spaces, or
        an ``"Error: ..."`` string if docTR raised.
    """
    try:
        if isinstance(image, Image.Image):
            image = np.array(image)
        result = models["DocTR"]([image])
        # The docTR result tree is pages -> blocks -> lines -> words. A Word
        # object is not indexable; its recognized text is its `.value`.
        return ' '.join(
            word.value
            for page in result.pages
            for block in page.blocks
            for line in block.lines
            for word in line.words
        )
    except Exception as e:
        return f"Error: {str(e)}"

def compare_models(image):
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    
    results = {}
    times = {}
    
    # Run all OCR models
    for name, func in [("PaddleOCR", run_paddleocr),
                      ("EasyOCR", run_easyocr),
                      ("TrOCR", run_trocr),
                      ("Tesseract", run_tesseract),
                      ("DocTR", run_doctr)]:
        start = time.time()
        results[name] = func(image)
        times[name] = time.time() - start
    
    # Create comparison table
    table_rows = []
    for name in results:
        table_rows.append(f"""
        <tr>
            <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{name}</td>
            <td style="padding: 8px; border: 1px solid #ddd;">{results[name]}</td>
            <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{times[name]:.3f}</td>
        </tr>
        """)
    
    comparison = f"""
    <table style="width:100%; border-collapse: collapse; margin-bottom: 20px;">
        <tr style="background-color: #f2f2f2;">
            <th style="padding: 8px; border: 1px solid #ddd; text-align: center;">Model</th>
            <th style="padding: 8px; border: 1px solid #ddd; text-align: center;">Extracted Text</th>
            <th style="padding: 8px; border: 1px solid #ddd; text-align: center;">Time (s)</th>
        </tr>
        {''.join(table_rows)}
    </table>
    """
    
    return comparison, *results.values()

# Gradio Interface
# Gradio UI: image input (plus any sample images present on disk) on the left;
# the HTML comparison table and per-engine textboxes on the right.
with gr.Blocks(title="Advanced OCR Comparison") as demo:
    gr.Markdown("## 🚀 Advanced English OCR Comparison (5 Models)")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label="Upload Document", type="pil")
            # Offer only the sample images that actually exist (paths are
            # relative to the working directory). A missing example file can
            # make gr.Examples fail while the UI is being built.
            example_files = [
                path for path in ["sample1.jpg", "sample2.png"]
                if os.path.isfile(path)
            ]
            if example_files:
                gr.Examples(
                    examples=example_files,
                    inputs=img_input,
                    label="Sample Images"
                )
            submit_btn = gr.Button("Run Comparison", variant="primary")

        with gr.Column():
            comparison = gr.HTML(label="Comparison Results")
            with gr.Accordion("Detailed Results", open=False):
                gr.Markdown("### Individual Model Outputs")
                # One textbox per key of `models`, in dict order. This must
                # match the order in which compare_models returns texts.
                outputs = [gr.Textbox(label=name) for name in models]

    submit_btn.click(
        fn=compare_models,
        inputs=img_input,
        outputs=[comparison, *outputs]
    )

if __name__ == "__main__":
    demo.launch()