File size: 7,707 Bytes
bdecbe9
645334b
7f613e2
 
 
 
a991705
 
bdecbe9
 
 
 
 
 
 
 
 
cea0ade
458a087
d7eb8e2
bdecbe9
8537a08
bdecbe9
 
 
 
 
 
 
 
d7eb8e2
bdecbe9
 
 
 
d7eb8e2
bdecbe9
 
8537a08
 
 
 
d7eb8e2
 
 
 
 
 
 
1cc0f11
 
 
 
 
 
 
bdecbe9
 
 
 
 
 
 
 
 
d7eb8e2
 
 
 
 
 
 
 
 
 
 
 
 
bdecbe9
 
 
 
1cc0f11
 
 
 
 
bdecbe9
 
 
 
 
 
 
 
 
 
 
d7eb8e2
 
 
 
 
a8cd8c4
d7eb8e2
 
 
eaa333d
 
d7eb8e2
a8cd8c4
d7eb8e2
 
eaa333d
d7eb8e2
bdecbe9
 
 
 
 
6802a5d
458a087
6802a5d
bdecbe9
98ab6e8
bdecbe9
d7eb8e2
bdecbe9
e47bd9f
 
 
 
bdecbe9
 
 
 
 
 
 
d7eb8e2
 
eaa333d
d7eb8e2
eaa333d
d7eb8e2
 
 
 
 
bdecbe9
 
 
 
 
d7eb8e2
bdecbe9
 
 
 
 
d7eb8e2
 
bdecbe9
 
 
 
 
e47bd9f
bdecbe9
d7eb8e2
 
bdecbe9
 
 
b6ba9c1
bdecbe9
cea0ade
 
 
 
 
 
d7eb8e2
cea0ade
 
d7eb8e2
cea0ade
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
os.environ["XDG_CONFIG_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"
os.environ["HF_HOME"] = "/tmp/huggingface"  # pour les modèles/datasets
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"

import streamlit as st
import tempfile
import pandas as pd
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from peft import PeftModel
import torch
import librosa
import numpy as np
import evaluate
import tempfile
from huggingface_hub import snapshot_download
from transformers import pipeline


st.title("📊 Évaluation WER d'un modèle Whisper")
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")

# Section : Choix du modèle
st.subheader("1. Choix du modèle")
model_option = st.radio("Quel modèle veux-tu utiliser ?", (
    "Whisper Large (baseline)",
    "Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
    "Whisper Large + LoRA + Post-processing"
))

# Section : Lien du dataset
st.subheader("2. Chargement du dataset Hugging Face")
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/Dataset_Test")
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")

if hf_token:
    from huggingface_hub import login
    login(hf_token)

# Section : Choix du split
split_option = st.selectbox(
    "Choix du split à évaluer",
    options=["Tous", "train", "validation", "test"],
    index=0  # par défaut "Tous"
)

# Section : Choix du nombre maximal d'exemples à évaluer
max_examples_option = st.selectbox(
    "Nombre maximum d'audios à traiter",
    options=["1", "5", "10", "Tous"],
    index=3  # par défaut "Tous"
)

# Section : Bouton pour lancer l'évaluation
start_eval = st.button("🚀 Lancer l'évaluation WER")

if start_eval:
    st.subheader("🔍 Traitement en cours...")

    # 🔹 Télécharger dataset
    with st.spinner("Chargement du dataset..."):
        try:

            dataset_full = load_dataset(dataset_link, split="train", token=hf_token)

            # 🔹 Filtrage selon la colonne 'split'
            if split_option != "Tous":
                dataset = dataset_full.filter(lambda x: x.get("split", "unknown") == split_option)
            else:
                dataset = dataset_full
            
            if len(dataset) == 0:
                st.warning(f"Aucun exemple trouvé pour le split sélectionné : '{split_option}'.")
                st.stop()
            
        except Exception as e:
            st.error(f"Erreur lors du chargement du dataset : {e}")
            st.stop()

     # Limiter le nombre d'exemples selon la sélection
    if max_examples_option != "Tous":
        max_examples = int(max_examples_option)
        dataset = dataset.select(range(min(max_examples, len(dataset))))

    # 🔹 Charger le modèle choisi
    with st.spinner("Chargement du modèle..."):
        base_model_name = "openai/whisper-large"
        model = WhisperForConditionalGeneration.from_pretrained(base_model_name)

        if "LoRA" in model_option:
            model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token)

        processor = WhisperProcessor.from_pretrained(base_model_name)
        model.eval()

        # Charger le pipeline de Mistral si post-processing demandé
        if "Post-processing" in model_option:
            with st.spinner("Chargement du modèle de post-traitement Mistral..."):
                postproc_pipe = pipeline(
                    "text2text-generation",
                    model="mistralai/Mistral-7B-Instruct-v0.2",
                    device_map="auto",  # ou device=0 si tu veux forcer le GPU
                    torch_dtype=torch.float16  # optionnel mais plus léger
                )
                st.success("✅ Modèle Mistral chargé.")
                
                def postprocess_with_llm(text):
                    prompt = f"Tu es CorrecteurAI, une AI française qui permet de corriger les erreurs de saisie vocal. La translation d'un enregistrement audio tiré d'une inspection détaillé de pont t'es envoyé et tu renvoies le texte identique mais avec les éventuelles corrections si des erreurs sont détectés. Le texte peut comprendre du vocabulaire technique associé aux ouvrages d'art. Renvoies uniquement le texte corrigé en français et sans autre commentaire. Voici le texte : {text}"
                    result = postproc_pipe(prompt, max_new_tokens=256)[0]["generated_text"]
                    return result.strip()
    

    # 🔹 Préparer WER metric
    wer_metric = evaluate.load("wer")

    results = []

    # Téléchargement explicite du dossier audio (chemin local vers chaque fichier)
    repo_local_path = snapshot_download(repo_id=dataset_link, repo_type="dataset", token=hf_token)

    for example in dataset:
        st.write("Exemple brut :", example)
        try:
           
            reference = example["text"]
            
            waveform = example["audio"]["array"]
            audio_path = example["audio"]["path"]
            
            waveform = np.expand_dims(waveform, axis=0)
            inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")

            with torch.no_grad():
                pred_ids = model.generate(input_features=inputs.input_features)
            prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]

            # === Post-processing conditionnel ===
            if "Post-processing" in model_option:
                st.write("⏳ Post-processing avec Mistral...")
                postprocessed_prediction = postprocess_with_llm(prediction)
                st.write("✅ Post-processing terminé.")
                final_prediction = postprocessed_prediction
            else:
                postprocessed_prediction = "-"
                final_prediction = prediction

            # 🔹 Nettoyage ponctuation pour WER "sans ponctuation"
            def clean(text):
                return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip()

            ref_clean = clean(reference)
            pred_clean = clean(final_prediction)
            wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean])

            results.append({
                "Fichier": audio_path,
                "Référence": reference,
                "Transcription brute": prediction,
                "Transcription corrigée": postprocessed_prediction,
                "WER": round(wer, 4)
            })

        except Exception as e:
            results.append({
                "Fichier": example["audio"].get("path", "unknown"),
                "Référence": "Erreur",
                "Transcription brute": f"Erreur: {e}",
                "Transcription corrigée": "-",
                "WER": "-"
            })

    # 🔹 Générer le tableau de résultats
    df = pd.DataFrame(results)

    # 🔹 Créer un fichier temporaire pour le CSV
    with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as tmp_csv:
        df.to_csv(tmp_csv.name, index=False)

        mean_wer = df[df["WER"] != "-"]["WER"].mean()
        
        st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")




        # 🔹 Bouton de téléchargement
        with open(tmp_csv.name, "rb") as f:
            st.download_button(
                label="📥 Télécharger les résultats WER (.csv)",
                data=f,
                file_name="wer_results.csv",
                mime="text/csv"
            )