import os
import threading
from collections import defaultdict

import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

# Define model paths
model_name_to_path = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}

# Load Hugging Face token
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]

# Preload models and tokenizers
loaded_models = defaultdict(dict)
for name, path in model_name_to_path.items():
    loaded_models[name]["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
    loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
    loaded_models[name]["model"].eval()


def respond(message, history, model_name, max_tokens, temperature, top_p):
    """
    Generate a response from the selected model, streaming the output and updating the chat history.

    Args:
        message (str): User's input message.
        history (list): Current chat history as a list of (user_msg, bot_msg) tuples.
        model_name (str): Selected model name.
        max_tokens (int): Maximum number of tokens to generate.
        temperature (float): Sampling temperature.
        top_p (float): Top-p sampling parameter.

    Yields:
        list: Updated chat history with the user's message and the streaming bot response.
    """
    # Append the user's message to the history with an empty bot response
    history = history + [(message, "")]
    yield history  # Display the user's message immediately

    # Select tokenizer and model
    tokenizer = loaded_models[model_name]["tokenizer"]
    model = loaded_models[model_name]["model"]

    # Tokenize input
    inputs = tokenizer(message, return_tensors="pt")

    # Set up streaming
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,  # include the prompt in the stream so the reply continues the user's opening line
        skip_special_tokens=True,
    )

    # Configure generation parameters
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Start generation in a background thread so the streamer can be consumed here
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Stream the response with the model name as a prefix
    accumulated = f"**{model_name}:** "
    for new_text in streamer:
        accumulated += new_text
        history[-1] = (message, accumulated)
        yield history


def submit(message, history, model_name, max_tokens, temperature, top_p):
    """
    Handle form submission by calling respond and clearing the input box.

    Args:
        message (str): User's input message.
        history (list): Current chat history.
        model_name (str): Selected model name.
        max_tokens (int): Max tokens parameter.
        temperature (float): Temperature parameter.
        top_p (float): Top-p parameter.

    Yields:
        tuple: (updated chat history, cleared user input)
    """
    for updated_history in respond(message, history, model_name, max_tokens, temperature, top_p):
        yield updated_history, ""


def select_model(model_name, current_model):
    """
    Update the selected model name when a model button is clicked.

    Args:
        model_name (str): The model name to select.
        current_model (str): The currently selected model (unused; present to match the click inputs).

    Returns:
        str: The newly selected model name.
    """
    return model_name


# Create the Gradio interface with Blocks
with gr.Blocks(css=".gr-button {margin: 5px; width: 100%;} .gr-column {padding: 10px;}") as demo:
    # Title and description
    gr.Markdown("# LeCarnet")
    gr.Markdown("Select a model on the right and type a message to chat.")

    # Two-column layout with specific widths
    with gr.Row():
        # Left column: chat interface (80% width)
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                avatar_images=(None, "media/le-carnet.png"),  # User avatar: None, bot avatar: logo
                label="Chat",
                height=600,  # Increase chat height for larger display
            )
            user_input = gr.Textbox(placeholder="Type your message here...", label="Message")
            submit_btn = gr.Button("Send")

        # Right column: model selection and parameters (20% width)
        with gr.Column(scale=1, min_width=200):
            # State to track selected model
            model_state = gr.State(value="LeCarnet-8M")

            # Model selection buttons
            gr.Markdown("**Select Model**")
            btn_3m = gr.Button("LeCarnet-3M")
            btn_8m = gr.Button("LeCarnet-8M")
            btn_21m = gr.Button("LeCarnet-21M")

            # Sliders for parameters
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # Example prompts
    examples = gr.Examples(
        examples=[
            ["Il était une fois un petit garçon qui vivait dans un village paisible."],
            ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
            ["Il était une fois un petit lapin perdu"],
        ],
        inputs=user_input,
    )
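
    # Clicking an example only fills the message box; the user still presses "Send" to generate.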

    # Event handling for submit button
    submit_btn.click(
        fn=submit,
        inputs=[user_input, chatbot, model_state, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
    )

    # Event handling for model selection buttons
    btn_3m.click(
        fn=select_model,
        inputs=[gr.State("LeCarnet-3M"), model_state],
        outputs=model_state,
    )
    btn_8m.click(
        fn=select_model,
        inputs=[gr.State("LeCarnet-8M"), model_state],
        outputs=model_state,
    )
    btn_21m.click(
        fn=select_model,
        inputs=[gr.State("LeCarnet-21M"), model_state],
        outputs=model_state,
    )
if __name__ == "__main__":
demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10) |
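
# To run locally (assuming the script is saved as app.py and the token is exported in the shell):
#   HUGGINGFACEHUB_API_TOKEN=<your_token> python app.py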