OmidSakaki commited on
Commit
db9549c
·
verified ·
1 Parent(s): 0d931e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -37
app.py CHANGED
@@ -2,32 +2,22 @@ import gradio as gr
2
  import time
3
  import numpy as np
4
  from PIL import Image
5
- from paddleocr import PaddleOCR
6
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
7
  import easyocr
8
- import pytesseract
9
  from doctr.models import ocr_predictor
10
 
11
- # Initialize all models
12
  models = {
13
- "PaddleOCR": PaddleOCR(lang='en'),
14
  "EasyOCR": easyocr.Reader(['en']),
15
  "TrOCR": {
16
  "processor": TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed"),
17
  "model": VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
18
  },
19
- "Tesseract": None, # Initialized by pytesseract
20
  "DocTR": ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
21
  }
22
 
23
- def run_paddleocr(image):
24
- try:
25
- result = models["PaddleOCR"].ocr(np.array(image))
26
- return ' '.join([line[1][0] for line in result[0]]) if result else ''
27
- except Exception as e:
28
- return f"Error: {str(e)}"
29
-
30
  def run_easyocr(image):
 
31
  try:
32
  result = models["EasyOCR"].readtext(np.array(image), detail=0)
33
  return ' '.join(result) if result else ''
@@ -35,6 +25,7 @@ def run_easyocr(image):
35
  return f"Error: {str(e)}"
36
 
37
  def run_trocr(image):
 
38
  try:
39
  pixel_values = models["TrOCR"]["processor"](image, return_tensors="pt").pixel_values
40
  generated_ids = models["TrOCR"]["model"].generate(pixel_values)
@@ -42,13 +33,8 @@ def run_trocr(image):
42
  except Exception as e:
43
  return f"Error: {str(e)}"
44
 
45
- def run_tesseract(image):
46
- try:
47
- return pytesseract.image_to_string(image, lang='eng')
48
- except Exception as e:
49
- return f"Error: {str(e)}"
50
-
51
  def run_doctr(image):
 
52
  try:
53
  if isinstance(image, Image.Image):
54
  image = np.array(image)
@@ -59,6 +45,7 @@ def run_doctr(image):
59
  return f"Error: {str(e)}"
60
 
61
  def compare_models(image):
 
62
  if isinstance(image, np.ndarray):
63
  image = Image.fromarray(image)
64
  image = image.convert("RGB")
@@ -67,10 +54,8 @@ def compare_models(image):
67
  times = {}
68
 
69
  # Run all OCR models
70
- for name, func in [("PaddleOCR", run_paddleocr),
71
- ("EasyOCR", run_easyocr),
72
  ("TrOCR", run_trocr),
73
- ("Tesseract", run_tesseract),
74
  ("DocTR", run_doctr)]:
75
  start = time.time()
76
  results[name] = func(image)
@@ -81,49 +66,56 @@ def compare_models(image):
81
  for name in results:
82
  table_rows.append(f"""
83
  <tr>
84
- <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{name}</td>
85
  <td style="padding: 8px; border: 1px solid #ddd;">{results[name]}</td>
86
- <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{times[name]:.3f}</td>
87
  </tr>
88
  """)
89
 
90
  comparison = f"""
91
- <table style="width:100%; border-collapse: collapse; margin-bottom: 20px;">
92
- <tr style="background-color: #f2f2f2;">
93
- <th style="padding: 8px; border: 1px solid #ddd; text-align: center;">Model</th>
94
- <th style="padding: 8px; border: 1px solid #ddd; text-align: center;">Extracted Text</th>
95
- <th style="padding: 8px; border: 1px solid #ddd; text-align: center;">Time (s)</th>
 
96
  </tr>
97
  {''.join(table_rows)}
98
  </table>
 
99
  """
100
 
101
- return comparison, *results.values()
102
 
103
- # Gradio Interface
104
- with gr.Blocks(title="Advanced OCR Comparison") as demo:
105
- gr.Markdown("## 🚀 Advanced English OCR Comparison (5 Models)")
 
 
 
106
 
107
  with gr.Row():
108
  with gr.Column():
109
- img_input = gr.Image(label="Upload Document", type="pil")
110
  gr.Examples(
111
  examples=["sample1.jpg", "sample2.png"],
112
  inputs=img_input,
113
- label="Sample Images"
114
  )
115
- submit_btn = gr.Button("Run Comparison", variant="primary")
116
 
117
  with gr.Column():
118
  comparison = gr.HTML(label="Comparison Results")
119
  with gr.Accordion("Detailed Results", open=False):
120
  gr.Markdown("### Individual Model Outputs")
121
- outputs = [gr.Textbox(label=name) for name in models]
 
 
122
 
123
  submit_btn.click(
124
  fn=compare_models,
125
  inputs=img_input,
126
- outputs=[comparison, *outputs]
127
  )
128
 
129
  if __name__ == "__main__":
 
2
  import time
3
  import numpy as np
4
  from PIL import Image
 
5
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
6
  import easyocr
 
7
  from doctr.models import ocr_predictor
8
 
9
+ # Initialize models
10
  models = {
 
11
  "EasyOCR": easyocr.Reader(['en']),
12
  "TrOCR": {
13
  "processor": TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed"),
14
  "model": VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
15
  },
 
16
  "DocTR": ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
17
  }
18
 
 
 
 
 
 
 
 
19
  def run_easyocr(image):
20
+ """Run EasyOCR on image"""
21
  try:
22
  result = models["EasyOCR"].readtext(np.array(image), detail=0)
23
  return ' '.join(result) if result else ''
 
25
  return f"Error: {str(e)}"
26
 
27
  def run_trocr(image):
28
+ """Run TrOCR on image"""
29
  try:
30
  pixel_values = models["TrOCR"]["processor"](image, return_tensors="pt").pixel_values
31
  generated_ids = models["TrOCR"]["model"].generate(pixel_values)
 
33
  except Exception as e:
34
  return f"Error: {str(e)}"
35
 
 
 
 
 
 
 
36
  def run_doctr(image):
37
+ """Run DocTR on image"""
38
  try:
39
  if isinstance(image, Image.Image):
40
  image = np.array(image)
 
45
  return f"Error: {str(e)}"
46
 
47
  def compare_models(image):
48
+ """Compare all OCR models"""
49
  if isinstance(image, np.ndarray):
50
  image = Image.fromarray(image)
51
  image = image.convert("RGB")
 
54
  times = {}
55
 
56
  # Run all OCR models
57
+ for name, func in [("EasyOCR", run_easyocr),
 
58
  ("TrOCR", run_trocr),
 
59
  ("DocTR", run_doctr)]:
60
  start = time.time()
61
  results[name] = func(image)
 
66
  for name in results:
67
  table_rows.append(f"""
68
  <tr>
69
+ <td style="padding: 8px; border: 1px solid #ddd; text-align: center; font-weight: bold;">{name}</td>
70
  <td style="padding: 8px; border: 1px solid #ddd;">{results[name]}</td>
71
+ <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{times[name]:.3f}s</td>
72
  </tr>
73
  """)
74
 
75
  comparison = f"""
76
+ <div style="overflow-x: auto;">
77
+ <table style="width:100%; border-collapse: collapse; margin: 15px 0; font-family: Arial, sans-serif;">
78
+ <tr style="background-color: #4CAF50; color: white;">
79
+ <th style="padding: 12px; border: 1px solid #ddd; text-align: center;">Model</th>
80
+ <th style="padding: 12px; border: 1px solid #ddd; text-align: center;">Extracted Text</th>
81
+ <th style="padding: 12px; border: 1px solid #ddd; text-align: center;">Processing Time</th>
82
  </tr>
83
  {''.join(table_rows)}
84
  </table>
85
+ </div>
86
  """
87
 
88
+ return comparison, results['EasyOCR'], results['TrOCR'], results['DocTR']
89
 
90
+ # Create Gradio interface
91
+ with gr.Blocks(title="English OCR Comparison", theme=gr.themes.Soft()) as demo:
92
+ gr.Markdown("""
93
+ # 🚀 English OCR Model Comparison
94
+ Compare the performance of top OCR models for English text extraction
95
+ """)
96
 
97
  with gr.Row():
98
  with gr.Column():
99
+ img_input = gr.Image(label="Upload Image", type="pil")
100
  gr.Examples(
101
  examples=["sample1.jpg", "sample2.png"],
102
  inputs=img_input,
103
+ label="Try these sample images"
104
  )
105
+ submit_btn = gr.Button("Compare Models", variant="primary")
106
 
107
  with gr.Column():
108
  comparison = gr.HTML(label="Comparison Results")
109
  with gr.Accordion("Detailed Results", open=False):
110
  gr.Markdown("### Individual Model Outputs")
111
+ easy_output = gr.Textbox(label="EasyOCR")
112
+ trocr_output = gr.Textbox(label="TrOCR")
113
+ doctr_output = gr.Textbox(label="DocTR")
114
 
115
  submit_btn.click(
116
  fn=compare_models,
117
  inputs=img_input,
118
+ outputs=[comparison, easy_output, trocr_output, doctr_output]
119
  )
120
 
121
  if __name__ == "__main__":