from threading import Thread
import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

# Model configuration for Napoleon, based on Mistral Small 24B
MODEL_ID = "baconnier/Napoleon_24B_V0.0"
CHAT_TEMPLATE = "mistral"  # Mistral uses its own chat template
MODEL_NAME = "NAPOLEON"
CONTEXT_LENGTH = 2048  # Conservative window; the base Mistral model supports a longer context
COLOR = "black"
EMOJI = "🦅"  # Imperial eagle, Napoleon's emblem
DESCRIPTION = f"This is the {MODEL_NAME} model, a 24B-parameter language model based on Mistral Small."

# Custom CSS for the chat bubbles, tuned for the dark theme
css = """
.message-row {
    justify-content: space-evenly !important;
}
.message-bubble-border {
    border-radius: 6px !important;
}
.dark.message-bubble-border {
    border-color: #21293b !important;
}
.dark.user {
    background: #0a1120 !important;
}
.dark.assistant {
    background: transparent !important;
}
"""

# Chat placeholder card; the quote translates to "Our wisdom comes from our
# experience, and our experience comes from our follies" (Napoleon)
PLACEHOLDER = """
<div class="message-bubble-border" style="display:flex; max-width: 600px; border-radius: 8px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); backdrop-filter: blur(10px);">
    <figure style="margin: 0;">
        <img src="https://huggingface.co/spaces/baconnier/NAPOLEON2/resolve/main/napoleon.jpg" style="width: 100%; height: 100%; border-radius: 8px;">
    </figure>
    <div style="padding: .5rem 1.5rem;">
        <h2 style="text-align: left; font-size: 1.5rem; font-weight: 700; margin-bottom: 0.5rem;"> </h2>
        <p style="text-align: left; font-size: 16px; line-height: 1.5; margin-bottom: 15px;">Notre sagesse vient de notre expérience, et notre expérience vient de nos sottises</p>
    </div>    
</div>
"""


def load_system_message():
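    """Load the system prompt from system_message.txt, falling back to a short French default."""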
    try:
        with open('system_message.txt', 'r', encoding='utf-8') as file:
            return file.read().strip()
    except FileNotFoundError:
        print("Warning: system_message.txt not found. Using default message.")
        return "Tu es Napoleon, reponds uniquement en francais."
    except Exception as e:
        print(f"Error loading system message: {e}")
        return "Tu es Napoleon, reponds uniquement en francais."

SYSTEM_MESSAGE = load_system_message()

@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
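    """Build a Mistral-format prompt from the chat history and stream the model's reply."""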
    # Build the conversation in Mistral's chat format, preferring the UI-provided
    # system prompt and falling back to the one loaded from disk
    messages = [{"role": "system", "content": system_prompt or SYSTEM_MESSAGE}]
    
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    
    messages.append({"role": "user", "content": message})
    
    # Render the conversation with the model's chat template, ending with the
    # assistant prefix so generation continues as the assistant
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    
    # Stream decoded tokens as they are generated; skip_prompt avoids echoing the input
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    # Left-truncate so the most recent turns stay within the context window
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]
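    # Caveat: left-truncation discards the oldest tokens first, so on very long
    # chats the system prompt itself can be cut off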

    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p
    )
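    # The sampling parameters above come straight from the Gradio sliders below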

    # Run generation on a background thread so tokens can be yielded as they stream in
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Accumulate streamed tokens, yielding the growing reply for live display
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        yield "".join(outputs)

# Load model with optimized settings for Mistral-24B
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,  # Double quantization saves additional memory
    bnb_4bit_quant_type="nf4"  # Normal float 4 gives better precision than plain FP4
)
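# Rough estimate: in 4-bit NF4 a 24B model needs around 12-14 GB of VRAM, before activations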

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Set the pad token to be the same as the end of sequence token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16
)

# Create Gradio interface
gr.ChatInterface(
    predict,
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue", neutral_hue="gray", font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]).set(
        body_background_fill_dark="#0f172a",
        block_background_fill_dark="#0f172a",
        block_border_width="1px",
        block_title_background_fill_dark="#070d1b",
        button_secondary_background_fill_dark="#070d1b",
        border_color_primary_dark="#21293b",
        background_fill_secondary_dark="#0f172a",
        color_accent_soft_dark="transparent"
    ),
    css=css,
    title=f"{EMOJI} {MODEL_NAME} 🇫🇷",
    description=DESCRIPTION,
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox(SYSTEM_MESSAGE, label="System prompt", visible=False),  # Hidden system prompt
        gr.Slider(0, 1, 0.7, label="Temperature"),  # Adjusted default for Mistral
        gr.Slider(0, 2048, 1024, label="Max new tokens"),
        gr.Slider(1, 100, 50, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    chatbot=gr.Chatbot(scale=1, placeholder=PLACEHOLDER),
    examples=[
        ['Comment un Français peut-il survivre sans fromage pendant plus de 24h ?'],
        ['Pourquoi les serveurs parisiens sont-ils si "charmants" avec les touristes ?'],
        ["Est-il vrai que les Français font la grève plus souvent qu'ils ne travaillent ?"],
    ],
).queue().launch()