OscarFAI committed
Commit 9577ec2 · 1 Parent(s): 69d0c7f
Files changed (1):
  1. app.py +31 -16
app.py CHANGED

@@ -9,7 +9,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 DESCRIPTION = '''
 <div>
-<h1 style="text-align: center;">Mistral Chat</h1>
+<h1 style="text-align: center;">Mistral 8B Instruct</h1>
 </div>
 '''
 
@@ -20,7 +20,7 @@ LICENSE = """
 
 PLACEHOLDER = """
 <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
-   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">DeepSeek-R1-Distill-Llama-8B</h1>
+   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Mistral-8B</h1>
    <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
 </div>
 """
@@ -49,17 +49,19 @@ terminators = [
 ]
 
 @spaces.GPU(duration=120)
-def chat_llama3_8b(message: str,
-                   history: list,
-                   temperature: float,
-                   max_new_tokens: int,
-                   system_prompt: str) -> str:
+def chat_mistral(message: str,
+                 history: list,
+                 temperature: float,
+                 top_p: float,
+                 max_new_tokens: int,
+                 system_prompt: str) -> str:
     """
     Generate a streaming response using the Mistral-8B model.
     Args:
         message (str): The input message.
         history (list): The conversation history used by ChatInterface.
         temperature (float): The temperature for generating the response.
+        top_p (float): The top-p (nucleus) sampling parameter.
         max_new_tokens (int): The maximum number of new tokens to generate.
         system_prompt (str): The system prompt to guide the assistant's behavior.
     Returns:
@@ -67,14 +69,24 @@ def chat_llama3_8b(message: str,
     """
     conversation = []
 
-    # Include system prompt at the beginning if provided
+    # Format system prompt correctly using [INST]
     if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
-
-    for user, assistant in history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-
-    conversation.append({"role": "user", "content": message})
+        formatted_prompt = f"[INST] {system_prompt} [/INST]\n\n"
+    else:
+        formatted_prompt = ""
+
+    # Modify first user message to include system prompt
+    if history:
+        first_user_msg = f"{formatted_prompt}{history[0][0]}" if formatted_prompt else history[0][0]
+        conversation.append({"role": "user", "content": first_user_msg})
+        conversation.append({"role": "assistant", "content": history[0][1]})
+
+        for user, assistant in history[1:]:
+            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    else:
+        # First message in a new conversation
+        first_message = f"{formatted_prompt}{message}" if formatted_prompt else message
+        conversation.append({"role": "user", "content": first_message})
 
     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
 
@@ -86,6 +98,7 @@ def chat_llama3_8b(message: str,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
+        top_p=top_p,
         eos_token_id=terminators,
     )
 
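For context, the keyword arguments in the hunk above (max_new_tokens, do_sample, temperature, top_p, eos_token_id) are the usual inputs to the transformers streaming pattern built on TextIteratorStreamer and a background generation thread. Below is a minimal sketch of that pattern; the stream_chat name is illustrative, and model, tokenizer, input_ids, and terminators are assumed to be the objects defined elsewhere in app.py. It is a sketch, not code from this commit.

from threading import Thread

from transformers import TextIteratorStreamer


def stream_chat(model, tokenizer, input_ids, temperature, top_p, max_new_tokens, terminators):
    """Yield the partially generated reply as new tokens arrive (illustrative sketch)."""
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=terminators,
    )
    # Run generation in a worker thread so this generator can yield while tokens stream in.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    partial = []
    for chunk in streamer:
        partial.append(chunk)
        yield "".join(partial)  # gr.ChatInterface renders each partial string it receives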
 
@@ -115,19 +128,21 @@ with gr.Blocks(fill_height=True, css=css) as demo:
     )
 
     gr.ChatInterface(
-        fn=chat_llama3_8b,
+        fn=chat_mistral,
         chatbot=chatbot,
         fill_height=True,
         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
         additional_inputs=[
             system_prompt_input,
             gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False),
+            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.9, label="Top-p", render=False),
             gr.Slider(minimum=128, maximum=4096, step=1, value=4096, label="Max new tokens", render=False),
         ],
         examples=[
             ['Are you a sentient being?']
         ],
-        cache_examples=False
+        cache_examples=False,
+        type='messages',
     )
 
 if __name__ == "__main__":
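A note on the new type='messages' argument in the hunk above: with it, gr.ChatInterface passes history as a list of {"role": ..., "content": ...} dicts rather than (user, assistant) tuples. Below is a minimal sketch of assembling the model conversation from that shape, with the [INST]-formatted system prompt folded into the first user turn; the build_conversation name is illustrative and not part of this commit.

def build_conversation(message: str, history: list, system_prompt: str) -> list:
    """Assemble a role/content conversation from messages-style Gradio history (sketch)."""
    prefix = f"[INST] {system_prompt} [/INST]\n\n" if system_prompt else ""

    # With type='messages', each history entry already carries "role" and "content" keys.
    conversation = [{"role": turn["role"], "content": turn["content"]} for turn in history]

    # Prepend the formatted system prompt to the first user turn, if there is one.
    for turn in conversation:
        if turn["role"] == "user":
            turn["content"] = f"{prefix}{turn['content']}"
            break
    else:
        # New conversation: attach the prefix to the incoming message instead.
        message = f"{prefix}{message}"

    conversation.append({"role": "user", "content": message})
    return conversation

The resulting list can be passed straight to tokenizer.apply_chat_template(conversation, return_tensors="pt"), as app.py already does.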