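# Gradio chat demo that streams grammar corrections from Mixtral-8x7B-Instruct-v0.1 via the
# Hugging Face InferenceClient. Prompts are built in the Mistral [INST] ... [/INST] format
# and seeded with a couple of few-shot grammar-correction examples.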
from huggingface_hub import InferenceClient
import gradio as gr

client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
)

# Formats the prompt to hold all of the past messages
def format_prompt_basic(message, history):
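    """Builds a Mixtral-style prompt from the full chat history followed by the new message.

    Each user turn is wrapped in [INST] ... [/INST] and each assistant reply ends with </s>,
    per the Mistral instruction format.
    """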
    prompt = "<s>"

    # String to add before every prompt
    prompt_prefix = ""
    #prompt_prefix = "Please correct the grammar in the following sentence:"
    prompt_template = "[INST] " + prompt_prefix + " {} [/INST]"
    

    # Iterates through every past user input and response to be added to the prompt
    for user_prompt, bot_response in history:
        prompt += prompt_template.format(user_prompt)
        prompt += f" {bot_response}</s> "
        
    prompt += prompt_template.format(message)
    print("\nPROMPT: \n\t" + prompt)

    return prompt

# Formats the prompt with two few-shot grammar-correction examples followed by the full chat history
def format_prompt_grammar1(message, history):
    prompt = "<s>"

    # String to add before every prompt
    prompt_prefix = "Correct any grammatical errors in the following sentence and provide the corrected version:\n\nSentence: "
    prompt_template = "[INST] " + prompt_prefix + ' {} [/INST]'
    
    
    # Two fixed few-shot examples are shown to the model first, followed by the real history
    few_shot_examples = [["It is my friends house in England.", "Corrected Sentence: It is my friend's house in England."],
                         ["Every girl must bring their books to school.", "Corrected Sentence: Every girl must bring her books to school."]]
    turns = few_shot_examples + history

    # Iterates through every past user input and response to be added to the prompt
    for user_prompt, bot_response in turns:
        prompt += prompt_template.format(user_prompt)
        prompt += f" {bot_response}</s> \n"

    prompt += prompt_template.format(message)
    print("PROMPT: \n\t{}\n".format(prompt))
    return prompt

# Don't track actual history
def format_prompt_grammar(message):
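    """Builds a few-shot grammar-correction prompt from two fixed examples plus the new
    sentence; the real chat history is intentionally not included."""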
    prompt = "<s>"

    # String to add before every prompt
    prompt_prefix = "Correct any grammatical errors in the following sentence and provide the corrected version:\n\nSentence:"
    prompt_template = "[INST] " + prompt_prefix + ' {} [/INST]'

    history = [["It is my friends house in England.", "It is my friend's house in England."],
               ["Every girl must bring their books to school.", "Every girl must bring her books to school."]]
               # ["I have been to New York last summer.", "I went to New York last summer."]

    # Iterates through every past user input and response to be added to the prompt
    for user_prompt, bot_response in history:
        prompt += prompt_template.format(user_prompt)
        prompt += f" {bot_response}</s> \n"

    prompt += prompt_template.format(message)
    print("PROMPT: \n\t{}\n".format(prompt))
    return prompt
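
# Illustrative shape of the prompt that format_prompt_grammar builds (whitespace simplified):
#   <s>[INST] Correct any grammatical errors ... Sentence: It is my friends house in England. [/INST] It is my friend's house in England.</s>
#   [INST] Correct any grammatical errors ... Sentence: Every girl must bring their books to school. [/INST] Every girl must bring her books to school.</s>
#   [INST] Correct any grammatical errors ... Sentence: <system prompt> <new user message> [/INST]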



def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
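    """Formats the prompt, then streams the model's reply token by token so the
    Gradio chat window can update incrementally."""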
    print("\n\nSystem Prompt: '{}'".format(system_prompt))
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, repetition_penalty=repetition_penalty, do_sample=True, seed=42,)
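    # Note: do_sample=True with a fixed seed makes sampling reproducible for identical
    # requests; remove the seed for more varied outputs.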
    
    #formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    #formatted_prompt = format_prompt_grammar(f"{system_prompt} {prompt}", history)
    #formatted_prompt = format_prompt_grammar(prompt)
    formatted_prompt = format_prompt_grammar(f"{system_prompt} {prompt}")
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        output += response.token.text
        yield output
    return output



additional_inputs=[
    gr.Textbox( label="System Prompt", value="Corrected Sentence:" , max_lines=1, interactive=True, ),
    gr.Slider( label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs", ),
    gr.Slider( label="Max new tokens", value=256, minimum=0, maximum=1048, step=64, interactive=True, info="The maximum number of new tokens", ),
    gr.Slider( label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens", ),
    gr.Slider( label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens", )
]


examples=['Give me the grammatically correct version of the sentence: "We shood buy an car"', "Give me an example exam question testing students on square roots on basic integers", "Would this block of HTML code run?\n```\n\n```", "I have been to New York last summer.", "We shood buy an car.", "People is coming to my party.", "She is more taller.", "Their were lot of sheeps.", "I want to speak English good.", "I must to buy a new cartoon book."]
examples = [[x, None, None, None, None, None] for x in examples]


gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Mixtral 46.7B",
    examples=examples,
    concurrency_limit=20,
).launch(show_api=False)