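"""Gradio chat demo for the LeCarnet models.

Streams text completions from a selected MaxLSB/LeCarnet-* checkpoint, using a
TextIteratorStreamer fed by a background generation thread.
"""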
import os
import torch
import threading
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face token (optional for public repos; required for gated/private ones)
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Cap PyTorch at 4 CPU threads (presumably sized for the host's CPU allotment)
torch.set_num_threads(4)

# Globals
tokenizer = None
model = None
current_model_name = None

# Load the selected model and tokenizer into module-level globals
# (note: swapping these globals is not thread-safe while a generation is running)
def load_model(model_name):
    global tokenizer, model, current_model_name
    full_model_name = f"MaxLSB/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
    model.eval()
    current_model_name = model_name

# Initialize default model
load_model("LeCarnet-8M")

# Streaming generation: tokenize the prompt and yield the growing completion
def respond(message, max_tokens, temperature, top_p):
    inputs = tokenizer(message, return_tensors="pt")
    # skip_prompt=False echoes the prompt so the output reads as one continuous text
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

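    # model.generate blocks until generation finishes, so run it on a background
    # thread and consume tokens from the streamer on this one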
    def run():
        with torch.no_grad():
            model.generate(**generate_kwargs)

    thread = threading.Thread(target=run)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        # prepend model name on its own line
        yield f"**Model: {current_model_name}**\n\n{response}"

# User input handler
def user(message, chat_history):
    chat_history.append([message, None])
    return "", chat_history

# Bot response handler: re-yields the chat history after each partial response,
# which makes Gradio re-render the Chatbot and produces the streaming effect
def bot(chatbot, max_tokens, temperature, top_p):
    message = chatbot[-1][0]
    response_generator = respond(message, max_tokens, temperature, top_p)
    for response in response_generator:
        chatbot[-1][1] = response
        yield chatbot

# Model selector handler: load the chosen model and clear the chat history
def update_model(model_name):
    load_model(model_name)
    return []

# Gradio UI
with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    with gr.Row():
        gr.HTML("""
        <div style="text-align: center; width: 100%;">
            <h1 style="margin: 0;">LeCarnet Demo 📊</h1>
        </div>
        """ )

    with gr.Row():
        with gr.Column(scale=1, min_width=150):
            model_selector = gr.Dropdown(
                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                value="LeCarnet-8M",
                label="Select Model"
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
            clear_button = gr.Button("Clear Chat")

        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                height=500
            )
            msg_input = gr.Textbox(
                placeholder="Type your message and press Enter...",
                label="User Input"
            )
            gr.Examples(
                examples=[
                    ["Il était une fois un petit renard nommé Roux. Roux aimait jouer dans la forêt."],
                    ["Dans un petit village, il y avait un jardin magnifique."],
                    ["Il était une fois un petit garçon nommé Tom. Tom aimait beaucoup dessiner."],
                ],
                inputs=msg_input,
                label="Example Prompts"
            )
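            # Example prompts are French story openings (presumably the models'
            # training domain)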

    model_selector.change(fn=update_model, inputs=[model_selector], outputs=[chatbot])
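    # On submit: user() appends the message and clears the textbox, then bot()
    # streams the model's reply into the last chat turn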
    msg_input.submit(fn=user, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot, max_tokens, temperature, top_p], outputs=[chatbot]
    )
    clear_button.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)
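
# A minimal way to run this locally (a sketch; assumes the file is saved as
# app.py and the MaxLSB/LeCarnet-* checkpoints are accessible):
#   HUGGINGFACEHUB_API_TOKEN=hf_xxx python app.py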