bbektas25 commited on
Commit
fbd20b5
·
verified ·
1 Parent(s): 289eebc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, send_from_directory
2
+ from flask_cors import CORS
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import torch
5
+ from datasets import load_dataset
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ from sentence_transformers import SentenceTransformer
8
+
9
+ app = Flask(__name__)
10
+ CORS(app)
11
+
12
+ # Load the DialoGPT model and tokenizer
13
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
14
+ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
15
+
16
+ # Load the Bitext Travel Dataset
17
+ try:
18
+ dataset = load_dataset("bitext/Bitext-travel-llm-chatbot-training-dataset")
19
+ print("Bitext dataset loaded successfully.")
20
+ except Exception as e:
21
+ print(f"Error loading Bitext dataset: {str(e)}")
22
+ dataset = None
23
+
24
+ # Load a pre-trained sentence transformer model for semantic similarity
25
+ sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
26
+
27
+
28
+ def find_closest_response(user_input):
29
+ if dataset is None:
30
+ return "I'm sorry, but I couldn't load the travel dataset. Please try again later."
31
+
32
+ try:
33
+ # Precompute embeddings for all instructions in the dataset
34
+ instructions = [example['instruction'] for example in dataset['train']]
35
+ instruction_embeddings = sentence_model.encode(instructions)
36
+
37
+ # Encode the user input
38
+ user_embedding = sentence_model.encode([user_input])
39
+
40
+ # Compute cosine similarity between the user input and all instructions
41
+ similarities = cosine_similarity(
42
+ user_embedding, instruction_embeddings)
43
+ closest_index = similarities.argmax()
44
+
45
+ # Return the closest response
46
+ return dataset['train'][closest_index]['response']
47
+ except Exception as e:
48
+ print(f"Error finding closest response: {str(e)}")
49
+ return "I'm sorry, but I couldn't find a suitable response. Please try again."
50
+
51
+
52
+ def chat_with_bot(user_input, chat_history_ids=None):
53
+ try:
54
+ # Find the closest response from the Bitext dataset
55
+ closest_response = find_closest_response(user_input)
56
+ print(f"Closest response: {closest_response}") # Debugging statement
57
+
58
+ # Generate a response using DialoGPT
59
+ new_user_input_ids = tokenizer.encode(
60
+ user_input + tokenizer.eos_token, return_tensors='pt')
61
+ bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids],
62
+ dim=-1) if chat_history_ids is not None else new_user_input_ids
63
+ chat_history_ids = model.generate(
64
+ bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
65
+ bot_reply = tokenizer.decode(
66
+ chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
67
+ print(f"DialoGPT response: {bot_reply}") # Debugging statement
68
+
69
+ # Combine the Bitext response and DialoGPT response
70
+ combined_response = f"{closest_response}\n\n{bot_reply}"
71
+
72
+ return combined_response, chat_history_ids
73
+ except Exception as e:
74
+ print(f"Exception: {str(e)}") # Print the full exception
75
+ return "Sorry, an unexpected error occurred. Please try again.", None
76
+
77
+ # Serve the HTML file
78
+
79
+
80
+ @app.route("/")
81
+ def serve_html():
82
+ return send_from_directory(".", "index.html")
83
+
84
+ # Chat route
85
+
86
+
87
+ @app.route("/chat", methods=["POST"])
88
+ def chat():
89
+ user_input = request.json.get("message")
90
+ if not user_input:
91
+ return jsonify({"error": "No message provided"}), 400
92
+
93
+ # Get the chat history from the session (if any)
94
+ chat_history_ids = request.json.get("chat_history_ids")
95
+ if chat_history_ids:
96
+ chat_history_ids = torch.tensor(
97
+ chat_history_ids) # Convert back to a tensor
98
+
99
+ # Get the bot's response
100
+ bot_response, chat_history_ids = chat_with_bot(
101
+ user_input, chat_history_ids)
102
+
103
+ # Convert chat_history_ids to a list for JSON serialization
104
+ chat_history_ids_list = chat_history_ids.tolist(
105
+ ) if chat_history_ids is not None else None
106
+
107
+ return jsonify({
108
+ "response": bot_response,
109
+ "chat_history_ids": chat_history_ids_list
110
+ })
111
+
112
+
113
+ if __name__ == "__main__":
114
+ app.run(debug=True)