ambrosfitz committed
Commit 9c3b53a · verified · 1 parent: 4958c2a

Create app.py

Files changed (1): app.py (+225, -0)
app.py ADDED
@@ -0,0 +1,225 @@
import os
import gradio as gr
import requests
import logging
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# API keys configuration
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Fail fast if any provider key is missing
if not all([COHERE_API_KEY, MISTRAL_API_KEY, GEMINI_API_KEY]):
    raise ValueError("Missing required API keys in environment variables")

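# The three keys above are typically supplied via a .env file alongside
# app.py, which load_dotenv() picks up (placeholder values shown, not real keys):
#   COHERE_API_KEY=...
#   MISTRAL_API_KEY=...
#   GEMINI_API_KEY=...
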
# API endpoints configuration
COHERE_API_URL = "https://api.cohere.ai/v1/chat"
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-002:generateContent"
VECTOR_API_URL = "https://sendthat.cc"
HISTORY_INDEX = "onramps"

# Model configurations
MODELS = {
    "Cohere": {
        "name": "command-r-08-2024",
        "api_url": COHERE_API_URL,
        "api_key": COHERE_API_KEY
    },
    "Mistral": {
        "name": "ft:open-mistral-nemo:ef730d29:20241022:2a0e7d46",
        "api_url": MISTRAL_API_URL,
        "api_key": MISTRAL_API_KEY
    },
    "Gemini": {
        "name": "gemini-1.5-pro-002",
        "api_url": GEMINI_API_URL,
        "api_key": GEMINI_API_KEY
    }
}

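# The vector search service at VECTOR_API_URL is assumed to expose a
# POST /search/<index> endpoint that accepts {"text": ..., "k": ...} and
# returns {"results": [{"metadata": {"content": ..., "title": ..., "source": ...}}, ...]};
# answer_question below depends on exactly that response shape.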
def search_document(query, k):
    try:
        url = f"{VECTOR_API_URL}/search/{HISTORY_INDEX}"
        payload = {"text": query, "k": k}
        headers = {"Content-Type": "application/json"}
        response = requests.post(url, json=payload, headers=headers)
        response.raise_for_status()
        return response.json(), "", k
    except requests.exceptions.RequestException as e:
        logging.error(f"Error in search: {e}")
        return {"error": str(e)}, query, k

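# The three generate_answer_* functions below share the same prompt template
# and citation-appending logic; only each provider's request and response
# shapes differ.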
def generate_answer_cohere(question, context, citations):
    headers = {
        "Authorization": f"Bearer {MODELS['Cohere']['api_key']}",
        "Content-Type": "application/json"
    }

    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"

    payload = {
        "message": prompt,
        "model": MODELS['Cohere']['name'],
        "preamble": "You are an AI-assistant chatbot. Provide thorough responses with citations.",
        "chat_history": []
    }

    try:
        response = requests.post(MODELS['Cohere']['api_url'], headers=headers, json=payload)
        response.raise_for_status()
        answer = response.json()['text']

        answer += "\n\nSources:"
        for i, citation in enumerate(citations, 1):
            answer += f"\n[{i}] {citation}"

        return answer
    except requests.exceptions.RequestException as e:
        logging.error(f"Error in generate_answer_cohere: {e}")
        return f"An error occurred: {str(e)}"

def generate_answer_mistral(question, context, citations):
    headers = {
        "Authorization": f"Bearer {MODELS['Mistral']['api_key']}",
        "Content-Type": "application/json",
        "Accept": "application/json"
    }

    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"

    payload = {
        "model": MODELS['Mistral']['name'],
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ]
    }

    try:
        response = requests.post(MODELS['Mistral']['api_url'], headers=headers, json=payload)
        response.raise_for_status()
        answer = response.json()['choices'][0]['message']['content']

        answer += "\n\nSources:"
        for i, citation in enumerate(citations, 1):
            answer += f"\n[{i}] {citation}"

        return answer
    except requests.exceptions.RequestException as e:
        logging.error(f"Error in generate_answer_mistral: {e}")
        return f"An error occurred: {str(e)}"

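# Mistral's chat completions endpoint returns an OpenAI-style response body,
# hence the choices[0]['message']['content'] access above.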
def generate_answer_gemini(question, context, citations):
    headers = {
        "Content-Type": "application/json"
    }

    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"

    payload = {
        "contents": [
            {
                "role": "user",
                "parts": [
                    {
                        "text": prompt
                    }
                ]
            }
        ],
        "generationConfig": {
            "temperature": 1,
            "topK": 40,
            "topP": 0.95,
            "maxOutputTokens": 8192,
            "responseMimeType": "text/plain"
        }
    }

    try:
        # Unlike Cohere and Mistral, Gemini takes the API key as a query
        # parameter rather than an Authorization header
        url = f"{MODELS['Gemini']['api_url']}?key={MODELS['Gemini']['api_key']}"
        response = requests.post(url, headers=headers, json=payload)
        response.raise_for_status()
        answer = response.json()['candidates'][0]['content']['parts'][0]['text']

        answer += "\n\nSources:"
        for i, citation in enumerate(citations, 1):
            answer += f"\n[{i}] {citation}"

        return answer
    except requests.exceptions.RequestException as e:
        logging.error(f"Error in generate_answer_gemini: {e}")
        return f"An error occurred: {str(e)}"

def answer_question(question, model_choice, k=3):
    # Search the vector database
    search_results, _, _ = search_document(question, k)

    # Extract and combine the retrieved contexts
    if "results" in search_results:
        contexts = []
        citations = []
        for item in search_results['results']:
            contexts.append(item['metadata']['content'])
            citations.append(f"{item['metadata'].get('title', 'Unknown Source')} - {item['metadata'].get('source', 'No source provided')}")
        combined_context = " ".join(contexts)
    else:
        logging.error(f"Error in database search or no results found: {search_results}")
        combined_context = ""
        citations = []

    # Generate answer using the selected model
    if model_choice == "Cohere":
        return generate_answer_cohere(question, combined_context, citations)
    elif model_choice == "Mistral":
        return generate_answer_mistral(question, combined_context, citations)
    else:
        return generate_answer_gemini(question, combined_context, citations)

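# Note: Gradio supplies the chat history, but it is not forwarded to the
# models; each turn is answered independently from freshly retrieved context.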
def chatbot(message, history, model_choice):
    response = answer_question(message, model_choice)
    return response

# Example questions with default model choice
EXAMPLE_QUESTIONS = [
    ["Why was Anne Hutchinson banished from Massachusetts?", "Cohere"],
    ["What were the major causes of World War I?", "Mistral"],
    ["Who was the first President of the United States?", "Gemini"],
    ["What was the significance of the Industrial Revolution?", "Cohere"]
]

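# Each example row is [message, model_choice], matching the ChatInterface
# message input below followed by its additional_inputs.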
# Create Gradio interface
with gr.Blocks(theme="soft") as iface:
    gr.Markdown("# History Chatbot")
    gr.Markdown("Ask me anything about history, and I'll provide answers with citations!")

    with gr.Row():
        model_choice = gr.Radio(
            choices=["Cohere", "Mistral", "Gemini"],
            value="Cohere",
            label="Choose LLM Model",
            info="Select which AI model to use for generating responses"
        )

    # chatbot already has the (message, history, model_choice) signature that
    # ChatInterface expects with additional_inputs, so no wrapper lambda is needed
    chatbot_interface = gr.ChatInterface(
        fn=chatbot,
        additional_inputs=[model_choice],
        chatbot=gr.Chatbot(height=300),
        textbox=gr.Textbox(placeholder="Ask a question about history...", container=False, scale=7),
        examples=EXAMPLE_QUESTIONS,
        cache_examples=False,
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
    )

# Launch the app
if __name__ == "__main__":
    iface.launch()