# Page header residue from the hosted file view: author Mishal23,
# commit 923ef72 ("Update app.py", verified), file size 2.91 kB.
# app.py
import os

import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
from datasets import load_dataset
# Load the two knowledge sources: the `rjac/e-commerce-customer-support-qa`
# dataset and a local FAQ CSV. The CSV location can be overridden with the
# FAQ_CSV_PATH environment variable; the default is the original hard-coded path.
support_data = load_dataset("rjac/e-commerce-customer-support-qa")
FAQ_CSV_PATH = os.environ.get("FAQ_CSV_PATH", "/content/Ecommerce_FAQs.csv")
faq_data = pd.read_csv(FAQ_CSV_PATH)

# Data preparation: map the FAQ's prompt/response columns onto Question/Answer
# and keep only those two; turn the dataset's "train" split into a DataFrame.
faq_data = faq_data.rename(columns={'prompt': 'Question', 'response': 'Answer'})
faq_data = faq_data[['Question', 'Answer']]
support_data_df = pd.DataFrame(support_data['train'])
def extract_conversation(data):
    """Split one raw conversation string into a Question/Answer pair.

    The text is split on blank lines ("\\n\\n"). The second block is taken as
    the question and the third as the answer. Everything up to and including
    the first ": " in each block is dropped (presumably a speaker label).

    Args:
        data: The conversation text. Non-string values (e.g. None or NaN
            from a missing cell) are tolerated.

    Returns:
        A ``pd.Series`` with keys ``"Question"`` and ``"Answer"``.
        - A missing block yields "" for that field.
        - A present block without ": " yields "" for BOTH fields.
        - Non-string input yields "" for both fields.
    """
    # A missing cell would otherwise raise AttributeError on .split and abort
    # the whole DataFrame.apply that calls this function.
    if not isinstance(data, str):
        return pd.Series({"Question": "", "Answer": ""})

    parts = data.split("\n\n")
    try:
        question = parts[1].split(": ", 1)[1] if len(parts) > 1 else ""
        answer = parts[2].split(": ", 1)[1] if len(parts) > 2 else ""
    except IndexError:
        # A block lacked the ": " separator; treat the whole pair as unusable.
        return pd.Series({"Question": "", "Answer": ""})
    return pd.Series({"Question": question, "Answer": answer})
# Split each raw conversation into Question/Answer columns, then stack the FAQ
# rows and the parsed dataset rows into one table.
support_data_df[['Question', 'Answer']] = support_data_df['conversation'].apply(extract_conversation)
combined_data = pd.concat([faq_data, support_data_df[['Question', 'Answer']]], ignore_index=True)

# Drop unusable rows before they reach the search index.
# - extract_conversation yields "" for conversations it cannot parse.
# - The CSV may contain missing cells.
# A blank question is still embedded and can become the nearest match, which
# makes /ask return an empty answer. NaN cells are not text, are likely to break
# encoding, and are not valid JSON in a response.
combined_data = combined_data.dropna(subset=['Question', 'Answer'])
combined_data = combined_data.astype({'Question': str, 'Answer': str})
# reset_index(drop=True) restores positional ids 0..n-1; FAISS row ids must
# line up with combined_data.iloc positions.
combined_data = combined_data[
    combined_data['Question'].str.strip().ne('')
    & combined_data['Answer'].str.strip().ne('')
].reset_index(drop=True)
# SBERT model used both for the stored questions below and for incoming
# queries in ask_bot.
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

# Encode every known question once at startup. The vectors go into an exact
# (brute-force) L2 FAISS index, so FAISS row i corresponds to
# combined_data.iloc[i].
questions = combined_data['Question'].tolist()
embeddings = model.encode(questions, convert_to_tensor=True)
# FAISS consumes NumPy arrays, so bring the tensor to host memory first.
embedding_matrix = embeddings.cpu().numpy()
index = faiss.IndexFlatL2(embedding_matrix.shape[1])
index.add(embedding_matrix)
# FastAPI application that exposes the question-answering endpoint.
app = FastAPI()


# Request body for POST /ask, e.g. {"question": "Where is my order?"}.
# (Documented with comments rather than a docstring: pydantic would copy a
# class docstring into the published JSON schema.)
class Query(BaseModel):
    question: str  # free-text question to match against the indexed questions
@app.post("/ask")
def ask_bot(query: Query):
    """Answer a question with the stored answer of the most similar stored question.

    The question is embedded with the same model used to build the index. An
    exact L2 search (k=1) finds the closest stored question, and its row in
    the combined table supplies the answer. A fallback message is returned if
    no neighbour is found or the matched row has no answer.
    """
    fallback = "Sorry, I couldn't find an answer to that question."
    question_embedding = model.encode([query.question], convert_to_tensor=True)
    question_embedding_np = question_embedding.cpu().numpy()
    _, closest_index = index.search(question_embedding_np, k=1)
    best_match_idx = int(closest_index[0][0])
    if best_match_idx < 0:
        # FAISS reports -1 when it has no neighbour (e.g. an empty index).
        # Passed to iloc, -1 would silently pick the LAST row, or raise on an
        # empty table.
        return {"answer": fallback}
    answer = combined_data.iloc[best_match_idx]['Answer']
    if pd.isna(answer):
        # A missing answer (NaN) is not valid JSON and would surface as a
        # server error instead of a reply.
        return {"answer": fallback}
    return {"answer": answer}
# Gradio Interface
import gradio as gr
import requests
# URL of the FastAPI /ask endpoint. The default targets the server started in
# __main__; set the API_URL environment variable to use a deployed instance.
API_URL = os.environ.get("API_URL", "http://localhost:8000/ask")
# Seconds to wait for the API. requests has no default timeout, so without one
# a stalled server would hang the chat UI indefinitely.
REQUEST_TIMEOUT = 30


def respond(message, history: list[tuple[str, str]]):
    """Forward a chat message to the /ask endpoint and return the reply text.

    Used as the ``fn`` of ``gr.ChatInterface``, which expects the reply string
    back and keeps the conversation history itself. The history is therefore
    neither returned nor modified here; doing so would duplicate turns or be
    rendered as a non-text message.

    Args:
        message: The user's chat message, sent as {"question": message}.
        history: Prior turns supplied by Gradio (unused).

    Returns:
        One of the following:
        - the "answer" field of the JSON response;
        - "Sorry, I didn't get that." if that field is absent;
        - "Request Error: ..." if the HTTP request fails or times out.
    """
    payload = {"question": message}
    try:
        response = requests.post(API_URL, json=payload, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        response_data = response.json()
    except requests.exceptions.RequestException as e:
        return f"Request Error: {str(e)}"
    return response_data.get("answer", "Sorry, I didn't get that.")
# Gradio chat UI: every submitted message is answered by `respond`.
demo = gr.ChatInterface(fn=respond)
if __name__ == "__main__":
    import threading
    import uvicorn

    # Serve the FastAPI app on port 8000 in a background thread. daemon=True
    # lets the process exit once the Gradio UI below shuts down, instead of the
    # server thread keeping it alive.
    api_thread = threading.Thread(
        target=uvicorn.run,
        args=(app,),
        kwargs={"host": "0.0.0.0", "port": 8000},
        daemon=True,
    )
    api_thread.start()

    # Launch the Gradio interface in the main thread.
    demo.launch()