File size: 4,431 Bytes
9453eac
2bf547d
24f0403
9453eac
2bf547d
5c3f634
768d260
9453eac
db9549c
768d260
 
 
 
 
 
 
 
9453eac
5c3f634
db9549c
5c3f634
768d260
5c3f634
 
768d260
9453eac
2bf547d
db9549c
dd4c7df
768d260
 
 
 
 
 
 
db9549c
768d260
 
 
 
 
 
dd4c7df
768d260
2bf547d
 
db9549c
2bf547d
 
 
 
 
24f0403
2bf547d
768d260
db9549c
768d260
 
 
 
 
2bf547d
5c3f634
768d260
 
 
 
db9549c
768d260
db9549c
768d260
 
 
2bf547d
db9549c
 
 
 
 
 
2bf547d
768d260
2bf547d
db9549c
2bf547d
 
db9549c
9453eac
db9549c
 
 
 
 
 
9453eac
 
 
db9549c
768d260
 
 
db9549c
768d260
db9549c
279ab91
 
768d260
 
 
db9549c
 
 
2bf547d
 
 
768d260
db9549c
9453eac
 
 
2bf547d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import html
import time

import gradio as gr
import numpy as np
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import easyocr
from doctr.models import ocr_predictor

# Initialize models
# Every engine is instantiated once at import time and looked up by name
# from the run_* helpers below.
_TROCR_CHECKPOINT = "microsoft/trocr-base-printed"

models = {}
models["EasyOCR"] = easyocr.Reader(['en'])
# The TrOCR processor and model are both loaded from the same checkpoint.
models["TrOCR"] = dict(
    processor=TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT),
    model=VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT),
)
models["DocTR"] = ocr_predictor(
    det_arch='db_resnet50',
    reco_arch='crnn_vgg16_bn',
    pretrained=True,
)

def run_easyocr(image):
    """Extract text from ``image`` using the shared EasyOCR reader.

    The image is converted to a NumPy array and read with ``detail=0``;
    the returned items are joined with single spaces.

    Returns:
        The joined text, an empty string when nothing was returned, or an
        ``"Error: ..."`` string if any exception occurred.
    """
    try:
        reader = models["EasyOCR"]
        fragments = reader.readtext(np.array(image), detail=0)
        if not fragments:
            return ''
        return ' '.join(fragments)
    except Exception as exc:
        # Failures are reported as text so the caller can still display them.
        return f"Error: {exc}"

def run_trocr(image):
    """Extract text from ``image`` using the shared TrOCR processor and model.

    The processor turns the image into PyTorch pixel values, the model
    generates token ids from them, and the processor decodes those ids with
    special tokens stripped. Only the first decoded sequence is returned.

    Returns:
        The decoded text, or an ``"Error: ..."`` string if any exception
        occurred.
    """
    try:
        trocr = models["TrOCR"]
        processor = trocr["processor"]
        encoded = processor(image, return_tensors="pt")
        token_ids = trocr["model"].generate(encoded.pixel_values)
        decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
        return decoded[0]
    except Exception as exc:
        # Failures are reported as text so the caller can still display them.
        return f"Error: {exc}"

def run_doctr(image):
    """Extract text from ``image`` using the shared docTR predictor.

    Args:
        image: A PIL image (converted to a NumPy array here) or an object
            the predictor already accepts, passed as a one-element batch.

    Returns:
        Every recognized word, in page/block/line/word order, joined by
        single spaces, or an ``"Error: ..."`` string if any exception
        occurred.
    """
    try:
        if isinstance(image, Image.Image):
            image = np.array(image)
        result = models["DocTR"]([image])
        # docTR `Word` elements are objects, not sequences: the recognized
        # string lives in `.value`. Indexing with `word[0]` raises TypeError,
        # which previously turned every call into an "Error: ..." result.
        return ' '.join(
            word.value
            for page in result.pages
            for block in page.blocks
            for line in block.lines
            for word in line.words
        )
    except Exception as e:
        # Failures are reported as text so the caller can still display them.
        return f"Error: {str(e)}"

def _render_comparison_table(results, times):
    """Build the HTML comparison table from per-model text and timings.

    Args:
        results: Mapping of model name to extracted text.
        times: Mapping of model name to elapsed seconds.

    Returns:
        An HTML string with one table row per entry in ``results``.
    """
    table_rows = []
    for name, text in results.items():
        # OCR output is arbitrary text: escape it so characters such as
        # '<' or '&' are displayed literally instead of being parsed as HTML.
        table_rows.append(f"""
        <tr>
            <td style="padding: 8px; border: 1px solid #ddd; text-align: center; font-weight: bold;">{html.escape(str(name))}</td>
            <td style="padding: 8px; border: 1px solid #ddd;">{html.escape(str(text))}</td>
            <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{times[name]:.3f}s</td>
        </tr>
        """)

    return f"""
    <div style="overflow-x: auto;">
    <table style="width:100%; border-collapse: collapse; margin: 15px 0; font-family: Arial, sans-serif;">
        <tr style="background-color: #4CAF50; color: white;">
            <th style="padding: 12px; border: 1px solid #ddd; text-align: center;">Model</th>
            <th style="padding: 12px; border: 1px solid #ddd; text-align: center;">Extracted Text</th>
            <th style="padding: 12px; border: 1px solid #ddd; text-align: center;">Processing Time</th>
        </tr>
        {''.join(table_rows)}
    </table>
    </div>
    """


def compare_models(image):
    """Run every OCR model on ``image`` and summarize the results.

    Args:
        image: A PIL image, a NumPy array (converted to a PIL image), or
            ``None`` when no image was supplied.

    Returns:
        A 4-tuple ``(comparison_html, easyocr_text, trocr_text, doctr_text)``.
        When ``image`` is ``None`` the HTML is a short prompt to upload an
        image and the three texts are empty strings.
    """
    # Guard the "button clicked without an upload" case, which would
    # otherwise crash on `None.convert`.
    if image is None:
        return ("<p>Please upload an image before comparing models.</p>", "", "", "")

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    runners = (
        ("EasyOCR", run_easyocr),
        ("TrOCR", run_trocr),
        ("DocTR", run_doctr),
    )

    results = {}
    times = {}
    for name, func in runners:
        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be skewed by system clock adjustments.
        start = time.perf_counter()
        results[name] = func(image)
        times[name] = time.perf_counter() - start

    comparison = _render_comparison_table(results, times)
    return comparison, results['EasyOCR'], results['TrOCR'], results['DocTR']

# Create Gradio interface
# Two-column layout: image input, examples and the trigger button on the left;
# the comparison table and per-model text outputs on the right.
with gr.Blocks(title="English OCR Comparison", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🚀 English OCR Model Comparison
    Compare the performance of top OCR models for English text extraction
    """)
    
    with gr.Row():
        with gr.Column():
            # type="pil" asks Gradio to pass the upload to the callback as a PIL image.
            img_input = gr.Image(label="Upload Image", type="pil")
            # NOTE(review): sample1.jpg / sample2.png are referenced by relative
            # path — confirm these files are shipped alongside this script.
            gr.Examples(
                examples=["sample1.jpg", "sample2.png"],
                inputs=img_input,
                label="Try these sample images"
            )
            submit_btn = gr.Button("Compare Models", variant="primary")
        
        with gr.Column():
            comparison = gr.HTML(label="Comparison Results")
            # Raw per-model text, collapsed by default.
            with gr.Accordion("Detailed Results", open=False):
                gr.Markdown("### Individual Model Outputs")
                easy_output = gr.Textbox(label="EasyOCR")
                trocr_output = gr.Textbox(label="TrOCR")
                doctr_output = gr.Textbox(label="DocTR")
    
    # The outputs list is in the same order as the 4-tuple returned by
    # compare_models: (html table, EasyOCR text, TrOCR text, DocTR text).
    submit_btn.click(
        fn=compare_models,
        inputs=img_input,
        outputs=[comparison, easy_output, trocr_output, doctr_output]
    )

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()