Spaces:
Sleeping
Sleeping
import os | |
os.environ["XDG_CONFIG_HOME"] = "/tmp" | |
os.environ["XDG_CACHE_HOME"] = "/tmp" | |
os.environ["HF_HOME"] = "/tmp/huggingface" # pour les modèles/datasets | |
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers" | |
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub" | |
import streamlit as st | |
import tempfile | |
import pandas as pd | |
from datasets import load_dataset | |
from transformers import WhisperProcessor, WhisperForConditionalGeneration | |
from peft import PeftModel | |
import torch | |
import librosa | |
import numpy as np | |
import evaluate | |
import tempfile | |
from huggingface_hub import snapshot_download | |
from transformers import pipeline | |
st.title("📊 Évaluation WER d'un modèle Whisper") | |
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.") | |
# Section : Choix du modèle | |
st.subheader("1. Choix du modèle") | |
model_option = st.radio("Quel modèle veux-tu utiliser ?", ( | |
"Whisper Large (baseline)", | |
"Whisper Large + LoRA (SimpleFrog/whisper_finetuned)", | |
"Whisper Large + LoRA + Post-processing" | |
)) | |
# Section : Lien du dataset | |
st.subheader("2. Chargement du dataset Hugging Face") | |
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/Dataset_Test") | |
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password") | |
if hf_token: | |
from huggingface_hub import login | |
login(hf_token) | |
# Section : Choix du split | |
split_option = st.selectbox( | |
"Choix du split à évaluer", | |
options=["Tous", "train", "validation", "test"], | |
index=0 # par défaut "Tous" | |
) | |
# Section : Choix du nombre maximal d'exemples à évaluer | |
max_examples_option = st.selectbox( | |
"Nombre maximum d'audios à traiter", | |
options=["1", "5", "10", "Tous"], | |
index=3 # par défaut "Tous" | |
) | |
# Section : Bouton pour lancer l'évaluation | |
start_eval = st.button("🚀 Lancer l'évaluation WER") | |
if start_eval: | |
st.subheader("🔍 Traitement en cours...") | |
# 🔹 Télécharger dataset | |
with st.spinner("Chargement du dataset..."): | |
try: | |
dataset_full = load_dataset(dataset_link, split="train", token=hf_token) | |
# 🔹 Filtrage selon la colonne 'split' | |
if split_option != "Tous": | |
dataset = dataset_full.filter(lambda x: x.get("split", "unknown") == split_option) | |
else: | |
dataset = dataset_full | |
if len(dataset) == 0: | |
st.warning(f"Aucun exemple trouvé pour le split sélectionné : '{split_option}'.") | |
st.stop() | |
except Exception as e: | |
st.error(f"Erreur lors du chargement du dataset : {e}") | |
st.stop() | |
# Limiter le nombre d'exemples selon la sélection | |
if max_examples_option != "Tous": | |
max_examples = int(max_examples_option) | |
dataset = dataset.select(range(min(max_examples, len(dataset)))) | |
# 🔹 Charger le modèle choisi | |
with st.spinner("Chargement du modèle..."): | |
base_model_name = "openai/whisper-large" | |
model = WhisperForConditionalGeneration.from_pretrained(base_model_name) | |
if "LoRA" in model_option: | |
model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token) | |
processor = WhisperProcessor.from_pretrained(base_model_name) | |
model.eval() | |
# Charger le pipeline de Mistral si post-processing demandé | |
if "Post-processing" in model_option: | |
with st.spinner("Chargement du modèle de post-traitement Mistral..."): | |
postproc_pipe = pipeline( | |
"text2text-generation", | |
model="mistralai/Mistral-7B-Instruct-v0.2", | |
device_map="auto", # ou device=0 si tu veux forcer le GPU | |
torch_dtype=torch.float16 # optionnel mais plus léger | |
) | |
st.success("✅ Modèle Mistral chargé.") | |
def postprocess_with_llm(text): | |
prompt = f"Tu es CorrecteurAI, une AI française qui permet de corriger les erreurs de saisie vocal. La translation d'un enregistrement audio tiré d'une inspection détaillé de pont t'es envoyé et tu renvoies le texte identique mais avec les éventuelles corrections si des erreurs sont détectés. Le texte peut comprendre du vocabulaire technique associé aux ouvrages d'art. Renvoies uniquement le texte corrigé en français et sans autre commentaire. Voici le texte : {text}" | |
result = postproc_pipe(prompt, max_new_tokens=256)[0]["generated_text"] | |
return result.strip() | |
# 🔹 Préparer WER metric | |
wer_metric = evaluate.load("wer") | |
results = [] | |
# Téléchargement explicite du dossier audio (chemin local vers chaque fichier) | |
repo_local_path = snapshot_download(repo_id=dataset_link, repo_type="dataset", token=hf_token) | |
for example in dataset: | |
st.write("Exemple brut :", example) | |
try: | |
reference = example["text"] | |
waveform = example["audio"]["array"] | |
audio_path = example["audio"]["path"] | |
waveform = np.expand_dims(waveform, axis=0) | |
inputs = processor(waveform, sampling_rate=16000, return_tensors="pt") | |
with torch.no_grad(): | |
pred_ids = model.generate(input_features=inputs.input_features) | |
prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0] | |
# === Post-processing conditionnel === | |
if "Post-processing" in model_option: | |
st.write("⏳ Post-processing avec Mistral...") | |
postprocessed_prediction = postprocess_with_llm(prediction) | |
st.write("✅ Post-processing terminé.") | |
final_prediction = postprocessed_prediction | |
else: | |
postprocessed_prediction = "-" | |
final_prediction = prediction | |
# 🔹 Nettoyage ponctuation pour WER "sans ponctuation" | |
def clean(text): | |
return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip() | |
ref_clean = clean(reference) | |
pred_clean = clean(final_prediction) | |
wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean]) | |
results.append({ | |
"Fichier": audio_path, | |
"Référence": reference, | |
"Transcription brute": prediction, | |
"Transcription corrigée": postprocessed_prediction, | |
"WER": round(wer, 4) | |
}) | |
except Exception as e: | |
results.append({ | |
"Fichier": example["audio"].get("path", "unknown"), | |
"Référence": "Erreur", | |
"Transcription brute": f"Erreur: {e}", | |
"Transcription corrigée": "-", | |
"WER": "-" | |
}) | |
# 🔹 Générer le tableau de résultats | |
df = pd.DataFrame(results) | |
# 🔹 Créer un fichier temporaire pour le CSV | |
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as tmp_csv: | |
df.to_csv(tmp_csv.name, index=False) | |
mean_wer = df[df["WER"] != "-"]["WER"].mean() | |
st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`") | |
# 🔹 Bouton de téléchargement | |
with open(tmp_csv.name, "rb") as f: | |
st.download_button( | |
label="📥 Télécharger les résultats WER (.csv)", | |
data=f, | |
file_name="wer_results.csv", | |
mime="text/csv" | |
) | |