# LeCarnet-Demo / app.py
# Author: MaxLSB — commit 99f5fa0 (verified), 3.94 kB
import os
import threading
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face token, read once at import time.
# Raises KeyError immediately if HUGGINGFACEHUB_API_TOKEN is unset — fail fast
# rather than failing later inside from_pretrained().
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]

# Global model & tokenizer, rebound by load_model(); they are None only until
# the eager load_model("LeCarnet-8M") call below runs.
tokenizer = None
model = None
# Load selected model
def load_model(model_name):
    """Load the ``MaxLSB/<model_name>`` checkpoint into the module globals.

    Rebinds the module-level ``tokenizer`` and ``model`` and puts the model
    in eval mode. Authenticates with the token read at import time.
    """
    global tokenizer, model
    repo_id = f"MaxLSB/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(repo_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(repo_id, token=hf_token)
    model.eval()
# Eagerly load the default checkpoint at import time so the first chat message
# doesn't pay the download/load cost; must match the Dropdown's initial value.
load_model("LeCarnet-8M")
# Streaming generation function
def respond(message, max_tokens, temperature, top_p):
    """Yield the model's continuation of *message*, growing token by token.

    Runs ``model.generate`` in a background thread and drains a
    ``TextIteratorStreamer``, yielding the accumulated text after each new
    chunk. With ``skip_prompt=False`` the yielded text includes the prompt
    itself — presumably intentional for these completion-style models
    (the UI shows the prompt as the start of the story).

    Args:
        message: Prompt text to continue.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature (UI slider keeps it > 0).
        top_p: Nucleus-sampling threshold.
    """
    inputs = tokenizer(message, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        yield response
    # Fix: the original never joined the worker, leaving the generate() thread
    # unaccounted for after the stream ended. The streamer is exhausted only
    # once generation finishes, so this join returns promptly.
    thread.join()
# User input handler
def user(message, chat_history):
    """Record a submitted message as a new turn awaiting its reply.

    Appends ``[message, None]`` to the history (the ``None`` placeholder is
    filled in by ``bot``) and returns an empty string to clear the textbox,
    together with the updated history.
    """
    pending_turn = [message, None]
    chat_history.append(pending_turn)
    return "", chat_history
# Bot response handler
def bot(chatbot, max_tokens, temperature, top_p):
    """Stream the model's reply into the last chat turn.

    Takes the most recent user message, forwards it (with the sampling
    settings) to ``respond``, and yields the whole history each time the
    partial reply grows, so the UI updates live.
    """
    prompt = chatbot[-1][0]
    for partial in respond(prompt, max_tokens, temperature, top_p):
        chatbot[-1][1] = partial
        yield chatbot
# Model selector handler
def update_model(model_name):
    """Swap in the checkpoint chosen from the dropdown.

    Returns an empty list, the value used to reset a Chatbot history.
    """
    load_model(model_name)
    return []
# NOTE(review): image_path is never used below, and launch() whitelists
# "media/" while this points at "static/" — confirm which asset directory
# is actually intended.
image_path = "static/le-carnet.png"

# Gradio UI: centered title banner, then options (left) / chat (right).
with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    with gr.Row():
        # Page title banner.
        gr.HTML("""
<div style="text-align: center; width: 100%;">
<h1 style="margin: 0;">LeCarnet</h1>
</div>
""")
    # Main layout
    with gr.Row():
        # Left column: Options
        with gr.Column(scale=1, min_width=150):
            model_selector = gr.Dropdown(
                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                # Initial value matches the checkpoint loaded at import time.
                value="LeCarnet-8M",
                label="Select Model"
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            # Slider minimum of 0.1 keeps temperature strictly positive.
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
            clear_button = gr.Button("Clear Chat")
        # Right column: Chat
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                height=500
            )
            msg_input = gr.Textbox(
                placeholder="Type your message and press Enter...",
                label="User Input"
            )
            # Clickable French story openings that pre-fill the textbox.
            gr.Examples(
                examples=[
                    ["Il était une fois un petit garçon qui vivait dans un village paisible."],
                    ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
                    ["Il était une fois un petit lapin perdu"],
                ],
                inputs=msg_input,
                label="Example Prompts"
            )
    # Event handlers
    # NOTE(review): update_model returns [] (presumably meant to clear the
    # chat on model switch), but outputs=[] discards that value — consider
    # outputs=[chatbot]; verify intended behavior.
    model_selector.change(fn=update_model, inputs=[model_selector], outputs=[])
    # Submit: record the user turn immediately (unqueued), then stream the
    # bot reply via the generator handler.
    msg_input.submit(fn=user, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot, max_tokens, temperature, top_p], outputs=[chatbot]
    )
    # Returning None resets the Chatbot component.
    clear_button.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)

if __name__ == "__main__":
    # queue() is required for generator (streaming) event handlers.
    demo.queue()
    demo.launch(allowed_paths=["media/"])