Update app.py
Browse files
app.py
CHANGED
@@ -1,25 +1,42 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
2 |
|
3 |
-
|
|
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
max_tokens = 200
|
8 |
-
if response_type == "Short":
|
9 |
-
max_tokens = 50
|
10 |
-
elif response_type == "Medium":
|
11 |
-
max_tokens = 100
|
12 |
-
return response[:max_tokens]
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
gr.Textbox(label="Prompt"),
|
18 |
-
gr.Radio(["Short", "Medium", "Long"], label="Response Type")
|
19 |
-
],
|
20 |
-
outputs=gr.Textbox(label="Response"),
|
21 |
-
title="Mixtral-8x7B-Instruct-v0.1 Chat",
|
22 |
-
description="Chat with the Mixtral-8x7B-Instruct-v0.1 model."
|
23 |
-
)
|
24 |
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import time
|
3 |
+
import re
|
4 |
+
import os
|
5 |
|
6 |
+
# Available models
|
7 |
+
MODEL = "models/mistralai/Mixtral-8x7B-Instruct-v0.1"
|
8 |
|
9 |
+
# Sambanova API base URL
|
10 |
+
API_BASE = "https://api.sambanova.ai/v1"
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
def create_client():
|
13 |
+
"""Creates an client instance."""
|
14 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
+
def chat_with_ai(message, chat_history, system_prompt):
|
17 |
+
"""Formats the chat history for the API call."""
|
18 |
+
messages = [{"role": "system", "content": system_prompt}]
|
19 |
+
for tup in chat_history:
|
20 |
+
first_key = list(tup.keys())[0] # First key
|
21 |
+
last_key = list(tup.keys())[-1] # Last key
|
22 |
+
messages.append({"role": "user", "content": tup[first_key]})
|
23 |
+
messages.append({"role": "assistant", "content": tup[last_key]})
|
24 |
+
messages.append({"role": "user", "content": message})
|
25 |
+
return messages
|
26 |
+
|
27 |
+
def respond(message, chat_history, system_prompt, thinking_budget):
|
28 |
+
"""Sends the message to the API and gets the response."""
|
29 |
+
messages = chat_with_ai(message, chat_history, system_prompt.format(budget=thinking_budget))
|
30 |
+
start_time = time.time()
|
31 |
+
|
32 |
+
try:
|
33 |
+
response =
|
34 |
+
thinking_time = time.time() - start_time
|
35 |
+
return response, thinking_time
|
36 |
+
except Exception as e:
|
37 |
+
error_message = f"Error: {str(e)}"
|
38 |
+
return error_message, time.time() - start_time
|
39 |
+
|
40 |
+
def parse_response(response):
|
41 |
+
"""Parses the response from the API."""
|
42 |
+
answer_match = re.search(r'<answer>(.*?)
|