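# app.py: Gradio chat UI for Google's Gemma models, served through the
# Hugging Face Inference API, with optional Persian/English translation of
# prompts and responses.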
import random

import gradio as gr
from deep_translator import GoogleTranslator
from huggingface_hub import InferenceClient
models = [
    "google/gemma-7b",
    "google/gemma-7b-it",
    "google/gemma-2b",
    "google/gemma-2b-it",
]
# One inference client per model; the UI dropdown picks one by index.
clients = [InferenceClient(model) for model in models]
VERBOSE = False
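
# deep_translator wraps Google Translate; source='auto' auto-detects the input
# language, so Persian (or any other language) prompts need no configuration.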
def translate_to_english(prompt):
    return GoogleTranslator(source='auto', target='en').translate(prompt)

def translate_to_persian_text(response):
    return GoogleTranslator(source='auto', target='fa').translate(response)
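
# The model dropdown is declared with type='index', so `inp` arrives as a
# 0-based index into `models`.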
def load_models(inp):
    if VERBOSE:
        print(type(inp))
        print(inp)
        print(models[inp])
    return gr.update(label=models[inp])
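
# Gemma's chat template wraps each turn as "<start_of_turn>role\n...<end_of_turn>\n"
# and begins the sequence with a single <bos> token.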
def format_prompt(message, history, cust_p):
    prompt = "<bos>"
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
            prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
    if VERBOSE:
        print(prompt)
    prompt += cust_p.replace("USER_INPUT", message)
    return prompt
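
# Streaming chat generator: yields (history, memory) pairs so the chatbot
# updates token by token. `memory` holds the full transcript; `chat_mem`
# limits how many past turns are replayed into the prompt.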
def chat_inf(system_prompt, prompt, history, memory, client_choice, seed, temp, tokens, top_p, rep_p, chat_mem, custom_prompt, translate_to_persian):
    client = clients[int(client_choice)]  # type='index' dropdown is already 0-based
    if not history:
        history = []
    if not memory:
        memory = []
    chat_mem = int(chat_mem)
    # Rough budget check: character counts stand in for token counts here.
    hist_len = 0
    for ea in memory[-chat_mem:]:
        hist_len += len(str(ea))
    in_len = len(system_prompt + prompt) + hist_len
    if (in_len + tokens) > 8000:
        history.append((prompt, "Wait, that's too many tokens, please reduce the 'Chat Memory' value, or reduce the 'Max new tokens' value"))
        yield history, memory
    else:
        generate_kwargs = dict(
            temperature=temp,
            max_new_tokens=int(tokens),
            top_p=top_p,
            repetition_penalty=rep_p,
            do_sample=True,
            seed=int(seed),
        )
        if system_prompt:
            formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", memory[-chat_mem:], custom_prompt)
        else:
            formatted_prompt = format_prompt(prompt, memory[-chat_mem:], custom_prompt)
        translated_prompt = translate_to_english(formatted_prompt)
        stream = client.text_generation(translated_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        output = ""
        for response in stream:
            if not response.token.special:
                output += response.token.text
            yield history + [(prompt, output)], memory
        # Translate once, after streaming finishes; translating inside the loop
        # would re-translate the whole growing text on every token.
        if translate_to_persian:
            output = translate_to_persian_text(output)
        history.append((prompt, output))
        memory.append((prompt, output))
        yield history, memory
def clear_fn():
    # Clears the prompt box, system prompt, chatbot display, and memory state.
    return None, None, None, None
rand_val = random.randint(1, 1111111111111111)

def check_rand(inp, val):
    # "Random Seed" checked: roll a fresh seed; otherwise keep the entered value.
    if inp:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1, 1111111111111111))
    else:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))
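# UI layout: chat window on top; prompt inputs and buttons on the left,
# sampling controls on the right.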
with gr.Blocks() as app:
    memory = gr.State()
    chat_b = gr.Chatbot(height=500)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                inp = gr.Textbox(label="Prompt")
                sys_inp = gr.Textbox(label="System Prompt (optional)")
                with gr.Row():
                    with gr.Column(scale=2):
                        btn = gr.Button("Chat")
                    with gr.Column(scale=1):
                        with gr.Group():
                            stop_btn = gr.Button("Stop")
                            clear_btn = gr.Button("Clear")
                client_choice = gr.Dropdown(label="Models", type='index', choices=models, value=models[0], interactive=True)
                with gr.Accordion("Prompt Format", open=False):
                    custom_prompt = gr.Textbox(label="Modify Prompt Format", info="For testing purposes. 'USER_INPUT' is where 'SYSTEM_PROMPT, PROMPT' will be placed", lines=5, value="<start_of_turn>user\nUSER_INPUT<end_of_turn>\n<start_of_turn>model\n")
            with gr.Column(scale=1):
                with gr.Group():
                    translate_to_persian_checkbox = gr.Checkbox(label="Translate to Persian", value=True)
                    rand = gr.Checkbox(label="Random Seed", value=True)
                    seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
                    tokens = gr.Slider(label="Max new tokens", value=1600, minimum=0, maximum=8000, step=64, interactive=True, visible=True, info="The maximum number of new tokens to generate")
                    temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)
                    chat_mem = gr.Number(label="Chat Memory", info="Number of previous turns to keep in the prompt", precision=0, value=4)
    # Event wiring: chat_inf is a generator, so its yields stream into
    # [chat_b, memory].
    client_choice.change(load_models, client_choice, [chat_b])
    app.load(load_models, client_choice, [chat_b])
    chat_sub = inp.submit(check_rand, [rand, seed], seed).then(
        chat_inf,
        [sys_inp, inp, chat_b, memory, client_choice, seed, temp, tokens, top_p, rep_p, chat_mem, custom_prompt, translate_to_persian_checkbox],
        [chat_b, memory],
    )
    go = btn.click(check_rand, [rand, seed], seed).then(
        chat_inf,
        [sys_inp, inp, chat_b, memory, client_choice, seed, temp, tokens, top_p, rep_p, chat_mem, custom_prompt, translate_to_persian_checkbox],
        [chat_b, memory],
    )
    # Stop cancels any in-flight generation from either trigger.
    stop_btn.click(None, None, None, cancels=[go, chat_sub])
    clear_btn.click(clear_fn, None, [inp, sys_inp, chat_b, memory])

app.queue(default_concurrency_limit=10).launch()