Spaces:
Sleeping
Sleeping
import gradio as gr | |
import easyocr | |
import numpy as np | |
from transformers import pipeline | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
import torch | |
# 1. OCR Processor
class OCRProcessor:
    """Read text out of an image through an EasyOCR reader."""

    def __init__(self):
        # The reader is created once, for the 'fa' language list, and reused
        # by every extract_text call on this instance.
        self.reader = easyocr.Reader(['fa'])

    def extract_text(self, image: np.ndarray) -> str:
        """Return the text read from *image* as a single string.

        The reader's results are joined with newlines; an empty result gives
        an empty string. Any exception raised while reading is not
        propagated: it is returned as a message string that starts with
        "خطا در پردازش OCR:".
        """
        try:
            paragraphs = self.reader.readtext(image, detail=0, paragraph=True)
            if not paragraphs:
                return ""
            return "\n".join(paragraphs)
        except Exception as exc:
            return f"خطا در پردازش OCR: {exc!s}"
# 2. Text Chunker
def text_chunker(text, chunk_size=250, overlap=50):
    """Split *text* into overlapping chunks of whitespace-separated words.

    Args:
        text: The text to split; words are obtained with ``str.split()``.
        chunk_size: Maximum number of words per chunk (must be positive).
        overlap: Number of words shared by consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        A list of chunk strings, each made of words joined by single spaces.
        Empty or whitespace-only text yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range. Without
            this check a non-positive step would make the loop never end.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        # Once a window reaches the last word, any further window would only
        # repeat the tail of this one, so stop here.
        if start + chunk_size >= len(words):
            break
    return chunks
# 3. Embedding Agent
class EmbeddingAgent:
    """Encode texts into vectors with a SentenceTransformer model."""

    def __init__(self):
        # The model is loaded once per agent instance and reused by embed().
        self.model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')

    def embed(self, texts):
        """Encode *texts* with the loaded model and return its output as-is."""
        vectors = self.model.encode(texts)
        return vectors
# 4. Retriever Agent (with FAISS)
class RetrieverAgent:
    """Nearest-neighbour lookup of text chunks over a flat L2 FAISS index."""

    def __init__(self, embeddings, texts):
        """Build the index.

        Args:
            embeddings: 2-D array-like with one row per entry of *texts*.
            texts: The chunk strings; row ``i`` of *embeddings* belongs to
                ``texts[i]``.
        """
        self.texts = texts
        # FAISS works on C-contiguous float32 matrices; normalise here so
        # float64 or non-contiguous input does not fail inside the index.
        vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
        d = vectors.shape[1]
        self.index = faiss.IndexFlatL2(d)
        self.index.add(vectors)

    def retrieve(self, query_embedding, top_k=1):
        """Return up to *top_k* texts closest to *query_embedding*.

        Args:
            query_embedding: A single query vector, either 1-D or shaped
                ``(1, d)``.
            top_k: Maximum number of texts to return.

        Returns:
            A list of texts in the order the index reports them. It can be
            shorter than *top_k*, and is empty when there are no texts or
            ``top_k`` is below 1.
        """
        if top_k < 1 or len(self.texts) == 0:
            return []
        query = np.ascontiguousarray(query_embedding, dtype=np.float32)
        if query.ndim == 1:
            query = query.reshape(1, -1)
        # Never ask for more neighbours than there are stored texts.
        k = min(top_k, len(self.texts))
        _, neighbours = self.index.search(query, k)
        # The index marks "no result" slots with -1; used as a Python index
        # that would silently pick the last text, so skip those slots.
        return [self.texts[int(idx)] for idx in neighbours[0] if idx >= 0]
# 5. QA Agent (extractive question answering)
class MultilingualQAModel:
    """Answer a question by extracting a span from a context passage."""

    # Strings that are tokenizer placeholders rather than real answers. The
    # bracketed forms are BERT-style; the angle-bracket forms are the
    # RoBERTa-style equivalents matching the default checkpoint family.
    _SPECIAL_TOKENS = frozenset({'[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>'})

    def __init__(self, model_name="deepset/roberta-base-squad2"):
        """Create the question-answering pipeline.

        Args:
            model_name: Checkpoint used for both model and tokenizer.
        """
        # NOTE(review): despite the class name, the default checkpoint looks
        # like an English SQuAD2 model; confirm it is adequate for Persian or
        # pass a multilingual checkpoint via model_name.
        self.qa_pipeline = pipeline(
            "question-answering",
            model=model_name,
            tokenizer=model_name,
        )

    def answer_question(self, context: str, question: str) -> str:
        """Return the extracted answer, or a Persian status message.

        Args:
            context: Passage to search for the answer.
            question: The question to answer.

        Returns:
            The stripped answer text. A fixed message is returned instead
            when either input is missing or blank, when the pipeline yields
            no usable answer, or when the pipeline raises; exceptions are
            never propagated to the caller.
        """
        # Treat None like an empty string instead of failing on .strip().
        if not (context and context.strip()) or not (question and question.strip()):
            return "متن یا سوال وارد نشده است."
        try:
            # Keyword arguments are the documented calling convention.
            result = self.qa_pipeline(question=question, context=context)
            answer = (result.get('answer') or '').strip()
            if not answer or answer in self._SPECIAL_TOKENS:
                return "جوابی یافت نشد."
            return answer
        except Exception as e:
            return f"خطا در مدل پرسش و پاسخ: {str(e)}"
# Full DocQA Pipeline
# Module-level agent instances: each is constructed once, when this module is
# executed, and then used by docqa_pipeline on every request.
ocr_processor = OCRProcessor()
embedder_agent = EmbeddingAgent()
qa_agent = MultilingualQAModel()
def docqa_pipeline(image, question):
    """Run OCR, chunking, retrieval and QA for one image/question pair.

    Args:
        image: The input image passed to the OCR processor, or ``None`` when
            no image was supplied.
        question: The user's question; ``None`` or blank is reported to the
            user instead of being sent through retrieval.

    Returns:
        A ``(context, answer)`` tuple of strings: the extracted text (or a
        status/error message) and the text for the answer box.
    """
    # Nothing to read: answer directly instead of sending None to the OCR
    # reader and surfacing its raw exception text.
    if image is None:
        return "تصویری بارگذاری نشده است.", "پاسخی وجود ندارد"

    # 1. OCR
    context = ocr_processor.extract_text(image)
    # Match the full OCR error prefix rather than just its first word, so a
    # document that merely starts with that word is not taken for a failure.
    if context.startswith("خطا در پردازش OCR:"):
        return context, "پاسخی وجود ندارد"

    # 2. Chunking
    chunks = text_chunker(context)
    # No recognised words means no chunks; embedding and indexing an empty
    # list would fail, so stop here with a clear message.
    if not chunks:
        return context, "متنی در تصویر یافت نشد."

    question = (question or "").strip()
    if not question:
        return context, "متن یا سوال وارد نشده است."

    # 3. Embedding (chunks + question)
    chunk_embeddings = embedder_agent.embed(chunks)
    question_embedding = embedder_agent.embed([question])

    # 4. Retriever: find the most relevant chunk
    retriever = RetrieverAgent(chunk_embeddings, chunks)
    relevant_chunk = retriever.retrieve(question_embedding, top_k=1)[0]

    # 5. QA: answer the question from the retrieved chunk
    answer = qa_agent.answer_question(relevant_chunk, question)
    return context, f"متن مرتبط:\n{relevant_chunk}\n\nپاسخ مدل:\n{answer}"
# Gradio UI: the first column holds the inputs and the trigger button, the
# second column holds the two read-only output boxes.
with gr.Blocks(title="DocQA Agent: پرسش و پاسخ هوشمند از سند فارسی استخراجشده از تصویر") as app:
    gr.Markdown("""
# DocQA Agent
<br>
یک سامانه چندعاملی برای پرسش و پاسخ از اسناد فارسی (OCR + جستجو + پاسخ هوشمند)
""")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(
                label="تصویر ورودی",
                type="numpy",
            )
            question_input = gr.Textbox(
                label="سوال شما به فارسی",
                placeholder="مثلاً: نویسنده این متن کیست؟",
                lines=1,
            )
            process_btn = gr.Button("پاسخ")
        with gr.Column():
            context_output = gr.Textbox(
                label="متن استخراج شده",
                lines=10,
                max_lines=None,
                interactive=False,
            )
            answer_output = gr.Textbox(
                label="جواب مدل (بخش مرتبط و پاسخ)",
                lines=10,
                max_lines=None,
                interactive=False,
            )
    # Clicking the button runs the pipeline on the two inputs and writes its
    # two return values into the two output boxes, in that order.
    process_btn.click(
        fn=docqa_pipeline,
        inputs=[img_input, question_input],
        outputs=[context_output, answer_output],
    )

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    app.launch()