Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -18,6 +18,7 @@ import evaluate
|
|
18 |
import tempfile
|
19 |
from huggingface_hub import snapshot_download
|
20 |
from transformers import pipeline
|
|
|
21 |
|
22 |
|
23 |
st.title("📊 Évaluation WER d'un modèle Whisper")
|
@@ -28,7 +29,8 @@ st.subheader("1. Choix du modèle")
|
|
28 |
model_option = st.radio("Quel modèle veux-tu utiliser ?", (
|
29 |
"Whisper Large (baseline)",
|
30 |
"Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
|
31 |
-
"Whisper Large + LoRA + Post-processing"
|
|
|
32 |
))
|
33 |
|
34 |
# Section : Lien du dataset
|
@@ -36,6 +38,8 @@ st.subheader("2. Chargement du dataset Hugging Face")
|
|
36 |
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/Dataset_Test")
|
37 |
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")
|
38 |
|
|
|
|
|
39 |
if hf_token:
|
40 |
from huggingface_hub import login
|
41 |
login(hf_token)
|
@@ -97,7 +101,7 @@ if start_eval:
|
|
97 |
model.eval()
|
98 |
|
99 |
# Charger le pipeline de Mistral si post-processing demandé
|
100 |
-
if "Post-processing" in model_option:
|
101 |
with st.spinner("Chargement du modèle de post-traitement Mistral..."):
|
102 |
postproc_pipe = pipeline(
|
103 |
"text2text-generation",
|
@@ -111,6 +115,21 @@ if start_eval:
|
|
111 |
prompt = f"Tu es CorrecteurAI, une AI française qui permet de corriger les erreurs de saisie vocal. La translation d'un enregistrement audio tiré d'une inspection détaillé de pont t'es envoyé et tu renvoies le texte identique mais avec les éventuelles corrections si des erreurs sont détectés. Le texte peut comprendre du vocabulaire technique associé aux ouvrages d'art. Renvoies uniquement le texte corrigé en français et sans autre commentaire. Voici le texte : {text}"
|
112 |
result = postproc_pipe(prompt, max_new_tokens=256)[0]["generated_text"]
|
113 |
return result.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
|
116 |
# 🔹 Préparer WER metric
|
@@ -138,11 +157,23 @@ if start_eval:
|
|
138 |
prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
|
139 |
|
140 |
# === Post-processing conditionnel ===
|
141 |
-
if "Post-processing" in model_option:
|
142 |
st.write("⏳ Post-processing avec Mistral...")
|
143 |
postprocessed_prediction = postprocess_with_llm(prediction)
|
144 |
-
st.write("✅
|
145 |
final_prediction = postprocessed_prediction
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
else:
|
147 |
postprocessed_prediction = "-"
|
148 |
final_prediction = prediction
|
|
|
18 |
import tempfile
|
19 |
from huggingface_hub import snapshot_download
|
20 |
from transformers import pipeline
|
21 |
+
import openai
|
22 |
|
23 |
|
24 |
st.title("📊 Évaluation WER d'un modèle Whisper")
|
|
|
29 |
model_option = st.radio("Quel modèle veux-tu utiliser ?", (
|
30 |
"Whisper Large (baseline)",
|
31 |
"Whisper Large + LoRA (SimpleFrog/whisper_finetuned)",
|
32 |
+
"Whisper Large + LoRA + Post-processing Mistral 7B",
|
33 |
+
"Whisper Large + LoRA + Post-processing GPT-4o"
|
34 |
))
|
35 |
|
36 |
# Section : Lien du dataset
|
|
|
38 |
dataset_link = st.text_input("Lien du dataset (format: user/dataset_name)", value="SimpleFrog/Dataset_Test")
|
39 |
hf_token = st.text_input("Token Hugging Face (si dataset privé)", type="password")
|
40 |
|
41 |
+
openai_api_key = st.text_input("Clé API OpenAI (pour GPT-4o)", type="password")
|
42 |
+
|
43 |
if hf_token:
|
44 |
from huggingface_hub import login
|
45 |
login(hf_token)
|
|
|
101 |
model.eval()
|
102 |
|
103 |
# Charger le pipeline de Mistral si post-processing demandé
|
104 |
+
if "Post-processing Mistral" in model_option:
|
105 |
with st.spinner("Chargement du modèle de post-traitement Mistral..."):
|
106 |
postproc_pipe = pipeline(
|
107 |
"text2text-generation",
|
|
|
115 |
prompt = f"Tu es CorrecteurAI, une AI française qui permet de corriger les erreurs de saisie vocal. La translation d'un enregistrement audio tiré d'une inspection détaillé de pont t'es envoyé et tu renvoies le texte identique mais avec les éventuelles corrections si des erreurs sont détectés. Le texte peut comprendre du vocabulaire technique associé aux ouvrages d'art. Renvoies uniquement le texte corrigé en français et sans autre commentaire. Voici le texte : {text}"
|
116 |
result = postproc_pipe(prompt, max_new_tokens=256)[0]["generated_text"]
|
117 |
return result.strip()
|
118 |
+
|
119 |
+
|
120 |
+
#fonction process GPT4o
|
121 |
+
def postprocess_with_gpt4o(text, api_key):
|
122 |
+
openai.api_key = api_key
|
123 |
+
response = openai.ChatCompletion.create(
|
124 |
+
model="gpt-4o",
|
125 |
+
messages=[
|
126 |
+
{"role": "system", "content": "Tu es un assistant qui corrige et ponctue des transcriptions vocales françaises sans changer le sens du texte. Répond uniquement avec le texte corrigé."},
|
127 |
+
{"role": "user", "content": f"Corrige ce texte : {text}"}
|
128 |
+
],
|
129 |
+
temperature=0.3,
|
130 |
+
max_tokens=512
|
131 |
+
)
|
132 |
+
return response.choices[0].message["content"].strip()
|
133 |
|
134 |
|
135 |
# 🔹 Préparer WER metric
|
|
|
157 |
prediction = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
|
158 |
|
159 |
# === Post-processing conditionnel ===
|
160 |
+
if "Post-processing Mistral" in model_option:
|
161 |
st.write("⏳ Post-processing avec Mistral...")
|
162 |
postprocessed_prediction = postprocess_with_llm(prediction)
|
163 |
+
st.write("✅ Terminé.")
|
164 |
final_prediction = postprocessed_prediction
|
165 |
+
|
166 |
+
elif "Post-processing GPT-4o" in model_option:
|
167 |
+
if not openai_api_key:
|
168 |
+
st.error("Clé API OpenAI requise pour GPT-4o.")
|
169 |
+
st.stop()
|
170 |
+
st.write("🤖 Post-processing avec GPT-4o...")
|
171 |
+
try:
|
172 |
+
postprocessed_prediction = postprocess_with_gpt4o(prediction, openai_api_key)
|
173 |
+
except Exception as e:
|
174 |
+
postprocessed_prediction = f"[Erreur GPT-4o: {e}]"
|
175 |
+
final_prediction = postprocessed_prediction
|
176 |
+
|
177 |
else:
|
178 |
postprocessed_prediction = "-"
|
179 |
final_prediction = prediction
|