OmidSakaki committed
Commit 87a1c7f · verified · 1 Parent(s): fd5f89e

Update app.py

Files changed (1): app.py (+53 -72)
app.py CHANGED
@@ -1,10 +1,15 @@
 import gradio as gr
 import easyocr
 import numpy as np
-from transformers import pipeline
-from sentence_transformers import SentenceTransformer
-import faiss
-import torch
+
+from langchain_community.llms import HuggingFacePipeline
+from langchain.chains import RetrievalQA
+from langchain_community.vectorstores import FAISS
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.docstore.document import Document
+
+from transformers import pipeline as hf_pipeline
 
 # 1. OCR Processor (English)
 class OCRProcessor:
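
A compatibility note on the new imports: in LangChain 0.2 and later, the Hugging Face wrappers live in the separate langchain-huggingface package, and the langchain_community paths emit deprecation warnings. A version-tolerant import, as a sketch rather than part of this commit, could look like:

# Sketch only (not in the commit): fall back to the older langchain_community layout.
try:
    from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
except ImportError:
    from langchain_community.llms import HuggingFacePipeline
    from langchain_community.embeddings import HuggingFaceEmbeddings
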
@@ -18,89 +23,65 @@ class OCRProcessor:
         except Exception as e:
             return f"OCR error: {str(e)}"
 
-# 2. Text Chunker
-def text_chunker(text, chunk_size=250, overlap=50):
-    words = text.split()
-    chunks = []
-    i = 0
-    while i < len(words):
-        chunk = " ".join(words[i:i+chunk_size])
-        chunks.append(chunk)
-        i += chunk_size - overlap
-    return chunks
-
-# 3. Embedding Agent (English)
-class EmbeddingAgent:
-    def __init__(self):
-        self.model = SentenceTransformer('all-MiniLM-L6-v2')
-
-    def embed(self, texts):
-        return self.model.encode(texts)
-
-# 4. Retriever Agent (with FAISS)
-class RetrieverAgent:
-    def __init__(self, embeddings, texts):
-        self.texts = texts
-        d = embeddings.shape[1]
-        self.index = faiss.IndexFlatL2(d)
-        self.index.add(embeddings)
-
-    def retrieve(self, query_embedding, top_k=1):
-        D, I = self.index.search(query_embedding, top_k)
-        return [self.texts[idx] for idx in I[0]]
-
-# 5. QA Agent (English QA model)
-class EnglishQAModel:
+# 2. LangChain-based DocQA Agent
+class LangChainDocQAAgent:
     def __init__(self):
-        self.qa_pipeline = pipeline(
-            "question-answering",
-            model="deepset/roberta-base-squad2",
-            tokenizer="deepset/roberta-base-squad2"
+        # Embedding model
+        self.embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+        # Text splitter (chunk size and overlap for better retrieval)
+        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+        # HuggingFace QA pipeline as an LLM
+        self.qa_llm = HuggingFacePipeline(
+            pipeline=hf_pipeline(
+                "question-answering",
+                model="deepset/roberta-base-squad2",
+                tokenizer="deepset/roberta-base-squad2"
+            ),
+            model_kwargs={"return_full_text": False}
         )
 
-    def answer_question(self, context: str, question: str) -> str:
-        if not context.strip() or not question.strip():
-            return "No text or question provided."
-        try:
-            result = self.qa_pipeline({"context": context, "question": question})
-            answer = result.get('answer', '').strip()
-            if not answer or answer in ['[CLS]', '[SEP]', '[PAD]']:
-                return "No answer found."
-            return answer
-        except Exception as e:
-            return f"QA error: {str(e)}"
+    def prepare_retriever(self, text):
+        # Split text into LangChain Document objects
+        docs = [Document(page_content=chunk) for chunk in self.text_splitter.split_text(text)]
+        # Create FAISS vectorstore for retrieval
+        vectorstore = FAISS.from_documents(docs, self.embeddings)
+        return vectorstore.as_retriever(), docs
+
+    def qa(self, text, question):
+        if not text.strip() or not question.strip():
+            return "No text or question provided.", ""
+        # Build retriever from text
+        retriever, docs = self.prepare_retriever(text)
+        # RetrievalQA chain: retrieve relevant chunk and answer
+        qa_chain = RetrievalQA.from_chain_type(
+            llm=self.qa_llm,
+            chain_type="stuff",
+            retriever=retriever,
+            return_source_documents=True
+        )
+        result = qa_chain({"query": question})
+        answer = result["result"]
+        # Show the most relevant chunk as context
+        relevant_context = result["source_documents"][0].page_content if result["source_documents"] else ""
+        return relevant_context, answer
 
-# Full DocQA Pipeline (English)
 ocr_processor = OCRProcessor()
-embedder_agent = EmbeddingAgent()
-qa_agent = EnglishQAModel()
+docqa_agent = LangChainDocQAAgent()
 
 def docqa_pipeline(image, question):
     # 1. OCR
     context = ocr_processor.extract_text(image)
     if context.startswith("OCR error"):
         return context, "No answer."
-
-    # 2. Chunking
-    chunks = text_chunker(context)
-
-    # 3. Embedding
-    chunk_embeddings = embedder_agent.embed(chunks)
-    question_embedding = embedder_agent.embed([question])
-
-    # 4. Retrieval
-    retriever = RetrieverAgent(chunk_embeddings, chunks)
-    relevant_chunk = retriever.retrieve(question_embedding, top_k=1)[0]
-
-    # 5. QA
-    answer = qa_agent.answer_question(relevant_chunk, question)
+    # 2. LangChain RetrievalQA
+    relevant_chunk, answer = docqa_agent.qa(context, question)
     return context, f"Relevant chunk:\n{relevant_chunk}\n\nModel answer:\n{answer}"
 
-with gr.Blocks(title="DocQA Agent: Intelligent Q&A from Extracted English Document") as app:
+with gr.Blocks(title="DocQA Agent (LangChain): Intelligent Q&A from Extracted English Document") as app:
     gr.Markdown("""
-    # DocQA Agent
+    # DocQA Agent (LangChain)
     <br>
-    A multi-agent system for question answering from English documents (OCR + retrieval + intelligent answer)
+    A multi-agent system for question answering from English documents (OCR + retrieval + intelligent answer with LangChain)
     """)
     with gr.Row():
         with gr.Column():
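
A caveat on the new agent: in the LangChain releases I know, HuggingFacePipeline only supports generative tasks ("text-generation", "text2text-generation", "summarization"), so wrapping a "question-answering" pipeline is likely to raise a ValueError the first time RetrievalQA invokes it; the "stuff" chain also hands the model a single formatted prompt string, not the question/context pair an extractive model expects, and model_kwargs={"return_full_text": False} has no effect on a pipeline passed in ready-made. Note too that RecursiveCharacterTextSplitter counts characters, so chunk_size=500 yields much smaller chunks than the removed text_chunker's 250 words. One way to keep the extractive checkpoint is to use FAISS only for retrieval and call the transformers pipeline directly; a minimal sketch under those assumptions:

# Sketch, not the committed code: FAISS retrieval + direct extractive QA.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import pipeline as hf_pipeline

embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
qa = hf_pipeline("question-answering", model="deepset/roberta-base-squad2")

def answer_from_text(text: str, question: str):
    # Embed the chunks and retrieve the single most similar one
    chunks = splitter.split_text(text)
    store = FAISS.from_texts(chunks, embeddings)
    best = store.similarity_search(question, k=1)[0].page_content
    # Run the extractive QA model on the retrieved chunk
    result = qa({"question": question, "context": best})
    return best, result["answer"]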
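
Finally, a quick end-to-end smoke test without launching the Gradio UI; the image path is a placeholder, and the numpy conversion assumes the array format Gradio's image component hands to docqa_pipeline:

# Hypothetical smoke test; "sample_doc.png" is a placeholder path.
import numpy as np
from PIL import Image

image = np.array(Image.open("sample_doc.png"))  # any scanned English document
context, result = docqa_pipeline(image, "What is the document about?")
print(result)  # "Relevant chunk: ... Model answer: ..."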