import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import time
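
# Assumed environment (not pinned in the original): the imports above map to the
# gradio, torch, transformers, huggingface_hub, matplotlib, seaborn and numpy
# packages; device_map="auto" additionally needs accelerate, and load_in_8bit
# needs bitsandbytes. A hypothetical install line:
#   pip install gradio torch transformers huggingface_hub matplotlib seaborn numpy accelerate bitsandbytes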

# Authenticate with the Hugging Face Hub (expects an HF_TOKEN secret/env var)
login(token=os.environ["HF_TOKEN"])

# Hierarchical structure of the models (company -> family -> variation)
model_hierarchy = {
    "meta-llama": {
        "Llama-2": ["7B", "13B", "70B"],
        "Llama-3": ["8B", "3.2B", "3.1B"]
    },
    "mistralai": {
        "Mistral": ["7B-v0.1", "7B-v0.3"],
        "Mixtral": ["8x7B-v0.1"]
    },
    "google": {
        "Gemma": ["2B", "9B", "27B"]
    },
    "croissantllm": {
        "CroissantLLM": ["Base"]
    }
}

# List of models and the languages each one supports
models_and_languages = {
    "meta-llama/Llama-2-7B": ["en"],
    "meta-llama/Llama-2-13B": ["en"],
    "meta-llama/Llama-2-70B": ["en"],
    "meta-llama/Llama-3-8B": ["en"],
    "meta-llama/Llama-3-3.2B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
    "meta-llama/Llama-3-3.1B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
    "mistralai/Mistral-7B-v0.1": ["en"],
    "mistralai/Mixtral-8x7B-v0.1": ["en", "fr", "it", "de", "es"],
    "mistralai/Mistral-7B-v0.3": ["en"],
    "google/Gemma-2B": ["en"],
    "google/Gemma-9B": ["en"],
    "google/Gemma-27B": ["en"],
    "croissantllm/CroissantLLMBase": ["en", "fr"]
}

# Recommended sampling parameters for each model
model_parameters = {
    "meta-llama/Llama-2-7B": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
    "meta-llama/Llama-2-13B": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
    "meta-llama/Llama-2-70B": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
    "meta-llama/Llama-3-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
    "meta-llama/Llama-3-3.2B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
    "meta-llama/Llama-3-3.1B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
    "mistralai/Mistral-7B-v0.1": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
    "mistralai/Mixtral-8x7B-v0.1": {"temperature": 0.8, "top_p": 0.95, "top_k": 50},
    "mistralai/Mistral-7B-v0.3": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
    "google/Gemma-2B": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
    "google/Gemma-9B": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
    "google/Gemma-27B": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
    "croissantllm/CroissantLLMBase": {"temperature": 0.8, "top_p": 0.92, "top_k": 50}
}

# Global state shared across the Gradio callbacks
model = None
tokenizer = None
selected_language = None

def update_model_choices(company):
    return gr.Dropdown(choices=list(model_hierarchy[company].keys()), value=None)

def update_variation_choices(company, model_name):
    return gr.Dropdown(choices=model_hierarchy[company][model_name], value=None)

def load_model(company, model_name, variation, progress=gr.Progress()):
    global model, tokenizer
    full_model_name = f"{company}/{model_name}-{variation}"
    
    try:
        progress(0, desc="Loading the tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(full_model_name)
        progress(0.5, desc="Loading the model")
        
        # Model-specific loading configuration
        if "mixtral" in full_model_name.lower():
            # Quantize Mixtral to 8 bits to reduce memory (requires bitsandbytes)
            model = AutoModelForCausalLM.from_pretrained(
                full_model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                load_in_8bit=True
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                full_model_name,
                torch_dtype=torch.float16,
                device_map="auto"
            )
        
        # Some tokenizers ship without a pad token; fall back to EOS
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        
        progress(1.0, desc="Model loaded")
        available_languages = models_and_languages[full_model_name]
        
        # Update the sliders with the model's recommended values
        params = model_parameters[full_model_name]
        return (
            f"Model {full_model_name} loaded successfully. Available languages: {', '.join(available_languages)}",
            gr.Dropdown(choices=available_languages, value=available_languages[0], visible=True, interactive=True),
            params["temperature"],
            params["top_p"],
            params["top_k"]
        )
    except Exception as e:
        return f"Erreur lors du chargement du modèle : {str(e)}", gr.Dropdown(visible=False), None, None, None

# The remaining helper code was elided in the original ("the rest of the code
# remains unchanged"); sketches of the referenced callbacks follow.
# ...
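
# ---------------------------------------------------------------------------
# Hedged sketches of the four elided callbacks wired up in the UI below
# (set_language, analyze_next_token, generate_text, reset). These are minimal
# reconstructions under the assumptions visible in this file; the original
# implementations may differ. k_display is an illustrative parameter not in
# the original, and top_p is only applied where model.generate handles it.
# ---------------------------------------------------------------------------

def set_language(language):
    # Record the chosen language in module-level state (sketch).
    global selected_language
    selected_language = language
    return f"Selected language: {language}"

def analyze_next_token(text, temperature, top_p, top_k, k_display=10):
    # Sketch: one forward pass, then report the most likely next tokens,
    # a last-layer attention heatmap, and a probability bar chart.
    if model is None or tokenizer is None:
        return "Please load a model first.", None, None
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    probs = torch.softmax(outputs.logits[0, -1, :] / temperature, dim=-1)
    top_probs, top_ids = torch.topk(probs, k=min(int(top_k), k_display))
    tokens = [tokenizer.decode([i]) for i in top_ids.tolist()]
    prob_text = "\n".join(f"{t!r}: {p:.4f}" for t, p in zip(tokens, top_probs.tolist()))

    # Bar chart of the top next-token probabilities.
    prob_fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(range(len(tokens)), top_probs.float().cpu().numpy())
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha="right")
    ax.set_ylabel("Probability")
    prob_fig.tight_layout()

    # Last-layer attention weights, averaged over heads.
    attn = outputs.attentions[-1][0].mean(dim=0).float().cpu().numpy()
    in_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    attn_fig, ax2 = plt.subplots(figsize=(6, 5))
    sns.heatmap(attn, xticklabels=in_tokens, yticklabels=in_tokens, cmap="viridis", ax=ax2)
    attn_fig.tight_layout()
    return prob_text, attn_fig, prob_fig

def generate_text(text, temperature, top_p, top_k):
    # Sketch: sample a single continuation token and return prompt + token.
    if model is None or tokenizer is None:
        return "Please load a model first."
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=1,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        pad_token_id=tokenizer.pad_token_id,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def reset():
    # Sketch: restore everything wired to the reset button to its default,
    # matching the slider defaults declared in the UI below.
    return "", 1.0, 1.0, 50, "", None, None, "", gr.Dropdown(visible=False), ""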

with gr.Blocks() as demo:
    gr.Markdown("# LLM&BIAS")
    
    with gr.Accordion("Model selection"):
        company_dropdown = gr.Dropdown(choices=list(model_hierarchy.keys()), label="Choose a company")
        model_dropdown = gr.Dropdown(label="Choose a model", choices=[])
        variation_dropdown = gr.Dropdown(label="Choose a variation", choices=[])
        load_button = gr.Button("Load model")
        load_output = gr.Textbox(label="Loading status")
        language_dropdown = gr.Dropdown(label="Choose a language", visible=False)
        language_output = gr.Textbox(label="Selected language")
    
    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=1.0, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
    
    input_text = gr.Textbox(label="Input text", lines=3)
    analyze_button = gr.Button("Analyze next token")
    
    next_token_probs = gr.Textbox(label="Next-token probabilities")
    
    with gr.Row():
        attention_plot = gr.Plot(label="Attention visualization")
        prob_plot = gr.Plot(label="Probabilities of the following tokens")
    
    generate_button = gr.Button("Generate next word")
    generated_text = gr.Textbox(label="Generated text")
    
    reset_button = gr.Button("Reset")
    
    company_dropdown.change(update_model_choices, inputs=[company_dropdown], outputs=[model_dropdown])
    model_dropdown.change(update_variation_choices, inputs=[company_dropdown, model_dropdown], outputs=[variation_dropdown])
    load_button.click(load_model, 
                      inputs=[company_dropdown, model_dropdown, variation_dropdown], 
                      outputs=[load_output, language_dropdown, temperature, top_p, top_k])
    language_dropdown.change(set_language, inputs=[language_dropdown], outputs=[language_output])
    analyze_button.click(analyze_next_token, 
                         inputs=[input_text, temperature, top_p, top_k], 
                         outputs=[next_token_probs, attention_plot, prob_plot])
    generate_button.click(generate_text, 
                          inputs=[input_text, temperature, top_p, top_k], 
                          outputs=[generated_text])
    reset_button.click(reset, 
                       outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text, language_dropdown, language_output])

if __name__ == "__main__":
    demo.launch()