# DocQA_Agent / app.py
import gradio as gr
import easyocr
import numpy as np
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import faiss
import torch  # not used directly; surfaces a missing torch install early (easyocr and transformers both require it)
# 1. OCR Processor
class OCRProcessor:
    def __init__(self):
        # Persian-language OCR; easyocr downloads its detection/recognition models on first use.
        self.reader = easyocr.Reader(['fa'])

    def extract_text(self, image: np.ndarray) -> str:
        try:
            # detail=0 returns plain strings; paragraph=True merges adjacent lines into paragraphs.
            results = self.reader.readtext(image, detail=0, paragraph=True)
            return "\n".join(results) if results else ""
        except Exception as e:
            return f"OCR error: {str(e)}"
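# Note: easyocr.Reader also takes a gpu=True/False flag; on a CPU-only host,
# easyocr.Reader(['fa'], gpu=False) skips the CUDA lookup. This is a deployment
# detail (an assumption about the environment), not something this app requires.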
# 2. Text Chunker
def text_chunker(text, chunk_size=250, overlap=50):
    # Split into word-based chunks of `chunk_size` words, overlapping by `overlap`
    # words so that answers spanning a chunk boundary are not lost.
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    chunks = []
    i = 0
    while i < len(words):
        chunks.append(" ".join(words[i:i + chunk_size]))
        i += chunk_size - overlap
    return chunks
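# Sanity check (illustrative values, not part of the app): with chunk_size=4 and overlap=2,
#   text_chunker("a b c d e f", chunk_size=4, overlap=2)
# yields ["a b c d", "c d e f", "e f"].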
# 3. Embedding Agent
class EmbeddingAgent:
    def __init__(self):
        # Multilingual sentence encoder (covers Persian).
        self.model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')

    def embed(self, texts):
        # Returns a (len(texts), dim) numpy array.
        return self.model.encode(texts)
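# Shape check (illustrative): MiniLM-L12-v2 emits 384-dimensional vectors, so
#   EmbeddingAgent().embed(["hello", "goodbye"]).shape  # -> (2, 384)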
# 4. Retriever Agent (with FAISS)
class RetrieverAgent:
    def __init__(self, embeddings, texts):
        self.texts = texts
        embeddings = np.asarray(embeddings, dtype=np.float32)  # FAISS expects float32
        d = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(d)  # exact L2-distance search
        self.index.add(embeddings)

    def retrieve(self, query_embedding, top_k=1):
        query_embedding = np.asarray(query_embedding, dtype=np.float32)
        D, I = self.index.search(query_embedding, top_k)  # distances, indices
        return [self.texts[idx] for idx in I[0]]
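# Note: IndexFlatL2 ranks chunks by Euclidean distance. If cosine similarity is
# preferred instead, a minimal alternative (a sketch, not what this app does) is
# inner product over L2-normalized vectors:
#   faiss.normalize_L2(embeddings)
#   index = faiss.IndexFlatIP(d)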
# 5. QA Agent (extractive question answering)
class MultilingualQAModel:
    def __init__(self):
        # Note: deepset/roberta-base-squad2 is an English-only SQuAD2 model, so it may
        # struggle with Persian text; a multilingual checkpoint such as
        # deepset/xlm-roberta-base-squad2 is likely a better fit, with an identical API.
        self.qa_pipeline = pipeline(
            "question-answering",
            model="deepset/roberta-base-squad2",
            tokenizer="deepset/roberta-base-squad2"
        )

    def answer_question(self, context: str, question: str) -> str:
        if not context.strip() or not question.strip():
            return "No text or question was provided."
        try:
            result = self.qa_pipeline({"context": context, "question": question})
            answer = result.get('answer', '').strip()
            # Reject empty answers and stray special tokens.
            if not answer or answer in ['[CLS]', '[SEP]', '[PAD]']:
                return "No answer found."
            return answer
        except Exception as e:
            return f"QA model error: {str(e)}"
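# Note: the question-answering pipeline returns a dict with 'answer', 'score',
# 'start', and 'end'. The 'score' field could be used to filter low-confidence
# answers, e.g. (illustrative threshold, not part of the original app):
#   if result.get('score', 0.0) < 0.1:
#       return "No answer found."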
# Full DocQA pipeline
ocr_processor = OCRProcessor()
embedder_agent = EmbeddingAgent()
qa_agent = MultilingualQAModel()

def docqa_pipeline(image, question):
    # 1. OCR
    context = ocr_processor.extract_text(image)
    if context.startswith("OCR error"):
        return context, "No answer available"
    if not context.strip():
        return "", "No text could be extracted from the image."
    # 2. Chunking
    chunks = text_chunker(context)
    # 3. Embedding (chunks + question)
    chunk_embeddings = embedder_agent.embed(chunks)
    question_embedding = embedder_agent.embed([question])
    # 4. Retriever: find the chunk most relevant to the question
    retriever = RetrieverAgent(chunk_embeddings, chunks)
    relevant_chunk = retriever.retrieve(question_embedding, top_k=1)[0]
    # 5. QA: answer the question based on the retrieved chunk
    answer = qa_agent.answer_question(relevant_chunk, question)
    return context, f"Relevant passage:\n{relevant_chunk}\n\nModel answer:\n{answer}"
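# Illustrative local smoke test (hypothetical file name, not part of the app):
#   from PIL import Image
#   context, answer = docqa_pipeline(np.array(Image.open("sample_doc.png")),
#                                    "Who is the author of this text?")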
with gr.Blocks(title="DocQA Agent: question answering over Persian documents extracted from images") as app:
    gr.Markdown("""
    # DocQA Agent
    <br>
    A multi-agent system for question answering over Persian documents (OCR + retrieval + QA)
    """)
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label="Input image", type="numpy")
            question_input = gr.Textbox(label="Your question (in Persian)", placeholder="e.g., Who is the author of this text?", lines=1)
            process_btn = gr.Button("Answer")
        with gr.Column():
            context_output = gr.Textbox(label="Extracted text", lines=10, max_lines=None, interactive=False)
            answer_output = gr.Textbox(label="Model answer (relevant passage and answer)", lines=10, max_lines=None, interactive=False)

    process_btn.click(
        fn=docqa_pipeline,
        inputs=[img_input, question_input],
        outputs=[context_output, answer_output]
    )
if __name__ == "__main__":
    app.launch()