|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Select JAX as the Keras backend. Keras picks its backend when it is first
# imported, so this assignment presumably has to stay above the keras_hub
# import below — do not move it after the imports.
os.environ["KERAS_BACKEND"] = "jax"
|
|
|
import gradio as gr |
|
from gradio import ChatMessage |
|
import keras_hub |
|
|
|
from chatstate import ChatState |
|
from enum import Enum |
|
from models import ( |
|
model_presets, |
|
load_model, |
|
model_labels, |
|
preset_to_website_url, |
|
get_appropriate_chat_template, |
|
) |
|
|
|
|
|
class TextRoute(Enum):
    """Which chat pane(s) the shared message box sends to.

    LEFT targets the first chat pane, RIGHT the second, BOTH sends the same
    message to both panes.
    """

    # Values are assigned positionally: LEFT == 0, RIGHT == 1, BOTH == 2.
    LEFT, RIGHT, BOTH = range(3)
|
|
|
|
|
# Materialise the labels once; every model dropdown enumerates this list.
model_labels_list = list(model_labels)

# Load every preset up front; `models[i]` is the model for `model_presets[i]`.
# NOTE(review): dropdown indices come from `model_labels`, so this assumes
# `model_labels` and `model_presets` are in the same order — confirm in models.py.
# Each model answers a throwaway "Hello" right after loading, presumably so that
# one-time costs (e.g. JAX compilation) are paid at startup, not on the first
# user message.
models = []
for preset in model_presets:
    model = load_model(preset)
    chat_state = ChatState(model, "", get_appropriate_chat_template(preset))
    response = chat_state.send_message("Hello")[1]
    print("model " + preset + " loaded and initialized.")
    print("The model responded: " + response)
    models.append(model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def chat_turn_assistant(
    message,
    sel,
    history,
    system_message,
):
    """Generate the reply of model number ``sel`` to ``message``.

    A fresh ChatState is built from ``system_message`` and the earlier turns
    in ``history``, the model is asked to answer ``message``, and the answer
    is appended to ``history`` as an assistant message.

    Args:
        message: The user's latest message text.
        sel: Index into ``models`` / ``model_presets`` of the model to use.
        history: Chat history as Gradio message dicts (``role``/``content``)
            or ``ChatMessage`` objects. In the submit chains of this app,
            ``chat_turn_user`` has already appended ``message`` as the last
            user turn before this function runs.
        system_message: System prompt for the conversation.

    Returns:
        The same ``history`` list, with the assistant reply appended.
    """
    model = models[sel]
    preset = model_presets[sel]
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, system_message, chat_template)

    # Gradio hands messages over as dicts; accept ChatMessage objects too.
    turns = [m if isinstance(m, ChatMessage) else ChatMessage(**m) for m in history]

    # The current user message is usually already the last entry of history
    # (appended by chat_turn_user). send_message(message) below is handed the
    # same text, so replaying that entry too would presumably put the user
    # turn into the prompt twice. Drop the trailing copy.
    if turns and turns[-1].role == "user" and turns[-1].content == message:
        turns = turns[:-1]

    for turn in turns:
        if turn.role == "user":
            chat_state.add_to_history_as_user(turn.content)
        elif turn.role == "assistant":
            chat_state.add_to_history_as_model(turn.content)

    _, response = chat_state.send_message(message)
    history.append(ChatMessage(role="assistant", content=response))
    return history
|
|
|
|
|
def chat_turn_both_assistant(
    message, sel1, sel2, history1, history2, system_message
):
    """Answer ``message`` in both panes, each with its own selected model.

    Both panes share the same message and system prompt. The left pane
    (``sel1``/``history1``) is processed first, then the right pane.

    Returns:
        ``(history1, history2)``, each with its model's reply appended.
    """
    panes = ((sel1, history1), (sel2, history2))
    return tuple(
        chat_turn_assistant(message, pane_sel, pane_history, system_message)
        for pane_sel, pane_history in panes
    )
|
|
|
|
|
def chat_turn_user(message, history):
    """Append ``message`` to ``history`` as a user turn (in place) and return it."""
    history += [ChatMessage(role="user", content=message)]
    return history
|
|
|
|
|
def chat_turn_both_user(message, history1, history2):
    """Echo the same user ``message`` into both chat histories.

    Returns:
        ``(history1, history2)``, each with the user turn appended.
    """
    return tuple(chat_turn_user(message, h) for h in (history1, history2))
|
|
|
|
|
def bot_icon_select(model_name):
    """Return the avatar image path for a model, based on its preset name.

    The name is checked for family keywords in a fixed order (gemma, llama,
    vicuna, mistral). The first match wins. Names with no known keyword get
    the generic bot icon. Matching is a case-sensitive substring test.
    """
    family_icons = (
        ("gemma", "img/gemma.png"),
        ("llama", "img/meta.png"),
        ("vicuna", "img/vicuna.png"),
        ("mistral", "img/mistral.png"),
    )
    return next(
        (icon for keyword, icon in family_icons if keyword in model_name),
        "img/bot.png",
    )
|
|
|
|
|
def instantiate_select_box(sel, model_labels):
    """Build the model-picker dropdown with entry ``sel`` selected.

    Each choice shows a label and carries its index as the value. The info
    line under the dropdown links to the selected preset's website.

    Args:
        sel: Index of the selected model in ``model_presets``.
        model_labels: Display labels, in the same order as ``model_presets``.
    """
    url = preset_to_website_url(model_presets[sel])
    choices = [(label, idx) for idx, label in enumerate(model_labels)]
    info_html = (
        "<span style='color:black'>Selected model:</span> "
        f"<a href='{url}'>{url}</a>"
    )
    return gr.Dropdown(
        choices=choices,
        show_label=False,
        value=sel,
        info=info_html,
    )
|
|
|
|
|
def instantiate_chatbot(sel, key):
    """Build a message-style chat pane for model number ``sel``.

    The bot avatar is chosen from the model's preset name. ``key`` is passed
    through to ``gr.Chatbot``.
    """
    avatars = ("img/usr.png", bot_icon_select(model_presets[sel]))
    return gr.Chatbot(
        type="messages",
        key=key,
        show_label=False,
        show_share_button=False,
        show_copy_all_button=True,
        avatar_images=avatars,
    )
|
|
|
|
|
def instantiate_arrow_button(route, text_route):
    """Create a small arrow button that switches the message routing to ``route``.

    Clicking the button writes ``route`` into the ``text_route`` state.

    Args:
        route: The TextRoute this button selects; also chooses the arrow icon.
        text_route: ``gr.State`` holding the current TextRoute.
    """
    arrow_icon = {
        TextRoute.LEFT: "img/arrowL.png",
        TextRoute.RIGHT: "img/arrowR.png",
        TextRoute.BOTH: "img/arrowRL.png",
    }[route]
    arrow = gr.Button("", size="sm", scale=0, min_width=40, icon=arrow_icon)
    arrow.click(lambda: route, outputs=[text_route])
    return arrow
|
|
|
|
|
def instantiate_retry_button(route):
    """Create the small "retry" icon button.

    Args:
        route: Currently unused. Accepted so call sites can pass the active
            TextRoute.
    """
    return gr.Button(
        value="",
        icon="img/retry.png",
        size="sm",
        min_width=40,
        scale=0,
    )
|
|
|
|
|
def instantiate_trash_button():
    """Create the small "clear conversation" (trash can) icon button."""
    return gr.Button(
        value="",
        icon="img/trash.png",
        size="sm",
        min_width=40,
        scale=0,
    )
|
|
|
|
|
def instantiate_text_box():
    """Create the shared message input box, which has a built-in submit button."""
    # The fixed key presumably lets Gradio treat the box as the same component
    # across re-renders of the input row.
    message_box = gr.Textbox(
        key="msg",
        label="Your message:",
        submit_btn=True,
    )
    return message_box
|
|
|
|
|
def instantiate_additional_settings():
    """Render a collapsed "Additional settings" accordion holding the system prompt.

    Returns:
        The system-prompt ``gr.Textbox``. Its default value names the
        assistant "Eliza".
    """
    with gr.Accordion("Additional settings", open=False):
        system_message = gr.Textbox(
            label="System prompt",
            key="system_prompt",
            value="You are a helpful assistant and your name is Eliza.",
        )
    return system_message
|
|
|
|
|
def retry_fn(history):
    """Undo the latest exchange and hand the last user message back for editing.

    Everything from the most recent user message onward is removed from
    ``history`` in place. This works both for a normal [user, assistant] tail
    and for a user message whose reply never arrived.

    Args:
        history: Gradio message dicts with ``role`` and ``content`` keys.

    Returns:
        ``(content_of_last_user_message, history)``. If there is no user
        message, returns ``(gr.skip(), gr.skip())`` so that neither the
        message box nor the chat pane changes.
    """
    last_user = next(
        (i for i in range(len(history) - 1, -1, -1) if history[i]["role"] == "user"),
        None,
    )
    if last_user is None:
        return gr.skip(), gr.skip()
    content = history[last_user]["content"]
    del history[last_user:]
    return content, history
|
|
|
|
|
def retry_fn_both(history1, history2):
    """Apply ``retry_fn`` to both panes and merge the recovered messages.

    If both panes give back the same text, it is used once. If they differ,
    the two texts are joined with " / ". If only one pane gives back text,
    that text is used. Otherwise the left pane's (skip) result is passed on.

    Returns:
        ``(message, history1, history2)``
    """
    msg1, history1 = retry_fn(history1)
    msg2, history2 = retry_fn(history2)
    texts = [m for m in (msg1, msg2) if isinstance(m, str)]
    if len(texts) == 2 and texts[0] != texts[1]:
        merged = " / ".join(texts)
    elif texts:
        merged = texts[0]
    else:
        merged = msg1
    return merged, history1, history2
|
|
|
|
|
# The model pickers (defaulting to the first and second presets) and the two
# chat panes are created outside the Blocks context and placed into the layout
# later with .render(). Each pane's avatar follows its picker's initial value.
sel1, sel2 = (instantiate_select_box(idx, model_labels_list) for idx in (0, 1))

chatbot1 = instantiate_chatbot(sel1.value, "chat1")
chatbot2 = instantiate_chatbot(sel2.value, "chat2")

# Right-aligns the retry/trash/arrow controls used when routing to the right pane.
CSS = ".stick-to-the-right {align-items: end; justify-content: end}"
|
|
|
with gr.Blocks(fill_width=True, title="Keras demo", css=CSS) as demo:

    # Which chat pane(s) the message box currently targets. The arrow buttons
    # write to it, and render_text_area() rebuilds the input row from it.
    text_route = gr.State(TextRoute.BOTH)

    # Header row: Keras logo next to a short description of the demo.
    with gr.Row():
        gr.Image(
            "img/keras_logo_k.png",
            width=80,
            height=80,
            min_width=80,
            show_label=False,
            show_download_button=False,
            show_fullscreen_button=False,
            show_share_button=False,
            interactive=False,
            scale=0,
            container=False,
        )
        gr.HTML(
            "<H2>Keras chatbot arena - running with JAX on TPU</H2>"
            + "All the models are loaded into the TPU memory. "
            + "You can call any of them and compare their answers. "
            + "The entire chat<br/>history is fed to the models at every submission. "
            + "This demo is running on a Google TPU v5e 2x4 (8 cores) in bfloat16 precision."
        )

    # Model pickers and chat panes were built at module level; place them here.
    with gr.Row():
        sel1.render()
        sel2.render()

    with gr.Row():
        chatbot1.render()
        chatbot2.render()

    @gr.render(inputs=text_route)
    def render_text_area(route):
        """Build the message box, its buttons and the submit chain for ``route``.

        LEFT targets chatbot1, RIGHT targets chatbot2, BOTH targets both.
        The retry/trash buttons and the submit chain act on the same pane(s).
        """
        if route == TextRoute.BOTH:
            # Message box with a compact 2x2 button grid to its right.
            with gr.Row():
                msg = instantiate_text_box()
                with gr.Column(scale=0, min_width=100):
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.LEFT, text_route)
                        retry = instantiate_retry_button(route)
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.RIGHT, text_route)
                        trash = instantiate_trash_button()
            retry.click(
                retry_fn_both,
                inputs=[chatbot1, chatbot2],
                outputs=[msg, chatbot1, chatbot2],
            )
            # Clear the message box and both chat histories.
            trash.click(lambda: ("", [], []), outputs=[msg, chatbot1, chatbot2])

        elif route == TextRoute.LEFT:
            # Message box under the left pane, controls under the right one.
            with gr.Row():
                with gr.Column(scale=1):
                    msg = instantiate_text_box()
                with gr.Column(scale=1):
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.RIGHT, text_route)
                        retry = instantiate_retry_button(route)
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.BOTH, text_route)
                        trash = instantiate_trash_button()
            retry.click(retry_fn, inputs=[chatbot1], outputs=[msg, chatbot1])
            trash.click(lambda: ("", []), outputs=[msg, chatbot1])

        elif route == TextRoute.RIGHT:
            # Mirror image of LEFT. The controls column is pushed against the
            # message box by the "stick-to-the-right" CSS class.
            with gr.Row():
                with gr.Column(scale=1, elem_classes="stick-to-the-right"):
                    with gr.Row(elem_classes="stick-to-the-right"):
                        retry = instantiate_retry_button(route)
                        instantiate_arrow_button(TextRoute.LEFT, text_route)
                    with gr.Row(elem_classes="stick-to-the-right"):
                        trash = instantiate_trash_button()
                        instantiate_arrow_button(TextRoute.BOTH, text_route)
                with gr.Column(scale=1):
                    msg = instantiate_text_box()
            retry.click(retry_fn, inputs=[chatbot2], outputs=[msg, chatbot2])
            trash.click(lambda: ("", []), outputs=[msg, chatbot2])

        system_message = instantiate_additional_settings()

        # Submitting first echoes the user message into the target pane(s),
        # then asks the selected model(s) for a reply.
        if route == TextRoute.LEFT:
            submission = msg.submit(
                chat_turn_user, inputs=[msg, chatbot1], outputs=[chatbot1]
            ).then(
                chat_turn_assistant,
                [msg, sel1, chatbot1, system_message],
                outputs=[chatbot1],
            )
        elif route == TextRoute.RIGHT:
            submission = msg.submit(
                chat_turn_user, inputs=[msg, chatbot2], outputs=[chatbot2]
            ).then(
                chat_turn_assistant,
                [msg, sel2, chatbot2, system_message],
                outputs=[chatbot2],
            )
        elif route == TextRoute.BOTH:
            submission = msg.submit(
                chat_turn_both_user,
                inputs=[msg, chatbot1, chatbot2],
                outputs=[chatbot1, chatbot2],
            ).then(
                chat_turn_both_assistant,
                [msg, sel1, sel2, chatbot1, chatbot2, system_message],
                outputs=[chatbot1, chatbot2],
            )

        # Empty the message box once the reply (or replies) have arrived.
        submission.then(lambda: "", outputs=msg)

    # Picking a model rebuilds that pane's Chatbot (new bot avatar). The picker
    # is then rebuilt too, so its info link points at the newly selected
    # preset. These listeners do not depend on the text route, so they are
    # registered once here instead of inside the re-run render function.
    sel1.select(
        lambda sel: instantiate_chatbot(sel, "chat1"),
        inputs=[sel1],
        outputs=[chatbot1],
    ).then(
        lambda sel: instantiate_select_box(sel, model_labels_list),
        inputs=[sel1],
        outputs=[sel1],
    )

    sel2.select(
        lambda sel: instantiate_chatbot(sel, "chat2"),
        inputs=[sel2],
        outputs=[chatbot2],
    ).then(
        lambda sel: instantiate_select_box(sel, model_labels_list),
        inputs=[sel2],
        outputs=[sel2],
    )


if __name__ == "__main__":
    demo.launch()
|
|