|  | import os | 
					
						
						|  |  | 
					
						
						|  | os.environ["KERAS_BACKEND"] = "jax" | 
					
						
						|  |  | 
					
						
						|  | import gradio as gr | 
					
						
						|  | from gradio import ChatMessage | 
					
						
						|  | import keras_hub | 
					
						
						|  |  | 
					
						
						|  | from chatstate import ChatState | 
					
						
						|  | from models import ( | 
					
						
						|  | model_presets, | 
					
						
						|  | load_model, | 
					
						
						|  | model_labels, | 
					
						
						|  | preset_to_website_url, | 
					
						
						|  | get_appropriate_chat_template, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | model_labels_list = list(model_labels) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | models = [] | 
					
						
						|  | for preset in model_presets: | 
					
						
						|  | model = load_model(preset) | 
					
						
						|  | chat_template = get_appropriate_chat_template(preset) | 
					
						
						|  | chat_state = ChatState(model, "", chat_template) | 
					
						
						|  | prompt, response = chat_state.send_message("Hello") | 
					
						
						|  | print("model " + preset + "loaded and initialized.") | 
					
						
						|  | print("The model responded: " + response) | 
					
						
						|  |  | 
					
						
						|  | models = [load_model(preset) for preset in model_presets] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def chat_turn_assistant_1( | 
					
						
						|  | model, | 
					
						
						|  | message, | 
					
						
						|  | history, | 
					
						
						|  | system_message, | 
					
						
						|  | preset, | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | ): | 
					
						
						|  | chat_template = get_appropriate_chat_template(preset) | 
					
						
						|  | chat_state = ChatState(model, system_message, chat_template) | 
					
						
						|  |  | 
					
						
						|  | for msg in history: | 
					
						
						|  | msg = ChatMessage(**msg) | 
					
						
						|  | if msg.role == "user": | 
					
						
						|  | chat_state.add_to_history_as_user(msg.content) | 
					
						
						|  | elif msg.role == "assistant": | 
					
						
						|  | chat_state.add_to_history_as_model(msg.content) | 
					
						
						|  |  | 
					
						
						|  | prompt, response = chat_state.send_message(message) | 
					
						
						|  | history.append(ChatMessage(role="assistant", content=response)) | 
					
						
						|  | return history | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def chat_turn_assistant( | 
					
						
						|  | message, | 
					
						
						|  | sel1, | 
					
						
						|  | history1, | 
					
						
						|  | sel2, | 
					
						
						|  | history2, | 
					
						
						|  | system_message, | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | ): | 
					
						
						|  | history1 = chat_turn_assistant_1( | 
					
						
						|  | models[sel1], message, history1, system_message, model_presets[sel1] | 
					
						
						|  | ) | 
					
						
						|  | history2 = chat_turn_assistant_1( | 
					
						
						|  | models[sel2], message, history2, system_message, model_presets[sel2] | 
					
						
						|  | ) | 
					
						
						|  | return "", history1, history2 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def chat_turn_user_1(message, history): | 
					
						
						|  | history.append(ChatMessage(role="user", content=message)) | 
					
						
						|  | return history | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def chat_turn_user(message, history1, history2): | 
					
						
						|  | history1 = chat_turn_user_1(message, history1) | 
					
						
						|  | history2 = chat_turn_user_1(message, history2) | 
					
						
						|  | return "", history1, history2 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def bot_icon_select(model_name): | 
					
						
						|  | if "gemma" in model_name: | 
					
						
						|  | return "img/gemma.png" | 
					
						
						|  | elif "llama" in model_name: | 
					
						
						|  | return "img/llama.png" | 
					
						
						|  | elif "vicuna" in model_name: | 
					
						
						|  | return "img/vicuna.png" | 
					
						
						|  | elif "mistral" in model_name: | 
					
						
						|  | return "img/mistral.png" | 
					
						
						|  |  | 
					
						
						|  | return "img/bot.png" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def instantiate_chatbots(sel1, sel2): | 
					
						
						|  | model_name1 = model_presets[sel1] | 
					
						
						|  | chatbot1 = gr.Chatbot( | 
					
						
						|  | type="messages", | 
					
						
						|  | show_label=False, | 
					
						
						|  | avatar_images=("img/usr.png", bot_icon_select(model_name1)), | 
					
						
						|  | ) | 
					
						
						|  | model_name2 = model_presets[sel2] | 
					
						
						|  | chatbot2 = gr.Chatbot( | 
					
						
						|  | type="messages", | 
					
						
						|  | show_label=False, | 
					
						
						|  | avatar_images=("img/usr.png", bot_icon_select(model_name2)), | 
					
						
						|  | ) | 
					
						
						|  | return chatbot1, chatbot2 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def instantiate_select_boxes(sel1, sel2, model_labels): | 
					
						
						|  | sel1 = gr.Dropdown( | 
					
						
						|  | choices=[(name, i) for i, name in enumerate(model_labels)], | 
					
						
						|  | show_label=False, | 
					
						
						|  | info="<span style='color:black'>Selected model 1:</span> " | 
					
						
						|  | + "<a href='" | 
					
						
						|  | + preset_to_website_url(model_presets[sel1]) | 
					
						
						|  | + "'>" | 
					
						
						|  | + preset_to_website_url(model_presets[sel1]) | 
					
						
						|  | + "</a>", | 
					
						
						|  | value=sel1, | 
					
						
						|  | ) | 
					
						
						|  | sel2 = gr.Dropdown( | 
					
						
						|  | choices=[(name, i) for i, name in enumerate(model_labels)], | 
					
						
						|  | show_label=False, | 
					
						
						|  | info="<span style='color:black'>Selected model 2:</span> " | 
					
						
						|  | + "<a href='" | 
					
						
						|  | + preset_to_website_url(model_presets[sel2]) | 
					
						
						|  | + "'>" | 
					
						
						|  | + preset_to_website_url(model_presets[sel2]) | 
					
						
						|  | + "</a>", | 
					
						
						|  | value=sel2, | 
					
						
						|  | ) | 
					
						
						|  | return sel1, sel2 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def instantiate_chatbots_and_select_boxes(sel1, sel2, model_labels): | 
					
						
						|  | chatbot1, chatbot2 = instantiate_chatbots(sel1, sel2) | 
					
						
						|  | sel1, sel2 = instantiate_select_boxes(sel1, sel2, model_labels) | 
					
						
						|  | return sel1, chatbot1, sel2, chatbot2 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with gr.Blocks(fill_width=True, title="Keras demo") as demo: | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | gr.Image( | 
					
						
						|  | "img/keras_logo_k.png", | 
					
						
						|  | width=80, | 
					
						
						|  | height=80, | 
					
						
						|  | min_width=80, | 
					
						
						|  | show_label=False, | 
					
						
						|  | show_download_button=False, | 
					
						
						|  | show_fullscreen_button=False, | 
					
						
						|  | interactive=False, | 
					
						
						|  | scale=0.01, | 
					
						
						|  | container=False, | 
					
						
						|  | ) | 
					
						
						|  | gr.HTML( | 
					
						
						|  | "<H2> Battle of the Keras chatbots on TPU</H2>" | 
					
						
						|  | + "All the models are loaded into the TPU memory. " | 
					
						
						|  | + "You can call them at will and compare their answers. <br/>" | 
					
						
						|  | + "The entire chat history is fed to the models at every submission." | 
					
						
						|  | + "This demno is runnig on a Google TPU v5e 2x4 (8 cores).", | 
					
						
						|  | ) | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | sel1, sel2 = instantiate_select_boxes(0, 1, model_labels_list) | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | chatbot1, chatbot2 = instantiate_chatbots(sel1.value, sel2.value) | 
					
						
						|  |  | 
					
						
						|  | msg = gr.Textbox( | 
					
						
						|  | label="Your message:", | 
					
						
						|  | ) | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | gr.ClearButton([msg, chatbot1, chatbot2]) | 
					
						
						|  | with gr.Accordion("Additional settings", open=False): | 
					
						
						|  | system_message = gr.Textbox( | 
					
						
						|  | label="Sytem prompt", | 
					
						
						|  | value="You are a helpful assistant and your name is Eliza.", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | sel1.select( | 
					
						
						|  | lambda sel1, sel2: instantiate_chatbots_and_select_boxes( | 
					
						
						|  | sel1, sel2, model_labels_list | 
					
						
						|  | ), | 
					
						
						|  | inputs=[sel1, sel2], | 
					
						
						|  | outputs=[sel1, chatbot1, sel2, chatbot2], | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | sel2.select( | 
					
						
						|  | lambda sel1, sel2: instantiate_chatbots_and_select_boxes( | 
					
						
						|  | sel1, sel2, model_labels_list | 
					
						
						|  | ), | 
					
						
						|  | inputs=[sel1, sel2], | 
					
						
						|  | outputs=[sel1, chatbot1, sel2, chatbot2], | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | msg.submit( | 
					
						
						|  | chat_turn_user, | 
					
						
						|  | inputs=[msg, chatbot1, chatbot2], | 
					
						
						|  | outputs=[msg, chatbot1, chatbot2], | 
					
						
						|  | ).then( | 
					
						
						|  | chat_turn_assistant, | 
					
						
						|  | [msg, sel1, chatbot1, sel2, chatbot2, system_message], | 
					
						
						|  | outputs=[msg, chatbot1, chatbot2], | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | demo.launch() | 
					
						
						|  |  |