File size: 5,106 Bytes
958557f
 
 
 
 
 
 
9814153
2915268
3a3ea8c
919832d
b2631fe
f79dbe4
 
b2631fe
f7657a2
 
f79dbe4
 
3a3ea8c
871ce9c
3a3ea8c
f7657a2
3a3ea8c
871ce9c
b2631fe
3a3ea8c
 
9814153
2915268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdc235c
 
 
 
 
 
919832d
5335780
958557f
 
 
 
 
e53bb43
111aa32
2915268
 
958557f
 
 
 
 
 
 
 
 
111aa32
958557f
f79dbe4
46bd600
 
 
 
958557f
 
2338a53
 
e9c195f
 
bbc202f
 
958557f
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from huggingface_hub import InferenceClient
import gradio as gr

# Shared Hugging Face Inference API client for the Mixtral instruct model;
# generate() streams every completion through it.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Formats the prompt to hold all of the past messages
def format_prompt1(message, history):
    prompt = "<s>"

    # String to add before every prompt
    prompt_prefix = "Please correct the grammar in the following sentence:"
    prompt_template = "[INST] " + prompt_prefix + " {} [/INST]"
    
    #history.append("It is my friends house in England.", "It is my friend's house in England.")
    #history.append("Every girl must bring their books to school.", "Every girl must bring her books to school.")

    # Iterates through every past user input and response to be added to the prompt
    for user_prompt, bot_response in history:
        prompt += prompt_template.format(user_prompt)
        prompt += f" {bot_response}</s> "
        

    prompt += prompt_template.format(message)
    print("\nPROMPT: \n\t" + prompt)

    return prompt

def format_prompt(message, history):
    """Build the Mixtral ``[INST]`` prompt used by ``generate``.

    Each past ``(user_prompt, bot_response)`` pair in ``history`` is rendered
    as a quoted grammar-correction instruction, followed by the model's reply
    and ``</s>``. The new ``message`` is appended last as the open
    instruction. It ends with ``Corrected Sentence:`` so the model continues
    directly with its answer. The history type and the final prompt are
    printed to stdout for debugging.

    Args:
        message: The newest user input.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, or
            ``None`` for an empty conversation. A ``None`` bot response is
            rendered as an empty reply.

    Returns:
        The complete prompt string, starting with ``<s>``.
    """
    prompt_prefix = "Correct any grammatical errors in the following sentence and provide the corrected version:\n\nSentence: "
    prompt_template = "[INST] " + prompt_prefix + ' "{}" [/INST] Corrected Sentence:'

    print("History Type: {}".format(type(history)))
    if history is None:
        # Treat a missing history as an empty conversation instead of
        # crashing on iteration.
        history = []
    elif not isinstance(history, list):
        # isinstance (not type equality) so list subclasses are accepted;
        # other iterables still work but are worth flagging.
        print("WARNING: expected history to be a list, got {}".format(type(history).__name__))

    parts = ["<s>"]
    for user_prompt, bot_response in history:
        parts.append(prompt_template.format(user_prompt))
        # Avoid injecting the literal text "None" for a turn with no reply.
        reply = "" if bot_response is None else bot_response
        parts.append(f" {reply}</s> ")
    parts.append(prompt_template.format(message))

    prompt = "".join(parts)
    print("PROMPT: \n\t{}\n".format(prompt))
    return prompt

def format_my_prompt(user_input):
    """Build a single-turn Mixtral prompt asking only for the corrected sentence.

    Args:
        user_input: The sentence to correct.

    Returns:
        A prompt string with a single ``[INST] ... [/INST]`` block.
    """
    # Both requests share one [INST] block. Putting a "Model answer</s>"
    # placeholder between them would present fabricated text to the model
    # as its own earlier reply.
    return (
        "<s> [INST] Please correct the grammatical errors in the following "
        f"sentence: {user_input}\n"
        "Return only the grammatically corrected sentence. [/INST]"
    )


def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    """Stream a Mixtral completion for ``prompt`` to the Gradio chat UI.

    Args:
        prompt: The latest user message.
        history: Past ``(user, bot)`` turns supplied by the chat interface.
        system_prompt: Instruction text joined in front of ``prompt`` as
            ``"<system_prompt>, <prompt>"``; skipped when empty or ``None``.
        temperature: Sampling temperature; values below 0.01 are raised to
            0.01.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The accumulated response text after each streamed non-special token.
    """
    print("\n\nSystem Prompt: '{}'".format(system_prompt))
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # Only prepend the system prompt when one is given; otherwise an empty
    # textbox would turn the message into ", <prompt>" (or "None, <prompt>").
    message = f"{system_prompt}, {prompt}" if system_prompt else prompt
    formatted_prompt = format_prompt(message, history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        # Skip special tokens (e.g. an end-of-sequence marker) so they do not
        # appear in the chat transcript and, via history, in the next prompt.
        if response.token.special:
            continue
        output += response.token.text
        yield output
    return output



# Extra controls shown under the chat box. Their values reach generate()
# after (message, history), in this order: system_prompt, temperature,
# max_new_tokens, top_p, repetition_penalty.
additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        value="Correct the following sentence to make it grammatically accurate while maintaining the original meaning. Output only the corrected sentence.",
        max_lines=1,
        interactive=True,
    ),
    # A temperature of 0.0 is allowed here; generate() raises values below
    # 0.01 to 0.01 before sampling.
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    # The minimum is one step (64) because a request for 0 new tokens is
    # invalid. The maximum sits on the 64-token step grid.
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=64,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    # NOTE(review): some text-generation backends reject top_p of exactly
    # 0 or 1; confirm against the inference endpoint before relying on the
    # slider's end points.
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

# Example rows for the chat UI: the message text followed by five ``None``
# placeholders, one for each entry of ``additional_inputs``.
examples = [
    [message, None, None, None, None, None]
    for message in (
        'Give me the grammatically correct version of the sentence: "We shood buy an car"',
        "Give me an example exam question testing students on square roots on basic integers",
        "Would this block of HTML code run?\n```\n\n```",
        "I have been to New York last summer.",
        "We shood buy an car.",
        "People is coming to my party.",
    )
]

# Assemble the chat UI around the streaming generate() function, then serve
# it with the API page hidden.
demo = gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(
        show_label=False,
        show_share_button=False,
        show_copy_button=True,
        likeable=True,
        layout="panel",
    ),
    additional_inputs=additional_inputs,
    title="Mixtral 46.7B",
    examples=examples,
    concurrency_limit=20,
)
demo.launch(show_api=False)