|
import os |
|
import re |
|
import json |
|
import copy |
|
import gradio as gr |
|
|
|
from palmapi import GradioPaLMChatPPManager |
|
from palmapi import gen_text |
|
|
|
from styles import MODEL_SELECTION_CSS |
|
from js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE, UPDATE_PLACEHOLDERS |
|
from templates import templates |
|
from constants import DEFAULT_GLOBAL_CTX |
|
|
|
from pingpong import PingPong |
|
from pingpong.context import CtxLastWindowStrategy |
|
from pingpong.context import InternetSearchStrategy, SimilaritySearcher |
|
|
|
# Hugging Face access token read from the environment (used by the
# HF Inference API code path in `rollback_last`).
TOKEN = os.getenv('HF_TOKEN')

# Model id requests are addressed to on the HF Inference API code path.
MODEL_ID = 'meta-llama/Llama-2-70b-chat-hf'
|
|
|
def build_prompts(ppmanager, global_context, win_size=3):
    """Build the model prompt from the most recent conversations.

    Operates on a deep copy so the caller's manager is never mutated;
    the copy's context is overridden with *global_context* before a
    last-window strategy of size *win_size* renders the prompt.
    """
    manager = copy.deepcopy(ppmanager)
    manager.ctx = global_context
    strategy = CtxLastWindowStrategy(win_size)
    return strategy(manager)
|
|
|
# Example prompts shown in the initial popup, one per line of examples.txt.
# Context managers close the handles deterministically (the previous
# version opened both files and never closed them).
with open("examples.txt", "r") as ex_file:
    examples = ex_file.read().split("\n")
ex_btns = []

# Chat channel (history slot) titles, one per line of channels.txt.
with open("channels.txt", "r") as chl_file:
    channels = chl_file.read().split("\n")
channel_btns = []
|
|
|
def get_placeholders(text):
    """Return all placeholder names found in *text*.

    A placeholder is any substring wrapped in square brackets, e.g.
    "Hello [name]" yields ["name"]. (The previous docstring described
    <placeholder>...</placeholder> tags, which the regex never matched.)

    Args:
        text: template string to scan.

    Returns:
        List of placeholder names in order of appearance; [] if none.
    """
    pattern = r"\[([^\]]*)\]"
    return re.findall(pattern, text)
|
|
|
def fill_up_placeholders(txt):
    """Prepare the template UI after an example template is chosen.

    Shows the chosen template text, reveals one input textbox per
    detected ``[placeholder]`` (up to three), and clears the main
    prompt box while the template still has placeholders to fill.

    Returns:
        5-tuple of Gradio updates: template markdown, the three
        placeholder textboxes, and the prompt-box value.
    """
    placeholders = get_placeholders(txt)

    def _slot_update(i):
        # Reveal the i-th placeholder textbox only when the template
        # actually contains that many placeholders.
        has_slot = len(placeholders) > i
        return gr.update(
            visible=has_slot,
            placeholder=placeholders[i] if has_slot else ""
        )

    return (
        gr.update(visible=True, value=txt),
        _slot_update(0),
        _slot_update(1),
        _slot_update(2),
        # Clear the prompt box while placeholders remain unfilled;
        # otherwise pass the (placeholder-free) template straight through.
        "" if placeholders else txt,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def rollback_last(
    idx, local_data, chat_state,
    global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
    internet_option, serper_api_key, palm_if
):
    """Regenerate the last response of the selected chat channel.

    Drops the final ping-pong from channel ``idx``, re-adds the same user
    message with an empty response, rebuilds the prompt, and streams a
    fresh answer chunk by chunk.

    Yields:
        (prompt, chatbot UI messages, serialized histories,
         "Regen" button update, internet-mode radio value); the final
        yield re-enables the "Regen" button.
    """
    # The radio passes "on"/"off" strings; normalize to a bool.
    internet_option = True if internet_option == "on" else False

    # Rebuild every channel's PingPong manager from the stored JSON.
    res = [
        chat_state["ppmanager_type"].from_json(json.dumps(ppm))
        for ppm in local_data
    ]

    # Pop the last exchange but keep its user message for regeneration.
    ppm = res[idx]
    last_user_message = res[idx].pingpongs[-1].ping
    res[idx].pingpongs = res[idx].pingpongs[:-1]

    ppm.add_pingpong(
        PingPong(last_user_message, "")
    )
    prompt = build_prompts(ppm, global_context, ctx_num_lconv)

    if internet_option:
        # Augment the prompt with search results, streaming interim UI
        # updates while the search runs.
        # NOTE(review): `internet_search` is not defined or imported in the
        # visible portion of this file — presumably provided elsewhere;
        # confirm before relying on this path.
        search_prompt = None
        for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
            search_prompt = tmp_prompt
            yield prompt, uis, str(res), gr.update(interactive=False), "off"

    # Stream the regenerated response, appending each chunk to the pong.
    # NOTE(review): this keyword/async-iterator usage of `gen_text` differs
    # from the tuple-returning call in `chat_stream` — it looks like a
    # leftover from the HF Inference API code path; confirm which
    # palmapi.gen_text signature is actually in effect.
    async for result in gen_text(
        search_prompt if internet_option else prompt,
        hf_model=MODEL_ID, hf_token=TOKEN,
        parameters={
            'max_new_tokens': res_mnts,
            'do_sample': res_sample,
            'return_full_text': False,
            'temperature': res_temp,
            'top_k': res_topk,
            'repetition_penalty': res_rpen
        }
    ):
        ppm.append_pong(result)
        yield prompt, ppm.build_uis(), str(res), gr.update(interactive=False), "off"

    # Final yield re-enables "Regen" and resets the internet radio to "off".
    yield prompt, ppm.build_uis(), str(res), gr.update(interactive=True), "off"
|
|
|
def reset_chat(idx, ld, state, palm_if):
    """Wipe the conversation history of the currently selected channel.

    Returns updates that clear the prompt box and chatbot, persist the
    wiped histories, re-show the intro popup, and disable "Regen".
    """
    conversations = [
        state["ppmanager_type"].from_json(json.dumps(item)) for item in ld
    ]
    conversations[idx].pingpongs = []

    cleared_prompt = ""
    cleared_chatbot = []
    return (
        cleared_prompt,
        cleared_chatbot,
        str(conversations),
        gr.update(visible=True),
        gr.update(interactive=False),
    )
|
|
|
async def chat_stream(
    idx, local_data, instruction_txtbox, chat_state,
    global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
    internet_option, serper_api_key, palm_if
):
    """Handle one user turn for the selected channel via the PaLM API.

    NOTE(review): the generation knobs (res_temp, res_topk, res_rpen,
    res_mnts, res_sample), global_context, ctx_num_lconv, internet_option
    and serper_api_key are accepted to match the Gradio event wiring but
    are unused on this code path — only `instruction_txtbox` and
    `palm_if` reach `gen_text`.

    Returns:
        (cleared prompt box, cleared context inspector, chatbot UI
        messages, serialized histories, "Regen" button update,
        internet-mode radio value).
    """
    # The radio passes "on"/"off" strings; normalize to a bool (unused below).
    internet_option = True if internet_option == "on" else False

    # Rebuild every channel's PingPong manager from the stored JSON.
    res = [
        chat_state["ppmanager_type"].from_json(json.dumps(ppm))
        for ppm in local_data
    ]

    # NOTE(review): the refreshed `palm_if` session object is rebound
    # locally but never returned to the gr.State output, so the updated
    # session may not persist across turns — confirm against
    # palmapi.gen_text's contract.
    palm_if, response_txt = gen_text(instruction_txtbox, palm_if)

    # Record the new exchange on the active channel.
    ppm = res[idx]
    ppm.add_pingpong(
        PingPong(instruction_txtbox, response_txt)
    )

    return "", "", ppm.build_uis(), str(res), gr.update(interactive=True), "off"
|
|
|
def channel_num(btn_title):
    """Map a channel button's title to its index in `channels`.

    Falls back to 0 when the title is unknown; if the title appears
    more than once, the last occurrence wins (matching the original
    full-scan behavior).
    """
    matches = [i for i, title in enumerate(channels) if title == btn_title]
    return matches[-1] if matches else 0
|
|
|
def set_chatbot(btn, ld, state):
    """Switch the chat view to the channel whose button was clicked.

    Returns the channel's rendered messages, its index, an intro-popup
    visibility update (shown for empty channels), and a "Regen" button
    update (enabled only when the channel has history).
    """
    selected = channel_num(btn)

    conversations = [
        state["ppmanager_type"].from_json(json.dumps(item)) for item in ld
    ]
    is_empty = not conversations[selected].pingpongs
    return (
        conversations[selected].build_uis(),
        selected,
        gr.update(visible=is_empty),
        gr.update(interactive=not is_empty),
    )
|
|
|
def set_example(btn):
    """Copy the clicked example's text into the prompt box and hide the popup."""
    hide_popup = gr.update(visible=False)
    return btn, hide_popup
|
|
|
def get_final_template(
    txt, placeholder_txt1, placeholder_txt2, placeholder_txt3
):
    """Substitute up to three placeholder values into the template *txt*.

    Each non-empty value replaces its corresponding ``[placeholder]``
    (first value -> first placeholder, and so on); empty values leave
    their placeholder untouched.

    Returns:
        (final prompt, "", "", "") — the empties clear the three
        placeholder textboxes.
    """
    names = get_placeholders(txt)
    example_prompt = txt

    # Pair each detected placeholder with its textbox value; zip stops at
    # the shorter sequence, so missing slots are simply skipped.
    values = (placeholder_txt1, placeholder_txt2, placeholder_txt3)
    for name, value in zip(names, values):
        if value != "":
            example_prompt = example_prompt.replace(f"[{name}]", value)

    return (
        example_prompt,
        "",
        "",
        ""
    )
|
|
|
with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
    # Per-session PaLM chat session object (populated lazily by gen_text).
    palm_if = gr.State()

    with gr.Column() as chat_view:
        # Index of the currently selected chat channel (history slot).
        idx = gr.State(0)
        # Class used to (de)serialize each channel's conversation history.
        chat_state = gr.State({
            "ppmanager_type": GradioPaLMChatPPManager
        })
        # Hidden mirror of all channel histories; synced with the browser's
        # localStorage via the _js hooks wired below.
        local_data = gr.JSON({}, visible=False)

        gr.Markdown("## LLaMA2 70B with Gradio Chat and Hugging Face Inference API", elem_classes=["center"])
        gr.Markdown(
            "This space demonstrates how to build feature rich chatbot UI in [Gradio](https://www.gradio.app/). Supported features "
            "include • multiple chatting channels, • chat history save/restoration, • stop generating text response, • regenerate the "
            "last conversation, • clean the chat history, • dynamic kick-starting prompt templates, • adjusting text generation parameters, "
            "• inspecting the actual prompt that the model sees. The underlying Large Language Model is the [Meta AI](https://ai.meta.com/)'s "
            "[LLaMA2-70B](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) which is hosted as [Hugging Face Inference API](https://huggingface.co/inference-api), "
            "and [Text Generation Inference](https://github.com/huggingface/text-generation-inference) is the underlying serving framework. "
        )

        with gr.Row():
            # ----- Left sidebar: channel list + internet-search options -----
            with gr.Column(scale=1, min_width=180):
                gr.Markdown("GradioChat", elem_id="left-top")

                with gr.Column(elem_id="left-pane"):
                    with gr.Accordion("Histories", elem_id="chat-history-accordion", open=True):
                        # The first channel starts highlighted as the active one.
                        channel_btns.append(gr.Button(channels[0], elem_classes=["custom-btn-highlight"]))

                        for channel in channels[1:]:
                            channel_btns.append(gr.Button(channel, elem_classes=["custom-btn"]))

                    internet_option = gr.Radio(
                        choices=["on", "off"], value="off",
                        label="internet mode", elem_id="internet_option_radio")
                    # Hidden unless needed; pre-filled from the env when set.
                    serper_api_key = gr.Textbox(
                        value= os.getenv("SERPER_API_KEY"),
                        placeholder="Get one by visiting serper.dev",
                        label="Serper api key",
                        visible=False
                    )

            # ----- Right pane: chat area -----
            with gr.Column(scale=8, elem_id="right-pane"):
                # Intro popup shown while the selected channel has no history.
                with gr.Column(
                    elem_id="initial-popup", visible=False
                ) as example_block:
                    with gr.Row(scale=1):
                        with gr.Column(elem_id="initial-popup-left-pane"):
                            gr.Markdown("GradioChat", elem_id="initial-popup-title")
                            gr.Markdown("Making the community's best AI chat models available to everyone.")
                        with gr.Column(elem_id="initial-popup-right-pane"):
                            gr.Markdown("Chat UI is now open sourced on Hugging Face Hub")
                            gr.Markdown("check out the [↗ repository](https://huggingface.co/spaces/chansung/test-multi-conv)")

                    with gr.Column(scale=1):
                        gr.Markdown("Examples")
                        with gr.Row():
                            for example in examples:
                                ex_btns.append(gr.Button(example, elem_classes=["example-btn"]))

                # Floating auxiliary buttons over the chatbot.
                with gr.Column(elem_id="aux-btns-popup", visible=True):
                    with gr.Row():
                        regenerate = gr.Button("Regen", interactive=False, elem_classes=["aux-btn"])
                        clean = gr.Button("Clean", elem_classes=["aux-btn"])

                # Shows the raw prompt built for the most recent turn.
                with gr.Accordion("Context Inspector", elem_id="aux-viewer", open=False):
                    context_inspector = gr.Textbox(
                        "",
                        elem_id="aux-viewer-inspector",
                        label="",
                        lines=30,
                        max_lines=50,
                    )

                chatbot = gr.Chatbot(elem_id='chatbot', label="PaLM API")
                instruction_txtbox = gr.Textbox(placeholder="Ask anything", label="", elem_id="prompt-txt")

                # Prompt templates with up to three [placeholder] slots each.
                with gr.Accordion("Example Templates", open=False):
                    template_txt = gr.Textbox(visible=False)
                    template_md = gr.Markdown(label="Chosen Template", visible=False, elem_classes="template-txt")

                    with gr.Row():
                        placeholder_txt1 = gr.Textbox(label="placeholder #1", visible=False, interactive=True)
                        placeholder_txt2 = gr.Textbox(label="placeholder #2", visible=False, interactive=True)
                        placeholder_txt3 = gr.Textbox(label="placeholder #3", visible=False, interactive=True)

                    # One tab per template group; clicking an example runs
                    # fill_up_placeholders to reveal the matching textboxes.
                    for template in templates:
                        with gr.Tab(template['title']):
                            gr.Examples(
                                template['template'],
                                inputs=[template_txt],
                                outputs=[template_md, placeholder_txt1, placeholder_txt2, placeholder_txt3, instruction_txtbox],
                                run_on_click=True,
                                fn=fill_up_placeholders,
                            )

                # Generation / context-management knobs.
                with gr.Accordion("Control Panel", open=False) as control_panel:
                    with gr.Column():
                        with gr.Column():
                            gr.Markdown("#### Global context")
                            with gr.Accordion("global context will persist during conversation, and it is placed at the top of the prompt", open=True):
                                global_context = gr.Textbox(
                                    DEFAULT_GLOBAL_CTX,
                                    lines=5,
                                    max_lines=10,
                                    interactive=True,
                                    elem_id="global-context"
                                )

                            gr.Markdown("#### GenConfig for **response** text generation")
                            with gr.Row():
                                res_temp = gr.Slider(0.0, 2.0, 1.0, step=0.1, label="temp", interactive=True)
                                res_topk = gr.Slider(20, 1000, 50, step=1, label="top_k", interactive=True)
                                res_rpen = gr.Slider(0.0, 2.0, 1.2, step=0.1, label="rep_penalty", interactive=True)
                                res_mnts = gr.Slider(64, 8192, 512, step=1, label="new_tokens", interactive=True)
                                res_sample = gr.Radio([True, False], value=True, label="sample", interactive=True)

                        with gr.Column():
                            gr.Markdown("#### Context managements")
                            with gr.Row():
                                ctx_num_lconv = gr.Slider(2, 10, 3, step=1, label="number of recent talks to keep", interactive=True)

        gr.Markdown(
            "***NOTE:*** If you are subscribing [PRO](https://huggingface.co/pricing#pro), you can simply duplicate this space and use your "
            "Hugging Face Access Token to run the same application. Just add `HF_TOKEN` secret with the Token value accorindg to [this guide]"
            "(https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables). Also, if you want to enable internet search "
            "capability in your private space, please specify `SERPER_API_KEY` secret after getting one from [serper.dev](https://serper.dev/)."
        )

        gr.Markdown(
            "***NOTE:*** If you want to run more extended version of this application, check out [LLM As Chatbot](https://github.com/deep-diver/LLM-As-Chatbot) "
            "project. This project lets you choose a model among various Open Source LLMs including LLaMA2 variations, and others more than 50. Also, if you "
            "have any other further questions and considerations, please [contact me](https://twitter.com/algo_diver)"
        )

    # ----- Event wiring -----

    # Enter in the prompt box: hide the intro popup and enable "Regen",
    # run one chat turn, then persist the updated histories to localStorage.
    send_event = instruction_txtbox.submit(
        lambda: [
            gr.update(visible=False),
            gr.update(interactive=True)
        ],
        None,
        [example_block, regenerate]
    ).then(
        chat_stream,
        [idx, local_data, instruction_txtbox, chat_state,
         global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
         internet_option, serper_api_key, palm_if],
        [instruction_txtbox, context_inspector, chatbot, local_data, regenerate, internet_option]
    ).then(
        None, local_data, None,
        _js="(v)=>{ setStorage('local_data',v) }"
    )

    # "Regen": drop the last answer, regenerate it, then persist.
    regen_event = regenerate.click(
        rollback_last,
        [idx, local_data, chat_state,
         global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
         internet_option, serper_api_key, palm_if],
        [context_inspector, chatbot, local_data, regenerate, internet_option]
    ).then(
        None, local_data, None,
        _js="(v)=>{ setStorage('local_data',v) }"
    )

    # Channel buttons switch the visible conversation, then restyle the
    # sidebar highlight in the browser.
    for btn in channel_btns:
        btn.click(
            set_chatbot,
            [btn, local_data, chat_state],
            [chatbot, idx, example_block, regenerate]
        ).then(
            None, btn, None,
            _js=UPDATE_LEFT_BTNS_STATE
        )

    # Example buttons copy their text into the prompt box and hide the popup.
    for btn in ex_btns:
        btn.click(
            set_example,
            [btn],
            [instruction_txtbox, example_block]
        )

    # "Clean": wipe the current channel's history, then persist.
    clean.click(
        reset_chat,
        [idx, local_data, chat_state, palm_if],
        [instruction_txtbox, chatbot, local_data, example_block, regenerate]
    ).then(
        None, local_data, None,
        _js="(v)=>{ setStorage('local_data',v) }"
    )

    # Live-preview the template as placeholder values are typed; handled
    # entirely in the browser (fn=None) via the UPDATE_PLACEHOLDERS script.
    placeholder_txt1.change(
        inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        outputs=[template_md],
        show_progress=False,
        _js=UPDATE_PLACEHOLDERS,
        fn=None
    )

    placeholder_txt2.change(
        inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        outputs=[template_md],
        show_progress=False,
        _js=UPDATE_PLACEHOLDERS,
        fn=None
    )

    placeholder_txt3.change(
        inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        outputs=[template_md],
        show_progress=False,
        _js=UPDATE_PLACEHOLDERS,
        fn=None
    )

    # Enter in any placeholder box: materialize the filled template into
    # the prompt box and clear the placeholder inputs.
    placeholder_txt1.submit(
        inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        outputs=[instruction_txtbox, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        fn=get_final_template
    )

    placeholder_txt2.submit(
        inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        outputs=[instruction_txtbox, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        fn=get_final_template
    )

    placeholder_txt3.submit(
        inputs=[template_txt, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        outputs=[instruction_txtbox, placeholder_txt1, placeholder_txt2, placeholder_txt3],
        fn=get_final_template
    )

    # On page load, restore chat histories from the browser's localStorage.
    demo.load(
        None,
        inputs=None,
        outputs=[chatbot, local_data],
        _js=GET_LOCAL_STORAGE,
    )

# Up to 5 concurrent generations; queue holds up to 256 pending requests.
demo.queue(concurrency_count=5, max_size=256).launch()