|
import os |
|
|
|
os.environ["KERAS_BACKEND"] = "jax" |
|
|
|
import gradio as gr |
|
from gradio import ChatMessage |
|
import keras_hub |
|
|
|
from chatstate import ChatState |
|
from models import ( |
|
model_presets, |
|
load_model, |
|
model_labels, |
|
preset_to_website_url, |
|
get_appropriate_chat_template, |
|
) |
|
|
|
# Materialized as a list so it can be enumerated repeatedly when the
# dropdown choices are (re)built.
model_labels_list = list(model_labels)

# Load every preset exactly once and push a throwaway prompt through it
# (presumably to trigger compilation up front -- TODO confirm). The same
# instance that was exercised is kept: models[i] serves model_presets[i].
# Previously the exercised instances were dropped and every preset was
# loaded a second time.
models = []
for preset in model_presets:
    model = load_model(preset)
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, "", chat_template)
    # send_message returns a (prompt, response) pair; only the response
    # is needed for the log line below.
    _, response = chat_state.send_message("Hello")
    print(f"model {preset} loaded and initialized.")
    print("The model responded: " + response)
    models.append(model)
|
|
|
|
|
|
|
|
|
|
|
|
|
def chat_turn_assistant_1(
    model,
    message,
    history,
    system_message,
    preset,
):
    """Produce one assistant reply for a single chatbot and record it.

    A fresh ``ChatState`` is built for every call and the existing
    ``history`` is replayed into it before ``message`` is sent.

    Args:
        model: Model object handed to ``ChatState``.
        message: Text passed to ``ChatState.send_message``.
        history: List of message mappings; each entry is expanded into a
            ``ChatMessage`` via ``**``. The list is modified in place.
        system_message: System prompt handed to ``ChatState``.
        preset: Preset identifier used to pick the chat template.

    Returns:
        The same ``history`` list with the assistant's reply appended.
    """
    template = get_appropriate_chat_template(preset)
    state = ChatState(model, system_message, template)

    # Route each past turn to the matching recorder; entries whose role is
    # neither "user" nor "assistant" are skipped.
    recorders = {
        "user": state.add_to_history_as_user,
        "assistant": state.add_to_history_as_model,
    }
    for entry in history:
        turn = ChatMessage(**entry)
        record = recorders.get(turn.role)
        if record is not None:
            record(turn.content)

    # send_message yields a (prompt, response) pair; only the response is
    # shown in the chat window.
    _, reply = state.send_message(message)
    history.append(ChatMessage(role="assistant", content=reply))
    return history
|
|
|
|
|
def chat_turn_assistant(
    message,
    sel1,
    history1,
    sel2,
    history2,
    system_message,
):
    """Run the assistant turn for both chatbots, first one first.

    Args:
        message: Text forwarded to both models.
        sel1: Index of the first model in ``models`` / ``model_presets``.
        history1: Message history of the first chatbot.
        sel2: Index of the second model.
        history2: Message history of the second chatbot.
        system_message: System prompt shared by both models.

    Returns:
        Tuple ``("", history1, history2)``; the empty string is the new
        value for the message textbox.
    """
    pairs = ((sel1, history1), (sel2, history2))
    updated = [
        chat_turn_assistant_1(
            models[index], message, past, system_message, model_presets[index]
        )
        for index, past in pairs
    ]
    return "", updated[0], updated[1]
|
|
|
|
|
def chat_turn_user_1(message, history):
    """Append ``message`` as a user turn to ``history`` (in place) and return it."""
    user_turn = ChatMessage(role="user", content=message)
    history.append(user_turn)
    return history
|
|
|
|
|
def chat_turn_user(message, history1, history2):
    """Record the user's message in both chatbot histories.

    Returns:
        Tuple ``("", history1, history2)``; the empty string is the new
        value for the message textbox.
    """
    first, second = (
        chat_turn_user_1(message, past) for past in (history1, history2)
    )
    return "", first, second
|
|
|
|
|
def bot_icon_select(model_name):
    """Return the avatar image path for a model name.

    The name is searched, ignoring case, for a known model family. Families
    are tried in the order gemma, llama, vicuna, mistral and the first hit
    wins.

    Args:
        model_name: Model or preset identifier string.

    Returns:
        Path of the family icon, or ``"img/bot.png"`` when no family
        keyword occurs in ``model_name``.
    """
    icons = (
        ("gemma", "img/gemma.png"),
        ("llama", "img/llama.png"),
        ("vicuna", "img/vicuna.png"),
        ("mistral", "img/mistral.png"),
    )
    # Compare in lowercase so capitalised identifiers such as "Mistral-7B"
    # still get their family icon.
    lowered = model_name.lower()
    for keyword, icon_path in icons:
        if keyword in lowered:
            return icon_path
    return "img/bot.png"
|
|
|
|
|
def instantiate_chatbots(sel1, sel2):
    """Create the two chat windows, each with the avatar of its model.

    Args:
        sel1: Index into ``model_presets`` for the first chatbot.
        sel2: Index into ``model_presets`` for the second chatbot.

    Returns:
        Tuple ``(chatbot1, chatbot2)`` of ``gr.Chatbot`` components, created
        in that order.
    """

    def build(selection):
        # Avatar pair is (user icon, bot icon); the bot icon follows the
        # selected preset's name.
        avatars = ("img/usr.png", bot_icon_select(model_presets[selection]))
        return gr.Chatbot(
            type="messages",
            show_label=False,
            avatar_images=avatars,
        )

    chatbot1 = build(sel1)
    chatbot2 = build(sel2)
    return chatbot1, chatbot2
|
|
|
|
|
def instantiate_select_boxes(sel1, sel2, model_labels):
    """Create the two model-selection dropdowns.

    Each dropdown lists every label, uses the label's position as the
    choice value, and shows a link to the selected preset's website in its
    info line.

    Args:
        sel1: Index into ``model_presets`` preselected in the first dropdown.
        sel2: Index into ``model_presets`` preselected in the second dropdown.
        model_labels: Display labels, enumerated to form the choices.

    Returns:
        Tuple ``(dropdown1, dropdown2)`` of ``gr.Dropdown`` components,
        created in that order.
    """

    def build(position, selected):
        # Resolve the URL once; it is used both as link target and as
        # link text.
        url = preset_to_website_url(model_presets[selected])
        return gr.Dropdown(
            choices=[(label, index) for index, label in enumerate(model_labels)],
            show_label=False,
            info=(
                f"<span style='color:black'>Selected model {position}:</span> "
                f"<a href='{url}'>{url}</a>"
            ),
            value=selected,
        )

    dropdown1 = build(1, sel1)
    dropdown2 = build(2, sel2)
    return dropdown1, dropdown2
|
|
|
|
|
def instantiate_chatbots_and_select_boxes(sel1, sel2, model_labels):
    """Rebuild both chat windows and both dropdowns for the given selections.

    The chatbots are created before the dropdowns.

    Returns:
        Tuple ``(dropdown1, chatbot1, dropdown2, chatbot2)``.
    """
    chatbots = instantiate_chatbots(sel1, sel2)
    dropdowns = instantiate_select_boxes(sel1, sel2, model_labels)
    return dropdowns[0], chatbots[0], dropdowns[1], chatbots[1]
|
|
|
|
|
with gr.Blocks(fill_width=True, title="Keras demo") as demo:

    # Header row: logo next to the title and description.
    with gr.Row():
        gr.Image(
            "img/keras_logo_k.png",
            width=80,
            height=80,
            min_width=80,
            show_label=False,
            show_download_button=False,
            show_fullscreen_button=False,
            interactive=False,
            # NOTE(review): Gradio documents `scale` as an integer; the
            # fractional value is kept unchanged here -- confirm whether
            # scale=0 gives the same layout.
            scale=0.01,
            container=False,
        )
        gr.HTML(
            "<H2> Battle of the Keras chatbots on TPU</H2>"
            + "All the models are loaded into the TPU memory. "
            + "You can call them at will and compare their answers. <br/>"
            + "The entire chat history is fed to the models at every submission. "
            + "This demo is running on a Google TPU v5e 2x4 (8 cores).",
        )

    # Model pickers: the first two presets are selected initially.
    with gr.Row():
        sel1, sel2 = instantiate_select_boxes(0, 1, model_labels_list)

    # One chat window per selected model.
    with gr.Row():
        chatbot1, chatbot2 = instantiate_chatbots(sel1.value, sel2.value)

    msg = gr.Textbox(
        label="Your message:",
    )
    with gr.Row():
        gr.ClearButton([msg, chatbot1, chatbot2])
    with gr.Accordion("Additional settings", open=False):
        system_message = gr.Textbox(
            label="System prompt",
            value="You are a helpful assistant and your name is Eliza.",
        )

    # Changing either dropdown rebuilds both dropdowns and both chat
    # windows from the current pair of selections.
    sel1.select(
        lambda first, second: instantiate_chatbots_and_select_boxes(
            first, second, model_labels_list
        ),
        inputs=[sel1, sel2],
        outputs=[sel1, chatbot1, sel2, chatbot2],
    )

    sel2.select(
        lambda first, second: instantiate_chatbots_and_select_boxes(
            first, second, model_labels_list
        ),
        inputs=[sel1, sel2],
        outputs=[sel1, chatbot1, sel2, chatbot2],
    )

    # Two-step submit: the user turn is recorded in both histories first,
    # then both models answer.
    # NOTE(review): chat_turn_user writes "" back into `msg`, and the second
    # step lists `msg` as an input again -- confirm the assistant step
    # receives the text it is meant to receive.
    msg.submit(
        chat_turn_user,
        inputs=[msg, chatbot1, chatbot2],
        outputs=[msg, chatbot1, chatbot2],
    ).then(
        chat_turn_assistant,
        [msg, sel1, chatbot1, sel2, chatbot2, system_message],
        outputs=[msg, chatbot1, chatbot2],
    )


# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|