|
import pandas as pd |
|
import gradio as gr |
|
from langchain.vectorstores import FAISS |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from langchain.schema import Document |
|
import os |
|
import pickle |
|
|
|
# CSV of Q&A pairs; initialize_system reads its "Question" and "Answer" columns.
DATASET_PATH: str = "qa_dataset.csv"

# Where initialize_system pickles the built vector store so later runs can reuse it.
FAISS_INDEX_PATH: str = "faiss_index"

# Sentence-transformers model name handed to HuggingFaceEmbeddings.
EMBEDDING_MODEL: str = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
|
|
|
|
|
def _load_cached_store(path):
    """Return the vector store pickled at *path*, or None if absent/unreadable."""
    if not os.path.exists(path):
        return None
    print("Loading FAISS index from cache...")
    # SECURITY: pickle.load can execute arbitrary code. Only ever load a cache
    # file this application wrote itself; never point this at untrusted data.
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    except (EOFError, pickle.UnpicklingError, AttributeError, ImportError) as err:
        # A truncated or library-version-incompatible cache should trigger a
        # rebuild rather than stop the app from starting.
        print(f"Ignoring unreadable FAISS cache at {path!r}: {err}")
        return None


def _load_documents(csv_path, limit):
    """Read up to *limit* complete Q&A rows from *csv_path* as Documents."""
    # dtype=str keeps every cell textual. Numeric-looking answers would
    # otherwise become int/float and break string joins downstream. Empty
    # cells are still read as NaN so they can be dropped below.
    data = pd.read_csv(csv_path, dtype=str)
    # Only a missing question or answer makes a row unusable. NaNs in any
    # other column must not silently exclude the row.
    data = data.dropna(subset=["Question", "Answer"]).head(limit)
    return [
        Document(
            page_content=f"Q: {question}\nA: {answer}",
            metadata={"question": question, "answer": answer},
        )
        for question, answer in zip(data["Question"], data["Answer"])
    ]


def _save_store(store, path):
    """Pickle *store* to *path* atomically; on failure leave no cache behind."""
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(store, f)
        # Renaming a fully written file into place means an interrupted or
        # failed dump can never leave a truncated file at *path* for the next
        # run to trip over.
        os.replace(tmp_path, path)
    except (pickle.PicklingError, TypeError, AttributeError, OSError) as err:
        # Caching is only an optimisation: report the problem and keep using
        # the in-memory store instead of crashing after an expensive build.
        print(f"Could not cache FAISS index at {path!r}: {err}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup of the partial temp file


def initialize_system():
    """Return a FAISS vector store over the Q&A dataset, building it if needed.

    If a readable pickle exists at ``FAISS_INDEX_PATH`` it is loaded and
    returned. Otherwise the first 500 rows of ``DATASET_PATH`` that have both
    a ``Question`` and an ``Answer`` are wrapped as LangChain ``Document``
    objects ("Q: ...\\nA: ..." text plus question/answer metadata). They are
    embedded with ``EMBEDDING_MODEL``, indexed with FAISS, and the store is
    pickled to ``FAISS_INDEX_PATH`` for later runs.

    The cache is not keyed on the dataset or the model. Delete
    ``FAISS_INDEX_PATH`` after changing either to force a rebuild.

    Returns:
        The cached or freshly built ``FAISS`` vector store.

    Raises:
        FileNotFoundError: if there is no usable cache and the CSV is missing.
        KeyError: if the CSV lacks a ``Question`` or ``Answer`` column.
    """
    cached = _load_cached_store(FAISS_INDEX_PATH)
    if cached is not None:
        return cached

    print("Initializing FAISS from scratch...")
    documents = _load_documents(DATASET_PATH, limit=500)
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
    vector_store = FAISS.from_documents(documents, embeddings)
    _save_store(vector_store, FAISS_INDEX_PATH)
    return vector_store
|
|
|
# Build (or load from cache) the index once, when the module is executed;
# classify_question searches this module-level store on every call.
vector_store = initialize_system()
|
|
|
def classify_question(query: str, k: int = 3):
    """Label *query* using the answers of its nearest stored Q&A pairs.

    Runs a similarity search against the module-level ``vector_store`` and
    returns a dict with three string fields:

    * ``"Category"``: the first five distinct whitespace-separated tokens of
      the matched answers, taken in result order. This is a crude heuristic
      label, not a trained classifier.
    * ``"Top Matches"``: the matched Q&A pairs, separated by blank lines.
    * ``"Confidence"``: how many of the *k* requested matches were returned,
      as a percentage. NOTE(review): this ignores similarity scores and
      reads 100% whenever the search returns all *k* matches.

    Args:
        query: Free-text question. A blank query returns the "Unknown"
            result without searching.
        k: Number of neighbours to retrieve; must be at least 1.

    Raises:
        ValueError: if *k* is less than 1 (it is also the confidence divisor).
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    no_match = {"Category": "Unknown", "Top Matches": "No matches found", "Confidence": "0%"}
    if not query or not query.strip():
        return no_match

    results = vector_store.similarity_search(query, k=k)
    if not results:
        return no_match

    # str(): metadata values are not guaranteed to be strings (e.g. numeric
    # answers parsed from a CSV), and str.join rejects non-str items.
    answers = " ".join(str(doc.metadata["answer"]) for doc in results)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    keywords = list(dict.fromkeys(answers.split()))[:5]

    return {
        "Category": " ".join(keywords),
        "Top Matches": "\n\n".join(
            f"Q: {doc.metadata['question']}\nA: {doc.metadata['answer']}"
            for doc in results
        ),
        "Confidence": f"{len(results) / k:.0%}",
    }
|
|
|
|
|
def _classify_for_ui(query):
    """Adapt classify_question's result to the three Gradio output boxes.

    gr.Interface maps a multi-output function's return value onto ``outputs``
    by position. classify_question returns a dict keyed by plain strings
    (not components), which Gradio cannot map, so unpack it here in the
    same order as the output components below.
    """
    result = classify_question(query, 3)
    return result["Category"], result["Top Matches"], result["Confidence"]


interface = gr.Interface(
    fn=_classify_for_ui,
    inputs=gr.Textbox(label="Input Question", placeholder="Type your question here..."),
    # Order must match the tuple returned by _classify_for_ui.
    outputs=[
        gr.Textbox(label="Predicted Category"),
        gr.Textbox(label="Supporting Q&A"),
        gr.Textbox(label="Confidence"),
    ],
    title="Question Classification System",
    description="Classify questions based on existing Q&A pairs using FAISS",
)
|
|
|
if __name__ == "__main__":
    # share=True asks Gradio for a publicly shareable link in addition to the
    # local server. NOTE(review): confirm that public exposure is intended.
    launch_options = {"share": True}
    interface.launch(**launch_options)
|
|