Spaces:
Runtime error
Runtime error
| # app.py | |
| import pandas as pd | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| from datasets import load_dataset | |
# Load the two raw sources: the hosted customer-support conversation dataset
# and the local FAQ spreadsheet.
support_data = load_dataset("rjac/e-commerce-customer-support-qa")
# NOTE(review): "/content/..." looks like a Colab-style absolute path —
# confirm the CSV actually exists there in the deployment environment.
faq_data = pd.read_csv("/content/Ecommerce_FAQs.csv")

# Data preparation: normalise the FAQ columns to the shared
# Question/Answer schema and keep only those two columns.
faq_data = faq_data.rename(columns={'prompt': 'Question', 'response': 'Answer'})[
    ['Question', 'Answer']
]

# Materialise the 'train' split as a DataFrame for row-wise parsing.
support_data_df = pd.DataFrame(support_data['train'])
def extract_conversation(data):
    """Extract one question/answer pair from a raw conversation transcript.

    The transcript is split on blank lines (two consecutive newlines). The
    second chunk is treated as the question turn and the third chunk as the
    answer turn; for each, everything after the first ": " is kept.

    Args:
        data: The raw conversation text. Non-string values (e.g. None or
            NaN from a missing cell) are tolerated.

    Returns:
        A ``pd.Series`` with keys ``"Question"`` and ``"Answer"``. Both are
        empty strings when ``data`` is not a string or when a present turn
        has no ": " separator. A turn that is simply absent yields an empty
        string for that field only.
    """
    empty_pair = {"Question": "", "Answer": ""}

    # A missing cell would otherwise raise AttributeError on .split and
    # abort the whole DataFrame.apply.
    if not isinstance(data, str):
        return pd.Series(empty_pair)

    parts = data.split("\n\n")
    try:
        question = parts[1].split(": ", 1)[1] if len(parts) > 1 else ""
        answer = parts[2].split(": ", 1)[1] if len(parts) > 2 else ""
    except IndexError:
        # A turn without the ": " separator: treat the whole row as unparsable.
        return pd.Series(empty_pair)
    return pd.Series({"Question": question, "Answer": answer})
# Parse every support conversation into Question/Answer columns, then merge
# with the FAQ pairs into a single lookup table.
support_data_df[['Question', 'Answer']] = support_data_df['conversation'].apply(extract_conversation)
combined_data = pd.concat([faq_data, support_data_df[['Question', 'Answer']]], ignore_index=True)

# Drop rows that cannot serve as retrieval candidates: unparsable
# conversations come back as empty strings, and a blank question in the
# index could be matched and return a blank answer.
combined_data = combined_data.dropna(subset=['Question', 'Answer'])
has_text = (
    combined_data['Question'].astype(str).str.strip().ne("")
    & combined_data['Answer'].astype(str).str.strip().ne("")
)
# Reset the index so row positions line up with the FAISS vector ids.
combined_data = combined_data[has_text].reset_index(drop=True)

# Initialize SBERT Model
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

# Generate and Index Embeddings: one vector per question, stored in a flat
# L2 index whose ids are the row positions in combined_data.
questions = combined_data['Question'].tolist()
embeddings = model.encode(questions, convert_to_tensor=True)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings.cpu().numpy())
# Define FastAPI app and model
app = FastAPI()


class Query(BaseModel):
    """Request body for the /ask endpoint."""

    # The user's free-text question.
    question: str


# Register the handler as a route: the Gradio client posts to "/ask", and
# without this decorator the app exposes no such path.
@app.post("/ask")
def ask_bot(query: Query):
    """Return the stored answer whose question is closest to the query.

    The query text is embedded with the module-level ``model`` and the
    single nearest neighbour is looked up in the module-level FAISS
    ``index``; the matching row of ``combined_data`` supplies the answer.

    Args:
        query: Request body carrying the user's question.

    Returns:
        A dict of the form ``{"answer": <text>}``.
    """
    question_embedding = model.encode([query.question], convert_to_tensor=True)
    question_embedding_np = question_embedding.cpu().numpy()
    _, closest_index = index.search(question_embedding_np, k=1)
    best_match_idx = int(closest_index[0][0])
    if best_match_idx < 0:
        # FAISS reports -1 when it has no neighbour to return; using it with
        # iloc would silently pick the last row instead.
        return {"answer": "Sorry, I didn't get that."}
    answer = combined_data.iloc[best_match_idx]['Answer']
    return {"answer": answer}
| # Gradio Interface | |
| import gradio as gr | |
| import requests | |
# Define the URL of your FastAPI endpoint
API_URL = "http://localhost:8000/ask"  # Update to your deployed FastAPI URL if needed


def respond(message, history: list[tuple[str, str]]):
    """Chat callback: forward the user's message to the API and return its answer.

    Args:
        message: The text the user just sent.
        history: Prior conversation turns supplied by the chat UI. It is
            accepted for signature compatibility and left unmodified; the
            chat interface maintains the history itself.

    Returns:
        The answer text from the API, a fallback message when the response
        carries no "answer" field, or a "Request Error: ..." string when the
        HTTP request fails.
    """
    payload = {"question": message}
    try:
        # Bound the wait so a stalled backend cannot hang the chat UI.
        response = requests.post(API_URL, json=payload, timeout=30)
        response.raise_for_status()
        response_data = response.json()
        answer = response_data.get("answer", "Sorry, I didn't get that.")
    except requests.exceptions.RequestException as e:
        answer = f"Request Error: {str(e)}"
    # Return only the reply: the chat interface expects the response itself,
    # not an (answer, history) tuple.
    return answer
# Gradio Chat Interface: wrap the `respond` callback in a chat UI.
demo = gr.ChatInterface(fn=respond)
if __name__ == "__main__":
    import threading
    import uvicorn

    # Run FastAPI in a separate thread so the Gradio UI can occupy the main
    # thread. The thread is a daemon so that the process can exit once the
    # Gradio server stops, instead of being kept alive by uvicorn.
    api_thread = threading.Thread(
        target=uvicorn.run,
        args=(app,),
        kwargs={"host": "0.0.0.0", "port": 8000},
        daemon=True,
    )
    api_thread.start()

    # Launch Gradio interface
    demo.launch()