DocQA_Agent / app.py
OmidSakaki's picture
Update app.py
6ecc4f4 verified
raw
history blame
4.31 kB
import gradio as gr
import easyocr
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import numpy as np
from typing import Tuple
## 1. تنظیمات اولیه و مدل‌ها
# ----------------------------------
class OCRProcessor:
def __init__(self):
self.reader = easyocr.Reader(['fa'])
def extract_text(self, image: np.ndarray) -> str:
"""استخراج متن از تصویر با EasyOCR"""
try:
results = self.reader.readtext(image, detail=0, paragraph=True)
return "\n".join(results) if results else ""
except Exception as e:
raise RuntimeError(f"خطا در پردازش OCR: {str(e)}")
class TextPostProcessor:
def __init__(self):
self.replacements = {
'ي': 'ی', 'ك': 'ک',
'۰': '0', '۱': '1', '۲': '2', '۳': '3', '۴': '4',
'۵': '5', '۶': '6', '۷': '7', '۸': '8', '۹': '9'
}
try:
self.llm = pipeline(
"text-generation",
model="HooshvareLab/gpt2-fa",
tokenizer="HooshvareLab/gpt2-fa"
)
except Exception as e:
print("خطا در بارگذاری مدل زبانی:", e)
self.llm = None
def preprocess(self, text: str) -> str:
"""نرمال‌سازی ساده متن"""
if not text:
return ""
for old, new in self.replacements.items():
text = text.replace(old, new)
return " ".join(text.split())
def enhance_with_llm(self, text: str) -> str:
"""بازنویسی یا بهبود متن با LLM فارسی"""
if not text or not self.llm:
return text
prompt = f"متن زیر را بازنویسی کن و به صورت روان و صحیح برگردان:\n{text}\nبازنویسی:"
try:
output = self.llm(
prompt,
max_length=len(prompt) + len(text) + 60,
num_return_sequences=1,
do_sample=True,
temperature=0.9,
pad_token_id=0,
eos_token_id=2
)
gen_text = output[0]['generated_text']
# فقط بخش بازنویسی شده را جدا کن
if "بازنویسی:" in gen_text:
gen_text = gen_text.split("بازنویسی:")[-1].strip()
# اگر بازنویسی مدل بی‌معنا یا کوتاه بود، همان متن را برگردان
if len(gen_text) < 8:
return text
return gen_text
except Exception as e:
print("خطا در بازنویسی با LLM:", e)
return text
## 2. پایپلاین اصلی
def full_processing(image: np.ndarray) -> Tuple[str, str]:
try:
ocr_text = OCRProcessor().extract_text(image)
post_processor = TextPostProcessor()
cleaned_text = post_processor.preprocess(ocr_text)
enhanced_text = post_processor.enhance_with_llm(cleaned_text)
return cleaned_text, enhanced_text
except Exception as e:
return f"خطا: {str(e)}", ""
## 3. رابط کاربری Gradio
with gr.Blocks(title="پایپلاین OCR و بازنویسی متن فارسی") as app:
gr.Markdown("""
# سیستم استخراج و بازنویسی متن فارسی از تصویر
تصویر را بارگذاری کنید، متن استخراج و سپس با مدل زبانی بازنویسی می‌شود.
""")
with gr.Row():
with gr.Column():
img_input = gr.Image(label="تصویر ورودی", type="numpy")
process_btn = gr.Button("پردازش تصویر", variant="primary")
with gr.Column():
with gr.Tab("متن استخراج شده"):
raw_output = gr.Textbox(label="متن استخراج شده")
with gr.Tab("متن بازنویسی شده"):
enhanced_output = gr.Textbox(label="متن بازنویسی شده")
img_input.change(fn=lambda x: x, inputs=img_input, outputs=img_preview)
process_btn.click(fn=full_processing, inputs=img_input, outputs=[raw_output, enhanced_output])
if __name__ == "__main__":
app.launch()