# DocQA_Agent / app.py
# Source snapshot: commit 768d260 "Update app.py" by OmidSakaki (verified), 4.55 kB.
import html
import time

import gradio as gr
import numpy as np
from PIL import Image
from paddleocr import PaddleOCR
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import easyocr
import pytesseract
from doctr.models import ocr_predictor
# Load every OCR engine once, at import time, so each request reuses them.
# Insertion order matters: the UI creates one output textbox per key, in order.
_TROCR_CHECKPOINT = "microsoft/trocr-base-printed"

models = {}
models["PaddleOCR"] = PaddleOCR(lang='en')
models["EasyOCR"] = easyocr.Reader(['en'])
models["TrOCR"] = {
    "processor": TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT),
    "model": VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT),
}
# No model object is held for Tesseract; run_tesseract calls pytesseract directly.
models["Tesseract"] = None
models["DocTR"] = ocr_predictor(
    det_arch='db_resnet50',
    reco_arch='crnn_vgg16_bn',
    pretrained=True,
)
def run_paddleocr(image):
    """Recognize text with PaddleOCR and return it as one space-joined string.

    Args:
        image: A PIL image (or anything ``np.array`` accepts) to run OCR on.

    Returns:
        The recognized lines joined by single spaces, ``''`` when nothing is
        detected, or ``"Error: <message>"`` if PaddleOCR raises.
    """
    try:
        result = models["PaddleOCR"].ocr(np.array(image))
        # ``result[0]`` holds the lines for the single input image.
        # NOTE(review): PaddleOCR 2.x can return ``[None]`` when no text is
        # detected; guard it so a blank image yields '' instead of an
        # "Error: ..." string from iterating None.
        page = result[0] if result else None
        if not page:
            return ''
        # Each line is indexed as line[1][0] -> the recognized text string.
        return ' '.join(line[1][0] for line in page)
    except Exception as e:
        return f"Error: {str(e)}"
def run_easyocr(image):
    """Recognize text with EasyOCR and return it as one space-joined string.

    Args:
        image: A PIL image (or anything ``np.array`` accepts) to run OCR on.

    Returns:
        The recognized snippets joined by single spaces, ``''`` when EasyOCR
        finds nothing, or ``"Error: <message>"`` if EasyOCR raises.
    """
    try:
        reader = models["EasyOCR"]
        # detail=0 asks readtext for plain strings rather than box/score tuples.
        snippets = reader.readtext(np.array(image), detail=0)
        if not snippets:
            return ''
        return ' '.join(snippets)
    except Exception as exc:
        return f"Error: {str(exc)}"
def run_trocr(image):
    """Recognize text with the TrOCR encoder-decoder model.

    Args:
        image: A PIL image passed straight to the TrOCR processor.

    Returns:
        The first decoded sequence (special tokens removed), or
        ``"Error: <message>"`` if any step raises.

    NOTE(review): TrOCR checkpoints are generally trained on single text-line
    crops, and ``generate`` runs with its default length limit here, so
    full-page documents may come back truncated. Confirm whether line
    segmentation or a larger ``max_new_tokens`` is wanted.
    """
    try:
        bundle = models["TrOCR"]
        processor = bundle["processor"]
        encoded = processor(image, return_tensors="pt")
        token_ids = bundle["model"].generate(encoded.pixel_values)
        decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
        return decoded[0]
    except Exception as exc:
        return f"Error: {str(exc)}"
def run_tesseract(image):
    """Recognize English text with Tesseract via pytesseract.

    Args:
        image: A PIL image passed to ``pytesseract.image_to_string``.

    Returns:
        The recognized text with leading/trailing whitespace removed, or
        ``"Error: <message>"`` if pytesseract raises.
    """
    try:
        text = pytesseract.image_to_string(image, lang='eng')
    except Exception as e:
        return f"Error: {str(e)}"
    # Tesseract's plain-text output usually ends with newlines and a form-feed
    # page separator (\x0c). Strip them so the result lines up with the other
    # engines, which return clean joined text.
    return text.strip()
def run_doctr(image):
    """Recognize text with the DocTR OCR predictor.

    Args:
        image: A PIL image or a numpy array. PIL images are converted to numpy
            because the predictor is called with a list of arrays.

    Returns:
        Every recognized word, in page/block/line order, joined by single
        spaces, or ``"Error: <message>"`` if DocTR raises.
    """
    try:
        if isinstance(image, Image.Image):
            image = np.array(image)
        document = models["DocTR"]([image])
        # DocTR ``Word`` elements expose their text via ``.value``; they are
        # not subscriptable, so the previous ``word[0]`` always raised.
        return ' '.join(
            word.value
            for page in document.pages
            for block in page.blocks
            for line in block.lines
            for word in line.words
        )
    except Exception as e:
        return f"Error: {str(e)}"
def _build_comparison_table(results, times):
    """Render per-model OCR text and timings as an HTML table (one row per model).

    Args:
        results: Mapping of model name -> extracted text (or "Error: ..." string).
        times: Mapping of model name -> elapsed seconds for that model.

    Returns:
        An HTML ``<table>`` string. Names and texts are HTML-escaped so OCR
        output containing ``<``, ``>`` or ``&`` is shown literally rather than
        interpreted as markup.
    """
    cell = "padding: 8px; border: 1px solid #ddd;"
    centered = f"{cell} text-align: center;"
    rows = "".join(
        f"""
        <tr>
            <td style="{centered}">{html.escape(str(name))}</td>
            <td style="{cell}">{html.escape(str(text))}</td>
            <td style="{centered}">{times[name]:.3f}</td>
        </tr>
        """
        for name, text in results.items()
    )
    return f"""
    <table style="width:100%; border-collapse: collapse; margin-bottom: 20px;">
        <tr style="background-color: #f2f2f2;">
            <th style="{centered}">Model</th>
            <th style="{centered}">Extracted Text</th>
            <th style="{centered}">Time (s)</th>
        </tr>
        {rows}
    </table>
    """


def compare_models(image):
    """Run every OCR engine on *image* and summarize the results.

    Args:
        image: PIL image or numpy array from the Gradio ``Image`` input, or
            None when the button is clicked without an upload.

    Returns:
        A tuple ``(comparison_html, paddle, easyocr, trocr, tesseract, doctr)``:
        an HTML comparison table followed by each engine's text, in the same
        order as the runners below.
    """
    if image is None:
        # Gradio passes None when nothing was uploaded. Return one empty string
        # per model textbox (5 engines) so the output count still matches.
        return ("<p>Please upload an image to run the comparison.</p>", *([""] * 5))

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    runners = (
        ("PaddleOCR", run_paddleocr),
        ("EasyOCR", run_easyocr),
        ("TrOCR", run_trocr),
        ("Tesseract", run_tesseract),
        ("DocTR", run_doctr),
    )
    results = {}
    times = {}
    for name, run in runners:
        # perf_counter is monotonic and high-resolution; time.time() is neither.
        start = time.perf_counter()
        results[name] = run(image)
        times[name] = time.perf_counter() - start

    return (_build_comparison_table(results, times), *results.values())
# Gradio Interface: one upload column and one results column. The results
# column holds the HTML comparison table plus one textbox per entry in `models`.
with gr.Blocks(title="Advanced OCR Comparison") as demo:
    # Header emoji was stored as mojibake ("πŸš€"); restored to the intended 🚀.
    gr.Markdown("## 🚀 Advanced English OCR Comparison (5 Models)")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label="Upload Document", type="pil")
            # NOTE(review): assumes sample1.jpg / sample2.png ship next to app.py.
            gr.Examples(
                examples=["sample1.jpg", "sample2.png"],
                inputs=img_input,
                label="Sample Images",
            )
            submit_btn = gr.Button("Run Comparison", variant="primary")
        with gr.Column():
            comparison = gr.HTML(label="Comparison Results")
            with gr.Accordion("Detailed Results", open=False):
                gr.Markdown("### Individual Model Outputs")
                # Same key order as `models`, which matches compare_models' return order.
                outputs = [gr.Textbox(label=name) for name in models]
    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(
        fn=compare_models,
        inputs=img_input,
        outputs=[comparison, *outputs],
    )

if __name__ == "__main__":
    demo.launch()