# app.py

import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import faiss
from datasets import load_dataset

# Load the dataset
support_data = load_dataset("rjac/e-commerce-customer-support-qa")
# NOTE: this path assumes the FAQ CSV has been uploaded to a Colab-style runtime
faq_data = pd.read_csv("/content/Ecommerce_FAQs.csv")

# Data preparation
faq_data.rename(columns={'prompt': 'Question', 'response': 'Answer'}, inplace=True)
faq_data = faq_data[['Question', 'Answer']]
support_data_df = pd.DataFrame(support_data['train'])

def extract_conversation(data):
    """Split a raw conversation string into a question/answer pair.

    Assumes the conversation consists of blank-line-separated turns of the
    form "Speaker: text"; the second and third turns are taken as the
    customer question and the agent answer.
    """
    try:
        parts = data.split("\n\n")
        question = parts[1].split(": ", 1)[1] if len(parts) > 1 else ""
        answer = parts[2].split(": ", 1)[1] if len(parts) > 2 else ""
        return pd.Series({"Question": question, "Answer": answer})
    except IndexError:
        return pd.Series({"Question": "", "Answer": ""})

support_data_df[['Question', 'Answer']] = support_data_df['conversation'].apply(extract_conversation)
combined_data = pd.concat([faq_data, support_data_df[['Question', 'Answer']]], ignore_index=True)

# Initialize SBERT Model
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

# Generate and Index Embeddings
questions = combined_data['Question'].tolist()
embeddings = model.encode(questions, convert_to_tensor=True)

# IndexFlatL2 performs exact (brute-force) L2 nearest-neighbour search;
# FAISS expects float32 vectors, which is what the model already produces
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings.cpu().numpy())
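
# Optional sanity check (illustrative, left commented out): a question that
# is already in the index should retrieve itself as its own nearest
# neighbour, assuming it is unique in the data.
# _, idx = index.search(model.encode([questions[0]]), k=1)
# assert idx[0][0] == 0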

# Define FastAPI app and model
app = FastAPI()

class Query(BaseModel):
    question: str

@app.post("/ask")
def ask_bot(query: Query):
    question_embedding = model.encode([query.question], convert_to_tensor=True)
    question_embedding_np = question_embedding.cpu().numpy()
    _, closest_index = index.search(question_embedding_np, k=1)
    best_match_idx = closest_index[0][0]
    answer = combined_data.iloc[best_match_idx]['Answer']
    return {"answer": answer}

# Gradio Interface

import gradio as gr
import requests

# Define the URL of your FastAPI endpoint
API_URL = "http://localhost:8000/ask"  # Update to your deployed FastAPI URL if needed

def respond(message, history: list[tuple[str, str]]):
    payload = {"question": message}

    try:
        response = requests.post(API_URL, json=payload, timeout=30)
        response.raise_for_status()
        response_data = response.json()
        answer = response_data.get("answer", "Sorry, I didn't get that.")
    except requests.exceptions.RequestException as e:
        answer = f"Request Error: {e}"

    # gr.ChatInterface manages the conversation history itself, so the
    # handler only needs to return the reply string
    return answer

# Gradio Chat Interface
demo = gr.ChatInterface(respond)

if __name__ == "__main__":
    import threading
    import time
    import uvicorn

    # Run FastAPI in a daemon thread so it shuts down with the main process
    threading.Thread(
        target=uvicorn.run,
        args=(app,),
        kwargs={"host": "0.0.0.0", "port": 8000},
        daemon=True,
    ).start()

    # Give the API a moment to start before the UI begins sending requests
    time.sleep(1)

    # Launch Gradio interface
    demo.launch()
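
# To run this app (assuming the dependencies are installed):
#   pip install fastapi uvicorn gradio sentence-transformers faiss-cpu \
#       datasets pandas requests
#   python app.py
# FastAPI serves answers on port 8000 while the Gradio chat UI comes up on
# its default port (7860).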