Spaces:
Sleeping
Sleeping
File size: 4,099 Bytes
36942d4 f6b834f a7a20a5 39c555f 790cffd 954f37f 790cffd 954f37f f6b834f 3a38c1f 790cffd 3a38c1f 790cffd 3a38c1f 790cffd 3a38c1f 790cffd 90d1b16 954f37f a7a20a5 954f37f f6b834f a7a20a5 954f37f 790cffd a7a20a5 790cffd 3a38c1f 790cffd 7811152 f19d748 99f5fa0 07012cb f19d748 3a38c1f b7b0fd1 52a9a97 790cffd b7b0fd1 790cffd 6ecb51d 644b0a5 790cffd 644b0a5 790cffd 954f37f 2d0a01f 954f37f 790cffd 7b4f2fa 954f37f 790cffd a167f72 790cffd 6ecb51d 341bd22 3a38c1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import os
import torch
import threading
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face token
# NOTE: raises KeyError at import time if HUGGINGFACEHUB_API_TOKEN is unset —
# the Space fails fast rather than starting without credentials.
hf_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
# Cap intra-op CPU threads for torch (presumably sized to the Space's CPU
# allotment — TODO confirm against the hardware tier).
torch.set_num_threads(4)
# Globals
# Populated by load_model(); all three stay None until the first load completes.
tokenizer = None
model = None
current_model_name = None
# Load selected model
def load_model(model_name):
    """Load tokenizer and weights for ``MaxLSB/<model_name>`` into the module globals.

    Parameters
    ----------
    model_name : str
        Short model name (e.g. ``"LeCarnet-8M"``); prefixed with the
        ``MaxLSB/`` namespace to form the Hub repo id.

    Side effects: rebinds the module-level ``tokenizer``, ``model`` and
    ``current_model_name``; puts the model in eval mode.
    """
    global tokenizer, model, current_model_name
    # Fix: skip the expensive re-download/re-init when the dropdown re-fires
    # with the model that is already active.
    if model_name == current_model_name and model is not None:
        return
    full_model_name = f"MaxLSB/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
    model.eval()  # inference only — disable dropout etc.
    current_model_name = model_name


# Initialize default model
load_model("LeCarnet-8M")
# Streaming generation function
def respond(message, max_tokens, temperature, top_p):
    """Stream a completion for *message* from the currently loaded model.

    Yields progressively longer markdown strings of the form
    ``"**Model: <name>**\\n\\n<text so far>"`` as tokens arrive.

    Parameters
    ----------
    message : str
        Raw prompt text.
    max_tokens : int
        Upper bound on newly generated tokens.
    temperature, top_p : float
        Sampling controls forwarded to ``model.generate``.
    """
    inputs = tokenizer(message, return_tensors="pt")
    # skip_prompt=False: the prompt itself is echoed back as the start of the
    # streamed text (the model continues the story from the user's words).
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

    def run():
        # Inference only — no autograd bookkeeping needed.
        with torch.no_grad():
            model.generate(**generate_kwargs)

    # Fix: daemonize so a stuck generation can't block interpreter shutdown.
    thread = threading.Thread(target=run, daemon=True)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        # Prepend the model name on its own line so the UI shows which
        # checkpoint produced the text.
        yield f"**Model: {current_model_name}**\n\n{response}"
    # Fix: the streamer is exhausted only after generate() returns, so the
    # join is immediate — previously the worker thread was never reaped.
    thread.join()
# User input handler
def user(message, chat_history):
    """Record the user's turn and clear the input box.

    Appends ``[message, None]`` to *chat_history* in place (the ``None`` slot
    is filled in later by ``bot``) and returns an empty string so the Textbox
    resets, plus the same history list for the Chatbot component.
    """
    pending_turn = [message, None]
    chat_history.append(pending_turn)
    return "", chat_history
# Bot response handler
def bot(chatbot, max_tokens, temperature, top_p):
    """Stream the model's reply into the newest chat entry.

    Reads the user message from the last history row, then re-yields the
    whole history each time ``respond`` produces a longer partial answer so
    the Chatbot component updates live.
    """
    latest_message = chatbot[-1][0]
    for partial_reply in respond(latest_message, max_tokens, temperature, top_p):
        chatbot[-1][1] = partial_reply
        yield chatbot
# Model selector handler
def update_model(model_name):
    """Swap the active model to *model_name* and return an empty chat history."""
    load_model(model_name)
    # NOTE(review): this return value is only consumed if the event is wired
    # with an output component — verify against the .change() call below.
    return []
# Gradio UI
with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    # Centered page header
    with gr.Row():
        gr.HTML("""
        <div style="text-align: center; width: 100%;">
        <h1 style="margin: 0;">LeCarnet Demo 📊</h1>
        </div>
        """)
    with gr.Row():
        # Left column: model choice + sampling controls
        with gr.Column(scale=1, min_width=150):
            model_selector = gr.Dropdown(
                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                value="LeCarnet-8M",
                label="Select Model"
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
            clear_button = gr.Button("Clear Chat")
        # Right column: transcript, input box, and canned French prompts
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                height=500
            )
            msg_input = gr.Textbox(
                placeholder="Type your message and press Enter...",
                label="User Input"
            )
            gr.Examples(
                examples=[
                    ["Il était une fois un petit renard nommé Roux. Roux aimait jouer dans la forêt."],
                    ["Dans un petit village, il y avait un jardin magnifique."],
                    ["Il était une fois un petit garçon nommé Tom. Tom aimait beaucoup dessiner."],
                ],
                inputs=msg_input,
                label="Example Prompts"
            )
    # Event wiring.
    # NOTE(review): update_model returns [] but outputs=[] here, so the chat
    # history is NOT cleared when the model changes — confirm that is intended.
    model_selector.change(fn=update_model, inputs=[model_selector], outputs=[])
    # Submit appends the user turn synchronously (queue=False, feels instant),
    # then streams the bot reply through the queue.
    msg_input.submit(fn=user, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot, max_tokens, temperature, top_p], outputs=[chatbot]
    )
    # Returning None resets the Chatbot component.
    clear_button.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)
if __name__ == "__main__":
    # Queueing is required for streaming (generator) handlers; caps concurrent
    # generations and pending requests at 10 each.
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)
|