WER_Evaluation / app.py
SimpleFrog's picture
Update app.py
a8cd8c4 verified
raw
history blame
7.71 kB
import os
os.environ["XDG_CONFIG_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"
os.environ["HF_HOME"] = "/tmp/huggingface" # pour les modèles/datasets
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
import streamlit as st
import tempfile
import pandas as pd
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from peft import PeftModel
import torch
import librosa
import numpy as np
import evaluate
import tempfile
from huggingface_hub import snapshot_download
from transformers import pipeline
st.title("📊 Évaluation WER d'un modèle Whisper")
st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")
# Section : Choix du modèle
st.subheader("1. Choix du modèle")
model_option = st.radio("Quel modèle veux-tu utiliser ?", (
"Whisper Large (baseline)",
"Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
"Whisper Large + LoRA + Post-processing"
))
# Section : Lien du dataset
st.subheader("2. Chargement du dataset Hugging Face")
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/Dataset_Test")
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")
if hf_token:
from huggingface_hub import login
login(hf_token)
# Section : Choix du split
split_option = st.selectbox(
"Choix du split à évaluer",
options=["Tous", "train", "validation", "test"],
index=0 # par défaut "Tous"
)
# Section : Choix du nombre maximal d'exemples à évaluer
max_examples_option = st.selectbox(
"Nombre maximum d'audios à traiter",
options=["1", "5", "10", "Tous"],
index=3 # par défaut "Tous"
)
# Section : Bouton pour lancer l'évaluation
start_eval = st.button("🚀 Lancer l'évaluation WER")
if start_eval:
st.subheader("🔍 Traitement en cours...")
# 🔹 Télécharger dataset
with st.spinner("Chargement du dataset..."):
try:
dataset_full = load_dataset(dataset_link, split="train", token=hf_token)
# 🔹 Filtrage selon la colonne 'split'
if split_option != "Tous":
dataset = dataset_full.filter(lambda x: x.get("split", "unknown") == split_option)
else:
dataset = dataset_full
if len(dataset) == 0:
st.warning(f"Aucun exemple trouvé pour le split sélectionné : '{split_option}'.")
st.stop()
except Exception as e:
st.error(f"Erreur lors du chargement du dataset : {e}")
st.stop()
# Limiter le nombre d'exemples selon la sélection
if max_examples_option != "Tous":
max_examples = int(max_examples_option)
dataset = dataset.select(range(min(max_examples, len(dataset))))
# 🔹 Charger le modèle choisi
with st.spinner("Chargement du modèle..."):
base_model_name = "openai/whisper-large"
model = WhisperForConditionalGeneration.from_pretrained(base_model_name)
if "LoRA" in model_option:
model = PeftModel.from_pretrained(model, "SimpleFrog/whisper_finetuned", token=hf_token)
processor = WhisperProcessor.from_pretrained(base_model_name)
model.eval()
# Charger le pipeline de Mistral si post-processing demandé
if "Post-processing" in model_option:
with st.spinner("Chargement du modèle de post-traitement Mistral..."):
postproc_pipe = pipeline(
"text2text-generation",
model="mistralai/Mistral-7B-Instruct-v0.2",
device_map="auto", # ou device=0 si tu veux forcer le GPU
torch_dtype=torch.float16 # optionnel mais plus léger
)
st.success("✅ Modèle Mistral chargé.")
def postprocess_with_llm(text):
prompt = f"Tu es CorrecteurAI, une AI française qui permet de corriger les erreurs de saisie vocal. La translation d'un enregistrement audio tiré d'une inspection détaillé de pont t'es envoyé et tu renvoies le texte identique mais avec les éventuelles corrections si des erreurs sont détectés. Le texte peut comprendre du vocabulaire technique associé aux ouvrages d'art. Renvoies uniquement le texte corrigé en français et sans autre commentaire. Voici le texte : {text}"
result = postproc_pipe(prompt, max_new_tokens=256)[0]["generated_text"]
return result.strip()
# 🔹 Préparer WER metric
wer_metric = evaluate.load("wer")
results = []
# Téléchargement explicite du dossier audio (chemin local vers chaque fichier)
repo_local_path = snapshot_download(repo_id=dataset_link, repo_type="dataset", token=hf_token)
for example in dataset:
st.write("Exemple brut :", example)
try:
reference = example["text"]
waveform = example["audio"]["array"]
audio_path = example["audio"]["path"]
waveform = np.expand_dims(waveform, axis=0)
inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
pred_ids = model.generate(input_features=inputs.input_features)
prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
# === Post-processing conditionnel ===
if "Post-processing" in model_option:
st.write("⏳ Post-processing avec Mistral...")
postprocessed_prediction = postprocess_with_llm(prediction)
st.write("✅ Post-processing terminé.")
final_prediction = postprocessed_prediction
else:
postprocessed_prediction = "-"
final_prediction = prediction
# 🔹 Nettoyage ponctuation pour WER "sans ponctuation"
def clean(text):
return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip()
ref_clean = clean(reference)
pred_clean = clean(final_prediction)
wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean])
results.append({
"Fichier": audio_path,
"Référence": reference,
"Transcription brute": prediction,
"Transcription corrigée": postprocessed_prediction,
"WER": round(wer, 4)
})
except Exception as e:
results.append({
"Fichier": example["audio"].get("path", "unknown"),
"Référence": "Erreur",
"Transcription brute": f"Erreur: {e}",
"Transcription corrigée": "-",
"WER": "-"
})
# 🔹 Générer le tableau de résultats
df = pd.DataFrame(results)
# 🔹 Créer un fichier temporaire pour le CSV
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".csv") as tmp_csv:
df.to_csv(tmp_csv.name, index=False)
mean_wer = df[df["WER"] != "-"]["WER"].mean()
st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")
# 🔹 Bouton de téléchargement
with open(tmp_csv.name, "rb") as f:
st.download_button(
label="📥 Télécharger les résultats WER (.csv)",
data=f,
file_name="wer_results.csv",
mime="text/csv"
)