# LeCarnet — Hugging Face Space chat demo
import os
import threading

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face token used when downloading models. `.get` avoids a hard
# KeyError at import time when the variable is unset; `token=None` still
# works for public repositories.
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Module-level model/tokenizer for the currently selected checkpoint;
# reassigned by load_model().
tokenizer = None
model = None
def load_model(model_name):
    """Load tokenizer and model ``MaxLSB/<model_name>`` into the module globals.

    Replaces any previously loaded model. The model is switched to eval mode
    because this app only performs inference.
    """
    global tokenizer, model
    full_model_name = f"MaxLSB/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
    model.eval()
# Load the default model at import time so the UI is usable immediately.
load_model("LeCarnet-8M")
# Streaming generation function
def respond(message, max_tokens, temperature, top_p):
    """Generate a streamed continuation of *message*.

    Yields the accumulated response text after each decoded chunk so Gradio
    can update the chat bubble incrementally.

    Parameters:
        message: prompt text the model should continue.
        max_tokens: cap on newly generated tokens.
        temperature: sampling temperature (sliders keep it > 0).
        top_p: nucleus-sampling threshold.
    """
    inputs = tokenizer(message, return_tensors="pt")
    # skip_prompt=False: the prompt is echoed back so the story reads as one
    # continuous text in the chat bubble.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )
    # generate() blocks, so run it on a worker thread and consume the
    # streamer on this (generator) side.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    response = ""
    try:
        for new_text in streamer:
            response += new_text
            yield response
    finally:
        # Don't leak the worker thread, even if the consumer abandons
        # this generator early.
        thread.join()
# User input handler
def user(message, chat_history):
    """Append the submitted message as a new (user, pending-bot) chat turn.

    Returns an empty string (to clear the textbox) plus the updated history.
    """
    new_turn = [message, None]
    chat_history.append(new_turn)
    return "", chat_history
# Bot response handler
def bot(chatbot, max_tokens, temperature, top_p):
    """Stream the model's reply into the last chat turn.

    Yields the whole history after each partial response so the UI can
    re-render the chat as text arrives.
    """
    prompt = chatbot[-1][0]
    for partial in respond(prompt, max_tokens, temperature, top_p):
        chatbot[-1][1] = partial
        yield chatbot
# Model selector handler
def update_model(model_name):
    """Switch to the selected checkpoint and return an empty chat history."""
    load_model(model_name)
    return []
# Static asset path. NOTE(review): not referenced by the layout below, and
# launch() only allow-lists media/ — confirm where the logo actually lives.
image_path = "static/le-carnet.png"

# Gradio UI
with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    # Centered page title.
    with gr.Row():
        gr.HTML("""
        <div style="text-align: center; width: 100%;">
            <h1 style="margin: 0;">LeCarnet</h1>
        </div>
        """)
    # Main layout
    with gr.Row():
        # Left column: model choice and sampling controls.
        with gr.Column(scale=1, min_width=150):
            model_selector = gr.Dropdown(
                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                value="LeCarnet-8M",
                label="Select Model"
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
            clear_button = gr.Button("Clear Chat")
        # Right column: chat window, input box and example prompts.
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                height=500
            )
            msg_input = gr.Textbox(
                placeholder="Type your message and press Enter...",
                label="User Input"
            )
            gr.Examples(
                examples=[
                    ["Il était une fois un petit garçon qui vivait dans un village paisible."],
                    ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
                    ["Il était une fois un petit lapin perdu"],
                ],
                inputs=msg_input,
                label="Example Prompts"
            )

    # Event handlers.
    # update_model returns [] to reset the conversation; wire that return
    # value to the chatbot (outputs=[] would silently discard it and the
    # old chat would survive a model switch).
    model_selector.change(fn=update_model, inputs=[model_selector], outputs=[chatbot])
    # Submit: record the user turn immediately (unqueued), then stream the reply.
    msg_input.submit(
        fn=user, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot], queue=False
    ).then(
        fn=bot, inputs=[chatbot, max_tokens, temperature, top_p], outputs=[chatbot]
    )
    clear_button.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)

if __name__ == "__main__":
    demo.queue()  # queuing is required for generator (streaming) handlers
    # Allow both asset directories; image_path above points into static/.
    demo.launch(allowed_paths=["media/", "static/"])