Marina Kpamegan committed
Commit 87e9667 · 1 Parent(s): 103eb2f

test model for front end
src/config.py CHANGED
@@ -19,3 +19,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

 # Wav2Vec2 model
 MODEL_NAME = "facebook/wav2vec2-large-xlsr-53-french"
+
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+BEST_MODEL_NAME = os.path.join(BASE_DIR, "..", "best_model.pth")  # go up one level to the project root
+
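Because BEST_MODEL_NAME is anchored to config.py's own location, the checkpoint path resolves to the repository root no matter where the interpreter is launched from. A minimal sketch of the resolution, assuming config.py sits in src/ (the directory value shown is illustrative):

import os

# What os.path.dirname(os.path.abspath(__file__)) would return inside src/config.py
BASE_DIR = "/home/user/project/src"  # illustrative value, not the real path
path = os.path.join(BASE_DIR, "..", "best_model.pth")
print(os.path.normpath(path))  # -> /home/user/project/best_model.pth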
src/predict.py CHANGED
@@ -5,14 +5,13 @@ import librosa
 import numpy as np
 from src.model.emotion_classifier import EmotionClassifier
 from src.utils.preprocessing import collate_fn
-from src.config import DEVICE, NUM_LABELS
+from src.config import DEVICE, NUM_LABELS, BEST_MODEL_NAME
 import os

 # Load the trained model
-MODEL_PATH = "acc_model.pth"
 feature_dim = 40  # number of MFCCs used
 model = EmotionClassifier(feature_dim, NUM_LABELS).to(DEVICE)
-model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
+model.load_state_dict(torch.load(BEST_MODEL_NAME, map_location=DEVICE))
 model.eval()  # evaluation mode

 # Function to predict the emotion of an audio file
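With the checkpoint now loaded from the shared constant, a quick smoke test of this module could look like the following; predict_emotion is the function defined further down in predict.py, and the WAV path is an assumption based on the dataset layout used elsewhere in this commit:

# Hypothetical smoke test for predict.py, run as a module from the repo root
from src.predict import predict_emotion

emotion = predict_emotion("data/colere/c1ac.wav")  # illustrative path
print(f"Predicted emotion: {emotion}")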
src/test_backend.ipynb ADDED
@@ -0,0 +1,63 @@
+{
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": 5,
+      "metadata": {},
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Transcription : tu as encore oublié de faire le dossier c'était hurgent nom de chien\n"
+          ]
+        }
+      ],
+      "source": [
+        "# make a transcription from audio file\n",
+        "from model.transcriber import transcribe_audio\n",
+        "import os\n",
+        "\n",
+        "base_path = os.path.abspath(os.path.join(\"data\"))\n",
+        "audio_path = os.path.join(base_path, \"colere\", \"c1af.wav\")  # path to audio file\n",
+        "texte = transcribe_audio(audio_path)\n",
+        "print(f\"Transcription : {texte}\")"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "from predict import predict_emotion\n",
+        "\n",
+        "base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), \"data\"))\n",
+        "audio_file = os.path.join(base_path, \"colere\", \"c1ac.wav\")\n",
+        "emotion = predict_emotion(audio_file)\n",
+        "print(f\"🎤 L'émotion prédite est : {emotion}\")"
+      ]
+    }
+  ],
+  "metadata": {
+    "kernelspec": {
+      "display_name": ".venv",
+      "language": "python",
+      "name": "python3"
+    },
+    "language_info": {
+      "codemirror_mode": {
+        "name": "ipython",
+        "version": 3
+      },
+      "file_extension": ".py",
+      "mimetype": "text/x-python",
+      "name": "python",
+      "nbconvert_exporter": "python",
+      "pygments_lexer": "ipython3",
+      "version": "3.11.5"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+}
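One caveat in the second cell: __file__ is not defined inside a Jupyter kernel, so os.path.dirname(__file__) raises NameError there. A notebook-safe variant resolves against the working directory, as the first cell already does:

import os

# __file__ does not exist in a notebook; use the current working directory instead
base_path = os.path.abspath("data")
audio_file = os.path.join(base_path, "colere", "c1ac.wav")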
src/train.py CHANGED
@@ -6,7 +6,7 @@ from sklearn.metrics import accuracy_score
 from utils.dataset import load_audio_data
 from utils.preprocessing import preprocess_audio, prepare_features, collate_fn
 from model.emotion_classifier import EmotionClassifier
-from src.config import DEVICE, NUM_LABELS
+from config import DEVICE, NUM_LABELS, BEST_MODEL_NAME
 import os

 # Load the data and split it into train / test
@@ -51,7 +51,7 @@ def train_classifier(classifier, train_loader, test_loader, epochs=20):

        if train_acc > best_accuracy:
            best_accuracy = train_acc
-           torch.save(classifier.state_dict(), "best_model.pth")
+           torch.save(classifier.state_dict(), BEST_MODEL_NAME)
            print(f"✔️ Nouveau meilleur modèle sauvegardé ! Accuracy: {best_accuracy:.4f}")

    print(f"📢 Epoch {epoch+1}/{epochs} - Loss: {total_loss:.4f} - Accuracy: {train_acc:.4f}")
src/utils/dataset.py CHANGED
@@ -6,7 +6,7 @@ import pandas as pd
 import os
 from datasets import Dataset, DatasetDict
 import pandas as pd
-from config import LABELS
+from src.config import LABELS

 def load_audio_data(data_dir):
     data = []
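For context, load_audio_data walks a data directory; judging from the paths used in test_backend.ipynb, the layout appears to be one sub-folder per emotion label (this structure is inferred from the notebook, not confirmed by the diff):

# Inferred usage; the "data/<label>/*.wav" layout matches the notebook's
# data/colere/c1af.wav example but is an assumption
from src.utils.dataset import load_audio_data

dataset = load_audio_data("data")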
views/studio.py CHANGED
@@ -1,7 +1,7 @@
 import streamlit as st
 from st_audiorec import st_audiorec

-# from src.model.transcriber import transcribe_audio
+from src.model.transcriber import transcribe_audio


 def studio():
@@ -23,7 +23,7 @@ def studio():
     with tab1:
         st.header("⬆️ Upload Audio Record")
         st.write("Here you can upload a pre-recorded audio.")
-        audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "ogg"])
+        audio_file = st.file_uploader("Upload an audio file", type=["wav"])

         if "audio_file" not in st.session_state:
             st.session_state.audio_file = None
@@ -52,6 +52,7 @@ def studio():
             st.success("Audio recorded successfully !")
             st.session_state.audio_file = audio_file

+        ##############################################"realtime audio record"##############################################
         # Buttons to start and stop the recording
         # start_button = st.button("Démarrer l'enregistrement")
         # stop_button = st.button("Arrêter l'enregistrement")
@@ -103,24 +104,7 @@ def studio():
         # emotion_display.write(f"Émotion détectée : {emotion_prediction}")
         # # time.sleep(0.1)

-        # audio.terminate()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # audio.terminate(


         # stream = audio.open(format=FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK)
@@ -159,6 +143,7 @@ def studio():
         # final_emotion_placeholder.write(f"Émotion finale prédite : {final_emotion}")


+        ##############################################"end realtime audio record"##############################################

     with tab3:
         st.header("📝 Speech2Text Transcription")
@@ -168,24 +153,24 @@ def studio():

         ############################# To uncomment once this is debugged
         if st.button("Transcribe", key="transcribe-button"):
-            # # Function to transcribe the audio
-            # transcription = transcribe_audio(st.audio)
+            # Function to transcribe the audio
+            transcription = transcribe_audio(st.audio)

-            # # Load and transcribe the audio
-            # # audio, rate = load_audio(audio_file_path)  # (re)load the audio if needed
-            # transcription = transcribe_audio(audio_file, sampling_rate=16000)
+            # Load and transcribe the audio
+            # audio, rate = load_audio(audio_file_path)  # (re)load the audio if needed
+            transcription = transcribe_audio(audio_file, sampling_rate=16000)

-            # # Display the transcription
-            # st.write("Transcription :", transcription)
+            # Display the transcription
+            st.write("Transcription :", transcription)

             st.success("Audio registered successfully.")
-            # if save:
-            #     file_path = "transcript.txt"
+            if save:
+                file_path = "transcript.txt"

-            #     # Write the text to the file
-            #     with open(file_path, "w") as file:
-            #         file.write(transcription)
+                # Write the text to the file
+                with open(file_path, "w") as file:
+                    file.write(transcription)

-            #     st.success(f"Text saved to {file_path}")
+                st.success(f"Text saved to {file_path}")

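One caveat in the newly uncommented block: save is never assigned inside studio(), so the if save: branch raises NameError as soon as the button is clicked (the first transcribe_audio(st.audio) call is also dead code, since its result is immediately overwritten). A minimal, hypothetical repair defines save as a widget; the checkbox label and placement are assumptions:

# Hypothetical fix: give `save` a value before the Transcribe button reads it
save = st.checkbox("Save transcription to transcript.txt", value=False)

if st.button("Transcribe", key="transcribe-button"):
    transcription = transcribe_audio(audio_file, sampling_rate=16000)
    st.write("Transcription :", transcription)
    if save:
        with open("transcript.txt", "w") as file:
            file.write(transcription)
        st.success("Text saved to transcript.txt")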