# app.py
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
from datasets import load_dataset

# Load the dataset
support_data = load_dataset("rjac/e-commerce-customer-support-qa")
# NOTE: this path assumes the FAQ CSV is available locally (e.g. a Colab /content upload)
faq_data = pd.read_csv("/content/Ecommerce_FAQs.csv")

# Data preparation: keep only the question/answer columns under consistent names
faq_data.rename(columns={'prompt': 'Question', 'response': 'Answer'}, inplace=True)
faq_data = faq_data[['Question', 'Answer']]
support_data_df = pd.DataFrame(support_data['train'])

def extract_conversation(data):
    # Split the raw conversation text on blank lines; the second and third
    # chunks are treated as the question and answer (the text after the first ": ").
    try:
        parts = data.split("\n\n")
        question = parts[1].split(": ", 1)[1] if len(parts) > 1 else ""
        answer = parts[2].split(": ", 1)[1] if len(parts) > 2 else ""
        return pd.Series({"Question": question, "Answer": answer})
    except IndexError:
        return pd.Series({"Question": "", "Answer": ""})

support_data_df[['Question', 'Answer']] = support_data_df['conversation'].apply(extract_conversation)
combined_data = pd.concat([faq_data, support_data_df[['Question', 'Answer']]], ignore_index=True)
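# Optional quick look at the merged data: both sources now share the same two
# columns, with any rows the parser could not split left as empty strings.
print(combined_data.shape)
print(combined_data.head())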
# Initialize SBERT Model
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

# Generate and Index Embeddings
questions = combined_data['Question'].tolist()
embeddings = model.encode(questions, convert_to_tensor=True)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings.cpu().numpy())
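# Optional sanity check (a minimal sketch; the sample question below is only an
# illustrative string, not part of the original data): confirm the index returns
# a sensible nearest neighbour before wiring up the API.
sample_question = "How can I track my order?"
sample_emb = model.encode([sample_question], convert_to_tensor=True).cpu().numpy()
_, sample_idx = index.search(sample_emb, 1)
print("Matched question:", combined_data.iloc[sample_idx[0][0]]['Question'])
print("Matched answer:", combined_data.iloc[sample_idx[0][0]]['Answer'])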
# Define FastAPI app and model
app = FastAPI()

class Query(BaseModel):
    question: str

@app.post("/ask")
def ask_bot(query: Query):
    # Embed the incoming question and look up the closest stored question
    question_embedding = model.encode([query.question], convert_to_tensor=True)
    question_embedding_np = question_embedding.cpu().numpy()
    _, closest_index = index.search(question_embedding_np, k=1)
    best_match_idx = closest_index[0][0]
    answer = combined_data.iloc[best_match_idx]['Answer']
    return {"answer": answer}
# Gradio Interface
import gradio as gr
import requests

# Define the URL of your FastAPI endpoint
API_URL = "http://localhost:8000/ask"  # Update to your deployed FastAPI URL if needed

def respond(message, history: list[tuple[str, str]]):
    # Forward the user's message to the FastAPI endpoint and surface any request error
    payload = {"question": message}
    try:
        response = requests.post(API_URL, json=payload)
        response.raise_for_status()
        response_data = response.json()
        answer = response_data.get("answer", "Sorry, I didn't get that.")
    except requests.exceptions.RequestException as e:
        answer = f"Request Error: {str(e)}"
    # gr.ChatInterface keeps track of the chat history itself, so only the reply is returned
    return answer
# Gradio Chat Interface
demo = gr.ChatInterface(
    respond,
)

if __name__ == "__main__":
    import threading
    import uvicorn
    # Run FastAPI in a separate thread
    threading.Thread(target=uvicorn.run, args=(app,), kwargs={"host": "0.0.0.0", "port": 8000}).start()
    # Launch Gradio interface
    demo.launch()
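# To try this locally (assuming the dependencies are installed and the FAQ CSV
# exists at the path above), run:
#   python app.py
# This starts the FastAPI server on port 8000 in a background thread, then
# launches the Gradio chat UI, which forwards each message to the /ask endpoint.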