OmidSakaki committed on
Commit
768d260
·
verified ·
1 Parent(s): 520d7f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -76
app.py CHANGED
@@ -5,49 +5,60 @@ from PIL import Image
5
  from paddleocr import PaddleOCR
6
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
7
  import easyocr
 
 
8
 
9
- # Initialize models
10
- paddle_ocr = PaddleOCR(lang='en') # PaddleOCR برای انگلیسی
11
- easy_ocr = easyocr.Reader(['en']) # EasyOCR برای انگلیسی
12
- trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
13
- trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
 
 
 
 
 
 
14
 
15
  def run_paddleocr(image):
16
- """Run PaddleOCR on image"""
17
- if isinstance(image, Image.Image):
18
- image = np.array(image)
19
-
20
  try:
21
- result = paddle_ocr.ocr(image)
22
  return ' '.join([line[1][0] for line in result[0]]) if result else ''
23
  except Exception as e:
24
- return f"PaddleOCR Error: {str(e)}"
25
 
26
  def run_easyocr(image):
27
- """Run EasyOCR on image"""
28
- if isinstance(image, Image.Image):
29
- image = np.array(image)
30
-
31
  try:
32
- result = easy_ocr.readtext(image, detail=0)
33
  return ' '.join(result) if result else ''
34
  except Exception as e:
35
- return f"EasyOCR Error: {str(e)}"
36
 
37
  def run_trocr(image):
38
- """Run TrOCR on image"""
39
- if isinstance(image, np.ndarray):
40
- image = Image.fromarray(image)
41
-
42
  try:
43
- pixel_values = trocr_processor(image, return_tensors="pt").pixel_values
44
- generated_ids = trocr_model.generate(pixel_values)
45
- return trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  except Exception as e:
47
- return f"TrOCR Error: {str(e)}"
48
 
49
  def compare_models(image):
50
- """Compare all three OCR models"""
51
  if isinstance(image, np.ndarray):
52
  image = Image.fromarray(image)
53
  image = image.convert("RGB")
@@ -55,75 +66,64 @@ def compare_models(image):
55
  results = {}
56
  times = {}
57
 
58
- # Run PaddleOCR
59
- start = time.time()
60
- results['PaddleOCR'] = run_paddleocr(image)
61
- times['PaddleOCR'] = time.time() - start
62
-
63
- # Run EasyOCR
64
- start = time.time()
65
- results['EasyOCR'] = run_easyocr(image)
66
- times['EasyOCR'] = time.time() - start
67
-
68
- # Run TrOCR
69
- start = time.time()
70
- results['TrOCR'] = run_trocr(image)
71
- times['TrOCR'] = time.time() - start
72
 
73
  # Create comparison table
 
 
 
 
 
 
 
 
 
 
74
  comparison = f"""
75
- <table style="width:100%; border-collapse: collapse;">
76
  <tr style="background-color: #f2f2f2;">
77
  <th style="padding: 8px; border: 1px solid #ddd; text-align: center;">Model</th>
78
  <th style="padding: 8px; border: 1px solid #ddd; text-align: center;">Extracted Text</th>
79
- <th style="padding: 8px; border: 1px solid #ddd; text-align: center;">Processing Time (s)</th>
80
- </tr>
81
- <tr>
82
- <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">PaddleOCR</td>
83
- <td style="padding: 8px; border: 1px solid #ddd;">{results['PaddleOCR']}</td>
84
- <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{times['PaddleOCR']:.3f}</td>
85
- </tr>
86
- <tr>
87
- <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">EasyOCR</td>
88
- <td style="padding: 8px; border: 1px solid #ddd;">{results['EasyOCR']}</td>
89
- <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{times['EasyOCR']:.3f}</td>
90
- </tr>
91
- <tr>
92
- <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">TrOCR</td>
93
- <td style="padding: 8px; border: 1px solid #ddd;">{results['TrOCR']}</td>
94
- <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{times['TrOCR']:.3f}</td>
95
  </tr>
 
96
  </table>
97
  """
98
 
99
- return comparison, results['PaddleOCR'], results['EasyOCR'], results['TrOCR']
100
 
101
- # Create Gradio interface
102
- with gr.Blocks(title="English OCR Comparison Tool") as demo:
103
- gr.Markdown("""
104
- ## English OCR Models Comparison
105
- This tool compares three OCR models for English text:
106
- 1. PaddleOCR
107
- 2. EasyOCR
108
- 3. TrOCR (Microsoft)
109
- """)
110
 
111
  with gr.Row():
112
  with gr.Column():
113
- image_input = gr.Image(label="Input Image", type="pil")
114
- submit_btn = gr.Button("Compare Models", variant="primary")
 
 
 
 
 
115
 
116
  with gr.Column():
117
- comparison_output = gr.HTML(label="Comparison Results")
118
- with gr.Accordion("Individual Results", open=False):
119
- paddle_output = gr.Textbox(label="PaddleOCR Result")
120
- easy_output = gr.Textbox(label="EasyOCR Result")
121
- trocr_output = gr.Textbox(label="TrOCR Result")
122
 
123
  submit_btn.click(
124
  fn=compare_models,
125
- inputs=image_input,
126
- outputs=[comparison_output, paddle_output, easy_output, trocr_output]
127
  )
128
 
129
  if __name__ == "__main__":
 
5
  from paddleocr import PaddleOCR
6
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
7
  import easyocr
8
+ import pytesseract
9
+ from doctr.models import ocr_predictor
10
 
11
+ # Initialize all models
12
+ models = {
13
+ "PaddleOCR": PaddleOCR(lang='en'),
14
+ "EasyOCR": easyocr.Reader(['en']),
15
+ "TrOCR": {
16
+ "processor": TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed"),
17
+ "model": VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
18
+ },
19
+ "Tesseract": None, # Initialized by pytesseract
20
+ "DocTR": ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
21
+ }
22
 
23
  def run_paddleocr(image):
 
 
 
 
24
  try:
25
+ result = models["PaddleOCR"].ocr(np.array(image))
26
  return ' '.join([line[1][0] for line in result[0]]) if result else ''
27
  except Exception as e:
28
+ return f"Error: {str(e)}"
29
 
30
  def run_easyocr(image):
 
 
 
 
31
  try:
32
+ result = models["EasyOCR"].readtext(np.array(image), detail=0)
33
  return ' '.join(result) if result else ''
34
  except Exception as e:
35
+ return f"Error: {str(e)}"
36
 
37
  def run_trocr(image):
 
 
 
 
38
  try:
39
+ pixel_values = models["TrOCR"]["processor"](image, return_tensors="pt").pixel_values
40
+ generated_ids = models["TrOCR"]["model"].generate(pixel_values)
41
+ return models["TrOCR"]["processor"].batch_decode(generated_ids, skip_special_tokens=True)[0]
42
+ except Exception as e:
43
+ return f"Error: {str(e)}"
44
+
45
+ def run_tesseract(image):
46
+ try:
47
+ return pytesseract.image_to_string(image, lang='eng')
48
+ except Exception as e:
49
+ return f"Error: {str(e)}"
50
+
51
+ def run_doctr(image):
52
+ try:
53
+ if isinstance(image, Image.Image):
54
+ image = np.array(image)
55
+ result = models["DocTR"]([image])
56
+ return ' '.join([word[0] for page in result.pages for block in page.blocks
57
+ for line in block.lines for word in line.words])
58
  except Exception as e:
59
+ return f"Error: {str(e)}"
60
 
61
  def compare_models(image):
 
62
  if isinstance(image, np.ndarray):
63
  image = Image.fromarray(image)
64
  image = image.convert("RGB")
 
66
  results = {}
67
  times = {}
68
 
69
+ # Run all OCR models
70
+ for name, func in [("PaddleOCR", run_paddleocr),
71
+ ("EasyOCR", run_easyocr),
72
+ ("TrOCR", run_trocr),
73
+ ("Tesseract", run_tesseract),
74
+ ("DocTR", run_doctr)]:
75
+ start = time.time()
76
+ results[name] = func(image)
77
+ times[name] = time.time() - start
 
 
 
 
 
78
 
79
  # Create comparison table
80
+ table_rows = []
81
+ for name in results:
82
+ table_rows.append(f"""
83
+ <tr>
84
+ <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{name}</td>
85
+ <td style="padding: 8px; border: 1px solid #ddd;">{results[name]}</td>
86
+ <td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{times[name]:.3f}</td>
87
+ </tr>
88
+ """)
89
+
90
  comparison = f"""
91
+ <table style="width:100%; border-collapse: collapse; margin-bottom: 20px;">
92
  <tr style="background-color: #f2f2f2;">
93
  <th style="padding: 8px; border: 1px solid #ddd; text-align: center;">Model</th>
94
  <th style="padding: 8px; border: 1px solid #ddd; text-align: center;">Extracted Text</th>
95
+ <th style="padding: 8px; border: 1px solid #ddd; text-align: center;">Time (s)</th>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  </tr>
97
+ {''.join(table_rows)}
98
  </table>
99
  """
100
 
101
+ return comparison, *results.values()
102
 
103
+ # Gradio Interface
104
+ with gr.Blocks(title="Advanced OCR Comparison") as demo:
105
+ gr.Markdown("## 🚀 Advanced English OCR Comparison (5 Models)")
 
 
 
 
 
 
106
 
107
  with gr.Row():
108
  with gr.Column():
109
+ img_input = gr.Image(label="Upload Document", type="pil")
110
+ gr.Examples(
111
+ examples=["sample1.jpg", "sample2.png"],
112
+ inputs=img_input,
113
+ label="Sample Images"
114
+ )
115
+ submit_btn = gr.Button("Run Comparison", variant="primary")
116
 
117
  with gr.Column():
118
+ comparison = gr.HTML(label="Comparison Results")
119
+ with gr.Accordion("Detailed Results", open=False):
120
+ gr.Markdown("### Individual Model Outputs")
121
+ outputs = [gr.Textbox(label=name) for name in models]
 
122
 
123
  submit_btn.click(
124
  fn=compare_models,
125
+ inputs=img_input,
126
+ outputs=[comparison, *outputs]
127
  )
128
 
129
  if __name__ == "__main__":