from huggingface_hub import InferenceClient
import gradio as gr
client = InferenceClient(
"mistralai/Mixtral-8x7B-Instruct-v0.1"
)
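# The client calls the hosted Hugging Face Inference API for this model;
# no weights are downloaded or loaded locally.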

# Formats the prompt to hold all of the past messages.
# Earlier, generic version kept for reference (disabled via the triple-quoted string):
"""
def format_prompt(message, history):
    prompt = "<s>"
    # Iterates through every past user input and response to be added to the prompt
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += "[INST] Instruction [/INST] The instruction is: 'Instruction'. The answer is: Model answer</s> "
    # prompt += f"[INST] Please correct the grammar in the following sentence: {message} [/INST]"
    prompt += f"[INST] {message} [/INST]"
    return prompt
"""

def format_prompt(message, history):
    prompt = "<s>"
    # String to prepend to every user message
    prompt_prefix = "Please correct the grammar in the following sentence: "
    # Iterates through every past user input and response to be added to the prompt
    for user_prompt, bot_response in history:
        corrected_prompt = prompt_prefix + user_prompt
        prompt += f"[INST] {corrected_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
        # print(f"HISTORIC PROMPT: \n\t[INST] {corrected_prompt} [/INST] {bot_response}</s> ")
    # Also prepend the prefix to the current message
    corrected_message = prompt_prefix + message
    prompt += f"[INST] {corrected_message} [/INST]"
    print("\nPROMPT: \n\t" + prompt)
    return prompt
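
# Illustrative example (hypothetical history/message, not executed):
#   history = [("I has a dog", "I have a dog.")], message = "We shood buy an car"
# format_prompt then returns:
#   <s>[INST] Please correct the grammar in the following sentence: I has a dog [/INST] I have a dog.</s> [INST] Please correct the grammar in the following sentence: We shood buy an car [/INST]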

def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature)
    # Clamp temperature to a small positive value (zero/near-zero values can be rejected by the endpoint)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    # Stream tokens back to the UI as they arrive
    for response in stream:
        output += response.token.text
        yield output
    return output
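
# Illustrative smoke test (hypothetical; assumes a valid HF token is configured
# and the caller has access to the model):
# for partial in generate("We shood buy an car", [], "You are a grammar checker"):
#     print(partial)  # accumulated output so far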

# Extra controls shown in the Gradio UI alongside the chat box
additional_inputs = [
    gr.Textbox(label="System Prompt", max_lines=1, interactive=True),
    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
    gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=1048, step=64, interactive=True, info="The maximum number of new tokens"),
    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens"),
]
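# Note: the order of additional_inputs matches generate's extra parameters
# (system_prompt, temperature, max_new_tokens, top_p, repetition_penalty);
# ChatInterface passes them positionally after (message, history).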

# Example prompts shown beneath the chat box (the Nones leave the additional inputs at their defaults)
examples = [
    ['Give me the grammatically correct version of the sentence: "We shood buy an car"', None, None, None, None, None],
    ["Give me an example exam question testing students on square roots of basic integers", None, None, None, None, None],
    ["Would this block of HTML code run?\n```\n\n```", None, None, None, None, None],
]

gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Mixtral 46.7B",
    examples=examples,
    concurrency_limit=20,
).launch(show_api=False)