vericudebuget committed
Commit b0a938e · verified · 1 Parent(s): b765dcc

Update app.py

Files changed (1): app.py +11 -13
app.py CHANGED
@@ -8,16 +8,14 @@ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 def format_prompt(message, history):
     prompt = "<s>"
     for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f"\[INST\] {user_prompt} \[/INST\]"
         prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
+    prompt += f"\[INST\] {message} \[/INST\]"
     return prompt
 
 def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=9048, top_p=0.95, repetition_penalty=1.0):
-
     temperature = max(float(temperature), 1e-2)
     top_p = float(top_p)
-
     generate_kwargs = dict(
         temperature=temperature,
         max_new_tokens=max_new_tokens,
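
For reference, the template format_prompt builds uses literal [INST] … [/INST] tags, which is the documented instruct format for mistralai/Mixtral-8x7B-Instruct-v0.1. If the backslash-escaped \[INST\] tags really land in the committed string, rather than being escaping added by the diff viewer, the model receives tags it was not tuned on. A minimal sketch of the conventional template:

    def format_prompt(message, history):
        # Mixtral-Instruct turns look like: <s>[INST] user [/INST] answer</s>
        prompt = "<s>"
        for user_prompt, bot_response in history:
            prompt += f"[INST] {user_prompt} [/INST]"
            prompt += f" {bot_response}</s> "
        prompt += f"[INST] {message} [/INST]"
        return prompt
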
@@ -35,10 +33,9 @@ def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=9048, top_p=0.95, repetition_penalty=1.0):
     formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
-
     for response in stream:
         output += response.token.text
-        yield output
+        yield output
     return output
 
 additional_inputs = [
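
generate is a Python generator: with stream=True and details=True, huggingface_hub's InferenceClient.text_generation returns an iterator of token events, and yielding the accumulated string inside the loop is what lets gr.ChatInterface repaint the reply token by token. A sketch of the pattern in isolation (generate_stream is a hypothetical name; the text_generation call mirrors the one in app.py):

    from huggingface_hub import InferenceClient

    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    def generate_stream(formatted_prompt, **generate_kwargs):
        stream = client.text_generation(formatted_prompt, **generate_kwargs,
                                        stream=True, details=True, return_full_text=False)
        output = ""
        for response in stream:            # one event per generated token
            output += response.token.text
            yield output                   # partial text; the UI updates live
        # in a generator, the trailing `return output` in app.py only ends
        # iteration; its value is never seen by the Gradio caller

A yield placed after the loop instead of inside it would collapse the stream to a single final update.
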
@@ -49,10 +46,11 @@ additional_inputs = [
     gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens")
 ]
 
-gr.ChatInterface(
-    fn=generate,
-    chatbot=gr.Chatbot(show_label=True, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
-    additional_inputs=additional_inputs,
-    title="ConvoLite",
-    concurrency_limit=20,
-).launch(show_api=False,)
+with gr.themes.Soft():
+    gr.ChatInterface(
+        fn=generate,
+        chatbot=gr.Chatbot(show_label=True, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
+        additional_inputs=additional_inputs,
+        title="ConvoLite",
+        concurrency_limit=20,
+    ).launch(show_api=False)
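
One caution on the new launch block: gr.themes.Soft() returns a Theme object and is not documented as a context manager, so with gr.themes.Soft(): can be expected to fail at runtime. The conventional way to apply a theme (assuming Gradio 4.x, where gr.ChatInterface accepts a theme= argument) would be along these lines:

    gr.ChatInterface(
        fn=generate,
        chatbot=gr.Chatbot(show_label=True, show_share_button=False,
                           show_copy_button=True, likeable=True, layout="panel"),
        additional_inputs=additional_inputs,
        title="ConvoLite",
        theme=gr.themes.Soft(),  # theme is passed as an argument, not a `with` block
        concurrency_limit=20,
    ).launch(show_api=False)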