Cyr-CK committed
Commit aaa3b8b · Parent(s): f9bbbb3

Added real-time emotion detection over an uploaded audio file

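In short, the commit resamples the uploaded file to 16 kHz and runs the existing wav2vec2 emotion classifier over it with a sliding window, producing one probability per label and per window. A condensed sketch of that loop as added to views/application.py below; the file path is illustrative:

import librosa
from src.model.predict import predict_emotion

RATE = 16000                  # resampling rate used throughout the app
WINDOW_S, HOP_S = 1.0, 0.5    # 1 s analysis window, 0.5 s hop

audio, sr = librosa.load("audios/audio.wav", sr=RATE)  # illustrative path
window, hop = int(WINDOW_S * sr), int(HOP_S * sr)
for start in range(0, len(audio), hop):
    chunk = audio[start:start + window]
    if len(chunk) < window:
        break
    # maps "joie" / "colère" / "neutre" to softmax probabilities
    probs = predict_emotion(chunk, output_probs=True, sampling_rate=RATE)
    print(f"{start / sr:.1f}s", probs)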
.gitignore CHANGED
@@ -178,6 +178,7 @@ dataset/
 old/
 *.wav
 data/*
+*.pth
 
 # Mac
 .DS_Store
app.py CHANGED
@@ -3,6 +3,9 @@ from streamlit_option_menu import option_menu
 from views.application import application
 from views.about import about
 
+if "model_loaded" not in st.session_state:
+    st.session_state.model_loaded = None
+
 # Set the logo
 st.sidebar.image("img/logo.png", use_container_width=True)
 
src/model/predict.py CHANGED
@@ -1,35 +1,48 @@
+import os
 import torch
 from transformers import Wav2Vec2Processor
-from model import Wav2Vec2EmotionClassifier
+from src.model.emotion_classifier import Wav2Vec2EmotionClassifier
 import librosa
+import streamlit as st
+
+if "model_loaded" not in st.session_state:
+    st.session_state.model_loaded = None
 
-# Charger le modèle et le processeur
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-xlsr-53")
-model = Wav2Vec2EmotionClassifier()
-model.load_state_dict(torch.load("wav2vec2_emotion.pth"))
-model.to(device)
-model.eval()
+# Charger le modèle et le processeur
+if st.session_state.model_loaded is None:
+    st.session_state.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-xlsr-53-french")
+    st.session_state.model = Wav2Vec2EmotionClassifier()
+    st.session_state.model.load_state_dict(torch.load(os.path.join("src", "model", "wav2vec2_emotion.pth"), map_location=torch.device('cpu')), strict=False)
+    st.session_state.model_loaded = True
+
+if st.session_state.model_loaded:
+    processor = st.session_state.processor
+    model = st.session_state.model
+    model.to(device)
+    model.eval()
 
 emotion_labels = ["joie", "colère", "neutre"]
 
 def predict_emotion(audio_path, output_probs=False, sampling_rate=16000):
-    waveform, _ = librosa.load(audio_path, sr=sampling_rate)
-    input_values = processor(waveform, return_tensors="pt", sampling_rate=sampling_rate).input_values
+    # waveform, _ = librosa.load(audio_path, sr=sampling_rate)
+    input_values = processor(audio_path, return_tensors="pt", sampling_rate=sampling_rate).input_values
     input_values = input_values.to(device)
 
     with torch.no_grad():
         outputs = model(input_values)
 
     if output_probs:
         # Appliquer softmax pour obtenir des probabilités
-        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
+        probabilities = torch.nn.functional.softmax(outputs, dim=-1)
 
         # Convertir en numpy array et prendre le premier (et seul) élément
         probabilities = probabilities[0].detach().cpu().numpy()
 
         # Créer un dictionnaire associant chaque émotion à sa probabilité
         emotion_probabilities = {emotion: prob for emotion, prob in zip(emotion_labels, probabilities)}
+        # emotion_probabilities = {"emotions": [emotion for emotion in emotion_labels],
+        #                          "probabilities": [prob for prob in probabilities]}
         return emotion_probabilities
     else:
         # Obtenir l'émotion la plus probable (i.e. la prédiction)
@@ -38,6 +51,6 @@ def predict_emotion(audio_path, output_probs=False, sampling_rate=16000):
 
 
 # Exemple d'utilisation
-audio_test = "data/n1ac.wav"
-emotion = predict_emotion(audio_test)
-print(f"Émotion détectée : {emotion}")
+# audio_test = "data/n1ac.wav"
+# emotion = predict_emotion(audio_test)
+# print(f"Émotion détectée : {emotion}")
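With the librosa.load call commented out, predict_emotion now expects a raw waveform rather than a file path, so callers load the audio themselves. A minimal usage sketch, assuming the checkpoint at src/model/wav2vec2_emotion.pth exists and the session state has been initialised as above:

import librosa
from src.model.predict import predict_emotion

waveform, _ = librosa.load("data/n1ac.wav", sr=16000)  # example file from the old code
# returns a dict mapping each label in emotion_labels to its probability
probs = predict_emotion(waveform, output_probs=True, sampling_rate=16000)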
src/model/transcriber.py CHANGED
@@ -1,13 +1,16 @@
+import os
 import torch
 from transformers import Wav2Vec2Processor
 from src.model.emotion_classifier import Wav2Vec2EmotionClassifier
 import librosa
 
-# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-# processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-xlsr-53")
-# model = Wav2Vec2EmotionClassifier()
-# model.load_state_dict(torch.load("wav2vec2_emotion.pth"))
-# model.to(device)
+# Charger le modèle et le processeur
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# if st.
+processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-xlsr-53-french")
+model = Wav2Vec2EmotionClassifier()
+model.load_state_dict(torch.load(os.path.join("src", "model", "wav2vec2_emotion.pth"), map_location=torch.device('cpu')), strict=False)
+model.to(device)
 
 
 def transcribe_audio(audio, sampling_rate=16000):
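transcriber.py now loads the classifier eagerly at import time, without the session-state cache used in predict.py. A rough usage sketch; the return value is assumed to be the transcription text, which this hunk does not show:

import librosa
from src.model.transcriber import transcribe_audio

audio, _ = librosa.load("audios/audio.wav", sr=16000)        # illustrative path
transcription = transcribe_audio(audio, sampling_rate=16000)  # assumed to return a string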
views/application.py CHANGED
@@ -1,11 +1,20 @@
 import streamlit as st
+import pandas as pd
 from st_audiorec import st_audiorec
 import datetime
 import os
+import matplotlib.pyplot as plt
+import librosa
 from src.model.transcriber import transcribe_audio
+from src.model.predict import predict_emotion
+
 
 DIRECTORY = "audios"
 FILE_NAME = "audio.wav"
+CHUNK = 1024
+# FORMAT = pyaudio.paInt16
+CHANNELS = 1
+RATE = 16000
 
 def application():
     st.title("SISE ultimate challenge")
@@ -25,12 +34,141 @@ def application():
         st.header("⬆️ Upload Audio Record")
         st.write("Here you can upload a pre-recorded audio.")
         audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "ogg"])
+
         if audio_file is not None:
 
-            with open(f"{DIRECTORY}/{FILE_NAME}", "wb") as f:
+            with open(os.path.join(DIRECTORY, FILE_NAME), "wb") as f:
                 f.write(audio_file.getbuffer())
             st.success(f"Saved file: {FILE_NAME}")
 
+
+        start_inference = st.button("Start emotion recogniton", "inf_on_upl_btn")
+        emotion_labels = ["joie", "colère", "neutre"]
+        colors = ['#f6d60a', '#f71c1c', '#cac8c8']
+
+        if start_inference:
+            # Configuration Streamlit
+            with st.spinner("Real-time emotion analysis..."):
+                # uploaded_file = st.file_uploader("Choisissez un fichier audio", type=["wav", "mp3"])
+
+                if audio_file is not None:
+                    # Charger et rééchantillonner l'audio
+                    audio, sr = librosa.load(audio_file, sr=RATE)
+                    # chunk = audio_file
+
+                    # Paramètres de la fenêtre glissante
+                    window_size = 1  # en secondes
+                    hop_length = 0.5  # en secondes
+
+                    # Créer un graphique en temps réel
+                    fig, ax = plt.subplots()
+                    lines = [ax.plot([], [], label=emotion)[0] for emotion in emotion_labels]
+                    ax.set_ylim(0, 1)
+                    ax.set_xlim(0, len(audio) / sr)
+                    ax.set_xlabel("Temps (s)")
+                    ax.set_ylabel("Probabilité")
+                    ax.legend()
+
+                    chart = st.pyplot(fig)
+
+                    scores = [[], [], []]  # 3 émotions pour l'instant
+
+                    # Traitement par fenêtre glissante
+                    for i in range(0, len(audio), int(hop_length * sr)):
+                        chunk = audio[i:i + int(window_size * sr)]
+                        if len(chunk) < int(window_size * sr):
+                            break
+
+                        emotion_scores = predict_emotion(chunk, output_probs=True, sampling_rate=RATE)
+
+                        # Mettre à jour le graphique
+                        for emotion, line in zip(emotion_labels, lines):
+                            xdata = list(line.get_xdata())
+                            ydata = list(line.get_ydata())
+                            xdata.append(i / sr)
+                            ydata.append(emotion_scores[emotion])
+                            scores[list(emotion_scores).index(emotion)].append(emotion_scores[emotion])
+                            line.set_data(xdata, ydata)
+
+                        ax.relim()
+                        ax.autoscale_view()
+                        chart.pyplot(fig, use_container_width=True)
+
+                    # Prepare the styling
+                    st.markdown("""
+                    <style>
+                    .colored-box {
+                        padding: 10px;
+                        border-radius: 5px;
+                        color: white;
+                        font-weight: bold;
+                        text-align: center;
+                    }
+                    </style>
+                    """, unsafe_allow_html=True)
+
+                    # Dynamically create the specified number of columns
+                    columns = st.columns(len(emotion_scores))
+
+                    # emotion_scores_mean = [sum(sublist) / len(sublist) for sublist in scores]
+                    emotion_scores_mean = {emotion: sum(sublist) / len(sublist) for emotion, sublist in zip(emotion_labels, scores)}
+                    max_emo = max(emotion_scores_mean)
+                    emotion_scores_sorted = dict(sorted(emotion_scores_mean.items(), key=lambda x: x[1], reverse=True))
+                    colors_sorted = [colors[list(emotion_scores_mean.keys()).index(key)] for key in list(emotion_scores_sorted.keys())]
+
+                    # Add content to each column
+                    for i, (col, emotion) in enumerate(zip(columns, emotion_scores_sorted)):
+                        color = colors_sorted[i % len(colors_sorted)]  # Cycle through colors if more columns than colors
+                        col.markdown(f"""
+                        <div class="colored-box" style="background-color: {color};">
+                            {emotion} : {100*emotion_scores_sorted[emotion]:.2f} %
+                        </div>
+                        """, unsafe_allow_html=True)
+
+
+                    st.success("Analyse terminée !")
+                else:
+                    st.warning("You need to load an audio file !")
+
+        st.subheader("Feedback")
+
+        # Initialisation du fichier CSV
+        csv_file = os.path.join("src", "predictions", "feedback.csv")
+
+        # Vérifier si le fichier CSV existe, sinon le créer avec des colonnes appropriées
+        if not os.path.exists(csv_file):
+            df = pd.DataFrame(columns=["filepath", "prediction", "feedback"])
+            df.to_csv(csv_file, index=False)
+
+        # Charger les données existantes du CSV
+        df = pd.read_csv(csv_file)
+
+        with st.form("feedback_form"):
+            st.write("What should have been the correct prediction ? (*Choose the same emotion if the prediction was correct*).")
+            feedback = st.selectbox("Your answer :", ['Sadness', 'Anger', 'Disgust', 'Fear', 'Surprise', 'Joy', 'Neutral'])
+            submit_button = st.form_submit_button("Submit")
+            st.write("En cliquant sur ce bouton, vous acceptez que votre audio soit sauvegardé dans notre base de données.")
+
+            if submit_button:
+                # Ajouter le feedback au DataFrame
+                new_entry = {"filepath": audio_file.name, "prediction": max_emo, "feedback": feedback}
+                df = df.append(new_entry, ignore_index=True)
+
+                # Sauvegarder les données mises à jour dans le fichier CSV
+                df.to_csv(csv_file, index=False)
+
+                # Sauvegarder le fichier audio
+                with open(os.path.join("src", "predictions", "data"), "wb") as f:
+                    f.write(audio_file.getbuffer())
+
+                # Confirmation pour l'utilisateur
+                st.success("Merci pour votre retour ! Vos données ont été sauvegardées.")
+
+
+
     with tab2:
         st.header("🔈 Realtime Audio Record")
         st.write("Here you can record an audio.")
@@ -52,17 +190,16 @@ def application():
         ############################# A décommenté quand ce sera débogué
         if st.button("Transcribe", key="transcribe-button"):
             # # Fonction pour transcrire l'audio
-            # transcription = transcribe_audio(st.audio)
+            # transcription = transcribe_audio(st.audio)
 
             # # Charger et transcrire l'audio
             # # audio, rate = load_audio(audio_file_path) # (re)chargement de l'audio si nécessaire
-            # transcription = transcribe_audio(audio_file, sampling_rate=16000)
+            # transcription = transcribe_audio(audio_file, sampling_rate=16000)
 
             # # Afficher la transcription
-            # st.write("Transcription :", transcription)
+            # st.write("Transcription :", transcription)
 
-
-            st.success("Audio registered successfully.")
+            st.success("Audio registered successfully.")
             # if save:
             #     file_path = "transcript.txt"
 
views/real_time.py CHANGED
@@ -86,64 +86,51 @@ if start_button:
     ### Real time prediction for uploaded audio file
     ###############################
     # Charger le modèle wav2vec et le processeur
-    model = Wav2Vec2ForSequenceClassification.from_pretrained("your_emotion_model_path")
-    processor = Wav2Vec2Processor.from_pretrained("your_emotion_model_path")
 
-    # Définir les émotions
-    emotions = ["neutre", "joie", "colère", "tristesse"]  # Ajustez selon votre modèle
+    # # Configuration Streamlit
+    # st.title("Analyse des émotions en temps réel")
+    # uploaded_file = st.file_uploader("Choisissez un fichier audio", type=["wav", "mp3"])
 
-    # Fonction pour prédire l'émotion
-    # def predict_emotion(audio_chunk):
-    #     inputs = processor(audio_chunk, sampling_rate=16000, return_tensors="pt", padding=True)
-    #     with torch.no_grad():
-    #         logits = model(**inputs).logits
-    #     scores = torch.softmax(logits, dim=1).squeeze().tolist()
-    #     return dict(zip(emotions, scores))
-
-    # Configuration Streamlit
-    st.title("Analyse des émotions en temps réel")
-    uploaded_file = st.file_uploader("Choisissez un fichier audio", type=["wav", "mp3"])
-
-    if uploaded_file is not None:
-        # Charger et rééchantillonner l'audio
-        audio, sr = librosa.load(uploaded_file, sr=16000)
+    # if uploaded_file is not None:
+    #     # Charger et rééchantillonner l'audio
+    #     audio, sr = librosa.load(uploaded_file, sr=16000)
 
-        # Paramètres de la fenêtre glissante
-        window_size = 1  # en secondes
-        hop_length = 0.5  # en secondes
+    #     # Paramètres de la fenêtre glissante
+    #     window_size = 1  # en secondes
+    #     hop_length = 0.5  # en secondes
 
-        # Créer un graphique en temps réel
-        fig, ax = plt.subplots()
-        lines = [ax.plot([], [], label=emotion)[0] for emotion in emotions]
-        ax.set_ylim(0, 1)
-        ax.set_xlim(0, len(audio) / sr)
-        ax.set_xlabel("Temps (s)")
-        ax.set_ylabel("Probabilité")
-        ax.legend()
+    #     # Créer un graphique en temps réel
+    #     fig, ax = plt.subplots()
+    #     lines = [ax.plot([], [], label=emotion)[0] for emotion in emotions]
+    #     ax.set_ylim(0, 1)
+    #     ax.set_xlim(0, len(audio) / sr)
+    #     ax.set_xlabel("Temps (s)")
+    #     ax.set_ylabel("Probabilité")
+    #     ax.legend()
 
-        chart = st.pyplot(fig)
+    #     chart = st.pyplot(fig)
 
-        # Traitement par fenêtre glissante
-        for i in range(0, len(audio), int(hop_length * sr)):
-            chunk = audio[i:i + int(window_size * sr)]
-            if len(chunk) < int(window_size * sr):
-                break
+    #     # Traitement par fenêtre glissante
+    #     for i in range(0, len(audio), int(hop_length * sr)):
+    #         chunk = audio[i:i + int(window_size * sr)]
+    #         if len(chunk) < int(window_size * sr):
+    #             break
 
-            emotion_scores = predict_emotion(chunk, output_probs=False, sampling_rate=RATE)
+    #         emotion_scores = predict_emotion(chunk, output_probs=False, sampling_rate=RATE)
 
-            # Mettre à jour le graphique
-            for emotion, line in zip(emotions, lines):
-                xdata = line.get_xdata().tolist()
-                ydata = line.get_ydata().tolist()
-                xdata.append(i / sr)
-                ydata.append(emotion_scores[emotion])
-                line.set_data(xdata, ydata)
+    #         # Mettre à jour le graphique
+    #         for emotion, line in zip(emotions, lines):
+    #             xdata = line.get_xdata().tolist()
+    #             ydata = line.get_ydata().tolist()
+    #             xdata.append(i / sr)
+    #             ydata.append(emotion_scores[emotion])
+    #             line.set_data(xdata, ydata)
 
-            ax.relim()
-            ax.autoscale_view()
-            chart.pyplot(fig)
+    #         ax.relim()
+    #         ax.autoscale_view()
+    #         chart.pyplot(fig)
 
-        st.success("Analyse terminée !")
+    #     st.success("Analyse terminée !")
 
 
 