import os
import threading
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face token (optional for public checkpoints; needed for gated/private repos)
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Global model & tokenizer
tokenizer = None
model = None

# Load the selected checkpoint from the MaxLSB namespace on the Hugging Face Hub;
# the models are small (3M-21M parameters), so default CPU inference is assumed.
def load_model(model_name):
    global tokenizer, model
    full_model_name = f"MaxLSB/{model_name}"
    tokenizer = AutoTokenizer.from_pretrained(full_model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(full_model_name, token=hf_token)
    model.eval()

# Initialize default model
load_model("LeCarnet-8M")

# Streaming generation: skip_prompt=False echoes the prompt back first,
# since these models continue a story rather than answer a question.
def respond(message, max_tokens, temperature, top_p):
    inputs = tokenizer(message, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

    # model.generate blocks until it finishes, so run it on a background
    # thread and consume decoded text from the streamer as it arrives.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response
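
# A minimal sketch of consuming the generator outside Gradio; the prompt and
# sampling values below are illustrative, not defaults taken from this app:
#   for partial in respond("Il était une fois", 64, 0.7, 0.9):
#       print(partial)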

# User input handler: append the message to the history and clear the input box
def user(message, chat_history):
    chat_history.append([message, None])
    return "", chat_history

# Bot response handler: stream the reply into the last (user, bot) pair
def bot(chatbot, max_tokens, temperature, top_p):
    message = chatbot[-1][0]
    response_generator = respond(message, max_tokens, temperature, top_p)
    for response in response_generator:
        chatbot[-1][1] = response
        yield chatbot

# Model selector handler: load the new checkpoint and clear the chat history
def update_model(model_name):
    load_model(model_name)
    return []
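
# Note (an observation, not implemented here): swapping the global model while
# a generation thread is still streaming is unsynchronized; a threading.Lock
# shared by load_model() and respond() would be needed for strict safety.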

# Logo asset (defined but not referenced in the layout below)
image_path = "static/le-carnet.png"

# Gradio UI
with gr.Blocks(title="LeCarnet - Chat Interface") as demo:
    with gr.Row():
        gr.HTML("""
        <div style="text-align: center; width: 100%;">
            <h1 style="margin: 0;">LeCarnet</h1>
        </div>
        """)

    # Main layout
    with gr.Row():
        # Left column: Options
        with gr.Column(scale=1, min_width=150):
            model_selector = gr.Dropdown(
                choices=["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"],
                value="LeCarnet-8M",
                label="Select Model"
            )
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p Sampling")
            clear_button = gr.Button("Clear Chat")

        # Right column: Chat
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                bubble_full_width=False,
                height=500
            )
            msg_input = gr.Textbox(
                placeholder="Type your message and press Enter...",
                label="User Input"
            )
            gr.Examples(
                examples=[
                    ["Il était une fois un petit garçon qui vivait dans un village paisible."],
                    ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
                    ["Il était une fois un petit lapin perdu"],
                ],
                inputs=msg_input,
                label="Example Prompts"
            )

    # Event handlers
    # update_model returns [] to clear the chat, so wire its output to the chatbot
    model_selector.change(fn=update_model, inputs=[model_selector], outputs=[chatbot])
    msg_input.submit(fn=user, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot, max_tokens, temperature, top_p], outputs=[chatbot]
    )
    clear_button.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)

if __name__ == "__main__":
    demo.queue()
    # Allow Gradio to serve the static/ directory, where the logo asset lives
    demo.launch(allowed_paths=["static/"])
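
# Running locally (a sketch; assumes gradio, transformers, and torch are
# installed, and the token only matters for gated or private checkpoints):
#   export HUGGINGFACEHUB_API_TOKEN=hf_...
#   python app.py
# Gradio serves on http://127.0.0.1:7860 by default.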