import os
import threading
from collections import defaultdict

import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

# Define model paths
model_name_to_path = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}

# Load the Hugging Face token if one is set (only needed for gated or private repos)
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Preload models and tokenizers
loaded_models = defaultdict(dict)

for name, path in model_name_to_path.items():
    loaded_models[name]["tokenizer"] = AutoTokenizer.from_pretrained(path, token=hf_token)
    loaded_models[name]["model"] = AutoModelForCausalLM.from_pretrained(path, token=hf_token)
    loaded_models[name]["model"].eval()

def respond(message, history, model_name, max_tokens, temperature, top_p):
    """
    Generate a response from the selected model, streaming the output and updating chat history.
    
    Args:
        message (str): User's input message.
        history (list): Current chat history as list of (user_msg, bot_msg) tuples.
        model_name (str): Selected model name.
        max_tokens (int): Maximum number of tokens to generate.
        temperature (float): Sampling temperature.
        top_p (float): Top-p sampling parameter.
    
    Yields:
        list: Updated chat history with the user's message and streaming bot response.
    """
    # Append user's message to history with an empty bot response
    history = history + [(message, "")]
    yield history  # Display user's message immediately

    # Select tokenizer and model
    tokenizer = loaded_models[model_name]["tokenizer"]
    model = loaded_models[model_name]["model"]

    # Tokenize input
    inputs = tokenizer(message, return_tensors="pt")

    # Set up streaming; skip_prompt=False echoes the prompt back so the
    # completion reads as one continuous story
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,
        skip_special_tokens=True,
    )
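    # TextIteratorStreamer is a producer/consumer bridge: generate() pushes
    # decoded text into it from the worker thread started below, and the
    # for-loop at the end of this function consumes it as it arrives.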

    # Configure generation parameters
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Start generation in a background thread
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
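    # Iterating the streamer below blocks until generate() signals completion,
    # so the worker thread is fully drained and no explicit join is needed here.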

    # Stream the response with model name prefix
    accumulated = f"**{model_name}:** "
    for new_text in streamer:
        accumulated += new_text
        history[-1] = (message, accumulated)
        yield history

def submit(message, history, model_name, max_tokens, temperature, top_p):
    """
    Handle form submission by calling respond and clearing the input box.
    
    Args:
        message (str): User's input message.
        history (list): Current chat history.
        model_name (str): Selected model name.
        max_tokens (int): Max tokens parameter.
        temperature (float): Temperature parameter.
        top_p (float): Top-p parameter.
    
    Yields:
        tuple: (updated chat history, cleared user input)
    """
    for updated_history in respond(message, history, model_name, max_tokens, temperature, top_p):
        yield updated_history, ""

def select_model(model_name, current_model):
    """
    Update the selected model name when a model button is clicked.
    
    Args:
        model_name (str): The model name to select.
        current_model (str): The currently selected model.
    
    Returns:
        str: The newly selected model name.
    """
    return model_name

# Create the Gradio interface with Blocks
with gr.Blocks(css=".gr-button {margin: 5px; width: 100%;} .gr-column {padding: 10px;}") as demo:
    # Title and description
    gr.Markdown("# LeCarnet")
    gr.Markdown("Select a model on the right and type a message to chat.")

    # Two-column layout with specific widths
    with gr.Row():
        # Left column: Chat interface (80% width)
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                avatar_images=(None, "media/le-carnet.png"),  # User avatar: None, Bot avatar: Logo
                label="Chat",
                height=600,  # Increase chat height for larger display
            )
            user_input = gr.Textbox(placeholder="Type your message here...", label="Message")
            submit_btn = gr.Button("Send")

        # Right column: Model selection and parameters (20% width)
        with gr.Column(scale=1, min_width=200):
            # State to track selected model
            model_state = gr.State(value="LeCarnet-8M")

            # Model selection buttons
            gr.Markdown("**Select Model**")
            btn_3m = gr.Button("LeCarnet-3M")
            btn_8m = gr.Button("LeCarnet-8M")
            btn_21m = gr.Button("LeCarnet-21M")

            # Sliders for parameters
            max_tokens = gr.Slider(1, 512, value=512, step=1, label="Max New Tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    # Example prompts (in French, matching the models' French training data)
    examples = gr.Examples(
        examples=[
            ["Il était une fois un petit garçon qui vivait dans un village paisible."],
            ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
            ["Il était une fois un petit lapin perdu"],
        ],
        inputs=user_input,
    )

    # Event handling for submit button
    submit_btn.click(
        fn=submit,
        inputs=[user_input, chatbot, model_state, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],
    )

    # Event handling for model selection buttons: each click passes its model
    # name as a constant gr.State into select_model
    btn_3m.click(
        fn=select_model,
        inputs=[gr.State("LeCarnet-3M"), model_state],
        outputs=model_state,
    )
    btn_8m.click(
        fn=select_model,
        inputs=[gr.State("LeCarnet-8M"), model_state],
        outputs=model_state,
    )
    btn_21m.click(
        fn=select_model,
        inputs=[gr.State("LeCarnet-21M"), model_state],
        outputs=model_state,
    )

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10, max_size=10).launch(ssr_mode=False, max_threads=10)