SimpleFrog commited on
Commit
d7eb8e2
·
verified ·
1 Parent(s): e47bd9f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -18
app.py CHANGED
@@ -17,6 +17,7 @@ import numpy as np
17
  import evaluate
18
  import tempfile
19
  from huggingface_hub import snapshot_download
 
20
 
21
  st.title("📊 Évaluation WER d'un modèle Whisper")
22
  st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")
@@ -26,14 +27,21 @@ st.subheader("1. Choix du modèle")
26
  model_option = st.radio("Quel modèle veux-tu utiliser ?", (
27
  "Whisper Large (baseline)",
28
  "Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
29
- "Whisper Large + LoRA + Post-processing (à venir)"
30
  ))
31
 
32
  # Section : Lien du dataset
33
  st.subheader("2. Chargement du dataset Hugging Face")
34
- dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/mon_dataset")
35
  hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")
36
 
 
 
 
 
 
 
 
37
  # Section : Bouton pour lancer l'évaluation
38
  start_eval = st.button("🚀 Lancer l'évaluation WER")
39
 
@@ -43,8 +51,19 @@ if start_eval:
43
  # 🔹 Télécharger dataset
44
  with st.spinner("Chargement du dataset..."):
45
  try:
46
- #dataset = load_dataset(dataset_link, data_files="metadata.csv", data_dir=".", split="train", token=hf_token)
47
- dataset = load_dataset(dataset_link, split="train", token=hf_token)
 
 
 
 
 
 
 
 
 
 
 
48
  except Exception as e:
49
  st.error(f"Erreur lors du chargement du dataset : {e}")
50
  st.stop()
@@ -60,6 +79,21 @@ if start_eval:
60
  processor = WhisperProcessor.from_pretrained(base_model_name)
61
  model.eval()
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # 🔹 Préparer WER metric
64
  wer_metric = evaluate.load("wer")
65
 
@@ -71,18 +105,11 @@ if start_eval:
71
  for example in dataset:
72
  st.write("Exemple brut :", example)
73
  try:
74
- #audio_path = example["file_name"] # full path or relative path in AudioFolder
75
- #audio_path = os.path.join(repo_local_path, example["file_name"])
76
  reference = example["text"]
77
 
78
  waveform = example["audio"]["array"]
79
  audio_path = example["audio"]["path"]
80
-
81
- #st.write(example)
82
- #st.write("Exemple brut :", dataset[0])
83
-
84
- # Load audio (we assume dataset is structured with 'file_name')
85
- #waveform, _ = librosa.load(audio_path, sr=16000)
86
 
87
  waveform = np.expand_dims(waveform, axis=0)
88
  inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
@@ -91,18 +118,27 @@ if start_eval:
91
  pred_ids = model.generate(input_features=inputs.input_features)
92
  prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
93
 
 
 
 
 
 
 
 
 
94
  # 🔹 Nettoyage ponctuation pour WER "sans ponctuation"
95
  def clean(text):
96
  return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip()
97
 
98
  ref_clean = clean(reference)
99
- pred_clean = clean(prediction)
100
  wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean])
101
 
102
  results.append({
103
  "Fichier": audio_path,
104
  "Référence": reference,
105
- "Transcription": prediction,
 
106
  "WER": round(wer, 4)
107
  })
108
 
@@ -110,7 +146,8 @@ if start_eval:
110
  results.append({
111
  "Fichier": example["audio"].get("path", "unknown"),
112
  "Référence": "Erreur",
113
- "Transcription": f"Erreur: {e}",
 
114
  "WER": "-"
115
  })
116
 
@@ -122,11 +159,10 @@ if start_eval:
122
  df.to_csv(tmp_csv.name, index=False)
123
 
124
  mean_wer = df[df["WER"] != "-"]["WER"].mean()
 
125
  st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")
126
 
127
- # Bloc placeholder pour post-processing à venir
128
- if "Post-processing" in model_option:
129
- st.info("🛠️ Le post-processing sera ajouté prochainement ici...")
130
 
131
 
132
  # 🔹 Bouton de téléchargement
 
17
  import evaluate
18
  import tempfile
19
  from huggingface_hub import snapshot_download
20
+ from transformers import pipeline
21
 
22
  st.title("📊 Évaluation WER d'un modèle Whisper")
23
  st.markdown("Ce Space permet d'évaluer la performance WER d'un modèle Whisper sur un dataset audio.")
 
27
  model_option = st.radio("Quel modèle veux-tu utiliser ?", (
28
  "Whisper Large (baseline)",
29
  "Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
30
+ "Whisper Large + LoRA + Post-processing"
31
  ))
32
 
33
  # Section : Lien du dataset
34
  st.subheader("2. Chargement du dataset Hugging Face")
35
+ dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/Dataset_Test")
36
  hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")
37
 
38
+ # Section : Choix du split
39
+ split_option = st.selectbox(
40
+ "Choix du split à évaluer",
41
+ options=["Tous", "train", "validation", "test"],
42
+ index=0 # par défaut "Tous"
43
+ )
44
+
45
  # Section : Bouton pour lancer l'évaluation
46
  start_eval = st.button("🚀 Lancer l'évaluation WER")
47
 
 
51
  # 🔹 Télécharger dataset
52
  with st.spinner("Chargement du dataset..."):
53
  try:
54
+
55
+ dataset_full = load_dataset(dataset_link, split="train", token=hf_token)
56
+
57
+ # 🔹 Filtrage selon la colonne 'split'
58
+ if split_option != "Tous":
59
+ dataset = dataset_full.filter(lambda x: x.get("split", "unknown") == split_option)
60
+ else:
61
+ dataset = dataset_full
62
+
63
+ if len(dataset) == 0:
64
+ st.warning(f"Aucun exemple trouvé pour le split sélectionné : '{split_option}'.")
65
+ st.stop()
66
+
67
  except Exception as e:
68
  st.error(f"Erreur lors du chargement du dataset : {e}")
69
  st.stop()
 
79
  processor = WhisperProcessor.from_pretrained(base_model_name)
80
  model.eval()
81
 
82
+ # Charger le pipeline de Mistral si post-processing demandé
83
+ if "Post-processing" in model_option:
84
+ with st.spinner("Chargement du modèle de post-traitement Mistral..."):
85
+ postproc_pipe = pipeline(
86
+ "text2text-generation",
87
+ model="mistralai/Mistral-7B-Instruct-v0.2",
88
+ device_map="auto", # ou device=0 si tu veux forcer le GPU
89
+ torch_dtype=torch.float16 # optionnel mais plus léger
90
+ )
91
+
92
+ def postprocess_with_llm(text):
93
+ prompt = f"Ce texte est issue d'une translation vocal. L'enregistrement est tiré d'une inspection détaillé de pont et comprend du vocabulaire technique associé. Corriges les éventuelles erreurs de translation : {text}"
94
+ result = postproc_pipe(prompt, max_new_tokens=256)[0]["generated_text"]
95
+ return result.strip()
96
+
97
  # 🔹 Préparer WER metric
98
  wer_metric = evaluate.load("wer")
99
 
 
105
  for example in dataset:
106
  st.write("Exemple brut :", example)
107
  try:
108
+
 
109
  reference = example["text"]
110
 
111
  waveform = example["audio"]["array"]
112
  audio_path = example["audio"]["path"]
 
 
 
 
 
 
113
 
114
  waveform = np.expand_dims(waveform, axis=0)
115
  inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
 
118
  pred_ids = model.generate(input_features=inputs.input_features)
119
  prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
120
 
121
+ # === Post-processing conditionnel ===
122
+ if "Post-processing" in model_option:
123
+ postprocessed_prediction = postprocess_with_llm(prediction)
124
+ final_prediction = postprocessed_prediction
125
+ else:
126
+ postprocessed_prediction = "-"
127
+ final_prediction = prediction
128
+
129
  # 🔹 Nettoyage ponctuation pour WER "sans ponctuation"
130
  def clean(text):
131
  return ''.join([c for c in text.lower() if c.isalnum() or c.isspace()]).strip()
132
 
133
  ref_clean = clean(reference)
134
+ pred_clean = clean(final_prediction)
135
  wer = wer_metric.compute(predictions=[pred_clean], references=[ref_clean])
136
 
137
  results.append({
138
  "Fichier": audio_path,
139
  "Référence": reference,
140
+ "Transcription brute": prediction,
141
+ "Transcription corrigée": postprocessed_prediction,
142
  "WER": round(wer, 4)
143
  })
144
 
 
146
  results.append({
147
  "Fichier": example["audio"].get("path", "unknown"),
148
  "Référence": "Erreur",
149
+ "Transcription brute": f"Erreur: {e}",
150
+ "Transcription corrigée": "-",
151
  "WER": "-"
152
  })
153
 
 
159
  df.to_csv(tmp_csv.name, index=False)
160
 
161
  mean_wer = df[df["WER"] != "-"]["WER"].mean()
162
+
163
  st.markdown(f"### 🎯 WER moyen (sans ponctuation) : `{mean_wer:.3f}`")
164
 
165
+
 
 
166
 
167
 
168
  # 🔹 Bouton de téléchargement