from huggingface_hub import InferenceClient
import gradio as gr

client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

def tokenize(text):
    # Pass-through placeholder. Note: a real tokenizer's encode() returns
    # token ids, not a string, so a round-trip would be needed instead:
    #   return tok.decode(tok.encode(text, add_special_tokens=False))
    return text
    
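# Mixtral's instruct format is <s>[INST] user [/INST] assistant</s>:
# each past (user, assistant) turn is replayed, then the new message is
# appended as an open [INST] block for the model to complete.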
def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += tokenize("[INST] ") + tokenize(user_prompt) + tokenize(" [/INST] ")
        prompt += tokenize(bot_response) + "</s> "
    prompt += tokenize("[INST] ") + tokenize(message) + tokenize(" [/INST]")
    return prompt

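# Streaming handler: ChatInterface re-renders the bot message on every
# yield, so the reply appears token by token.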
def generate(prompt, history, system_prompt, temperature=0.2, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
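    # The text-generation backend rejects a temperature of 0, so clamp it
    # to a small positive value.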
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # Prepend the system prompt (if any) so an empty one doesn't leave a stray ", ".
    if system_prompt:
        prompt = f"{system_prompt}, {prompt}"
    formatted_prompt = format_prompt(prompt, history)

    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""

    for response in stream:
        # Skip special tokens such as </s> so they don't leak into the chat.
        if not response.token.special:
            output += response.token.text
        yield output


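# Extra controls shown under "Additional Inputs"; ChatInterface passes their
# values to generate() after (message, history), in this order.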
additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.2,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=512,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.95,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.0,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]

mychatbot = gr.Chatbot(
    avatar_images=["./user.png", "./botm.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=True,
    likeable=True,
)

demo = gr.ChatInterface(fn=generate, 
                        chatbot=mychatbot,
                        additional_inputs=additional_inputs,
                        title="Kamran's Mixtral 8x7b Chat",
                        retry_btn=None,
                        undo_btn=None
                       )

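# queue() enables request queueing, which the streaming generator relies on;
# show_api=False hides the auto-generated API docs page.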
demo.queue().launch(show_api=False)