LLMnBiasV2 / app.py
Woziii's picture
Update app.py
63dd69c verified
raw
history blame
7.29 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import time
# Authenticate against the Hugging Face Hub. Most models listed below are gated
# and need a token, but a missing HF_TOKEN should not crash the app at import
# time: loading a gated model will then fail with an explicit error in the UI.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
else:
    print("HF_TOKEN non défini : connexion au Hub ignorée (les modèles à accès restreint ne pourront pas être chargés).")
# Catalogue used by the cascading dropdowns: company -> model family -> list of
# variation labels. The labels are display values; load_model combines
# company, family and variation into a Hub repo ID.
models_hierarchy = {
    "meta-llama": {
        "Llama-2": ["7b", "13b", "70b"],
        "Llama-3": ["8B", "3.2-3B", "3.1-8B"],
    },
    "mistralai": {
        "Mistral": ["7B-v0.1", "7B-v0.3"],
        "Mixtral": ["8x7B-v0.1"],
    },
    "google": {
        "gemma": ["2b", "9b", "27b"],
    },
    "croissantllm": {
        "CroissantLLM": ["Base"],
    },
}
# Languages supported by each model, keyed by Hugging Face Hub repo ID.
# Built from a (repo_id, languages) table; every entry gets its own list.
models_and_languages = {
    repo_id: list(languages)
    for repo_id, languages in (
        ("meta-llama/Llama-2-7b-hf", ("en",)),
        ("meta-llama/Llama-2-13b-hf", ("en",)),
        ("meta-llama/Llama-2-70b-hf", ("en",)),
        ("meta-llama/Meta-Llama-3-8B", ("en",)),
        ("meta-llama/Llama-3.2-3B", ("en", "de", "fr", "it", "pt", "hi", "es", "th")),
        ("meta-llama/Llama-3.1-8B", ("en", "de", "fr", "it", "pt", "hi", "es", "th")),
        ("mistralai/Mistral-7B-v0.1", ("en",)),
        ("mistralai/Mixtral-8x7B-v0.1", ("en", "fr", "it", "de", "es")),
        ("mistralai/Mistral-7B-v0.3", ("en",)),
        ("google/gemma-2-2b", ("en",)),
        ("google/gemma-2-9b", ("en",)),
        ("google/gemma-2-27b", ("en",)),
        ("croissantllm/CroissantLLMBase", ("en", "fr")),
    )
}
# Recommended sampling settings per model, keyed by Hub repo ID; load_model
# pushes them into the temperature / top-p / top-k sliders.
# Table columns: repo_id, temperature, top_p, top_k.
model_parameters = {
    repo_id: {"temperature": temp, "top_p": nucleus, "top_k": k}
    for repo_id, temp, nucleus, k in (
        ("meta-llama/Llama-2-13b-hf", 0.8, 0.9, 40),
        ("meta-llama/Llama-2-7b-hf", 0.8, 0.9, 40),
        ("meta-llama/Llama-2-70b-hf", 0.8, 0.9, 40),
        ("meta-llama/Meta-Llama-3-8B", 0.75, 0.9, 50),
        ("meta-llama/Llama-3.2-3B", 0.75, 0.9, 50),
        ("meta-llama/Llama-3.1-8B", 0.75, 0.9, 50),
        ("mistralai/Mistral-7B-v0.1", 0.7, 0.9, 50),
        ("mistralai/Mixtral-8x7B-v0.1", 0.8, 0.95, 50),
        ("mistralai/Mistral-7B-v0.3", 0.7, 0.9, 50),
        ("google/gemma-2-2b", 0.7, 0.95, 40),
        ("google/gemma-2-9b", 0.7, 0.95, 40),
        ("google/gemma-2-27b", 0.7, 0.95, 40),
        ("croissantllm/CroissantLLMBase", 0.8, 0.92, 50),
    )
}
# Module-level state shared with the Gradio callbacks. Everything starts as
# None (no model loaded yet); load_model() replaces `model` and `tokenizer`.
model = tokenizer = selected_language = None
def update_model_choices(company):
    """Return the model-family dropdown populated for *company*.

    Called when the company dropdown changes. The family selection is reset to
    ``None``. An empty or unknown company (e.g. ``None``) yields an empty
    dropdown instead of raising ``KeyError``.
    """
    families = models_hierarchy.get(company, {})
    return gr.Dropdown(choices=list(families.keys()), value=None)
def update_variation_choices(company, model_name):
    """Return the variation dropdown populated for *company* / *model_name*.

    This handler also fires when ``update_model_choices`` resets the family
    dropdown to ``None``. Any incomplete or unknown selection therefore yields an
    empty dropdown instead of raising ``KeyError``.
    """
    variations = models_hierarchy.get(company, {}).get(model_name, [])
    return gr.Dropdown(choices=list(variations), value=None)
# Hub repo IDs for UI selections whose ID cannot be derived with the
# "<company>/<family>-<variation>" or "<company>/<family><variation>" patterns:
# the Llama-2 repos carry a "-hf" suffix, Llama 3 / 3.1 / 3.2 use different
# prefixes, and the "gemma" family points at the Gemma 2 repos.
_MODEL_ID_OVERRIDES = {
    ("meta-llama", "Llama-2", "7b"): "meta-llama/Llama-2-7b-hf",
    ("meta-llama", "Llama-2", "13b"): "meta-llama/Llama-2-13b-hf",
    ("meta-llama", "Llama-2", "70b"): "meta-llama/Llama-2-70b-hf",
    ("meta-llama", "Llama-3", "8B"): "meta-llama/Meta-Llama-3-8B",
    ("meta-llama", "Llama-3", "3.2-3B"): "meta-llama/Llama-3.2-3B",
    ("meta-llama", "Llama-3", "3.1-8B"): "meta-llama/Llama-3.1-8B",
    ("google", "gemma", "2b"): "google/gemma-2-2b",
    ("google", "gemma", "9b"): "google/gemma-2-9b",
    ("google", "gemma", "27b"): "google/gemma-2-27b",
}


def _resolve_model_name(company, model_name, variation):
    """Map a UI selection to a repo ID listed in ``models_and_languages``.

    Returns ``None`` when the selection is incomplete or matches no known model.
    """
    if not (company and model_name and variation):
        return None
    candidates = (
        f"{company}/{model_name}-{variation}",
        f"{company}/{model_name}{variation}",
        _MODEL_ID_OVERRIDES.get((company, model_name, variation)),
    )
    return next((name for name in candidates if name in models_and_languages), None)


def _release_current_model():
    """Drop the loaded model/tokenizer and free cached GPU memory.

    Runs before a new load so two large models are never resident at once.
    """
    import gc

    global model, tokenizer
    model = None
    tokenizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def load_model(company, model_name, variation, progress=gr.Progress()):
    """Load the selected model and tokenizer into the module-level globals.

    Args:
        company: Hub organisation chosen in the UI (e.g. ``"mistralai"``).
        model_name: Model family chosen in the UI (e.g. ``"Mistral"``).
        variation: Variation label chosen in the UI (e.g. ``"7B-v0.1"``).
        progress: Gradio progress tracker (default ``gr.Progress()``).

    Returns:
        A 5-tuple matching the ``load_button.click`` outputs:
        - the status message;
        - the language dropdown update;
        - the recommended temperature, top-p and top-k.
        On failure the language dropdown is hidden and the sliders are left
        unchanged (``gr.update()``).
    """
    global model, tokenizer

    full_model_name = _resolve_model_name(company, model_name, variation)
    if full_model_name is None:
        # Unknown or incomplete selection: keep whatever model is already loaded.
        return (
            f"Erreur lors du chargement du modèle : sélection inconnue ({company} / {model_name} / {variation})",
            gr.Dropdown(visible=False),
            gr.update(),
            gr.update(),
            gr.update(),
        )

    _release_current_model()
    try:
        progress(0, desc="Chargement du tokenizer")
        new_tokenizer = AutoTokenizer.from_pretrained(full_model_name)
        progress(0.5, desc="Chargement du modèle")
        load_kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}
        # Model-specific configuration: Mixtral is loaded in 8-bit to fit in memory.
        if "mixtral" in full_model_name.lower():
            from transformers import BitsAndBytesConfig

            load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        new_model = AutoModelForCausalLM.from_pretrained(full_model_name, **load_kwargs)
        if new_tokenizer.pad_token is None:
            new_tokenizer.pad_token = new_tokenizer.eos_token
    except Exception as e:
        return f"Erreur lors du chargement du modèle : {str(e)}", gr.Dropdown(visible=False), gr.update(), gr.update(), gr.update()

    # Publish both globals together so a failed load never leaves a
    # tokenizer/model mismatch behind.
    tokenizer, model = new_tokenizer, new_model
    progress(1.0, desc="Modèle chargé")

    available_languages = models_and_languages[full_model_name]
    # Update the sliders with the recommended values (unchanged if none are listed).
    params = model_parameters.get(full_model_name)
    if params is None:
        slider_updates = (gr.update(), gr.update(), gr.update())
    else:
        slider_updates = (params["temperature"], params["top_p"], params["top_k"])
    return (
        f"Modèle {full_model_name} chargé avec succès. Langues disponibles : {', '.join(available_languages)}",
        gr.Dropdown(choices=available_languages, value=available_languages[0], visible=True, interactive=True),
        *slider_updates,
    )
# Callbacks wired to the UI below. Their return arity must match the
# `outputs=[...]` lists of the corresponding event listeners.


def set_language(language):
    """Remember the language picked in the UI and echo it back.

    Returns an empty string when the selection is cleared (``None``).
    """
    global selected_language
    selected_language = language
    return f"Langue sélectionnée : {language}" if language else ""


def _filter_logits(logits, temperature, top_p, top_k):
    """Apply temperature, top-k and top-p (nucleus) filtering to 1-D next-token logits.

    Filtered-out entries are set to ``-inf`` so they get zero probability after
    softmax. The single most likely token is always kept.
    """
    logits = logits.float() / max(float(temperature), 1e-5)
    top_k = int(top_k)
    if 0 < top_k < logits.size(-1):
        kth_best = torch.topk(logits, top_k).values[-1]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if float(top_p) < 1.0:
        sorted_logits, sorted_ids = torch.sort(logits, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        # Probability mass of strictly more likely tokens; dropping only tokens
        # whose preceding mass already exceeds top_p keeps the top-1 token.
        mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
        logits[sorted_ids[mass_before > float(top_p)]] = float("-inf")
    return logits


def _plot_attention(attentions, labels):
    """Heatmap of last-layer attention averaged over heads, or ``None`` if unavailable."""
    if not attentions or attentions[-1] is None:
        return None
    # assumes transformers' (batch, heads, seq, seq) layout for each layer
    weights = attentions[-1][0].float().mean(dim=0).cpu().numpy()
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(weights, xticklabels=labels, yticklabels=labels, cmap="viridis", ax=ax)
    ax.set_title("Attention (dernière couche, moyenne des têtes)")
    fig.tight_layout()
    plt.close(fig)  # detach from pyplot's registry; the Figure object stays renderable
    return fig


def _plot_probabilities(tokens, probabilities):
    """Bar chart of candidate next tokens and their probabilities."""
    positions = list(range(len(tokens)))
    fig, ax = plt.subplots(figsize=(8, 4))
    # Numeric x positions so identical-looking tokens are not merged into one bar.
    ax.bar(positions, probabilities)
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=45, ha="right")
    ax.set_ylabel("Probabilité")
    ax.set_title("Probabilités des tokens suivants")
    fig.tight_layout()
    plt.close(fig)
    return fig


def analyze_next_token(input_text, temperature, top_p, top_k):
    """Show the most likely next tokens for *input_text* under the current sampling settings.

    Returns ``(probabilities_text, attention_figure, probability_figure)``.
    Either figure may be ``None``; on error only the message is returned.
    """
    if model is None or tokenizer is None:
        return "Veuillez d'abord charger un modèle.", None, None
    if not input_text:
        return "Veuillez saisir un texte d'entrée.", None, None
    try:
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)
        logits = _filter_logits(outputs.logits[0, -1], temperature, top_p, top_k)
        probs = torch.softmax(logits, dim=-1)
        top_probs, top_ids = torch.topk(probs, min(10, probs.size(-1)))
        candidates = [
            (tokenizer.decode([token_id]), prob)
            for token_id, prob in zip(top_ids.tolist(), top_probs.tolist())
            if prob > 0  # tokens removed by top-k / top-p have zero probability
        ]
        tokens = [tok for tok, _ in candidates]
        values = [prob for _, prob in candidates]
        text = "\n".join(f"{tok!r}: {prob:.4f}" for tok, prob in candidates)
        labels = [tokenizer.decode([i]) for i in inputs["input_ids"][0].tolist()]
        return text, _plot_attention(outputs.attentions, labels), _plot_probabilities(tokens, values)
    except Exception as e:
        return f"Erreur lors de l'analyse : {str(e)}", None, None


def generate_text(input_text, temperature, top_p, top_k):
    """Sample one new token after *input_text* and return the extended text."""
    if model is None or tokenizer is None:
        return "Veuillez d'abord charger un modèle."
    if not input_text:
        return "Veuillez saisir un texte d'entrée."
    try:
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=1,
                do_sample=True,
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=int(top_k),
                pad_token_id=tokenizer.pad_token_id,
            )
        # Decode the whole sequence so the tokenizer restores inter-word spacing.
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
    except Exception as e:
        return f"Erreur lors de la génération : {str(e)}"


def reset():
    """Clear inputs/outputs and restore the slider defaults; the loaded model is kept."""
    global selected_language
    selected_language = None
    # Order matches reset_button.click outputs; slider defaults match their constructors.
    return "", 1.0, 1.0, 50, "", None, None, "", gr.update(value=None), ""


with gr.Blocks() as demo:
    gr.Markdown("# LLM&BIAS")
    with gr.Accordion("Sélection du modèle"):
        company_dropdown = gr.Dropdown(choices=list(models_hierarchy.keys()), label="Choisissez une société")
        model_dropdown = gr.Dropdown(label="Choisissez un modèle", choices=[])
        variation_dropdown = gr.Dropdown(label="Choisissez une variation", choices=[])
        load_button = gr.Button("Charger le modèle")
        load_output = gr.Textbox(label="Statut du chargement")
        language_dropdown = gr.Dropdown(label="Choisissez une langue", visible=False)
        language_output = gr.Textbox(label="Langue sélectionnée")
    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
    input_text = gr.Textbox(label="Texte d'entrée", lines=3)
    analyze_button = gr.Button("Analyser le prochain token")
    next_token_probs = gr.Textbox(label="Probabilités du prochain token")
    with gr.Row():
        attention_plot = gr.Plot(label="Visualisation de l'attention")
        prob_plot = gr.Plot(label="Probabilités des tokens suivants")
    generate_button = gr.Button("Générer le prochain mot")
    generated_text = gr.Textbox(label="Texte généré")
    reset_button = gr.Button("Réinitialiser")

    company_dropdown.change(update_model_choices, inputs=[company_dropdown], outputs=[model_dropdown])
    model_dropdown.change(update_variation_choices, inputs=[company_dropdown, model_dropdown], outputs=[variation_dropdown])
    load_button.click(load_model,
                      inputs=[company_dropdown, model_dropdown, variation_dropdown],
                      outputs=[load_output, language_dropdown, temperature, top_p, top_k])
    language_dropdown.change(set_language, inputs=[language_dropdown], outputs=[language_output])
    analyze_button.click(analyze_next_token,
                         inputs=[input_text, temperature, top_p, top_k],
                         outputs=[next_token_probs, attention_plot, prob_plot])
    generate_button.click(generate_text,
                          inputs=[input_text, temperature, top_p, top_k],
                          outputs=[generated_text])
    reset_button.click(reset,
                       outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text, language_dropdown, language_output])

if __name__ == "__main__":
    demo.launch()