Spaces:
Sleeping
Sleeping
import os | |
os.environ["XDG_CONFIG_HOME"] = "/tmp" | |
os.environ["XDG_CACHE_HOME"] = "/tmp" | |
os.environ["HF_HOME"] = "/tmp/huggingface" # pour les modèles/datasets | |
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers" | |
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub" | |
import streamlit as st | |
import tempfile | |
import pandas as pd | |
from datasets import load_dataset | |
from transformers import WhisperProcessor, WhisperForConditionalGeneration | |
from peft import PeftModel | |
import torch | |
import librosa | |
import numpy as np | |
import evaluate | |
import tempfile | |
from huggingface_hub import snapshot_download | |
from transformers import pipeline | |
st.title("📊 Évaluation WER d'un modèle Whisper") | |
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.") | |
# Section : Choix du modèle | |
st.subheader("1. Choix du modèle") | |
model_option = st.radio("Quel modèle veux-tu utiliser ?", ( | |
"Whisper Large (baseline)", | |
"Whisper Large + LoRA (SimpleFrog/whisper_finetuned)", | |
"Whisper Large + LoRA + Post-processing" | |
)) | |
# Section : Lien du dataset | |
st.subheader("2. Chargement du dataset Hugging Face") | |
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/Dataset_Test") | |
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password") | |
if hf_token: | |
from huggingface_hub import login | |
login(hf_token) | |
# Section : Choix du split | |
split_option = st.selectbox( | |
"Choix du split à évaluer", | |
options=["Tous", "train", "validation", "test"], | |
index=0 # par défaut "Tous" | |
) | |
# Section : Choix du nombre maximal d'exemples à évaluer | |
max_examples_option = st.selectbox( | |
"Nombre maximum d'audios à traiter", | |
options=["1", "5", "10", "Tous"], | |
index=3 # par défaut "Tous" | |
) | |
# Section : Bouton pour lancer l'évaluation | |
start_eval = st.button("🚀 Lancer l'évaluation WER") | |
if start_eval: | |
st.subheader("🔍 Traitement en cours...") | |
# 🔹 Télécharger dataset | |
with st.spinner("Chargement du dataset..."): | |
try: | |
dataset_full = load_dataset(dataset_link, split="train", token=hf_token) | |
# 🔹 Filtrage selon la colonne 'split' | |
if split_option != "Tous": | |
dataset = dataset_full.filter(lambda x: x.get("split", "unknown") == split_option) | |
else: | |
dataset = dataset_full | |
if len(dataset) == 0: | |
st.warning(f"Aucun exemple trouvé pour le split sélectionné : '{split_option}'.") | |
st.stop() | |
except Exception as e: | |
st.error(f"Erreur lors du chargement du dataset : {e}") | |
st.stop() | |
# Limiter le nombre d'exemples selon la sélection | |
if max_examples_option != "Tous": | |
max_examples = int(max_examples_option) | |
dataset = dataset.select(range(min(max_examples, len(dataset)))) | |
# 🔹 Charger le modèle choisi | |
with st.spinner("Chargement du modèle..."): | |
base_model_name = "openai/whisper-large" | |
model = WhisperForConditionalGeneration.from_pretrained(base_model_name) | |
if "LoRA" in model_option: | |
model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token) | |
processor = WhisperProcessor.from_pretrained(base_model_name) | |
model.eval() | |
# Charger le pipeline de Mistral si post-processing demandé | |
if "Post-processing" in model_option: | |
with st.spinner("Chargement du modèle de post-traitement Mistral..."): | |
postproc_pipe = pipeline( | |
"text2text-generation", | |
model="NousResearch/Nous-Hermes-2-Mistral-7B-DPO", | |
device_map="auto", # ou device=0 si tu veux forcer le GPU | |
torch_dtype=torch.float16 # optionnel mais plus léger | |
) | |
st.success("✅ Modèle Mistral chargé.") | |
def postprocess_with_llm(text): | |
prompt = f"Ce texte est issue d'une translation vocal. L'enregistrement est tiré d'une inspection détaillé de pont et comprend du vocabulaire technique associé. Corriges les éventuelles erreurs de translation : {text}" | |
result = postproc_pipe(prompt, max_new_tokens=256)[0]["generated_text"] | |
return result.strip() | |
# 🔹 Préparer WER metric | |
wer_metric = evaluate.load("wer") | |
results = [] | |
# Téléchargement explicite du dossier audio (chemin local vers chaque fichier) | |
repo_local_path = snapshot_download(repo_id=dataset_link, repo_type="dataset", token=hf_token) | |
for example in dataset: | |
st.write("Exemple brut :", example) | |
try: | |
reference = example["text"] | |
waveform = example["audio"]["array"] | |
audio_path = example["audio"]["path"] | |
waveform = np.expand_dims(waveform, axis=0) | |
inputs = processor(waveform, sampling_rate=16000, return_tensors="pt") | |
with torch.no_grad(): | |
pred_ids = model.generate(input_features=inputs.input_features) | |
prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0] | |
# === Post-processing conditionnel === | |
if "Post-processing" in model_option: | |
st.write("⏳ Post-processing avec Mistral...") | |
postprocessed_prediction = postprocess_with_llm(prediction) | |
st.write("✅ Post-processing terminé.") | |
final_prediction = postprocessed_prediction | |
else: | |
postprocessed_prediction = "-" | |
final_prediction = prediction | |
# 🔹 Nettoyage ponctuation pour WER "sans ponctuation" | |
def clean(text): | |
return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip() | |
ref_clean = clean(reference) | |
pred_clean = clean(final_prediction) | |
wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean]) | |
results.append({ | |
"Fichier": audio_path, | |
"Référence": reference, | |
"Transcription brute": prediction, | |
"Transcription corrigée": postprocessed_prediction, | |
"WER": round(wer, 4) | |
}) | |
except Exception as e: | |
results.append({ | |
"Fichier": example["audio"].get("path", "unknown"), | |
"Référence": "Erreur", | |
"Transcription brute": f"Erreur: {e}", | |
"Transcription corrigée": "-", | |
"WER": "-" | |
}) | |
# 🔹 Générer le tableau de résultats | |
df = pd.DataFrame(results) | |
# 🔹 Créer un fichier temporaire pour le CSV | |
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as tmp_csv: | |
df.to_csv(tmp_csv.name, index=False) | |
mean_wer = df[df["WER"] != "-"]["WER"].mean() | |
st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`") | |
# 🔹 Bouton de téléchargement | |
with open(tmp_csv.name, "rb") as f: | |
st.download_button( | |
label="📥 Télécharger les résultats WER (.csv)", | |
data=f, | |
file_name="wer_results.csv", | |
mime="text/csv" | |
) | |