import gradio as gr
import os
import docx
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint
from langchain_huggingface import HuggingFaceEmbeddings
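# Assumed runtime dependencies (package names inferred from the imports above, versions not pinned):
# gradio, python-docx, numpy, sentence-transformers, scikit-learn,
# langchain, langchain-community, langchain-huggingface, faiss-cpu (or faiss-gpu).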
# Initialize semantic model
semantic_model = SentenceTransformer("all-MiniLM-L6-v2")
def extract_text_from_docx(file_path):
    """Extract paragraph and table text from a .docx file as plain text."""
    doc = docx.Document(file_path)
    extracted_text = []
    for para in doc.paragraphs:
        if para.text.strip():
            extracted_text.append(para.text.strip())
    for table in doc.tables:
        extracted_text.append("Table Detected:")
        for row in table.rows:
            row_text = [cell.text.strip() for cell in row.cells]
            if any(row_text):
                extracted_text.append(" | ".join(row_text))
    return "\n".join(extracted_text)
def load_documents():
    """Load the source .docx manuals and split them into overlapping text chunks."""
    file_paths = {
        "Fastener_Types_Manual": "Fastener_Types_Manual.docx",
        "Manufacturing_Expert_Manual": "Manufacturing Expert Manual.docx"
    }
    all_splits = []
    for doc_name, file_path in file_paths.items():
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Document not found: {file_path}")
        print(f"Extracting text from {file_path}...")
        full_text = extract_text_from_docx(file_path)
        # Character-based splitting; the overlap preserves context across chunk boundaries
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
        doc_splits = text_splitter.create_documents([full_text])
        for chunk in doc_splits:
            chunk.metadata = {"source": doc_name}
        all_splits.extend(doc_splits)
    return all_splits
def create_db(splits):
    """Embed the chunks and build an in-memory FAISS vector store."""
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
    vectordb = FAISS.from_documents(splits, embeddings)
    return vectordb, embeddings
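# Note: the FAISS index above is held in memory and rebuilt on every startup. If that
# becomes slow, LangChain's FAISS wrapper can persist it via vectordb.save_local(<folder>)
# and restore it via FAISS.load_local(<folder>, embeddings); the folder path is a
# placeholder, and recent LangChain versions may require allow_dangerous_deserialization=True on load.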
def retrieve_documents(query, retriever, embeddings):
    """Retrieve candidate chunks and drop those below a cosine-similarity floor."""
    query_embedding = np.array(embeddings.embed_query(query)).reshape(1, -1)
    results = retriever.invoke(query)
    if not results:
        return []
    doc_embeddings = np.array([embeddings.embed_query(doc.page_content) for doc in results])
    similarity_scores = cosine_similarity(query_embedding, doc_embeddings)[0]
    MIN_SIMILARITY = 0.3
    filtered_results = [(doc, sim) for doc, sim in zip(results, similarity_scores) if sim >= MIN_SIMILARITY]
    print(f"Query: {query}")
    print(f"Retrieved Docs: {[(doc.metadata.get('source', 'Unknown'), sim) for doc, sim in filtered_results]}")
    return [doc for doc, _ in filtered_results] if filtered_results else []
def validate_query_semantically(query, retrieved_docs):
    """Second-pass relevance check: compare the query against the combined retrieved text."""
    if not retrieved_docs:
        return False
    combined_text = " ".join([doc.page_content for doc in retrieved_docs])
    query_embedding = semantic_model.encode(query, normalize_embeddings=True)
    doc_embedding = semantic_model.encode(combined_text, normalize_embeddings=True)
    similarity_score = np.dot(query_embedding, doc_embedding)
    print(f"Semantic Similarity Score: {similarity_score}")
    return similarity_score >= 0.3
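# Because both vectors above are encoded with normalize_embeddings=True, the dot product
# equals cosine similarity. The 0.3 threshold is a heuristic relevance floor, not a
# calibrated value; it may need tuning for the documents in use.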
def initialize_chatbot(vector_db, embeddings):
    """Build the retriever, the LLM endpoint, and the conversational retrieval chain."""
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
    retriever = vector_db.as_retriever(search_kwargs={"k": 5})
    # Prompt for the answer-generation step; {context} and {question} are filled in by the chain
    qa_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""You are an AI assistant that answers questions ONLY based on the provided documents.
- If no relevant documents are retrieved, respond with: "I couldn't find any relevant information."
- If the meaning of the query does not match the retrieved documents, say "I couldn't find any relevant information."
- Do NOT attempt to answer from general knowledge.

Context:
{context}

Question: {question}
Answer:"""
    )
    llm = HuggingFaceEndpoint(
        repo_id="tiiuae/falcon-40b-instruct",
        huggingfacehub_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
        temperature=0.1,
        max_new_tokens=400,
        task="text-generation"
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=memory,
        return_source_documents=True,
        # HuggingFaceEndpoint has no system_prompt parameter, so the instructions are
        # injected through the chain's combine-documents prompt instead
        combine_docs_chain_kwargs={"prompt": qa_prompt},
        verbose=False
    )
    return retriever, qa_chain
def handle_query(query, history, retriever, qa_chain, embeddings):
    """Answer a user query from the indexed documents, or refuse when nothing relevant is found."""
    history = history or []
    retrieved_docs = retrieve_documents(query, retriever, embeddings)
    if not retrieved_docs or not validate_query_semantically(query, retrieved_docs):
        return history + [(query, "I couldn't find any relevant information.")], ""
    response = qa_chain.invoke({"question": query, "chat_history": history})
    assistant_response = response['answer'].strip()
    assistant_response += f"\n\nSource: {', '.join(set(doc.metadata.get('source', 'Unknown') for doc in retrieved_docs))}"
    history.append((query, assistant_response))
    return history, ""
def demo():
    documents = load_documents()
    vector_db, embeddings = create_db(documents)
    retriever, qa_chain = initialize_chatbot(vector_db, embeddings)
    with gr.Blocks() as app:
        gr.Markdown("### Document Question Answering System")
        chatbot = gr.Chatbot()
        query_input = gr.Textbox(label="Ask a question about the documents")
        query_btn = gr.Button("Submit")
        def user_query_handler(query, history):
            return handle_query(query, history, retriever, qa_chain, embeddings)
        # Both the Submit button and pressing Enter in the textbox trigger the same handler
        query_btn.click(
            user_query_handler,
            inputs=[query_input, chatbot],
            outputs=[chatbot, query_input]
        )
        query_input.submit(
            user_query_handler,
            inputs=[query_input, chatbot],
            outputs=[chatbot, query_input]
        )
    app.launch()
if __name__ == "__main__":
    demo()
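# Typical local run (file name and token value are placeholders):
#   export HUGGINGFACE_API_TOKEN=<your_hf_token>
#   python app.py
# On Hugging Face Spaces, set HUGGINGFACE_API_TOKEN as a repository secret instead.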