import json
import os

import gradio as gr
import requests
from huggingface_hub import InferenceClient

# Client for the hosted Mixtral-8x7B-Instruct model on the Hugging Face Inference API.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
def google_search(query, **kwargs):
    """Query the Google Custom Search JSON API and return the result items."""
    # Credentials come from the environment; never hard-code API keys.
    api_key = os.environ.get("GOOGLE_API_KEY")
    cse_id = os.environ.get("GOOGLE_CSE_ID")
    service_url = "https://www.googleapis.com/customsearch/v1"
    params = {
        "key": api_key,
        "cx": cse_id,
        "q": query,
        **kwargs,
    }
    response = requests.get(service_url, params=params)
    if response.status_code == 200:
        # A query with no hits returns no "items" key, so default to an empty list.
        return response.json().get("items", [])
    print(f"Error: {response.status_code}")
    return []
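
# Example call (assumes GOOGLE_API_KEY and GOOGLE_CSE_ID are set in the environment;
# "num" is a standard Custom Search parameter limiting the result count):
#   items = google_search("gradio chatbot", num=3)
#   print(items[0]["title"], items[0]["link"])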
def tokenize(text):
    # Identity placeholder; a real tokenizer could be swapped in, e.g.:
    # return tok.encode(text, add_special_tokens=False)
    return text
def format_prompt(message, history):
    """Build a Mixtral-instruct prompt from the chat history and the new message."""
    prompt = ""
    for user_prompt, bot_response in history:
        prompt += "<s>" + tokenize("[INST]") + tokenize(user_prompt) + tokenize("[/INST]")
        prompt += tokenize(bot_response) + "</s> "
    prompt += tokenize("[INST]") + tokenize(message) + tokenize("[/INST]")
    return prompt
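
# With one prior exchange, the prompt produced above looks like:
#   <s>[INST]Hello[/INST]Hi there!</s> [INST]How are you?[/INST]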
def generate(prompt, history, system_prompt, temperature=0.2, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    # The Inference API rejects temperature <= 0, so clamp to a small positive value.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        print(response.token.text + "\n")
        output += response.token.text
        yield output
    return output
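
# `generate` is a generator: gr.ChatInterface consumes each yielded partial string
# and streams it into the chat window, so the reply appears token by token.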
def generateS(prompt, history, system_prompt, temperature=0.2, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    # Search-backed variant: streams Google results instead of model output.
    # It keeps the same signature so it can be dropped into the ChatInterface.
    results = google_search(prompt)
    output = ""
    for result in results:
        output += json.dumps(result) + "\n"
        yield output
    return output
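
# To serve the search-backed variant instead, a minimal sketch (assuming the
# same UI objects defined below):
#   demo = gr.ChatInterface(fn=generateS, chatbot=mychatbot, additional_inputs=additional_inputs)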
additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.2,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=512,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.95,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.0,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
mychatbot = gr.Chatbot(
    avatar_images=["./user.png", "./botm.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=True,
    likeable=False,
)

demo = gr.ChatInterface(
    fn=generate,
    chatbot=mychatbot,
    additional_inputs=additional_inputs,
    title="Kamran's Mixtral 8x7b Chat",
    retry_btn=None,
    undo_btn=None,
)

demo.queue().launch(show_api=False)