from huggingface_hub import InferenceClient
import gradio as gr

client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
)
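# No token is passed here, so this relies on anonymous access to the hosted
# model; InferenceClient(..., token="hf_...") can be supplied if the endpoint
# requires authentication.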

# Earlier version of format_prompt, kept for reference but disabled; it built
# the prompt without the per-message grammar-correction prefix.
"""def format_prompt(message, history):
    prompt = "<s>"

    # Iterates through every past user input and response to be added to the prompt
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "

    prompt += "[INST] Instruction [/INST] The instruction is: 'Instruction'. The answer is: Model answer</s> "
    #prompt += f"[INST] Please correct the grammar in the following sentence: {message} [/INST]"
    prompt += f"[INST] {message} [/INST]"
    
    return prompt
"""
# Formats the prompt to hold all of the past messages, prefixing each user
# message with the grammar-correction instruction
def format_prompt(message, history):
    prompt = "<s>"

    # String to add before every prompt
    prompt_prefix = "Please correct the grammar in the following sentence: "
    
    # Iterates through every past user input and response to be added to the prompt
    for user_prompt, bot_response in history:
        corrected_prompt = prompt_prefix + user_prompt
        
        prompt += f"[INST] {corrected_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
        #print(f"HISTORIC PROMPT: \n\t[INST] {corrected_prompt} [/INST] {bot_response}</s> ")

    # Also prepend the prefix to the current message
    corrected_message = prompt_prefix + message
    prompt += f"[INST] {corrected_message} [/INST]"
    print("\nPROMPT: \n\t" + prompt)

    return prompt
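
# For illustration, with history = [("I goed home", "I went home.")] and
# message = "He run fast", format_prompt returns:
#   <s>[INST] Please correct the grammar in the following sentence: I goed home [/INST] I went home.</s> [INST] Please correct the grammar in the following sentence: He run fast [/INST]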

def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # The system prompt is simply prepended to the user message before formatting
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        output += response.token.text
        yield output
    return output
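
# Minimal sketch of consuming generate() outside the UI (illustrative values;
# history is a list of (user, bot) tuples, as Gradio's ChatInterface supplies):
#
#   output = ""
#   for output in generate("We shood buy an car", history=[], system_prompt=""):
#       pass          # each iteration yields the text accumulated so far
#   print(output)     # final corrected sentence from the model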


# These inputs map positionally onto generate()'s parameters after (prompt, history)
additional_inputs = [
    gr.Textbox(label="System Prompt", max_lines=1, interactive=True),
    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
    gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=1048, step=64, interactive=True, info="The maximum number of new tokens"),
    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens"),
]

examples = [
    ['Give me the grammatically correct version of the sentence: "We shood buy an car"', None, None, None, None, None],
    ["Give me an example exam question testing students on square roots on basic integers", None, None, None, None, None],
    ["Would this block of HTML code run?\n```\n\n```", None, None, None, None, None],
]

gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Mixtral 46.7B",
    examples=examples,
    concurrency_limit=20,
).launch(show_api=False)
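
# Running this file directly (e.g. `python app.py`) starts the Gradio server;
# it assumes the gradio and huggingface_hub packages are installed.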