# old_napoleon / app.py — Hugging Face Space by baconnier
# Commit 4f9f77d (verified): "Update app.py" (5.74 kB)
import json
import subprocess
from threading import Thread
import os
import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
# ---------------------------------------------------------------------------
# Model identity and UI branding (Mistral-Small-24B based fine-tune).
# ---------------------------------------------------------------------------
MODEL_ID: str = "baconnier/Napoleon_24B_V0.0"  # Hub repo loaded by from_pretrained()
MODEL_NAME: str = "Napoleon 24B"  # used in the page title and DESCRIPTION
# Label for the chat-template family; not referenced elsewhere in the visible file.
CHAT_TEMPLATE: str = "mistral"
# Prompts longer than this many tokens are left-truncated before generation.
CONTEXT_LENGTH: int = 32768
COLOR: str = "black"  # not referenced elsewhere in the visible file
EMOJI: str = "🇫🇷"  # prefixed to the page title
DESCRIPTION: str = f"This is {MODEL_NAME} model, a powerful 24B parameter language model from Mistral AI."
# Custom CSS handed to gr.ChatInterface(css=...): spreads message rows evenly,
# rounds bubble borders and sets dark-theme colours. Written as adjacent string
# literals; the resulting text (including the leading and trailing newline) is
# exactly what the equivalent triple-quoted block would produce.
# NOTE(review): `.dark.user` etc. (no space) only match elements that carry
# BOTH classes; if Gradio puts `dark` on an ancestor element, `.dark .user`
# may have been intended — confirm against the rendered DOM.
css = (
    "\n"
    ".message-row {\n"
    "justify-content: space-evenly !important;\n"
    "}\n"
    ".message-bubble-border {\n"
    "border-radius: 6px !important;\n"
    "}\n"
    ".dark.message-bubble-border {\n"
    "border-color: #21293b !important;\n"
    "}\n"
    ".dark.user {\n"
    "background: #0a1120 !important;\n"
    "}\n"
    ".dark.assistant {\n"
    "background: transparent !important;\n"
    "}\n"
)
# HTML passed as `placeholder=` to gr.ChatInterface (the content Gradio shows
# in the chatbot before any message). Layout: a flex card with the
# napoleon.jpg image hosted in the baconnier/Napoleon Space, next to a French
# quote ("Our wisdom comes from our experience, and our experience comes from
# our follies"). The <h2> contains only a space, so it renders no visible title.
PLACEHOLDER = """
<div class="message-bubble-border" style="display:flex; max-width: 600px; border-radius: 8px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); backdrop-filter: blur(10px);">
<figure style="margin: 0;">
<img src="https://huggingface.co/spaces/baconnier/Napoleon/resolve/main/napoleon.jpg" style="width: 100%; height: 100%; border-radius: 8px;">
</figure>
<div style="padding: .5rem 1.5rem;">
<h2 style="text-align: left; font-size: 1.5rem; font-weight: 700; margin-bottom: 0.5rem;"> </h2>
<p style="text-align: left; font-size: 16px; line-height: 1.5; margin-bottom: 15px;">Notre sagesse vient de notre expérience, et notre expérience vient de nos sottises</p>
</div>
</div>
"""
# Fallback system prompt ("You are Napoleon, answer only in French").
DEFAULT_SYSTEM_MESSAGE = "Tu es Napoleon, reponds uniquement en francais."


def load_system_message(path="system_message.txt", default=DEFAULT_SYSTEM_MESSAGE):
    """Return the system prompt stored in *path*, falling back to *default*.

    The file is read as UTF-8 and surrounding whitespace is stripped.

    Args:
        path: File to read. Relative paths resolve against the current
            working directory.
        default: Prompt returned when the file is missing, unreadable,
            not valid UTF-8, or contains only whitespace.

    Returns:
        The stripped file contents, or *default*.
    """
    try:
        with open(path, "r", encoding="utf-8") as file:
            text = file.read().strip()
    except FileNotFoundError:
        print(f"Warning: {path} not found. Using default message.")
        return default
    except (OSError, UnicodeDecodeError) as e:
        # Permission problems, directories, bad encodings: degrade to the default
        # instead of crashing the app at import time.
        print(f"Error loading system message: {e}")
        return default
    if not text:
        # An empty file would otherwise produce an empty system prompt.
        print(f"Warning: {path} is empty. Using default message.")
        return default
    return text


SYSTEM_MESSAGE = load_system_message()
def _history_to_messages(history):
    """Convert Gradio chat history into chat-template message dicts.

    Accepts the "tuples" format (``[(user, assistant), ...]``) and the
    "messages" format (``[{"role": ..., "content": ...}, ...]``). Entries whose
    content is ``None`` are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            pairs = [(turn.get("role"), turn.get("content"))]
        else:
            user_text, assistant_text = turn
            pairs = [("user", user_text), ("assistant", assistant_text)]
        messages.extend(
            {"role": role, "content": content}
            for role, content in pairs
            if content is not None
        )
    return messages


@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    """Stream the model's reply to *message*, given the prior chat *history*.

    Args:
        message: The latest user message.
        history: Prior turns as supplied by gr.ChatInterface (tuples or
            messages format).
        system_prompt: System prompt from the (hidden) textbox; falls back to
            SYSTEM_MESSAGE when empty.
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        top_k: Top-k sampling cutoff (used only when sampling).
        repetition_penalty: Repetition penalty; values <= 0 are treated as 1.0
            (no penalty) because transformers rejects them.
        top_p: Nucleus-sampling mass (used only when sampling).

    Yields:
        The reply decoded so far, growing as new text is streamed.
    """
    messages = [{"role": "system", "content": system_prompt or SYSTEM_MESSAGE}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains the special tokens (e.g. BOS);
    # tokenizing it with add_special_tokens=True would insert them twice.
    enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask
    # Keep only the most recent CONTEXT_LENGTH tokens (drop the oldest context,
    # never the newest user turn).
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        # Gradio sliders may deliver floats; generate() expects ints here.
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(repetition_penalty) if repetition_penalty > 0 else 1.0,
    )
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
        )
    else:
        # Sampling with temperature 0 raises in transformers; use greedy decoding.
        generate_kwargs["do_sample"] = False

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    outputs = []
    for new_text in streamer:
        outputs.append(new_text)
        yield "".join(outputs)
    thread.join()
# Load the tokenizer and a 4-bit (NF4) quantized model for Mistral-24B.
# `device` is where predict() places its input tensors.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    # The BitsAndBytesConfig parameter is `bnb_4bit_use_double_quant`. The
    # previous `use_double_quant` key is not a recognized argument and was
    # ignored (transformers only logs an "Unused kwargs" warning), so double
    # quantization was never actually enabled.
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",  # NormalFloat4 quantization type
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Fall back to EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
# Build and launch the chat UI. The values of `additional_inputs` are passed to
# predict() after (message, history), in the order listed here.
demo = gr.ChatInterface(
    predict,
    css=css,
    title=EMOJI + " " + MODEL_NAME,
    description=DESCRIPTION,
    examples=[
        ['Pourquoi les Français portent-ils toujours une baguette sous le bras ?'],
        ['Comment un Français peut-il survivre sans fromage pendant plus de 24h ?'],
        ['Pourquoi les serveurs parisiens sont-ils si "charmants" avec les touristes ?'],
        ["Est-il vrai que les Français font la grève plus souvent qu'ils ne travaillent ?"],
        ['Comment les Français trouvent-ils le temps de faire une pause déjeuner de 2 heures ?'],
        ['Pourquoi les Français disent-ils bof à tout ?'],
    ],
    placeholder=PLACEHOLDER,
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox(SYSTEM_MESSAGE, label="System prompt", visible=False),  # hidden system prompt
        gr.Slider(0, 1, 0.7, label="Temperature"),
        gr.Slider(0, 32768, 12000, label="Max new tokens"),
        # Explicit integer step. When `step` is omitted, Gradio derives a
        # fractional step from the 1-100 range, and transformers rejects a
        # non-integer top_k.
        gr.Slider(1, 100, 50, step=1, label="Top K sampling"),
        # transformers requires a strictly positive repetition penalty.
        gr.Slider(0.1, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
)
demo.queue().launch()