import os
import threading
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

# Configuration
MODEL_NAMES = ["LeCarnet-3M", "LeCarnet-8M", "LeCarnet-21M"]
HF_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")  # optional; needed only for gated or private checkpoints
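# Assumption: this relative path resolves only if the image ships with the app
# (e.g. committed alongside app.py in the Space repo); otherwise the <img> tag
# used in the chat bubble below will not render.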
MEDIA_PATH = "media/le-carnet.png"  # Relative path to logo

# Pre-load all tokenizers and models
models = {}
tokenizers = {}
for name in MODEL_NAMES:
    hub_id = f"MaxLSB/{name}"  # e.g. "MaxLSB/LeCarnet-3M"
    tokenizers[name] = AutoTokenizer.from_pretrained(hub_id, token=HF_TOKEN)
    models[name] = AutoModelForCausalLM.from_pretrained(hub_id, token=HF_TOKEN)
    models[name].eval()
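
# Optional sketch (assumption: a PyTorch backend, which transformers uses by
# default here): move the models to a GPU when one is available. The LeCarnet
# checkpoints are small enough that CPU inference also works.
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
for name in MODEL_NAMES:
    models[name].to(DEVICE)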


def respond(
    prompt: str,
    chat_history,
    selected_model: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Generate a streaming response from the chosen LeCarnet model,
    prepending the logo and model name in the chat bubble.
    """
    tokenizer = tokenizers[selected_model]
    model = models[selected_model]
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)  # keep tensors on the model's device

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=False,  # echo the prompt so the reply reads as one continuous story
        skip_special_tokens=True,
    )

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,                       # sample instead of greedy decoding
        temperature=temperature,              # higher = more random
        top_p=top_p,                          # nucleus sampling cutoff
        eos_token_id=tokenizer.eos_token_id,  # stop at end-of-text
    )

    # model.generate blocks until generation completes, so run it in a
    # background thread; the streamer yields decoded chunks in this thread.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    prefix = f"<img src='{MEDIA_PATH}' alt='logo' width='20' style='vertical-align: middle;'/> <strong>{selected_model}</strong>: "
    accumulated = ""
    first = True
    for new_text in streamer:
        if first:
            # include prefix only once at start
            accumulated = prefix + new_text
            first = False
        else:
            accumulated += new_text
        yield accumulated
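
# A quick smoke test outside the UI (hypothetical invocation; the arguments
# mirror the Gradio inputs configured below):
#
#   for partial in respond("Il était une fois", [], "LeCarnet-3M", 64, 0.7, 0.9):
#       print(partial)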


# Build Gradio ChatInterface
with gr.Blocks() as demo:
    gr.Markdown("# LeCarnet: Short French Stories")
    with gr.Row():
        with gr.Column():
            chat = gr.ChatInterface(
                fn=respond,
                additional_inputs=[
                    gr.Dropdown(MODEL_NAMES, value="LeCarnet-8M", label="Model"),
                    gr.Slider(1, 512, value=512, step=1, label="Max new tokens"),
                    gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
                ],
                title="LeCarnet Chat",
                description="Type the beginning of a sentence and watch the model finish it.",
                examples=[
                    ["Il était une fois un petit garçon qui vivait dans un village paisible."],
                    ["Il était une fois une grenouille qui rêvait de toucher les étoiles chaque nuit depuis son étang."],
                    ["Il était une fois un petit lapin perdu"],
                ],
                cache_examples=False,
            )
    
if __name__ == "__main__":
    demo.queue()
    demo.launch()