from huggingface_hub import InferenceClient
import gradio as gr

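# Inference client for the hosted Mixtral-8x7B-Instruct model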
client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
)

# Builds the full model prompt from the conversation history plus the new message
def format_prompt(message, history):
    prompt = "<s>"

    # Instruction prepended to every user message
    prompt_prefix = "Please correct the grammar in the following sentence: "
    prompt_template = "[INST] " + prompt_prefix + "{} [/INST]"

    # Add every past user input and bot response to the prompt
    for user_prompt, bot_response in history:
        prompt += prompt_template.format(user_prompt)
        prompt += f" {bot_response}</s> "

    # Wrap the current message in the same template
    prompt += prompt_template.format(message)

    return prompt

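# Streams the model's reply: formats the conversation into a single prompt and yields the accumulated output as each token arrives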
def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    # Sampling temperature must be strictly positive
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # Prepend the system prompt to the user's message, build the full prompt, and stream the completion
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)

    output = ""
    for response in stream:
        output += response.token.text
        yield output


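# Extra controls shown below the chat box; their values are passed to generate() in order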
additional_inputs = [
    gr.Textbox(label="System Prompt", value="You are an expert at English grammar", max_lines=1, interactive=True),
    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
    gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=1024, step=64, interactive=True, info="The maximum number of new tokens"),
    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens"),
]

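# Example inputs; the trailing None values leave the additional inputs at their defaults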
examples = [
    ['Give me the grammatically correct version of the sentence: "We shood buy an car"', None, None, None, None, None],
    ["Give me an example exam question testing students on square roots on basic integers", None, None, None, None, None],
    ["Would this block of HTML code run?\n```\n\n```", None, None, None, None, None],
    ["I have been to New York last summer.", None, None, None, None, None],
    ["We shood buy an car.", None, None, None, None, None],
    ["People is coming to my party.", None, None, None, None, None],
]

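# Assemble the chat UI and launch the app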
gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Mixtral 46.7B",
    examples=examples,
    concurrency_limit=20,
).launch(show_api=False)