ColeGuion commited on
Commit
9814153
·
verified ·
1 Parent(s): 46bd600

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -5,17 +5,19 @@ client = InferenceClient(
5
  "mistralai/Mixtral-8x7B-Instruct-v0.1"
6
  )
7
 
8
-
9
  def format_prompt(message, history):
10
  prompt = "<s>"
 
 
11
  for user_prompt, bot_response in history:
12
- print("USER PROMPT: '{}' \nBOT RESPONSE: '{}'".format(user_prompt, bot_response))
13
  prompt += f"[INST] {user_prompt} [/INST]"
14
  prompt += f" {bot_response}</s> "
15
- print("MSG: '{}'".format(message))
16
  prompt += f"[INST] {message} [/INST]"
 
17
  return prompt
18
 
 
19
  def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
20
  temperature = float(temperature)
21
  if temperature < 1e-2:
 
5
  "mistralai/Mixtral-8x7B-Instruct-v0.1"
6
  )
7
 
8
+ # Formats the prompt to hold all of the past messages
9
  def format_prompt(message, history):
10
  prompt = "<s>"
11
+
12
+ # Iterates through every pass user input and response to be added to the prompt
13
  for user_prompt, bot_response in history:
 
14
  prompt += f"[INST] {user_prompt} [/INST]"
15
  prompt += f" {bot_response}</s> "
 
16
  prompt += f"[INST] {message} [/INST]"
17
+
18
  return prompt
19
 
20
+
21
  def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
22
  temperature = float(temperature)
23
  if temperature < 1e-2: