DocQA_Agent / app.py
OmidSakaki's picture
Update app.py
77d9d02 verified
raw
history blame
2.93 kB
import gradio as gr
import easyocr
import numpy as np
from typing import Tuple
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
class OCRProcessor:
    """Extract text from images using an EasyOCR reader.

    The reader is created once in ``__init__`` and reused by every
    ``extract_text`` call on the same instance.
    """

    def __init__(self, languages: Tuple[str, ...] = ("fa",)):
        """Create the EasyOCR reader.

        Args:
            languages: Language codes handed to ``easyocr.Reader``. The
                default, ``("fa",)``, keeps the previous hard-coded
                behaviour.
        """
        # Pass a fresh list, as the original code did with ['fa'].
        self.reader = easyocr.Reader(list(languages))

    def extract_text(self, image: np.ndarray) -> str:
        """Run OCR on ``image`` and return the recognised text.

        Args:
            image: Image forwarded unchanged to ``reader.readtext``.

        Returns:
            The entries returned by the reader joined with newlines, or an
            empty string when the reader returns nothing.

        Raises:
            RuntimeError: If OCR fails for any reason. The underlying
                exception is attached as ``__cause__``.
        """
        try:
            # detail=0 / paragraph=True: see the EasyOCR readtext docs; the
            # result is consumed here as a list of strings.
            results = self.reader.readtext(image, detail=0, paragraph=True)
            return "\n".join(results) if results else ""
        except Exception as e:
            # Chain explicitly so the original traceback stays available.
            raise RuntimeError(f"خطا در پردازش OCR: {e}") from e
class TextCorrector:
    """Post-process OCR text with a Hugging Face seq2seq model.

    The tokenizer and model are loaded once in ``__init__`` and reused by
    every ``correct`` call on the same instance.
    """

    # NOTE(review): the checkpoint name suggests a question-answering model
    # rather than a text-correction model -- confirm this is the intended one.
    DEFAULT_MODEL_NAME = "persiannlp/mt5-small-parsinlu-arc-comqa-question"

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """Load the tokenizer and the seq2seq model.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``.
                Defaults to the previously hard-coded checkpoint.

        Raises:
            RuntimeError: If either the tokenizer or the model fails to
                load. The underlying exception is attached as ``__cause__``.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        except Exception as e:
            # Chain explicitly so the original traceback stays available.
            raise RuntimeError(f"خطا در بارگذاری مدل زبانی: {e}") from e

    def correct(self, text: str) -> str:
        """Return the model's rewrite of ``text``.

        This is best-effort: if tokenisation, generation or decoding fails,
        the error is printed and the input text is returned unchanged.

        Args:
            text: Raw text to correct.

        Returns:
            The decoded first generated sequence. ``text`` itself is
            returned when it is ``None``, empty, whitespace-only, or when
            any step fails.
        """
        # Nothing to correct; also avoids calling .strip() on None.
        if not text or not text.strip():
            return text

        try:
            # The prompt prefix is prepended before tokenisation. Input is
            # truncated to 512 tokens, so text beyond that limit is
            # presumably not reflected in the output -- TODO confirm whether
            # long documents need to be processed in chunks.
            inputs = self.tokenizer(
                "اصلاح متن: " + text,
                return_tensors="pt",
                max_length=512,
                truncation=True,
            )
            outputs = self.model.generate(
                **inputs,
                max_length=512,
                num_beams=5,
                early_stopping=True,
            )
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            # Deliberate fallback: keep the raw OCR text rather than fail.
            print(f"خطا در تصحیح متن: {e}")
            return text
# Shared pipeline components, keyed by their class and built on first use.
# Loading the OCR reader and the language model is expensive, so they are
# created once per process instead of once per request.
_PIPELINE_COMPONENTS = {}


def _get_component(factory):
    """Return the shared instance for ``factory``, creating it on first use."""
    if factory not in _PIPELINE_COMPONENTS:
        _PIPELINE_COMPONENTS[factory] = factory()
    return _PIPELINE_COMPONENTS[factory]


def full_processing(image: np.ndarray) -> Tuple[str, str]:
    """Run OCR on ``image`` and then correct the extracted text.

    Errors never propagate to the caller; they are returned as text so the
    UI can show them in both output boxes.

    Args:
        image: Input image, or ``None`` when no image was supplied
            (presumably what the Gradio image input sends when it is
            empty -- confirm against the Gradio version in use).

    Returns:
        ``(raw_text, corrected_text)`` on success. On failure, or when
        ``image`` is ``None``, the same error message in both positions.
    """
    if image is None:
        # Report a missing image directly instead of an opaque OCR failure.
        error_msg = "خطا: تصویری بارگذاری نشده است"
        return error_msg, error_msg

    try:
        # Extract the text from the image.
        ocr_text = _get_component(OCRProcessor).extract_text(image)
        # Correct the extracted text with the language model.
        corrected_text = _get_component(TextCorrector).correct(ocr_text)
        return ocr_text, corrected_text
    except Exception as e:
        error_msg = f"خطا: {e}"
        return error_msg, error_msg
# Gradio front end: an image input and a trigger button in the first
# column, the raw and corrected text boxes in the second column.
with gr.Blocks(title="پایپلاین OCR + تصحیح خودکار متن فارسی") as app:
    gr.Markdown(value="\n# سیستم استخراج و تصحیح هوشمند متن فارسی\n")

    with gr.Row():
        with gr.Column():
            # type="numpy" makes the handler receive the image as an array.
            img_input = gr.Image(type="numpy", label="تصویر ورودی")
            process_btn = gr.Button(value="پردازش تصویر", variant="primary")

        with gr.Column():
            raw_output = gr.Textbox(
                lines=8,
                max_lines=None,
                label="متن استخراج شده (خام)",
            )
            corrected_output = gr.Textbox(
                lines=15,
                max_lines=None,
                label="متن تصحیح شده (هوشمند)",
            )

    # The handler's two return values fill the two text boxes in order.
    process_btn.click(
        full_processing,
        inputs=[img_input],
        outputs=[raw_output, corrected_output],
    )


# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    app.launch()