|
from flask import Flask, request, jsonify, send_from_directory |
|
from flask_cors import CORS |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
from datasets import load_dataset |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
from sentence_transformers import SentenceTransformer |
|
|
|
# Flask application object; CORS is enabled on it with no extra options.
app = Flask(__name__)
CORS(app)

# DialoGPT tokenizer and causal language model, both loaded from the same
# pretrained checkpoint when the module is imported.
DIALOGPT_CHECKPOINT = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(DIALOGPT_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(DIALOGPT_CHECKPOINT)

# Bitext travel dataset. A load failure is reported on stdout and leaves
# `dataset` as None so the rest of the module can still be imported.
BITEXT_DATASET_ID = "bitext/Bitext-travel-llm-chatbot-training-dataset"
try:
    dataset = load_dataset(BITEXT_DATASET_ID)
    print("Bitext dataset loaded successfully.")
except Exception as load_error:
    print(f"Error loading Bitext dataset: {str(load_error)}")
    dataset = None

# Sentence encoder used to embed dataset instructions and user messages.
SENTENCE_ENCODER_NAME = 'all-MiniLM-L6-v2'
sentence_model = SentenceTransformer(SENTENCE_ENCODER_NAME)
|
|
|
|
|
# Lazily-filled cache of the embeddings of every training instruction.
# Encoding the whole instruction list is the expensive step of a lookup and
# its input never changes after start-up, so it is computed at most once per
# process instead of once per request.
_instruction_embeddings_cache = None


def _get_instruction_embeddings():
    """Return the embeddings of all training instructions, encoding them once."""
    global _instruction_embeddings_cache
    if _instruction_embeddings_cache is None:
        instructions = [example['instruction'] for example in dataset['train']]
        # Only assigned after encode() succeeds, so a failed attempt is simply
        # retried on the next call.
        _instruction_embeddings_cache = sentence_model.encode(instructions)
    return _instruction_embeddings_cache


def find_closest_response(user_input):
    """Return the dataset response whose instruction is most similar to *user_input*.

    The user message is embedded with ``sentence_model`` and compared by
    cosine similarity against the (cached) embeddings of every instruction
    in ``dataset['train']``; the ``response`` field of the best match is
    returned.

    Args:
        user_input: The message text to match against dataset instructions.

    Returns:
        The matched response text, or an apology message when the dataset
        was not loaded or when any step of the lookup raises.
    """
    if dataset is None:
        return "I'm sorry, but I couldn't load the travel dataset. Please try again later."

    try:
        instruction_embeddings = _get_instruction_embeddings()
        user_embedding = sentence_model.encode([user_input])

        similarities = cosine_similarity(user_embedding, instruction_embeddings)
        # argmax() gives a NumPy integer; convert it to a built-in int so the
        # dataset is indexed with a plain Python integer.
        closest_index = int(similarities.argmax())

        return dataset['train'][closest_index]['response']
    except Exception as e:
        # Best-effort lookup: report the problem and fall back to an apology
        # rather than failing the whole chat request.
        print(f"Error finding closest response: {str(e)}")
        return "I'm sorry, but I couldn't find a suitable response. Please try again."
|
|
|
|
|
def chat_with_bot(user_input, chat_history_ids=None):
    """Build a reply from the dataset lookup and a DialoGPT generation.

    Args:
        user_input: The user's message text.
        chat_history_ids: Token-id tensor of the conversation so far, or
            None to start a new conversation.

    Returns:
        A ``(reply_text, history)`` tuple. ``reply_text`` is the closest
        dataset response, a blank line, then the DialoGPT reply; ``history``
        is the value returned by ``model.generate``. If anything raises,
        a generic error message and ``None`` are returned instead.
    """
    try:
        dataset_answer = find_closest_response(user_input)
        print(f"Closest response: {dataset_answer}")

        # Encode the new turn, terminated by the end-of-sequence token.
        turn_ids = tokenizer.encode(
            user_input + tokenizer.eos_token, return_tensors='pt')

        # Append the new turn to the existing history when there is one.
        if chat_history_ids is None:
            prompt_ids = turn_ids
        else:
            prompt_ids = torch.cat([chat_history_ids, turn_ids], dim=-1)

        chat_history_ids = model.generate(
            prompt_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)

        # Everything after the prompt is the newly generated reply.
        prompt_length = prompt_ids.shape[-1]
        reply_ids = chat_history_ids[:, prompt_length:][0]
        generated_reply = tokenizer.decode(reply_ids, skip_special_tokens=True)
        print(f"DialoGPT response: {generated_reply}")

        return f"{dataset_answer}\n\n{generated_reply}", chat_history_ids
    except Exception as error:
        print(f"Exception: {str(error)}")
        return "Sorry, an unexpected error occurred. Please try again.", None
|
|
|
|
|
|
|
|
|
@app.route("/")
def serve_html():
    """Serve ``index.html`` from the ``"."`` directory as the site root page."""
    page_directory, page_name = ".", "index.html"
    return send_from_directory(page_directory, page_name)
|
|
|
|
|
|
|
|
|
@app.route("/chat", methods=["POST"])
def chat():
    """Handle one chat turn posted as JSON.

    Expected request body: a JSON object with a non-empty string
    ``"message"`` and an optional ``"chat_history_ids"`` (nested list of
    token ids as returned by a previous call).

    Returns:
        A JSON object with ``"response"`` (the bot's reply text) and
        ``"chat_history_ids"`` (the updated history as a nested list, or
        null), or a JSON ``{"error": ...}`` with status 400 when the request
        body is unusable.
    """
    # silent=True gives None instead of raising when the body is missing,
    # malformed, or not sent as JSON, so the client gets a JSON error below.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    user_input = payload.get("message")
    if not user_input:
        return jsonify({"error": "No message provided"}), 400
    if not isinstance(user_input, str):
        return jsonify({"error": "Message must be a string"}), 400

    # An absent or empty history means "start a new conversation"; it must
    # reach chat_with_bot as None, not as an unconverted empty list.
    raw_history = payload.get("chat_history_ids")
    chat_history_ids = None
    if raw_history:
        try:
            chat_history_ids = torch.tensor(raw_history)
        except (TypeError, ValueError, RuntimeError):
            # Client-supplied history that cannot be turned into a tensor
            # (ragged lists, non-numeric values, ...) is a bad request.
            return jsonify({"error": "Invalid chat_history_ids"}), 400

    bot_response, chat_history_ids = chat_with_bot(
        user_input, chat_history_ids)

    # Tensors are not JSON-serializable; send the history back as nested lists.
    chat_history_ids_list = (
        chat_history_ids.tolist() if chat_history_ids is not None else None
    )

    return jsonify({
        "response": bot_response,
        "chat_history_ids": chat_history_ids_list,
    })
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: start Flask's built-in server with debug mode on.
    run_options = {"debug": True}
    app.run(**run_options)
|
|