File size: 3,701 Bytes
738953f
 
533b084
2f6a7e7
738953f
55b11ff
738953f
2f6a7e7
 
 
f78f29d
2f6a7e7
f78f29d
2f6a7e7
 
 
 
f78f29d
2f6a7e7
 
 
f78f29d
2f6a7e7
 
f78f29d
55b11ff
 
 
f16951b
738953f
f16951b
55b11ff
f16951b
 
 
55b11ff
738953f
e15a09f
738953f
 
 
 
 
 
 
 
 
 
 
 
 
 
1091ed2
55b11ff
738953f
 
 
 
31cf2be
baf9a7f
738953f
 
 
533b084
 
2f6a7e7
 
a114927
2f6a7e7
492d4b7
2f6a7e7
 
533b084
a114927
 
 
 
 
 
 
 
55b11ff
a114927
 
 
 
 
 
 
 
55b11ff
a114927
 
 
 
 
 
 
 
55b11ff
a114927
 
 
 
 
 
 
 
55b11ff
a114927
 
 
 
 
 
 
 
f16951b
e6c11b4
edcd873
31cf2be
f16951b
 
 
 
 
 
edcd873
f16951b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import json
import os

import gradio as gr
import requests
from huggingface_hub import InferenceClient

# Hugging Face Inference API client bound to the Mixtral 8x7B instruct model.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

def google_search(query, **kwargs):
    """Run a Google Custom Search query and return its result items.

    Parameters
    ----------
    query : str
        The search query string.
    **kwargs
        Extra query parameters forwarded verbatim to the Custom Search
        API (e.g. ``num``, ``start``).

    Returns
    -------
    list
        The ``items`` list from the JSON response, or ``[]`` on any
        non-200 status or when the search yields no results (the API
        omits ``items`` entirely in that case — the original code
        raised ``KeyError`` there).
    """
    # SECURITY: credentials were hard-coded in source. Prefer environment
    # variables; the literals remain only as a backward-compatible fallback
    # and should be rotated and removed.
    api_key = os.environ.get('GOOGLE_API_KEY', 'AIzaSyDseKKQCAUBmPidu_QapnpJCGLueDWYJbE')
    cse_id = os.environ.get('GOOGLE_CSE_ID', '001ae9bf840514e61')

    service_url = 'https://www.googleapis.com/customsearch/v1'
    params = {
        'key': api_key,
        'cx': cse_id,
        'q': query,
        **kwargs,
    }
    # timeout prevents the request from hanging indefinitely on a stalled
    # connection (requests has no default timeout).
    response = requests.get(service_url, params=params, timeout=30)
    if response.status_code == 200:
        # 'items' is absent when there are zero results.
        return response.json().get('items', [])
    print(f'Error: {response.status_code}')
    return []
        
def tokenize(text):
    """Identity passthrough kept as a tokenization hook.

    Returns *text* unchanged; a real tokenizer (e.g.
    ``tok.encode(text, add_special_tokens=False)``) can be swapped in
    here if token IDs are ever needed.
    """
    return text
    
def format_prompt(message, history):
    """Build a Mixtral-style instruct prompt from the chat history.

    Each ``(user, bot)`` turn in *history* is rendered as
    ``<s>[INST]user[/INST]bot</s> ``; the new *message* is appended as a
    final ``[INST]...[/INST]`` pair awaiting the model's reply.
    """
    pieces = []
    for user_msg, bot_msg in history:
        pieces.append("<s>" + tokenize("[INST]") + tokenize(user_msg) + tokenize("[/INST]"))
        pieces.append(tokenize(bot_msg) + "</s> ")
    pieces.append(tokenize("[INST]") + tokenize(message) + tokenize("[/INST]"))
    return "".join(pieces)

def generate(prompt, history, system_prompt, temperature=0.2, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Stream a Mixtral completion for *prompt*, yielding cumulative text.

    Parameters
    ----------
    prompt : str
        The user's latest message.
    history : list[tuple[str, str]]
        Prior (user, bot) exchanges, as provided by gr.ChatInterface.
    system_prompt : str
        Prepended to *prompt* (comma-separated) before formatting.
    temperature, max_new_tokens, top_p, repetition_penalty
        Sampling parameters forwarded to the Inference API.

    Yields
    ------
    str
        The generated text accumulated so far (Gradio streams each yield).
    """
    # Clamp temperature away from zero; the API rejects non-positive values.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed for reproducible sampling
    )

    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)

    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        # BUG FIX: original debug print appended a literal "/n" (typo for
        # "\n"), emitting a stray slash-n after every token; print already
        # terminates the line.
        print(response.token.text)
        output += response.token.text
        yield output
    return output

def generateS(prompt, history, system_prompt, temperature=0.2, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Search-backed sibling of generate(): streams Google results as JSON.

    The generation parameters are accepted only for interface parity with
    generate() and are not used; *prompt* is sent to google_search() and
    each result item is appended (JSON-encoded) to the streamed output.
    """
    accumulated = ""
    for item in google_search(prompt):
        accumulated += json.dumps(item)
        yield accumulated
    return accumulated
    
# Extra controls shown in the ChatInterface's "Additional inputs" accordion;
# their values are passed positionally to generate() after (prompt, history).
additional_inputs=[
    # -> system_prompt
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    # -> temperature (generate() clamps values below 1e-2)
    gr.Slider(
        label="Temperature",
        value=0.2,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    # -> max_new_tokens
    gr.Slider(
        label="Max new tokens",
        value=512,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    # -> top_p
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.95,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    # -> repetition_penalty (1.0 = no penalty)
    gr.Slider(
        label="Repetition penalty",
        value=1,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]

# Chat display widget; avatar image paths are relative to the working
# directory — presumably bundled with the app, TODO confirm they exist.
mychatbot = gr.Chatbot(
    avatar_images=["./user.png", "./botm.png"], bubble_full_width=False, show_label=False, show_copy_button=True, likeable=False)

# Wire the streaming generate() function into a chat UI; additional_inputs
# supplies the system prompt and sampling controls.
demo = gr.ChatInterface(fn=generate, 
                        chatbot=mychatbot,
                        additional_inputs=additional_inputs,
                        title="Kamran's Mixtral 8x7b Chat",
                        retry_btn=None,
                        undo_btn=None
                       )

# queue() enables concurrent streaming sessions; show_api=False hides the
# auto-generated API docs page.
demo.queue().launch(show_api=False)