Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -17,6 +17,7 @@ import numpy as np
|
|
17 |
import evaluate
|
18 |
import tempfile
|
19 |
from huggingface_hub import snapshot_download
|
|
|
20 |
|
21 |
st.title("📊 Évaluation WER d'un modèle Whisper")
|
22 |
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")
|
@@ -26,14 +27,21 @@ st.subheader("1. Choix du modèle")
|
|
26 |
model_option = st.radio("Quel modèle veux-tu utiliser ?", (
|
27 |
"Whisper Large (baseline)",
|
28 |
"Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
|
29 |
-
"Whisper Large + LoRA + Post-processing
|
30 |
))
|
31 |
|
32 |
# Section : Lien du dataset
|
33 |
st.subheader("2. Chargement du dataset Hugging Face")
|
34 |
-
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/
|
35 |
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
# Section : Bouton pour lancer l'évaluation
|
38 |
start_eval = st.button("🚀 Lancer l'évaluation WER")
|
39 |
|
@@ -43,8 +51,19 @@ if start_eval:
|
|
43 |
# 🔹 Télécharger dataset
|
44 |
with st.spinner("Chargement du dataset..."):
|
45 |
try:
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
except Exception as e:
|
49 |
st.error(f"Erreur lors du chargement du dataset : {e}")
|
50 |
st.stop()
|
@@ -60,6 +79,21 @@ if start_eval:
|
|
60 |
processor = WhisperProcessor.from_pretrained(base_model_name)
|
61 |
model.eval()
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
# 🔹 Préparer WER metric
|
64 |
wer_metric = evaluate.load("wer")
|
65 |
|
@@ -71,18 +105,11 @@ if start_eval:
|
|
71 |
for example in dataset:
|
72 |
st.write("Exemple brut :", example)
|
73 |
try:
|
74 |
-
|
75 |
-
#audio_path = os.path.join(repo_local_path, example["file_name"])
|
76 |
reference = example["text"]
|
77 |
|
78 |
waveform = example["audio"]["array"]
|
79 |
audio_path = example["audio"]["path"]
|
80 |
-
|
81 |
-
#st.write(example)
|
82 |
-
#st.write("Exemple brut :", dataset[0])
|
83 |
-
|
84 |
-
# Load audio (we assume dataset is structured with 'file_name')
|
85 |
-
#waveform, _ = librosa.load(audio_path, sr=16000)
|
86 |
|
87 |
waveform = np.expand_dims(waveform, axis=0)
|
88 |
inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
|
@@ -91,18 +118,27 @@ if start_eval:
|
|
91 |
pred_ids = model.generate(input_features=inputs.input_features)
|
92 |
prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
# 🔹 Nettoyage ponctuation pour WER "sans ponctuation"
|
95 |
def clean(text):
|
96 |
return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip()
|
97 |
|
98 |
ref_clean = clean(reference)
|
99 |
-
pred_clean = clean(
|
100 |
wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean])
|
101 |
|
102 |
results.append({
|
103 |
"Fichier": audio_path,
|
104 |
"Référence": reference,
|
105 |
-
"Transcription": prediction,
|
|
|
106 |
"WER": round(wer, 4)
|
107 |
})
|
108 |
|
@@ -110,7 +146,8 @@ if start_eval:
|
|
110 |
results.append({
|
111 |
"Fichier": example["audio"].get("path", "unknown"),
|
112 |
"Référence": "Erreur",
|
113 |
-
"Transcription": f"Erreur: {e}",
|
|
|
114 |
"WER": "-"
|
115 |
})
|
116 |
|
@@ -122,11 +159,10 @@ if start_eval:
|
|
122 |
df.to_csv(tmp_csv.name, index=False)
|
123 |
|
124 |
mean_wer = df[df["WER"] != "-"]["WER"].mean()
|
|
|
125 |
st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")
|
126 |
|
127 |
-
|
128 |
-
if "Post-processing" in model_option:
|
129 |
-
st.info("🛠️ Le post-processing sera ajouté prochainement ici...")
|
130 |
|
131 |
|
132 |
# 🔹 Bouton de téléchargement
|
|
|
17 |
import evaluate
|
18 |
import tempfile
|
19 |
from huggingface_hub import snapshot_download
|
20 |
+
from transformers import pipeline
|
21 |
|
22 |
st.title("📊 Évaluation WER d'un modèle Whisper")
|
23 |
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")
|
|
|
27 |
model_option = st.radio("Quel modèle veux-tu utiliser ?", (
|
28 |
"Whisper Large (baseline)",
|
29 |
"Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
|
30 |
+
"Whisper Large + LoRA + Post-processing"
|
31 |
))
|
32 |
|
33 |
# Section : Lien du dataset
|
34 |
st.subheader("2. Chargement du dataset Hugging Face")
|
35 |
+
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/Dataset_Test")
|
36 |
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")
|
37 |
|
38 |
+
# Section : Choix du split
|
39 |
+
split_option = st.selectbox(
|
40 |
+
"Choix du split à évaluer",
|
41 |
+
options=["Tous", "train", "validation", "test"],
|
42 |
+
index=0 # par défaut "Tous"
|
43 |
+
)
|
44 |
+
|
45 |
# Section : Bouton pour lancer l'évaluation
|
46 |
start_eval = st.button("🚀 Lancer l'évaluation WER")
|
47 |
|
|
|
51 |
# 🔹 Télécharger dataset
|
52 |
with st.spinner("Chargement du dataset..."):
|
53 |
try:
|
54 |
+
|
55 |
+
dataset_full = load_dataset(dataset_link, split="train", token=hf_token)
|
56 |
+
|
57 |
+
# 🔹 Filtrage selon la colonne 'split'
|
58 |
+
if split_option != "Tous":
|
59 |
+
dataset = dataset_full.filter(lambda x: x.get("split", "unknown") == split_option)
|
60 |
+
else:
|
61 |
+
dataset = dataset_full
|
62 |
+
|
63 |
+
if len(dataset) == 0:
|
64 |
+
st.warning(f"Aucun exemple trouvé pour le split sélectionné : '{split_option}'.")
|
65 |
+
st.stop()
|
66 |
+
|
67 |
except Exception as e:
|
68 |
st.error(f"Erreur lors du chargement du dataset : {e}")
|
69 |
st.stop()
|
|
|
79 |
processor = WhisperProcessor.from_pretrained(base_model_name)
|
80 |
model.eval()
|
81 |
|
82 |
+
# Charger le pipeline de Mistral si post-processing demandé
|
83 |
+
if "Post-processing" in model_option:
|
84 |
+
with st.spinner("Chargement du modèle de post-traitement Mistral..."):
|
85 |
+
postproc_pipe = pipeline(
|
86 |
+
"text2text-generation",
|
87 |
+
model="mistralai/Mistral-7B-Instruct-v0.2",
|
88 |
+
device_map="auto", # ou device=0 si tu veux forcer le GPU
|
89 |
+
torch_dtype=torch.float16 # optionnel mais plus léger
|
90 |
+
)
|
91 |
+
|
92 |
+
def postprocess_with_llm(text):
|
93 |
+
prompt = f"Ce texte est issue d'une translation vocal. L'enregistrement est tiré d'une inspection détaillé de pont et comprend du vocabulaire technique associé. Corriges les éventuelles erreurs de translation : {text}"
|
94 |
+
result = postproc_pipe(prompt, max_new_tokens=256)[0]["generated_text"]
|
95 |
+
return result.strip()
|
96 |
+
|
97 |
# 🔹 Préparer WER metric
|
98 |
wer_metric = evaluate.load("wer")
|
99 |
|
|
|
105 |
for example in dataset:
|
106 |
st.write("Exemple brut :", example)
|
107 |
try:
|
108 |
+
|
|
|
109 |
reference = example["text"]
|
110 |
|
111 |
waveform = example["audio"]["array"]
|
112 |
audio_path = example["audio"]["path"]
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
waveform = np.expand_dims(waveform, axis=0)
|
115 |
inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
|
|
|
118 |
pred_ids = model.generate(input_features=inputs.input_features)
|
119 |
prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
|
120 |
|
121 |
+
# === Post-processing conditionnel ===
|
122 |
+
if "Post-processing" in model_option:
|
123 |
+
postprocessed_prediction = postprocess_with_llm(prediction)
|
124 |
+
final_prediction = postprocessed_prediction
|
125 |
+
else:
|
126 |
+
postprocessed_prediction = "-"
|
127 |
+
final_prediction = prediction
|
128 |
+
|
129 |
# 🔹 Nettoyage ponctuation pour WER "sans ponctuation"
|
130 |
def clean(text):
|
131 |
return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip()
|
132 |
|
133 |
ref_clean = clean(reference)
|
134 |
+
pred_clean = clean(final_prediction)
|
135 |
wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean])
|
136 |
|
137 |
results.append({
|
138 |
"Fichier": audio_path,
|
139 |
"Référence": reference,
|
140 |
+
"Transcription brute": prediction,
|
141 |
+
"Transcription corrigée": postprocessed_prediction,
|
142 |
"WER": round(wer, 4)
|
143 |
})
|
144 |
|
|
|
146 |
results.append({
|
147 |
"Fichier": example["audio"].get("path", "unknown"),
|
148 |
"Référence": "Erreur",
|
149 |
+
"Transcription brute": f"Erreur: {e}",
|
150 |
+
"Transcription corrigée": "-",
|
151 |
"WER": "-"
|
152 |
})
|
153 |
|
|
|
159 |
df.to_csv(tmp_csv.name, index=False)
|
160 |
|
161 |
mean_wer = df[df["WER"] != "-"]["WER"].mean()
|
162 |
+
|
163 |
st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")
|
164 |
|
165 |
+
|
|
|
|
|
166 |
|
167 |
|
168 |
# 🔹 Bouton de téléchargement
|