# DocQA_Agent / app.py
# Author: OmidSakaki — commit db9549c ("Update app.py", verified), 4.43 kB
import html
import time

import gradio as gr
import numpy as np
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import easyocr
from doctr.models import ocr_predictor
# Load every OCR backend once, at import time, into a single registry that the
# run_* helpers below look up by name.
_TROCR_CHECKPOINT = "microsoft/trocr-base-printed"

models = {}
models["EasyOCR"] = easyocr.Reader(['en'])  # English-only reader
models["TrOCR"] = {
    "processor": TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT),
    "model": VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT),
}
models["DocTR"] = ocr_predictor(
    det_arch='db_resnet50',
    reco_arch='crnn_vgg16_bn',
    pretrained=True,
)
def run_easyocr(image):
    """Extract text from *image* with the shared EasyOCR reader.

    The image is converted with ``np.array`` and passed to ``readtext`` with
    ``detail=0``. The returned fragments are joined with single spaces, and an
    empty result gives ``''``. Any exception is reported as an ``"Error: ..."``
    string instead of being raised.
    """
    try:
        reader = models["EasyOCR"]
        pixels = np.array(image)
        fragments = reader.readtext(pixels, detail=0)
        if not fragments:
            return ''
        return ' '.join(fragments)
    except Exception as exc:
        return "Error: " + str(exc)
def run_trocr(image):
    """Recognise text in *image* with the loaded TrOCR processor/model pair.

    The processor turns the image into ``pixel_values``, the model generates
    token ids, and the first decoded sequence (special tokens removed) is
    returned. Any exception is reported as an ``"Error: ..."`` string.

    NOTE(review): TrOCR checkpoints are presumably single-text-line
    recognisers; on full multi-line pages the output may cover only part of
    the text. Confirm against the model card.
    """
    try:
        trocr = models["TrOCR"]
        processor = trocr["processor"]
        model = trocr["model"]
        encoded = processor(image, return_tensors="pt")
        token_ids = model.generate(encoded.pixel_values)
        decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
        return decoded[0]
    except Exception as exc:
        return "Error: " + str(exc)
def run_doctr(image):
    """Run the docTR detection + recognition predictor on *image*.

    Args:
        image: A PIL image or a numpy array. PIL images are converted with
            ``np.array`` before being passed to the predictor as a one-page
            batch.

    Returns:
        Every recognised word, in page -> block -> line -> word order, joined
        by single spaces. If any step raises, an ``"Error: ..."`` string is
        returned instead.
    """
    try:
        if isinstance(image, Image.Image):
            image = np.array(image)
        result = models["DocTR"]([image])
        # docTR Word elements expose their text via ``.value``; they are not
        # indexable, so the previous ``word[0]`` always failed and DocTR
        # reported an error instead of text.
        return ' '.join(
            word.value
            for page in result.pages
            for block in page.blocks
            for line in block.lines
            for word in line.words
        )
    except Exception as e:
        return f"Error: {str(e)}"
def compare_models(image):
"""Compare all OCR models"""
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
image = image.convert("RGB")
results = {}
times = {}
# Run all OCR models
for name, func in [("EasyOCR", run_easyocr),
("TrOCR", run_trocr),
("DocTR", run_doctr)]:
start = time.time()
results[name] = func(image)
times[name] = time.time() - start
# Create comparison table
table_rows = []
for name in results:
table_rows.append(f"""
<tr>
<td style="padding: 8px; border: 1px solid #ddd; text-align: center; font-weight: bold;">{name}</td>
<td style="padding: 8px; border: 1px solid #ddd;">{results[name]}</td>
<td style="padding: 8px; border: 1px solid #ddd; text-align: center;">{times[name]:.3f}s</td>
</tr>
""")
comparison = f"""
<div style="overflow-x: auto;">
<table style="width:100%; border-collapse: collapse; margin: 15px 0; font-family: Arial, sans-serif;">
<tr style="background-color: #4CAF50; color: white;">
<th style="padding: 12px; border: 1px solid #ddd; text-align: center;">Model</th>
<th style="padding: 12px; border: 1px solid #ddd; text-align: center;">Extracted Text</th>
<th style="padding: 12px; border: 1px solid #ddd; text-align: center;">Processing Time</th>
</tr>
{''.join(table_rows)}
</table>
</div>
"""
return comparison, results['EasyOCR'], results['TrOCR'], results['DocTR']
# Build the Gradio UI. The left column holds the image input, examples and
# submit button; the right column holds the HTML comparison table and a
# collapsible section with each model's raw text.
with gr.Blocks(title="English OCR Comparison", theme=gr.themes.Soft()) as demo:
    # Explicit "\n" joins keep the heading at column 0. The original
    # triple-quoted string also contained mojibake ("πŸš€" for the rocket emoji).
    gr.Markdown(
        "# 🚀 English OCR Model Comparison\n"
        "Compare the performance of top OCR models for English text extraction"
    )
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label="Upload Image", type="pil")
            # NOTE(review): these are relative paths; the files must exist
            # next to app.py for the examples to load. Confirm in the repo.
            gr.Examples(
                examples=["sample1.jpg", "sample2.png"],
                inputs=img_input,
                label="Try these sample images"
            )
            submit_btn = gr.Button("Compare Models", variant="primary")
        with gr.Column():
            comparison = gr.HTML(label="Comparison Results")
            with gr.Accordion("Detailed Results", open=False):
                gr.Markdown("### Individual Model Outputs")
                easy_output = gr.Textbox(label="EasyOCR")
                trocr_output = gr.Textbox(label="TrOCR")
                doctr_output = gr.Textbox(label="DocTR")
    # Output order must match the 4-tuple returned by compare_models.
    submit_btn.click(
        fn=compare_models,
        inputs=img_input,
        outputs=[comparison, easy_output, trocr_output, doctr_output]
    )

if __name__ == "__main__":
    demo.launch()