DocQA_Agent / app.py
OmidSakaki's picture
Update app.py
57fa964 verified
raw
history blame
2.83 kB
import gradio as gr
import easyocr
import numpy as np
from typing import Tuple
from transformers import pipeline
# --- 1. کلاس OCR برای استخراج متن از تصویر ---
class OCRProcessor:
def __init__(self):
self.reader = easyocr.Reader(['fa'])
def extract_text(self, image: np.ndarray) -> str:
try:
results = self.reader.readtext(image, detail=0, paragraph=True)
return "\n".join(results) if results else ""
except Exception as e:
raise RuntimeError(f"خطا در پردازش OCR: {str(e)}")
# --- 2. کلاس تصحیح متن با مدل زبانی ---
class TextCorrector:
def __init__(self):
# استفاده از مدل ParsBERT برای تصحیح متن فارسی
self.corrector = pipeline(
"text2text-generation",
model="persiannlp/parsbert-uncased", # مدل زبانی فارسی
tokenizer="persiannlp/parsbert-uncased"
)
def correct(self, text: str) -> str:
if not text.strip():
return text
try:
corrected = self.corrector(
text,
max_length=512,
num_beams=5,
early_stopping=True
)
return corrected[0]['generated_text']
except Exception as e:
print(f"خطا در تصحیح متن: {e}")
return text
# --- 3. پردازش کامل (OCR + تصحیح خودکار) ---
def full_processing(image: np.ndarray) -> Tuple[str, str]:
try:
# استخراج متن از تصویر
ocr_text = OCRProcessor().extract_text(image)
# تصحیح متن با مدل زبانی
corrected_text = TextCorrector().correct(ocr_text)
return ocr_text, corrected_text
except Exception as e:
error_msg = f"خطا: {str(e)}"
return error_msg, error_msg
# --- 4. رابط کاربری Gradio ---
with gr.Blocks(title="پایپلاین OCR + تصحیح خودکار متن فارسی") as app:
gr.Markdown("""
# استخراج و تصحیح هوشمند متن فارسی از تصویر
""")
with gr.Row():
with gr.Column():
img_input = gr.Image(label="تصویر ورودی", type="numpy")
process_btn = gr.Button("پردازش تصویر", variant="primary")
with gr.Column():
raw_output = gr.Textbox(label="متن استخراج شده (خام)", lines=8, interactive=True)
corrected_output = gr.Textbox(label="متن تصحیح شده (هوشمند)", lines=10, interactive=True)
process_btn.click(
fn=full_processing,
inputs=img_input,
outputs=[raw_output, corrected_output]
)
if __name__ == "__main__":
app.launch()