from threading import Thread
import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
# Model configuration for the Mistral-Small-24B-based Napoleon fine-tune
MODEL_ID = "baconnier/Napoleon_24B_V0.0"
CHAT_TEMPLATE = "mistral"  # Mistral ships its own chat template
MODEL_NAME = "NAPOLEON"
CONTEXT_LENGTH = 2048  # Truncation window used below; the base model supports longer contexts
COLOR = "black"
EMOJI = "🦅" # Mistral-themed emoji
DESCRIPTION = f"This is {MODEL_NAME} model, a powerful 24B parameter language model from Mistral AI."
css = """
.message-row {
justify-content: space-evenly !important;
}
.message-bubble-border {
border-radius: 6px !important;
}
.dark.message-bubble-border {
border-color: #21293b !important;
}
.dark.user {
background: #0a1120 !important;
}
.dark.assistant {
background: transparent !important;
}
"""
PLACEHOLDER = """
<div class="message-bubble-border" style="display:flex; max-width: 600px; border-radius: 8px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); backdrop-filter: blur(10px);">
<figure style="margin: 0;">
<img src="https://huggingface.co/spaces/baconnier/NAPOLEON2/resolve/main/napoleon.jpg" style="width: 100%; height: 100%; border-radius: 8px;">
</figure>
<div style="padding: .5rem 1.5rem;">
<h2 style="text-align: left; font-size: 1.5rem; font-weight: 700; margin-bottom: 0.5rem;"> </h2>
<p style="text-align: left; font-size: 16px; line-height: 1.5; margin-bottom: 15px;">Notre sagesse vient de notre expérience, et notre expérience vient de nos sottises</p>
</div>
</div>
"""
def load_system_message():
try:
with open('system_message.txt', 'r', encoding='utf-8') as file:
return file.read().strip()
except FileNotFoundError:
print("Warning: system_message.txt not found. Using default message.")
return "Tu es Napoleon, reponds uniquement en francais."
except Exception as e:
print(f"Error loading system message: {e}")
return "Tu es Napoleon, reponds uniquement en francais."
SYSTEM_MESSAGE = load_system_message()
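# On ZeroGPU Spaces, @spaces.GPU() requests a GPU for the duration of each call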
@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    # Build the conversation; fall back to the loaded system message if none is supplied
    messages = [{"role": "system", "content": system_prompt or SYSTEM_MESSAGE}]
for user, assistant in history:
messages.append({"role": "user", "content": user})
messages.append({"role": "assistant", "content": assistant})
messages.append({"role": "user", "content": message})
    # Render the conversation with the tokenizer's chat template and append the generation prompt
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
input_ids, attention_mask = enc.input_ids, enc.attention_mask
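    # Left-truncate so the most recent CONTEXT_LENGTH tokens (the latest turns) are kept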
if input_ids.shape[1] > CONTEXT_LENGTH:
input_ids = input_ids[:, -CONTEXT_LENGTH:]
attention_mask = attention_mask[:, -CONTEXT_LENGTH:]
generate_kwargs = dict(
input_ids=input_ids.to(device),
attention_mask=attention_mask.to(device),
streamer=streamer,
do_sample=True,
temperature=temperature,
max_new_tokens=max_new_tokens,
top_k=top_k,
repetition_penalty=repetition_penalty,
top_p=top_p
)
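    # Generate on a background thread so tokens can be streamed to the UI as they arrive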
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for new_token in streamer:
outputs.append(new_token)
yield "".join(outputs)
# Load model with optimized settings for Mistral-24B
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
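# 4-bit NF4 quantization cuts the 24B model's weights to roughly 14 GB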
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,  # Double quantization for additional memory savings
bnb_4bit_quant_type="nf4" # Use normal float 4 for better precision
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Set the pad token to be the same as the end of sequence token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
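# device_map="auto" lets accelerate place the quantized weights on the available GPU(s)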
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16
)
# Create Gradio interface
gr.ChatInterface(
predict,
theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue", neutral_hue="gray",font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]).set(
body_background_fill_dark="#0f172a",
block_background_fill_dark="#0f172a",
block_border_width="1px",
block_title_background_fill_dark="#070d1b",
#input_background_fill_dark="#0c1425",
button_secondary_background_fill_dark="#070d1b",
border_color_primary_dark="#21293b",
background_fill_secondary_dark="#0f172a",
color_accent_soft_dark="transparent"
),
css=css,
    title=f"{EMOJI} {MODEL_NAME} 🇫🇷",
description=DESCRIPTION,
additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
additional_inputs=[
gr.Textbox(SYSTEM_MESSAGE, label="System prompt", visible=False), # Hidden system prompt
        gr.Slider(0.1, 1, 0.7, label="Temperature"),  # do_sample requires a temperature above 0
        gr.Slider(0, 2048, 1024, label="Max new tokens"),  # Room for longer replies
gr.Slider(1, 100, 50, label="Top K sampling"),
        gr.Slider(1, 2, 1.1, label="Repetition penalty"),  # Values below 1 encourage repetition
gr.Slider(0, 1, 0.95, label="Top P sampling"),
],
chatbot=gr.Chatbot(scale=1, placeholder=PLACEHOLDER),
examples=[
['Comment un Français peut-il survivre sans fromage pendant plus de 24h ?'],
['Pourquoi les serveurs parisiens sont-ils si "charmants" avec les touristes ?'],
        ["Est-il vrai que les Français font la grève plus souvent qu'ils ne travaillent ?"],
],
).queue().launch()