jdalfonso committed
Commit 02ad44e · unverified · 2 parent(s): 312c28a 3acee9d

Merge pull request #6 from jdalfons/develop

This view is limited to 50 files because it contains too many changes.
.github/workflows/check_file_size.yml ADDED
@@ -0,0 +1,16 @@
+ name: Check file size
+ on: # or directly `on: [push]` to run the action on every push on any branch
+   pull_request:
+     branches: [main]
+
+   # to run this workflow manually from the Actions tab
+   workflow_dispatch:
+
+ jobs:
+   sync-to-hub:
+     runs-on: ubuntu-latest
+     steps:
+       - name: Check large files
+         uses: ActionsDesk/[email protected]
+         with:
+           filesizelimit: 10485760 # this is 10MB so we can sync to HF Spaces
.github/workflows/sync_hf.yml ADDED
@@ -0,0 +1,20 @@
+ name: Sync to Hugging Face hub
+ on:
+   push:
+     branches: [main]
+
+   # to run this workflow manually from the Actions tab
+   workflow_dispatch:
+
+ jobs:
+   sync-to-hub:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v3
+         with:
+           fetch-depth: 0
+           lfs: true
+       - name: Push to hub
+         env:
+           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+         run: git push https://HF_USERNAME:$HF_TOKEN@huggingface.co/spaces/jdalfonso/SISE-ULTIMATE-CHALLENGE main
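
For reference, the same push can also be done from Python with huggingface_hub, which is already listed in requirements.txt. This is only a sketch, not part of the commit: it assumes a write-enabled token and uploads the current checkout to the Space targeted by the workflow above.

from huggingface_hub import HfApi

# The token would normally come from the HF_TOKEN secret or an environment variable,
# not be hard-coded; "hf_xxx" is a placeholder.
api = HfApi(token="hf_xxx")
api.upload_folder(
    folder_path=".",                              # local checkout of the repository
    repo_id="jdalfonso/SISE-ULTIMATE-CHALLENGE",  # Space pushed to by the workflow
    repo_type="space",
)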
.gitignore CHANGED
@@ -2,6 +2,7 @@
  __pycache__/
  *.py[cod]
  *$py.class
+ .idea/
 
  # C extensions
  *.so
@@ -178,6 +179,10 @@ dataset/
  old/
  *.wav
  data/*
-
+ *.pth
+ old/
  # Mac
  .DS_Store
+ .idea
+ wav2vec2_emotion/
+ dataset/
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
.streamlit/config.toml CHANGED
@@ -1,4 +1,4 @@
  [theme]
- base="dark"
+ base="light"
  primaryColor="#7c99b4"
 
README.md CHANGED
@@ -1,5 +1,5 @@
  # SISE Ultimate Challenge
- ![SISE Ultimate Challenge logo](img/logo.png)
+ ![SISE Ultimate Challenge logo](img/logo_01.png)
 
  This is the Ultimate Challenge for the SISE Master's programme.
 
app.py CHANGED
@@ -1,34 +1,248 @@
  import streamlit as st
- from streamlit_option_menu import option_menu
- from views.application import application
- from views.about import about
-
- # Set the logo
- st.sidebar.image("img/logo.png", use_container_width=True)
-
- # Create a sidebar with navigation options
- # Sidebar navigation with streamlit-option-menu
- with st.sidebar:
-     # st.image("img/logo.png", use_container_width=True)
-     # st.markdown("<h1 style='text-align: center;'>SecureIA Dashboard</h1>", unsafe_allow_html=True)
-     # Navigation menu with icons
-     selected_tab = option_menu(
-         menu_title=None,  # Added menu_title parameter
-         options=["Application", "About"],
-         icons=["robot", "bar-chart", "robot"],
-         menu_icon="cast",
-         default_index=0,
-         # styles={
-         #     "container": {"padding": "5px", "background-color": "#f0f2f6"},
-         #     "icon": {"color": "orange", "font-size": "18px"},
-         #     "nav-link": {"font-size": "16px", "text-align": "left", "margin": "0px", "color": "black"},
-         #     "nav-link-selected": {"background-color": "#4CAF50", "color": "white"},
-         # }
+ import pandas as pd
+ import numpy as np
+ import os
+ import time
+ import matplotlib.pyplot as plt
+ from datetime import datetime
+ import tempfile
+ import io
+ import json
+ from model.transcriber import transcribe_audio
+ from predict import predict_emotion
+
+ # You'll need to install this package:
+ # pip install streamlit-audiorec
+ from st_audiorec import st_audiorec
+
+ # Page configuration
+ st.set_page_config(
+     page_title="Emotion Analyser",
+     page_icon="🎤",
+     layout="wide"
+ )
+
+ # Initialize session state variables if they don't exist
+ if 'audio_data' not in st.session_state:
+     st.session_state.audio_data = []
+ if 'current_audio_index' not in st.session_state:
+     st.session_state.current_audio_index = -1
+ if 'audio_history_csv' not in st.session_state:
+     # Define columns for our CSV storage
+     st.session_state.audio_history_csv = pd.DataFrame(
+         columns=['timestamp', 'file_path', 'transcription', 'emotion', 'probabilities']
+     )
+ if 'needs_rerun' not in st.session_state:
+     st.session_state.needs_rerun = False
+
+ # Function to ensure we keep only the last 10 entries
+ def update_audio_history(new_entry):
+     # Add the new entry
+     st.session_state.audio_history_csv = pd.concat([st.session_state.audio_history_csv, pd.DataFrame([new_entry])], ignore_index=True)
+
+     # Keep only the last 10 entries
+     if len(st.session_state.audio_history_csv) > 10:
+         st.session_state.audio_history_csv = st.session_state.audio_history_csv.iloc[-10:]
+
+     # Save to CSV
+     st.session_state.audio_history_csv.to_csv('audio_history.csv', index=False)
+
+ # Function to process audio and get results
+ def process_audio(audio_path):
+     try:
+         # Get transcription
+         transcription = transcribe_audio(audio_path)
+
+         # Get emotion prediction
+         predicted_emotion, probabilities = predict_emotion(audio_path)
+
+         # Update audio history
+         new_entry = {
+             'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+             'file_path': audio_path,
+             'transcription': transcription,
+             'emotion': predicted_emotion,
+             'probabilities': str(probabilities)  # Convert dict to string for storage
+         }
+         update_audio_history(new_entry)
+
+         # Update current index
+         st.session_state.current_audio_index = len(st.session_state.audio_history_csv) - 1
+
+         return transcription, predicted_emotion, probabilities
+     except Exception as e:
+         st.error(f"Error processing audio: {str(e)}")
+         return None, None, None
+
+ # Function to split audio into 10-second segments
+ def split_audio(audio_file, segment_length=10):
+     # This is a placeholder - in a real implementation, you'd use a library like pydub
+     # to split the audio file into segments
+     st.warning("Audio splitting functionality is a placeholder. Implement with pydub or similar library.")
+     # For now, we'll just return the whole file as a single segment
+     return [audio_file]
+
+ # Function to display emotion visualization
+ def display_emotion_chart(probabilities):
+     emotions = list(probabilities.keys())
+     values = list(probabilities.values())
+
+     fig, ax = plt.subplots(figsize=(10, 5))
+     bars = ax.bar(emotions, values, color=['red', 'gray', 'green'])
+
+     # Add data labels on top of bars
+     for bar in bars:
+         height = bar.get_height()
+         ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
+                 f'{height:.2f}', ha='center', va='bottom')
+
+     ax.set_ylim(0, 1.1)
+     ax.set_ylabel('Probability')
+     ax.set_title('Emotion Prediction Results')
+
+     st.pyplot(fig)
+
+ # Trigger rerun if needed (replaces experimental_rerun)
+ if st.session_state.needs_rerun:
+     st.session_state.needs_rerun = False
+     st.rerun()  # Using st.rerun() instead of experimental_rerun
+
+ # Main App Layout
+ st.image("./img/logo_01.png", width=400)
+
+ # Create two columns for the main layout
+ col1, col2 = st.columns([1, 1])
+
+ with col1:
+     st.header("Audio Input")
+
+     # Method selection
+
+     tab1, tab2 = st.tabs(["Record Audio", "Upload Audio"])
+
+     with tab1:
+         st.write("Record your audio (max 10 seconds):")
+
+         # Using streamlit-audiorec for better recording functionality
+         wav_audio_data = st_audiorec()
+
+         if wav_audio_data is not None:
+             # Save the recorded audio to a temporary file
+             with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
+                 tmp_file.write(wav_audio_data)
+                 tmp_file_path = tmp_file.name
+
+             st.success("Audio recorded successfully!")
+
+             # Process button
+             if st.button("Process Recorded Audio"):
+                 # Process the audio
+                 with st.spinner("Processing audio..."):
+                     transcription, emotion, probs = process_audio(tmp_file_path)
+                     # Set flag for rerun instead of calling experimental_rerun
+                     if transcription is not None:
+                         st.success("Audio processed successfully!")
+                         st.session_state.needs_rerun = True
+
+     with tab2:
+         uploaded_file = st.file_uploader("Upload an audio file (WAV format)", type=['wav'])
+
+         if uploaded_file is not None:
+             # Save the uploaded file to a temporary location
+             with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
+                 tmp_file.write(uploaded_file.getbuffer())
+                 tmp_file_path = tmp_file.name
+
+             st.audio(uploaded_file, format="audio/wav")
+
+             # Process button
+             if st.button("Process Uploaded Audio"):
+                 # Split audio into 10-second segments
+                 with st.spinner("Processing audio..."):
+                     segments = split_audio(tmp_file_path)
+
+                     # Process each segment
+                     for i, segment_path in enumerate(segments):
+                         st.write(f"Processing segment {i+1}...")
+                         transcription, emotion, probs = process_audio(segment_path)
+
+                     # Set flag for rerun instead of calling experimental_rerun
+                     st.success("Audio processed successfully!")
+                     st.session_state.needs_rerun = True
+
+ with col2:
+     st.header("Results")
+
+     # Display results if available
+     if st.session_state.current_audio_index >= 0 and len(st.session_state.audio_history_csv) > 0:
+         current_data = st.session_state.audio_history_csv.iloc[st.session_state.current_audio_index]
+
+         # Transcription
+         st.subheader("Transcription")
+         st.text_area("", value=current_data['transcription'], height=100, key="transcription_area")
+
+         # Emotion
+         st.subheader("Detected Emotion")
+         st.info(f"🎭 Predicted emotion: **{current_data['emotion']}**")
+
+         # Convert string representation of dict back to actual dict
+         try:
+             import ast
+             probs = ast.literal_eval(current_data['probabilities'])
+             display_emotion_chart(probs)
+         except Exception as e:
+             st.error(f"Error parsing probabilities: {str(e)}")
+             st.write(f"Raw probabilities: {current_data['probabilities']}")
+     else:
+         st.info("Record or upload audio to see results")
+
+ # Audio History and Analytics Section
+ st.header("Audio History and Analytics")
+
+ if len(st.session_state.audio_history_csv) > 0:
+     # Display a select box to choose from audio history
+     timestamps = st.session_state.audio_history_csv['timestamp'].tolist()
+     selected_timestamp = st.selectbox(
+         "Select audio from history:",
+         options=timestamps,
+         index=len(timestamps) - 1  # Default to most recent
      )
 
+     # Update current index when selection changes
+     selected_index = st.session_state.audio_history_csv[
+         st.session_state.audio_history_csv['timestamp'] == selected_timestamp
+     ].index[0]
+
+     # Only update if different
+     if st.session_state.current_audio_index != selected_index:
+         st.session_state.current_audio_index = selected_index
+         st.session_state.needs_rerun = True
 
- if selected_tab == "Application":
-     application()
- elif selected_tab == "About":
-     about()
-
+     # Analytics button
+     if st.button("Run Analytics on Selected Audio"):
+         st.subheader("Analytics Results")
+
+         # Get the selected audio data
+         selected_data = st.session_state.audio_history_csv.iloc[selected_index]
+
+         # Display analytics (this is where you would add more sophisticated analytics)
+         st.write(f"Selected Audio: {selected_data['timestamp']}")
+         st.write(f"Emotion: {selected_data['emotion']}")
+         st.write(f"File Path: {selected_data['file_path']}")
+
+         # Add any additional analytics you want here
+
+         # Try to play the selected audio
+         try:
+             if os.path.exists(selected_data['file_path']):
+                 st.audio(selected_data['file_path'], format="audio/wav")
+             else:
+                 st.warning("Audio file not found - it may have been deleted or moved.")
+         except Exception as e:
+             st.error(f"Error playing audio: {str(e)}")
+ else:
+     st.info("No audio history available. Record or upload audio to create history.")
+
+ # Footer
+ st.markdown("---")
+ st.caption("Audio Emotion Analyzer - Processes audio in 10-second segments and predicts emotions")
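
split_audio above is explicitly a placeholder that returns the whole file as a single segment. A possible implementation along the lines its comment suggests, using pydub, is sketched below; pydub is an assumed extra dependency (not in requirements.txt) and it requires ffmpeg to be installed.

import tempfile
from pydub import AudioSegment  # assumed dependency, not part of this commit

def split_audio(audio_file, segment_length=10):
    """Split a WAV file into chunks of at most `segment_length` seconds."""
    audio = AudioSegment.from_wav(audio_file)
    segment_ms = segment_length * 1000
    segment_paths = []
    for start in range(0, len(audio), segment_ms):
        chunk = audio[start:start + segment_ms]
        # Write each chunk to its own temporary WAV file, mirroring how the app
        # already handles recorded/uploaded audio.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp:
            chunk.export(tmp.name, format="wav")
            segment_paths.append(tmp.name)
    return segment_paths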
config.py ADDED
@@ -0,0 +1,25 @@
+ import os
+ import torch
+ from dotenv import load_dotenv
+
+ # Load the environment variables
+ load_dotenv()
+ HF_API_KEY = os.getenv("HF_API_KEY")
+
+ if not HF_API_KEY:
+     raise ValueError("The Hugging Face token was not found in .env")
+
+ # Emotion labels
+ LABELS = {"colere": 0, "neutre": 1, "joie": 2}
+ #LABELS = ["colere", "neutre", "joie"]
+ NUM_LABELS = len(LABELS)
+
+ # Choose the device
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Wav2Vec2 model
+ MODEL_NAME = "facebook/wav2vec2-large-xlsr-53-french"
+
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+ BEST_MODEL_NAME = os.path.join(BASE_DIR, "model", "fr-speech-emotion-model.pth")  # Resolved relative to the project root
+
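
config.py expects an HF_API_KEY entry in a local .env file and raises at import time if it is missing. A minimal sanity check, not part of the commit; the .env line in the comment is the assumed format and the value is a placeholder:

# Expected .env file next to config.py:
#   HF_API_KEY=hf_xxxxxxxxxxxxxxxxxxxx
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory
print("HF_API_KEY found:", bool(os.getenv("HF_API_KEY")))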
img/logo.png DELETED
Binary file (179 kB)
 
img/logo_01.png ADDED
model/__init__.py ADDED
File without changes
model/emotion_classifier.py ADDED
@@ -0,0 +1,54 @@
+
+
+ # Predicts roughly 33% for every class (in the 3-class case)
+
+ # class EmotionClassifier(nn.Module):
+ #     def __init__(self, feature_dim, num_labels):
+ #         super(EmotionClassifier, self).__init__()
+ #         self.fc1 = nn.Linear(feature_dim, 256)
+ #         self.relu = nn.ReLU()
+ #         self.dropout = nn.Dropout(0.3)
+ #         self.fc2 = nn.Linear(256, num_labels)
+
+ #     def forward(self, x):
+ #         x = self.fc1(x)
+ #         x = self.relu(x)
+ #         x = self.dropout(x)
+ #         return self.fc2(x)
+
+
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class Attention(nn.Module):
+     """Attention mechanism that weights the importance of the audio features"""
+     def __init__(self, hidden_dim):
+         super(Attention, self).__init__()
+         self.attention_weights = nn.Linear(hidden_dim, 1)
+
+     def forward(self, lstm_output):
+         # lstm_output: (batch_size, sequence_length, hidden_dim)
+         attention_scores = self.attention_weights(lstm_output)  # (batch_size, sequence_length, 1)
+         attention_weights = torch.softmax(attention_scores, dim=1)  # Softmax normalisation
+         weighted_output = lstm_output * attention_weights  # Weight the features
+         return weighted_output.sum(dim=1)  # Weighted sum over the sequence
+
+ class EmotionClassifier(nn.Module):
+     """Emotion classification model based on a BiLSTM with attention"""
+     def __init__(self, feature_dim, num_labels, hidden_dim=128):
+         super(EmotionClassifier, self).__init__()
+         self.lstm = nn.LSTM(feature_dim, hidden_dim, batch_first=True, bidirectional=True)
+         self.attention = Attention(hidden_dim * 2)  # Bidirectional → hidden_dim * 2
+         self.fc = nn.Linear(hidden_dim * 2, num_labels)  # Final classification layer
+
+     def forward(self, x):
+         lstm_out, _ = self.lstm(x)  # (batch_size, sequence_length, hidden_dim*2)
+         attention_out = self.attention(lstm_out)  # (batch_size, hidden_dim*2)
+         logits = self.fc(attention_out)  # (batch_size, num_labels)
+         return logits
+
+
+
model/feature_extractor.py ADDED
@@ -0,0 +1,6 @@
+ import torch
+ from transformers import Wav2Vec2Model, Wav2Vec2Processor
+ from config import MODEL_NAME, DEVICE
+
+ processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
+ feature_extractor = Wav2Vec2Model.from_pretrained(MODEL_NAME).to(DEVICE)
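
The commit does not show how these two objects are combined downstream; a minimal sketch of pulling Wav2Vec2 hidden states from a clip is given below. The file path is only an example (reused from transcriber.py), and 16 kHz mono audio is assumed, matching the rest of the repo.

import librosa
import torch
from model.feature_extractor import processor, feature_extractor
from config import DEVICE

audio, sr = librosa.load("data/colere/c1af.wav", sr=16000)  # example path, as in transcriber.py
inputs = processor(audio, sampling_rate=sr, return_tensors="pt").input_values.to(DEVICE)

with torch.no_grad():
    hidden_states = feature_extractor(inputs).last_hidden_state  # (1, frames, hidden_size)

print(hidden_states.shape)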
model/transcriber.py ADDED
@@ -0,0 +1,35 @@
+ import os
+ import torch
+ import librosa
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+
+ # Load the model and the processor
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ MODEL_NAME = "facebook/wav2vec2-large-xlsr-53-french"
+
+ processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
+ model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME).to(device)
+ model.eval()
+
+ def transcribe_audio(audio_path, sampling_rate=16000):
+     # Load the audio
+     audio, sr = librosa.load(audio_path, sr=sampling_rate)
+
+     # Turn the audio into model inputs
+     input_values = processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_values.to(device)
+
+     # Get the predictions
+     with torch.no_grad():
+         logits = model(input_values).logits
+
+     # Decode the predictions into text
+     predicted_ids = torch.argmax(logits, dim=-1)
+     transcription = processor.batch_decode(predicted_ids)[0]
+     return transcription
+
+ # Example usage
+ if __name__ == "__main__":
+     base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data"))
+     audio_path = os.path.join(base_path, "colere", "c1af.wav")
+     texte = transcribe_audio(audio_path)
+     print(f"Transcription: {texte}")
predict.py ADDED
@@ -0,0 +1,57 @@
+ import sys
+ import os
+ import torch
+ import librosa
+ import numpy as np
+ from model.emotion_classifier import EmotionClassifier
+ from utils.preprocessing import collate_fn
+ from config import DEVICE, NUM_LABELS, BEST_MODEL_NAME
+
+ # Load the trained model
+ feature_dim = 40  # Number of MFCCs used
+ model = EmotionClassifier(feature_dim, NUM_LABELS).to(DEVICE)
+ model.load_state_dict(torch.load(BEST_MODEL_NAME, map_location=DEVICE))
+ model.eval()  # Evaluation mode
+
+ # Emotion labels
+ LABELS = {0: "colère", 1: "neutre", 2: "joie"}
+
+ # Predict the emotion of an audio file, with class probabilities
+ def predict_emotion(audio_path, max_length=128):
+     # Load the audio
+     y, sr = librosa.load(audio_path, sr=16000)
+
+     # Extract the MFCCs
+     mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=40)
+
+     # Adjust the MFCC length with padding/truncation
+     if mfcc.shape[1] > max_length:
+         mfcc = mfcc[:, :max_length]  # Truncate if too long
+     else:
+         pad_width = max_length - mfcc.shape[1]
+         mfcc = np.pad(mfcc, pad_width=((0, 0), (0, pad_width)), mode='constant')
+
+     # Convert to a PyTorch tensor
+     input_tensor = torch.tensor(mfcc.T, dtype=torch.float32).unsqueeze(0).to(DEVICE)  # (1, max_length, 40)
+
+     # Predict with the model
+     with torch.no_grad():
+         logits = model(input_tensor)
+         probabilities = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy().flatten()  # Convert to probabilities
+         predicted_class = torch.argmax(logits, dim=-1).item()
+
+     # Map the probabilities to their labels
+     probabilities_dict = {LABELS[i]: float(probabilities[i]) for i in range(NUM_LABELS)}
+
+     return LABELS[predicted_class], probabilities_dict
+
+
+ # Example usage
+ if __name__ == "__main__":
+     base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "data"))
+     audio_file = os.path.join(base_path, "colere", "c1ac.wav")
+
+     predicted_emotion, probabilities = predict_emotion(audio_file)
+
+     print(f"🎤 Predicted emotion: {predicted_emotion}")
+     print(f"📊 Probabilities per class: {probabilities}")
requirements.txt CHANGED
@@ -15,3 +15,5 @@ scikit-learn
  huggingface
  huggingface_hub
  pyaudio
+ streamlit_audiorec
+ dotenv
src/data/colere/c1ac.wav DELETED
Binary file (110 kB)
 
src/data/colere/c1af.wav DELETED
Binary file (157 kB)
 
src/data/colere/c1aj.wav DELETED
Binary file (210 kB)
 
src/data/colere/c1an.wav DELETED
Binary file (148 kB)
 
src/data/colere/c1bc.wav DELETED
Binary file (65.8 kB)
 
src/data/colere/c1bf.wav DELETED
Binary file (117 kB)
 
src/data/colere/c1bj.wav DELETED
Binary file (76.9 kB)
 
src/data/colere/c1bn.wav DELETED
Binary file (74.3 kB)
 
src/data/colere/c1cc.wav DELETED
Binary file (112 kB)
 
src/data/colere/c1cf.wav DELETED
Binary file (138 kB)
 
src/data/colere/c1cj.wav DELETED
Binary file (101 kB)
 
src/data/colere/c2ac.wav DELETED
Binary file (108 kB)
 
src/data/colere/c2af.wav DELETED
Binary file (138 kB)
 
src/data/colere/c2aj.wav DELETED
Binary file (115 kB)
 
src/data/colere/c2an.wav DELETED
Binary file (140 kB)
 
src/data/colere/c2bc.wav DELETED
Binary file (89.1 kB)
 
src/data/colere/c2bf.wav DELETED
Binary file (115 kB)
 
src/data/colere/c2bj.wav DELETED
Binary file (110 kB)
 
src/data/colere/c2bn.wav DELETED
Binary file (138 kB)
 
src/data/colere/c2cn.wav DELETED
Binary file (123 kB)
 
src/data/colere/c3ac.wav DELETED
Binary file (119 kB)
 
src/data/colere/c3af.wav DELETED
Binary file (127 kB)
 
src/data/colere/c3aj.wav DELETED
Binary file (119 kB)
 
src/data/colere/c3an.wav DELETED
Binary file (129 kB)
 
src/data/colere/c3bc.wav DELETED
Binary file (115 kB)
 
src/data/colere/c3bf.wav DELETED
Binary file (142 kB)
 
src/data/colere/c3bj.wav DELETED
Binary file (99.7 kB)
 
src/data/colere/c3bn.wav DELETED
Binary file (153 kB)
 
src/data/colere/c4aaf.wav DELETED
Binary file (142 kB)
 
src/data/colere/c4ac.wav DELETED
Binary file (108 kB)
 
src/data/colere/c4af.wav DELETED
Binary file (127 kB)
 
src/data/colere/c4aj.wav DELETED
Binary file (159 kB)
 
src/data/colere/c4an.wav DELETED
Binary file (121 kB)
 
src/data/colere/c4bc.wav DELETED
Binary file (112 kB)