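"""Gradio chat demo for the MaxLSB/LeCarnet language models.

The selected checkpoint is pulled from the Hugging Face Hub and its output is
streamed token by token into the chat window.

Assumed dependencies: gradio, transformers, torch.
"""
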
import os
import threading
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face token (optional; None works for public checkpoints)
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Global model & tokenizer
tokenizer = None
model = None

# Load selected model
def load_model(model_name):
    global tokenizer, model
    full_model_name = f"MaxLSB/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
    model.eval()

# Initialize default model
load_model("LeCarnet-8M")
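
# NOTE: from_pretrained() downloads and caches the checkpoint on first use, so
# the initial load and the first switch to a new model can take a while.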

# Streaming generation function
def respond(message, max_tokens, temperature, top_p):
    inputs = tokenizer(message, return_tensors="pt")
    # skip_prompt=False: the prompt is echoed back first, so the model's
    # continuation reads as one uninterrupted story
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Run generation on a worker thread so tokens can be yielded as they arrive
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        yield response

# User input handler
def user(message, chat_history):
    chat_history.append([message, None])
    return "", chat_history

# Bot response handler
def bot(chatbot, max_tokens, temperature, top_p):
    message = chatbot[-1][0]
    response_generator = respond(message, max_tokens, temperature, top_p)
    # Stream partial completions into the last chat turn
    for response in response_generator:
        chatbot[-1][1] = response
        yield chatbot

# Model selector handler
def update_model(model_name):
    load_model(model_name)
    return []  # empty history clears the chat so replies don't mix models

image_path = "static/le-carnet.png"  # logo asset (not currently used in the layout)

# Gradio UI
with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    with gr.Row():
        gr.HTML("""
            <div style="text-align: center; width: 100%;">
                <h1 style="margin: 0;">LeCarnet</h1>
            </div>
        """)
    # Main layout
    with gr.Row():
        # Left column: Options
        with gr.Column(scale=1, min_width=150):
            model_selector = gr.Dropdown(
                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                value="LeCarnet-8M",
                label="Select Model"
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
            clear_button = gr.Button("Clear Chat")

        # Right column: Chat
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                height=500
            )
            msg_input = gr.Textbox(
                placeholder="Type your message and press Enter...",
                label="User Input"
            )
            gr.Examples(
                examples=[
                    ["Il était une fois un petit garçon qui vivait dans un village paisible."],
                    ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
                    ["Il était une fois un petit lapin perdu"],
                ],
                inputs=msg_input,
                label="Example Prompts"
            )
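
    # Clicking an example only fills the textbox; the user still presses Enter to submit.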

    # Event handlers
    model_selector.change(fn=update_model, inputs=[model_selector], outputs=[chatbot])  # reset chat on model switch
    msg_input.submit(fn=user, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot, max_tokens, temperature, top_p], outputs=[chatbot]
    )
    clear_button.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)

if __name__ == "__main__":
    demo.queue()
    demo.launch(allowed_paths=["static/"])  # serve the static/ assets referenced above
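
# Local smoke test (hypothetical; bypasses the UI and assumes the Hub models are reachable):
#   for partial in respond("Il était une fois", 64, 0.7, 0.9):
#       print(partial)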