import os

import requests
import gradio as gr
from huggingface_hub import InferenceClient

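# Client for the hosted Mixtral-8x7B-Instruct model on the Hugging Face
# Inference API.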
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

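# Query the Google Programmable Search (Custom Search) JSON API and return up
# to the first five result items for the given query, or None if nothing hits.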
def search_url(search_query):
    # Credentials come from environment variables rather than being hard-coded
    # (on Hugging Face Spaces, repository secrets are exposed this way). The
    # variable names below are placeholders; set them to match your own config.
    API_KEY = os.environ['GOOGLE_API_KEY']
    SEARCH_ENGINE_ID = os.environ['SEARCH_ENGINE_ID']
    
    url = 'https://customsearch.googleapis.com/customsearch/v1'

    params = {
        'q': search_query,
        'key': API_KEY,
        'cx': SEARCH_ENGINE_ID,
    }

    response = requests.get(url, params=params)

    results = response.json()

    if 'items' in results:
        for i in range(min(5, len(results['items']))):
            print(f"Link {i + 1}: {results['items'][i]['link']}")
        return results['items'][:5]
    else:
        print("No search results found.")
        return None
        
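# Pass-through tokenizer hook: prompts are assembled from raw strings here,
# but a real tokenizer could be swapped in later (see the commented line).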
def tokenize(text):
    return text
    # return tok.encode(text, add_special_tokens=False)
    
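# Build a Mixtral-Instruct prompt. Each past turn is wrapped as
# "<s>[INST]user[/INST]bot</s> ", followed by the new message, e.g.
#   format_prompt("Hi", [("Q", "A")]) -> "<s>[INST]Q[/INST]A</s> [INST]Hi[/INST]"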
def format_prompt(message, history):
    prompt = ""
    for user_prompt, bot_response in history:
        prompt += "<s>" + tokenize("[INST]") + tokenize(user_prompt) + tokenize("[/INST]")
        prompt += tokenize(bot_response) + "</s> "
    prompt += tokenize("[INST]") + tokenize(message) + tokenize("[/INST]")
    return prompt

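# Stream a completion from Mixtral, yielding the accumulated text after each
# token so the Gradio UI can render the reply incrementally.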
def generate(prompt, history, system_prompt, temperature=0.2, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)

    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    for response in stream:
        output += response.token.text
        yield output

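# Search-backed variant wired into the interface below: instead of calling the
# model, it looks the prompt up on Google and returns the top links. The extra
# sampling parameters are accepted (and ignored) because gr.ChatInterface
# passes every additional input to its fn.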
def generateS(prompt, history, system_prompt, temperature=0.2, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    items = search_url(prompt)
    if not items:
        return "No search results found."
    # ChatInterface expects a string reply, so format the links as one message.
    return "\n".join(f"Link {i + 1}: {item['link']}" for i, item in enumerate(items))
    
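# Extra controls rendered under the chat box; gr.ChatInterface appends their
# values to the (message, history) arguments of the chat function.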
additional_inputs=[
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.2,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=512,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.95,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.0,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]

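# Chatbot component with custom avatars; the two image files are expected to
# sit next to this script.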
mychatbot = gr.Chatbot(
    avatar_images=["./user.png", "./botm.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=True,
    likeable=False,
)

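# Note: fn=generateS wires in the search-backed handler; swap in generate to
# chat with the model directly.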
demo = gr.ChatInterface(fn=generateS, 
                        chatbot=mychatbot,
                        additional_inputs=additional_inputs,
                        title="Kamran's Mixtral 8x7b Chat",
                        retry_btn=None,
                        undo_btn=None
                       )

demo.queue().launch(show_api=False)
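
# Run locally with `python app.py`; on a Gradio Space this file (assuming it
# is saved as app.py) is launched automatically.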