import os os.environ["XDG_CONFIG_HOME"] = "/tmp" os.environ["XDG_CACHE_HOME"] = "/tmp" os.environ["HF_HOME"] = "/tmp/huggingface" # pour les modèles/datasets os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers" os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub" import streamlit as st import tempfile import pandas as pd from datasets import load_dataset from transformers import WhisperProcessor, WhisperForConditionalGeneration from peft import PeftModel import torch import librosa import numpy as np import evaluate import tempfile from huggingface_hub import snapshot_download from transformers import pipeline st.title("📊 Évaluation WER d'un modèle Whisper") st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.") # Section : Choix du modèle st.subheader("1. Choix du modèle") model_option = st.radio("Quel modèle veux-tu utiliser ?", ( "Whisper Large (baseline)", "Whisper Large + LoRA (SimpleFrog/whisper_finetuned)", "Whisper Large + LoRA + Post-processing" )) # Section : Lien du dataset st.subheader("2. Chargement du dataset Hugging Face") dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/Dataset_Test") hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password") if hf_token: from huggingface_hub import login login(hf_token) # Section : Choix du split split_option = st.selectbox( "Choix du split à évaluer", options=["Tous", "train", "validation", "test"], index=0 # par défaut "Tous" ) # Section : Choix du nombre maximal d'exemples à évaluer max_examples_option = st.selectbox( "Nombre maximum d'audios à traiter", options=["1", "5", "10", "Tous"], index=3 # par défaut "Tous" ) # Section : Bouton pour lancer l'évaluation start_eval = st.button("🚀 Lancer l'évaluation WER") if start_eval: st.subheader("🔍 Traitement en cours...") # 🔹 Télécharger dataset with st.spinner("Chargement du dataset..."): try: dataset_full = load_dataset(dataset_link, split="train", token=hf_token) # 🔹 Filtrage selon la colonne 'split' if split_option != "Tous": dataset = dataset_full.filter(lambda x: x.get("split", "unknown") == split_option) else: dataset = dataset_full if len(dataset) == 0: st.warning(f"Aucun exemple trouvé pour le split sélectionné : '{split_option}'.") st.stop() except Exception as e: st.error(f"Erreur lors du chargement du dataset : {e}") st.stop() # Limiter le nombre d'exemples selon la sélection if max_examples_option != "Tous": max_examples = int(max_examples_option) dataset = dataset.select(range(min(max_examples, len(dataset)))) # 🔹 Charger le modèle choisi with st.spinner("Chargement du modèle..."): base_model_name = "openai/whisper-large" model = WhisperForConditionalGeneration.from_pretrained(base_model_name) if "LoRA" in model_option: model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token) processor = WhisperProcessor.from_pretrained(base_model_name) model.eval() # Charger le pipeline de Mistral si post-processing demandé if "Post-processing" in model_option: with st.spinner("Chargement du modèle de post-traitement Mistral..."): postproc_pipe = pipeline( "text2text-generation", model="NousResearch/Nous-Hermes-2-Mistral-7B-DPO", device_map="auto", # ou device=0 si tu veux forcer le GPU torch_dtype=torch.float16 # optionnel mais plus léger ) st.success("✅ Modèle Mistral chargé.") def postprocess_with_llm(text): prompt = f"Ce texte est issue d'une translation vocal. L'enregistrement est tiré d'une inspection détaillé de pont et comprend du vocabulaire technique associé. Corriges les éventuelles erreurs de translation : {text}" result = postproc_pipe(prompt, max_new_tokens=256)[0]["generated_text"] return result.strip() # 🔹 Préparer WER metric wer_metric = evaluate.load("wer") results = [] # Téléchargement explicite du dossier audio (chemin local vers chaque fichier) repo_local_path = snapshot_download(repo_id=dataset_link, repo_type="dataset", token=hf_token) for example in dataset: st.write("Exemple brut :", example) try: reference = example["text"] waveform = example["audio"]["array"] audio_path = example["audio"]["path"] waveform = np.expand_dims(waveform, axis=0) inputs = processor(waveform, sampling_rate=16000, return_tensors="pt") with torch.no_grad(): pred_ids = model.generate(input_features=inputs.input_features) prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0] # === Post-processing conditionnel === if "Post-processing" in model_option: st.write("⏳ Post-processing avec Mistral...") postprocessed_prediction = postprocess_with_llm(prediction) st.write("✅ Post-processing terminé.") final_prediction = postprocessed_prediction else: postprocessed_prediction = "-" final_prediction = prediction # 🔹 Nettoyage ponctuation pour WER "sans ponctuation" def clean(text): return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip() ref_clean = clean(reference) pred_clean = clean(final_prediction) wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean]) results.append({ "Fichier": audio_path, "Référence": reference, "Transcription brute": prediction, "Transcription corrigée": postprocessed_prediction, "WER": round(wer, 4) }) except Exception as e: results.append({ "Fichier": example["audio"].get("path", "unknown"), "Référence": "Erreur", "Transcription brute": f"Erreur: {e}", "Transcription corrigée": "-", "WER": "-" }) # 🔹 Générer le tableau de résultats df = pd.DataFrame(results) # 🔹 Créer un fichier temporaire pour le CSV with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as tmp_csv: df.to_csv(tmp_csv.name, index=False) mean_wer = df[df["WER"] != "-"]["WER"].mean() st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`") # 🔹 Bouton de téléchargement with open(tmp_csv.name, "rb") as f: st.download_button( label="📥 Télécharger les résultats WER (.csv)", data=f, file_name="wer_results.csv", mime="text/csv" )