DocQA_Agent / app.py
OmidSakaki's picture
Update app.py
c55b6a8 verified
raw
history blame
2.81 kB
import functools
import logging
from typing import Tuple

import easyocr
import gradio as gr
import numpy as np
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
class OCRProcessor:
    """Extract text from images with EasyOCR.

    The EasyOCR reader is built once in ``__init__`` and reused by every
    ``extract_text`` call, so callers should keep an instance around instead
    of constructing a new one per image.
    """

    def __init__(self, languages: Tuple[str, ...] = ("fa",)):
        """Create the EasyOCR reader.

        Args:
            languages: EasyOCR language codes to recognise. Defaults to
                Persian only, matching the original ``['fa']`` setup.
        """
        # easyocr.Reader expects a list of language codes.
        self.reader = easyocr.Reader(list(languages))

    def extract_text(self, image: np.ndarray) -> str:
        """Run OCR on *image* and return the recognised text.

        Args:
            image: Image passed straight to ``easyocr.Reader.readtext``.

        Returns:
            The recognised text blocks joined with ``"\\n"``, or ``""`` when
            nothing was recognised.

        Raises:
            RuntimeError: If OCR fails; the underlying exception is chained
                as ``__cause__`` so its traceback is not lost.
        """
        try:
            # detail=0 -> plain strings (no boxes/confidences);
            # paragraph=True -> nearby boxes are merged into paragraph results.
            results = self.reader.readtext(image, detail=0, paragraph=True)
        except Exception as e:
            raise RuntimeError(f"خطا در پردازش OCR: {str(e)}") from e
        return "\n".join(results) if results else ""
class TextCorrector:
    """Post-correct OCR text with a Persian seq2seq model.

    NOTE(review): the default checkpoint name suggests a small mT5 model;
    confirm it was actually fine-tuned for the correction prompt below,
    otherwise its output will not be a correction of the input.
    """

    # Instruction prepended to every paragraph before generation
    # ("grammatical correction of text: "). Runtime string - keep as is.
    _PROMPT_PREFIX = "تصحیح نگارشی متن: "

    def __init__(self, model_name: str = "HooshvareLab/mt5-small-fa"):
        """Load the tokenizer and model.

        Args:
            model_name: Hugging Face model id or local path passed to
                ``from_pretrained``. Defaults to the original checkpoint.

        Raises:
            RuntimeError: If either the tokenizer or the model fails to load;
                the underlying exception is chained as ``__cause__``.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        except Exception as e:
            raise RuntimeError(f"خطا در بارگذاری مدل زبانی: {str(e)}") from e

    def correct(self, text: str) -> str:
        """Return a corrected version of *text*, preserving its line layout.

        Each non-blank line is corrected independently and blank lines are
        kept verbatim. Two reasons for this:
        - The tokenizer truncates at 512 tokens, so a whole page sent as one
          prompt would silently lose its tail.
        - mT5's SentencePiece tokenizer does not round-trip newlines, so one
          big prompt would flatten the OCR paragraph structure.

        Args:
            text: Raw (e.g. OCR) text, lines separated by ``"\\n"``.

        Returns:
            The corrected text. If *text* is blank, or if correction fails
            for any reason, *text* is returned unchanged (best effort).
        """
        if not text.strip():
            return text
        try:
            corrected = [
                self._correct_paragraph(line) if line.strip() else line
                for line in text.split("\n")
            ]
        except Exception as e:
            # Correction is optional polish: log and fall back to raw text
            # rather than failing the whole request.
            logging.getLogger(__name__).warning("خطا در تصحیح متن: %s", e)
            return text
        return "\n".join(corrected)

    def _correct_paragraph(self, paragraph: str) -> str:
        """Run the model on a single non-blank paragraph and decode the result."""
        inputs = self.tokenizer(
            self._PROMPT_PREFIX + paragraph,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        outputs = self.model.generate(
            **inputs,
            max_length=512,
            num_beams=5,
            early_stopping=True,
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
@functools.lru_cache(maxsize=1)
def _get_ocr_processor() -> "OCRProcessor":
    """Build the EasyOCR-backed processor once and reuse it for every request."""
    return OCRProcessor()


@functools.lru_cache(maxsize=1)
def _get_text_corrector() -> "TextCorrector":
    """Load the correction model once and reuse it for every request."""
    return TextCorrector()


def full_processing(image: np.ndarray) -> Tuple[str, str]:
    """Gradio callback: OCR the image, then correct the extracted text.

    Models are loaded lazily on first use and cached. Previously every click
    reloaded both EasyOCR and the transformer, which is very slow.
    ``lru_cache`` does not cache exceptions, so a failed load is retried on
    the next request.

    Args:
        image: The uploaded image from ``gr.Image(type="numpy")``, or
            ``None`` when the user clicks without uploading anything.

    Returns:
        ``(raw_ocr_text, corrected_text)``. When there is no image or
        processing fails, the same user-facing message is returned in both
        slots so both output boxes show it.
    """
    if image is None:
        missing_msg = "لطفاً ابتدا یک تصویر بارگذاری کنید."
        return missing_msg, missing_msg
    try:
        ocr_text = _get_ocr_processor().extract_text(image)
        corrected_text = _get_text_corrector().correct(ocr_text)
    except Exception as e:
        # Top-level UI boundary: log the traceback for the operator and show
        # a short message to the user instead of crashing the request.
        logging.getLogger(__name__).exception("full_processing failed")
        error_msg = f"خطا: {str(e)}"
        return error_msg, error_msg
    return ocr_text, corrected_text
# Gradio interface.
#   left column  -> image upload + "process" button
#   right column -> raw OCR text and the model-corrected text
with gr.Blocks(title="پایپلاین OCR + تصحیح خودکار متن فارسی") as app:
    gr.Markdown("\n# سیستم استخراج و تصحیح هوشمند متن فارسی\n")

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="numpy", label="تصویر ورودی")
            process_btn = gr.Button("پردازش تصویر", variant="primary")

        with gr.Column():
            raw_output = gr.Textbox(
                label="متن استخراج شده (خام)",
                lines=8,
                max_lines=None,
            )
            corrected_output = gr.Textbox(
                label="متن تصحیح شده (هوشمند)",
                lines=15,
                max_lines=None,
            )

    # The button feeds the uploaded image to the pipeline; its two return
    # values populate the raw and corrected text boxes, in that order.
    process_btn.click(
        full_processing,
        inputs=[img_input],
        outputs=[raw_output, corrected_output],
    )


if __name__ == "__main__":
    # Only start the web server when executed directly, not on import.
    app.launch()