# DocQA_Agent / app.py
# Source: Hugging Face Spaces page for this file (author: OmidSakaki),
# last commit "Update app.py" (87a1c7f, verified); raw | history | blame; 4.13 kB.
import gradio as gr
import easyocr
import numpy as np
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from transformers import pipeline as hf_pipeline
# 1. OCR Processor (English)
class OCRProcessor:
    """Extract text from images with EasyOCR (English by default)."""

    def __init__(self, languages=("en",)):
        """Create the EasyOCR reader.

        Args:
            languages: EasyOCR language codes to load. Defaults to English
                only, which matches the original hard-coded behaviour. A
                tuple is used as the default to avoid a shared mutable default.
        """
        # Reader construction is presumably the expensive step (model
        # loading), so it is done once here and reused for every image.
        self.reader = easyocr.Reader(list(languages))

    def extract_text(self, image: "np.ndarray | None") -> str:
        """Run OCR on *image* and return the recognised text.

        Args:
            image: Image array (e.g. the value from the app's
                ``gr.Image(type="numpy")`` input), or ``None`` when nothing
                was uploaded.

        Returns:
            The recognised paragraphs joined with newlines, ``""`` when no
            text was found, or a message starting with ``"OCR error"`` on
            failure. Callers detect failures by checking for that prefix.
        """
        if image is None:
            # Report a missing upload clearly instead of letting EasyOCR fail
            # on None with an opaque message.
            return "OCR error: no image provided."
        try:
            results = self.reader.readtext(image, detail=0, paragraph=True)
        except Exception as e:  # UI boundary: show any OCR failure as text
            return f"OCR error: {e}"
        return "\n".join(results) if results else ""
# 2. LangChain-based DocQA Agent
class LangChainDocQAAgent:
    """Answer questions about a text using retrieval plus extractive QA.

    The text is split into overlapping chunks and indexed in an in-memory
    FAISS store built from sentence embeddings. The chunks most similar to
    the question are passed to a Hugging Face ``question-answering`` pipeline,
    which extracts the answer span from them.
    """

    def __init__(
        self,
        embedding_model="all-MiniLM-L6-v2",
        qa_model="deepset/roberta-base-squad2",
        chunk_size=500,
        chunk_overlap=50,
        top_k=4,
    ):
        """Load the embedding model, text splitter and extractive QA model.

        Args:
            embedding_model: Sentence-embedding model name used for retrieval.
            qa_model: Hugging Face extractive QA model (also used as tokenizer).
            chunk_size: Maximum characters per chunk.
            chunk_overlap: Characters shared between consecutive chunks, so
                answers that straddle a chunk boundary are less likely lost.
            top_k: Number of chunks retrieved per question.
        """
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        # The QA model is called directly rather than through LangChain's
        # HuggingFacePipeline LLM wrapper. That wrapper only handles
        # generation-style tasks, and an extractive QA pipeline needs
        # separate `question` and `context` inputs rather than a single prompt.
        self.qa_pipeline = hf_pipeline(
            "question-answering", model=qa_model, tokenizer=qa_model
        )
        self.top_k = top_k

    def prepare_retriever(self, text):
        """Index *text* and return ``(retriever, docs)``.

        ``docs`` is the list of chunk ``Document`` objects that were indexed.
        The retriever returns up to ``self.top_k`` chunks per query.
        """
        docs = [
            Document(page_content=chunk)
            for chunk in self.text_splitter.split_text(text)
        ]
        vectorstore = FAISS.from_documents(docs, self.embeddings)
        retriever = vectorstore.as_retriever(search_kwargs={"k": self.top_k})
        return retriever, docs

    def qa(self, text, question):
        """Answer *question* using only the content of *text*.

        Args:
            text: Document text (e.g. OCR output) to search.
            question: Natural-language question about the text.

        Returns:
            ``(relevant_context, answer)``: the best-matching chunk and the
            answer span extracted by the QA model. If either input is missing
            or blank, returns ``("", "No text or question provided.")``. The
            message goes in the answer slot so callers display it as the answer.
        """
        if not text or not text.strip() or not question or not question.strip():
            return "", "No text or question provided."

        retriever, _docs = self.prepare_retriever(text)
        relevant_docs = retriever.invoke(question)
        if not relevant_docs:
            return "", "No answer found."

        # Pass every retrieved chunk (best match first) so the answer can come
        # from any of them. The transformers QA pipeline itself strides over
        # contexts longer than the model's input window.
        context = "\n\n".join(doc.page_content for doc in relevant_docs)
        prediction = self.qa_pipeline(question=question, context=context)
        return relevant_docs[0].page_content, prediction["answer"].strip()
# Module-level singletons, constructed once at import time (left to right:
# OCR reader first, then the QA agent), so each UI request reuses the same
# model objects instead of rebuilding them.
ocr_processor, docqa_agent = OCRProcessor(), LangChainDocQAAgent()
def docqa_pipeline(image, question):
    """Gradio handler: OCR the uploaded image, then answer the question.

    Args:
        image: Uploaded image array, or ``None`` if nothing was uploaded.
        question: The user's question text.

    Returns:
        ``(extracted_text, output)`` for the two output textboxes. ``output``
        is either a user-facing message or the relevant chunk followed by the
        model's answer.
    """
    # Guard clauses give a clear message for each kind of missing input
    # instead of forwarding empty data to the models.
    if image is None:
        return "", "Please upload an image first."

    context = ocr_processor.extract_text(image)
    if context.startswith("OCR error"):
        return context, "No answer."
    if not context.strip():
        return context, "No text could be extracted from the image."
    if not question or not question.strip():
        return context, "Please enter a question."

    relevant_chunk, answer = docqa_agent.qa(context, question)
    return context, f"Relevant chunk:\n{relevant_chunk}\n\nModel answer:\n{answer}"
# Static UI text and the shared settings of the two read-only output boxes.
_APP_TITLE = "DocQA Agent (LangChain): Intelligent Q&A from Extracted English Document"
_INTRO_MARKDOWN = (
    "\n"
    "# DocQA Agent (LangChain)\n"
    "<br>\n"
    "A multi-agent system for question answering from English documents "
    "(OCR + retrieval + intelligent answer with LangChain)\n"
)
_OUTPUT_BOX_KWARGS = {"lines": 10, "max_lines": None, "interactive": False}

with gr.Blocks(title=_APP_TITLE) as app:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            img_input = gr.Image(type="numpy", label="Input Image")
            question_input = gr.Textbox(
                lines=1,
                label="Your question (in English)",
                placeholder="e.g. Who is the author of this text?",
            )
            process_btn = gr.Button(value="Get Answer")
        # Right column: OCR text and the model's formatted output.
        with gr.Column():
            context_output = gr.Textbox(label="Extracted Text", **_OUTPUT_BOX_KWARGS)
            answer_output = gr.Textbox(
                label="Model Output (Relevant Chunk & Answer)", **_OUTPUT_BOX_KWARGS
            )

    # (image, question) -> (extracted text, relevant chunk + answer)
    process_btn.click(
        docqa_pipeline,
        [img_input, question_input],
        [context_output, answer_output],
    )

if __name__ == "__main__":
    app.launch()