# Gradio Space: exploration of LLM next-token probabilities and bias.
# Standard library
import os
import time

# Third-party
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM

# Authenticate against the Hugging Face Hub. A missing HF_TOKEN raises
# KeyError here, failing fast instead of failing later on gated downloads.
login(token=os.environ["HF_TOKEN"])
# Hierarchical model picker: company -> model family -> variations.
# load_model() rebuilds the repo id as "{company}/{family}-{variation}"
# (falling back to "{company}/{family}{variation}" without the dash), so every
# combination below must resolve to a key of models_and_languages.
models_hierarchy = {
    "meta-llama": {
        # The "-hf" suffix is part of the Llama-2 repo names on the Hub.
        "Llama-2": ["7b-hf", "13b-hf", "70b-hf"],
        "Meta-Llama-3": ["8B"],
        "Llama": ["3.2-3B", "3.1-8B"],
    },
    "mistralai": {
        "Mistral": ["7B-v0.1", "7B-v0.3"],
        "Mixtral": ["8x7B-v0.1"],
    },
    "google": {
        # gemma-2 generation: "gemma" + "-" + "2-2b" -> "google/gemma-2-2b".
        "gemma": ["2-2b", "2-9b", "2-27b"],
    },
    "croissantllm": {
        # Resolved via the no-dash fallback: "croissantllm/CroissantLLMBase".
        "CroissantLLM": ["Base"],
    },
}
# Languages supported per model (repo id -> ISO 639-1 language codes).
# Keys are the fully-qualified Hub repo ids that load_model() resolves to.
models_and_languages = {
    "meta-llama/Llama-2-7b-hf": ["en"],
    "meta-llama/Llama-2-13b-hf": ["en"],
    "meta-llama/Llama-2-70b-hf": ["en"],
    "meta-llama/Meta-Llama-3-8B": ["en"],
    "meta-llama/Llama-3.2-3B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
    "meta-llama/Llama-3.1-8B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
    "mistralai/Mistral-7B-v0.1": ["en"],
    "mistralai/Mixtral-8x7B-v0.1": ["en", "fr", "it", "de", "es"],
    "mistralai/Mistral-7B-v0.3": ["en"],
    "google/gemma-2-2b": ["en"],
    "google/gemma-2-9b": ["en"],
    "google/gemma-2-27b": ["en"],
    "croissantllm/CroissantLLMBase": ["en", "fr"],
}
# Recommended sampling parameters per model; pushed to the UI sliders by
# load_model() when a checkpoint is loaded.
model_parameters = {
    "meta-llama/Llama-2-7b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
    "meta-llama/Llama-2-13b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
    "meta-llama/Llama-2-70b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
    "meta-llama/Meta-Llama-3-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
    "meta-llama/Llama-3.2-3B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
    "meta-llama/Llama-3.1-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
    "mistralai/Mistral-7B-v0.1": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
    "mistralai/Mixtral-8x7B-v0.1": {"temperature": 0.8, "top_p": 0.95, "top_k": 50},
    "mistralai/Mistral-7B-v0.3": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
    "google/gemma-2-2b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
    "google/gemma-2-9b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
    "google/gemma-2-27b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
    "croissantllm/CroissantLLMBase": {"temperature": 0.8, "top_p": 0.92, "top_k": 50},
}
# Global mutable state shared between Gradio callbacks.
model = None      # AutoModelForCausalLM instance, set by load_model()
tokenizer = None  # matching AutoTokenizer, set by load_model()
# presumably set by set_language() (defined elsewhere in the file) — TODO confirm
selected_language = None
def update_model_choices(company):
    """Refresh the model-family dropdown after a company is picked.

    Returns a gr.Dropdown listing the families of *company*, with no
    pre-selected value so the user must choose explicitly.
    """
    families = list(models_hierarchy[company].keys())
    return gr.Dropdown(choices=families, value=None)
def update_variation_choices(company, model_name):
    """Refresh the variation dropdown after a model family is picked.

    Returns a gr.Dropdown listing the variations of *model_name* under
    *company*, with no pre-selected value.
    """
    variations = models_hierarchy[company][model_name]
    return gr.Dropdown(choices=variations, value=None)
def load_model(company, model_name, variation, progress=gr.Progress()):
    """Load the selected checkpoint and its tokenizer into the global state.

    Returns a 5-tuple consumed by the Gradio outputs:
    (status message, language dropdown, temperature, top_p, top_k).
    On any failure the exception text is returned as the status, the language
    dropdown is hidden, and the sliders are left untouched (None).
    """
    global model, tokenizer
    # Repo ids normally join family and variation with a dash
    # ("google/gemma-2-2b"); a few (CroissantLLMBase) concatenate directly,
    # hence the no-dash fallback.
    full_model_name = f"{company}/{model_name}-{variation}"
    if full_model_name not in models_and_languages:
        full_model_name = f"{company}/{model_name}{variation}"
    try:
        progress(0, desc="Chargement du tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(full_model_name)
        progress(0.5, desc="Chargement du modèle")
        if "mixtral" in full_model_name.lower():
            # Mixtral is quantized to 8-bit on top of fp16 to fit in memory.
            # NOTE(review): load_in_8bit= is deprecated in recent transformers;
            # prefer quantization_config=BitsAndBytesConfig(load_in_8bit=True).
            model = AutoModelForCausalLM.from_pretrained(
                full_model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                load_in_8bit=True,
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                full_model_name,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        # Several base models (e.g. Llama) ship without a pad token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        progress(1.0, desc="Modèle chargé")
        available_languages = models_and_languages[full_model_name]
        # Push the recommended sampling parameters to the UI sliders.
        params = model_parameters[full_model_name]
        return (
            f"Modèle {full_model_name} chargé avec succès. Langues disponibles : {', '.join(available_languages)}",
            gr.Dropdown(choices=available_languages, value=available_languages[0], visible=True, interactive=True),
            params["temperature"],
            params["top_p"],
            params["top_k"],
        )
    except Exception as e:
        return f"Erreur lors du chargement du modèle : {str(e)}", gr.Dropdown(visible=False), None, None, None
# UI definition. NOTE(review): set_language, analyze_next_token, generate_text
# and reset are wired below but defined elsewhere in the full source (this
# chunk ends with "Le reste du code reste inchangé..."); running this chunk
# alone would raise NameError on them — verify they exist in the final file.
with gr.Blocks() as demo:
    gr.Markdown("# LLM&BIAS")

    # Model selection: company -> family -> variation, then explicit load.
    with gr.Accordion("Sélection du modèle"):
        company_dropdown = gr.Dropdown(choices=list(models_hierarchy.keys()), label="Choisissez une société")
        model_dropdown = gr.Dropdown(label="Choisissez un modèle", choices=[])
        variation_dropdown = gr.Dropdown(label="Choisissez une variation", choices=[])
        load_button = gr.Button("Charger le modèle")
        load_output = gr.Textbox(label="Statut du chargement")

    # Hidden until a model is loaded; populated by load_model().
    language_dropdown = gr.Dropdown(label="Choisissez une langue", visible=False)
    language_output = gr.Textbox(label="Langue sélectionnée")

    # Sampling parameters; defaults are overwritten with per-model values.
    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")

    input_text = gr.Textbox(label="Texte d'entrée", lines=3)
    analyze_button = gr.Button("Analyser le prochain token")
    next_token_probs = gr.Textbox(label="Probabilités du prochain token")
    with gr.Row():
        attention_plot = gr.Plot(label="Visualisation de l'attention")
        prob_plot = gr.Plot(label="Probabilités des tokens suivants")
    generate_button = gr.Button("Générer le prochain mot")
    generated_text = gr.Textbox(label="Texte généré")
    reset_button = gr.Button("Réinitialiser")

    # Event wiring.
    company_dropdown.change(update_model_choices, inputs=[company_dropdown], outputs=[model_dropdown])
    model_dropdown.change(update_variation_choices, inputs=[company_dropdown, model_dropdown], outputs=[variation_dropdown])
    load_button.click(load_model,
                      inputs=[company_dropdown, model_dropdown, variation_dropdown],
                      outputs=[load_output, language_dropdown, temperature, top_p, top_k])
    language_dropdown.change(set_language, inputs=[language_dropdown], outputs=[language_output])
    analyze_button.click(analyze_next_token,
                         inputs=[input_text, temperature, top_p, top_k],
                         outputs=[next_token_probs, attention_plot, prob_plot])
    generate_button.click(generate_text,
                          inputs=[input_text, temperature, top_p, top_k],
                          outputs=[generated_text])
    reset_button.click(reset,
                       outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot,
                                prob_plot, generated_text, language_dropdown, language_output])

if __name__ == "__main__":
    demo.launch()