Merge pull request #6 from jdalfons/develop
This view is limited to 50 files because the pull request contains too many changes.
- .github/workflows/check_file_size.yml +16 -0
- .github/workflows/sync_hf.yml +20 -0
- .gitignore +6 -1
- .idea/.gitignore +8 -0
- .streamlit/config.toml +1 -1
- README.md +1 -1
- app.py +244 -30
- config.py +25 -0
- img/logo.png +0 -0
- img/logo_01.png +0 -0
- model/__init__.py +0 -0
- model/emotion_classifier.py +54 -0
- model/feature_extractor.py +6 -0
- model/transcriber.py +35 -0
- predict.py +57 -0
- requirements.txt +2 -0
- src/data/colere/c1ac.wav +0 -0
- src/data/colere/c1af.wav +0 -0
- src/data/colere/c1aj.wav +0 -0
- src/data/colere/c1an.wav +0 -0
- src/data/colere/c1bc.wav +0 -0
- src/data/colere/c1bf.wav +0 -0
- src/data/colere/c1bj.wav +0 -0
- src/data/colere/c1bn.wav +0 -0
- src/data/colere/c1cc.wav +0 -0
- src/data/colere/c1cf.wav +0 -0
- src/data/colere/c1cj.wav +0 -0
- src/data/colere/c2ac.wav +0 -0
- src/data/colere/c2af.wav +0 -0
- src/data/colere/c2aj.wav +0 -0
- src/data/colere/c2an.wav +0 -0
- src/data/colere/c2bc.wav +0 -0
- src/data/colere/c2bf.wav +0 -0
- src/data/colere/c2bj.wav +0 -0
- src/data/colere/c2bn.wav +0 -0
- src/data/colere/c2cn.wav +0 -0
- src/data/colere/c3ac.wav +0 -0
- src/data/colere/c3af.wav +0 -0
- src/data/colere/c3aj.wav +0 -0
- src/data/colere/c3an.wav +0 -0
- src/data/colere/c3bc.wav +0 -0
- src/data/colere/c3bf.wav +0 -0
- src/data/colere/c3bj.wav +0 -0
- src/data/colere/c3bn.wav +0 -0
- src/data/colere/c4aaf.wav +0 -0
- src/data/colere/c4ac.wav +0 -0
- src/data/colere/c4af.wav +0 -0
- src/data/colere/c4aj.wav +0 -0
- src/data/colere/c4an.wav +0 -0
- src/data/colere/c4bc.wav +0 -0
.github/workflows/check_file_size.yml
ADDED
@@ -0,0 +1,16 @@
+name: Check file size
+on:               # or directly `on: [push]` to run the action on every push on any branch
+  pull_request:
+    branches: [main]
+
+  # to run this workflow manually from the Actions tab
+  workflow_dispatch:
+
+jobs:
+  sync-to-hub:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Check large files
+        uses: ActionsDesk/[email protected]
+        with:
+          filesizelimit: 10485760 # this is 10MB so we can sync to HF Spaces
.github/workflows/sync_hf.yml
ADDED
@@ -0,0 +1,20 @@
+name: Sync to Hugging Face hub
+on:
+  push:
+    branches: [main]
+
+  # to run this workflow manually from the Actions tab
+  workflow_dispatch:
+
+jobs:
+  sync-to-hub:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+          lfs: true
+      - name: Push to hub
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        run: git push https://HF_USERNAME:[email protected]/spaces/jdalfonso/SISE-ULTIMATE-CHALLENGE main
.gitignore
CHANGED
@@ -2,6 +2,7 @@
 __pycache__/
 *.py[cod]
 *$py.class
+.idea/
 
 # C extensions
 *.so
@@ -178,6 +179,10 @@ dataset/
 old/
 *.wav
 data/*
-
+*.pth
+old/
 # Mac
 .DS_Store
+.idea
+wav2vec2_emotion/
+dataset/
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
.streamlit/config.toml
CHANGED
@@ -1,4 +1,4 @@
 [theme]
-base="
+base="light"
 primaryColor="#7c99b4"
 
README.md
CHANGED
@@ -1,5 +1,5 @@
 # SISE Ultimate Challenge
-
 [the replacement for this line was not captured in this view]
 
 Ceci est le Ultimate Challenge pour le Master SISE.
 
app.py
CHANGED
@@ -1,34 +1,248 @@
 import streamlit as st
 [removed: lines 2-34 of the previous version of app.py are not shown in this view]
+import pandas as pd
+import numpy as np
+import os
+import time
+import matplotlib.pyplot as plt
+from datetime import datetime
+import tempfile
+import io
+import json
+from model.transcriber import transcribe_audio
+from predict import predict_emotion
+
+# You'll need to install this package:
+# pip install streamlit-audiorec
+from st_audiorec import st_audiorec
+
+# Page configuration
+st.set_page_config(
+    page_title="Emotion Analyser",
+    page_icon="🎤",
+    layout="wide"
+)
+
+# Initialize session state variables if they don't exist
+if 'audio_data' not in st.session_state:
+    st.session_state.audio_data = []
+if 'current_audio_index' not in st.session_state:
+    st.session_state.current_audio_index = -1
+if 'audio_history_csv' not in st.session_state:
+    # Define columns for our CSV storage
+    st.session_state.audio_history_csv = pd.DataFrame(
+        columns=['timestamp', 'file_path', 'transcription', 'emotion', 'probabilities']
+    )
+if 'needs_rerun' not in st.session_state:
+    st.session_state.needs_rerun = False
+
+# Function to ensure we keep only the last 10 entries
+def update_audio_history(new_entry):
+    # Add the new entry
+    st.session_state.audio_history_csv = pd.concat([st.session_state.audio_history_csv, pd.DataFrame([new_entry])], ignore_index=True)
+
+    # Keep only the last 10 entries
+    if len(st.session_state.audio_history_csv) > 10:
+        st.session_state.audio_history_csv = st.session_state.audio_history_csv.iloc[-10:]
+
+    # Save to CSV
+    st.session_state.audio_history_csv.to_csv('audio_history.csv', index=False)
+
+# Function to process audio and get results
+def process_audio(audio_path):
+    try:
+        # Get transcription
+        transcription = transcribe_audio(audio_path)
+
+        # Get emotion prediction
+        predicted_emotion, probabilities = predict_emotion(audio_path)
+
+        # Update audio history
+        new_entry = {
+            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+            'file_path': audio_path,
+            'transcription': transcription,
+            'emotion': predicted_emotion,
+            'probabilities': str(probabilities)  # Convert dict to string for storage
+        }
+        update_audio_history(new_entry)
+
+        # Update current index
+        st.session_state.current_audio_index = len(st.session_state.audio_history_csv) - 1
+
+        return transcription, predicted_emotion, probabilities
+    except Exception as e:
+        st.error(f"Error processing audio: {str(e)}")
+        return None, None, None
+
+# Function to split audio into 10-second segments
+def split_audio(audio_file, segment_length=10):
+    # This is a placeholder - in a real implementation, you'd use a library like pydub
+    # to split the audio file into segments
+    st.warning("Audio splitting functionality is a placeholder. Implement with pydub or similar library.")
+    # For now, we'll just return the whole file as a single segment
+    return [audio_file]
+
+# Function to display emotion visualization
+def display_emotion_chart(probabilities):
+    emotions = list(probabilities.keys())
+    values = list(probabilities.values())
+
+    fig, ax = plt.subplots(figsize=(10, 5))
+    bars = ax.bar(emotions, values, color=['red', 'gray', 'green'])
+
+    # Add data labels on top of bars
+    for bar in bars:
+        height = bar.get_height()
+        ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
+                f'{height:.2f}', ha='center', va='bottom')
+
+    ax.set_ylim(0, 1.1)
+    ax.set_ylabel('Probability')
+    ax.set_title('Emotion Prediction Results')
+
+    st.pyplot(fig)
+
+# Trigger rerun if needed (replaces experimental_rerun)
+if st.session_state.needs_rerun:
+    st.session_state.needs_rerun = False
+    st.rerun()  # Using st.rerun() instead of experimental_rerun
+
+# Main App Layout
+st.image("./img/logo_01.png", width=400)
+
+# Create two columns for the main layout
+col1, col2 = st.columns([1, 1])
+
+with col1:
+    st.header("Audio Input")
+
+    # Method selection
+
+    tab1, tab2 = st.tabs(["Record Audio", "Upload Audio"])
+
+    with tab1:
+        st.write("Record your audio (max 10 seconds):")
+
+        # Using streamlit-audiorec for better recording functionality
+        wav_audio_data = st_audiorec()
+
+        if wav_audio_data is not None:
+            # Save the recorded audio to a temporary file
+            with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
+                tmp_file.write(wav_audio_data)
+                tmp_file_path = tmp_file.name
+
+            st.success("Audio recorded successfully!")
+
+            # Process button
+            if st.button("Process Recorded Audio"):
+                # Process the audio
+                with st.spinner("Processing audio..."):
+                    transcription, emotion, probs = process_audio(tmp_file_path)
+                    # Set flag for rerun instead of calling experimental_rerun
+                    if transcription is not None:
+                        st.success("Audio processed successfully!")
+                        st.session_state.needs_rerun = True
+
+    with tab2:
+        uploaded_file = st.file_uploader("Upload an audio file (WAV format)", type=['wav'])
+
+        if uploaded_file is not None:
+            # Save the uploaded file to a temporary location
+            with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
+                tmp_file.write(uploaded_file.getbuffer())
+                tmp_file_path = tmp_file.name
+
+            st.audio(uploaded_file, format="audio/wav")
+
+            # Process button
+            if st.button("Process Uploaded Audio"):
+                # Split audio into 10-second segments
+                with st.spinner("Processing audio..."):
+                    segments = split_audio(tmp_file_path)
+
+                    # Process each segment
+                    for i, segment_path in enumerate(segments):
+                        st.write(f"Processing segment {i+1}...")
+                        transcription, emotion, probs = process_audio(segment_path)
+
+                    # Set flag for rerun instead of calling experimental_rerun
+                    st.success("Audio processed successfully!")
+                    st.session_state.needs_rerun = True
+
+with col2:
+    st.header("Results")
+
+    # Display results if available
+    if st.session_state.current_audio_index >= 0 and len(st.session_state.audio_history_csv) > 0:
+        current_data = st.session_state.audio_history_csv.iloc[st.session_state.current_audio_index]
+
+        # Transcription
+        st.subheader("Transcription")
+        st.text_area("", value=current_data['transcription'], height=100, key="transcription_area")
+
+        # Emotion
+        st.subheader("Detected Emotion")
+        st.info(f"🎭 Predicted emotion: **{current_data['emotion']}**")
+
+        # Convert string representation of dict back to actual dict
+        try:
+            import ast
+            probs = ast.literal_eval(current_data['probabilities'])
+            display_emotion_chart(probs)
+        except Exception as e:
+            st.error(f"Error parsing probabilities: {str(e)}")
+            st.write(f"Raw probabilities: {current_data['probabilities']}")
+    else:
+        st.info("Record or upload audio to see results")
+
+# Audio History and Analytics Section
+st.header("Audio History and Analytics")
+
+if len(st.session_state.audio_history_csv) > 0:
+    # Display a select box to choose from audio history
+    timestamps = st.session_state.audio_history_csv['timestamp'].tolist()
+    selected_timestamp = st.selectbox(
+        "Select audio from history:",
+        options=timestamps,
+        index=len(timestamps) - 1  # Default to most recent
+    )
+
+    # Update current index when selection changes
+    selected_index = st.session_state.audio_history_csv[
+        st.session_state.audio_history_csv['timestamp'] == selected_timestamp
+    ].index[0]
+
+    # Only update if different
+    if st.session_state.current_audio_index != selected_index:
+        st.session_state.current_audio_index = selected_index
+        st.session_state.needs_rerun = True
+
+    # Analytics button
+    if st.button("Run Analytics on Selected Audio"):
+        st.subheader("Analytics Results")
+
+        # Get the selected audio data
+        selected_data = st.session_state.audio_history_csv.iloc[selected_index]
+
+        # Display analytics (this is where you would add more sophisticated analytics)
+        st.write(f"Selected Audio: {selected_data['timestamp']}")
+        st.write(f"Emotion: {selected_data['emotion']}")
+        st.write(f"File Path: {selected_data['file_path']}")
+
+        # Add any additional analytics you want here
+
+        # Try to play the selected audio
+        try:
+            if os.path.exists(selected_data['file_path']):
+                st.audio(selected_data['file_path'], format="audio/wav")
+            else:
+                st.warning("Audio file not found - it may have been deleted or moved.")
+        except Exception as e:
+            st.error(f"Error playing audio: {str(e)}")
+else:
+    st.info("No audio history available. Record or upload audio to create history.")
+
+# Footer
+st.markdown("---")
+st.caption("Audio Emotion Analyzer - Processes audio in 10-second segments and predicts emotions")
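Note on the placeholder above: `split_audio` currently returns the whole file as a single segment, and the in-code comment points to pydub as the intended implementation. A rough illustration of that approach is sketched below; it is not part of this PR, the function name `split_audio_pydub` is made up here, and it assumes pydub plus an ffmpeg install are available.

# Hypothetical sketch (not part of this PR): a pydub-based replacement for the
# split_audio placeholder in app.py. Assumes pydub and ffmpeg are installed.
import os
import tempfile

from pydub import AudioSegment

def split_audio_pydub(audio_path, segment_length=10):
    """Split a WAV file into consecutive segments of `segment_length` seconds."""
    audio = AudioSegment.from_wav(audio_path)      # an AudioSegment is indexed in milliseconds
    segment_ms = segment_length * 1000
    segment_paths = []
    for i, start in enumerate(range(0, len(audio), segment_ms)):
        chunk = audio[start:start + segment_ms]    # slicing yields a sub-segment by ms
        out_path = os.path.join(tempfile.gettempdir(), f"segment_{i}.wav")
        chunk.export(out_path, format="wav")
        segment_paths.append(out_path)
    return segment_paths

Each returned path could then be fed to process_audio() exactly as app.py already does for the single-segment case.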
config.py
ADDED
@@ -0,0 +1,25 @@
+import os
+import torch
+from dotenv import load_dotenv
+
+# Load the environment variables
+load_dotenv()
+HF_API_KEY = os.getenv("HF_API_KEY")
+
+if not HF_API_KEY:
+    raise ValueError("Le token Hugging Face n'a pas été trouvé dans .env")
+
+# Emotion labels
+LABELS = {"colere": 0, "neutre": 1, "joie": 2}
+#LABELS = ["colere", "neutre", "joie"]
+NUM_LABELS = len(LABELS)
+
+# Choose the device
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Wav2Vec2 model
+MODEL_NAME = "facebook/wav2vec2-large-xlsr-53-french"
+
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+BEST_MODEL_NAME = os.path.join(BASE_DIR, "model", "fr-speech-emotion-model.pth")  # Goes up one level to reach the project root
+
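Since config.py raises as soon as HF_API_KEY is missing, the Space needs a .env file (or an exported environment variable) before anything imports it. A minimal sanity check is sketched below; it is illustration only, not part of this PR, and it assumes the `dotenv` entry added to requirements.txt resolves to the python-dotenv package, which is what usually provides `from dotenv import load_dotenv`.

# Minimal environment check (illustration only, not part of this PR).
# Assumes a .env file at the project root containing a line like: HF_API_KEY=hf_xxx
import os
from dotenv import load_dotenv  # provided by the python-dotenv package

load_dotenv()
assert os.getenv("HF_API_KEY"), "HF_API_KEY is missing from .env"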
img/logo.png
DELETED
Binary file (179 kB)
img/logo_01.png
ADDED
model/__init__.py
ADDED
File without changes
model/emotion_classifier.py
ADDED
@@ -0,0 +1,54 @@
+
+
+# Predicts roughly 33% for every class (in the 3-class case)
+
+# class EmotionClassifier(nn.Module):
+#     def __init__(self, feature_dim, num_labels):
+#         super(EmotionClassifier, self).__init__()
+#         self.fc1 = nn.Linear(feature_dim, 256)
+#         self.relu = nn.ReLU()
+#         self.dropout = nn.Dropout(0.3)
+#         self.fc2 = nn.Linear(256, num_labels)
+
+#     def forward(self, x):
+#         x = self.fc1(x)
+#         x = self.relu(x)
+#         x = self.dropout(x)
+#         return self.fc2(x)
+
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Attention(nn.Module):
+    """Attention mechanism that weights the importance of the audio features"""
+    def __init__(self, hidden_dim):
+        super(Attention, self).__init__()
+        self.attention_weights = nn.Linear(hidden_dim, 1)
+
+    def forward(self, lstm_output):
+        # lstm_output: (batch_size, sequence_length, hidden_dim)
+        attention_scores = self.attention_weights(lstm_output)  # (batch_size, sequence_length, 1)
+        attention_weights = torch.softmax(attention_scores, dim=1)  # Softmax normalisation
+        weighted_output = lstm_output * attention_weights  # Weight the features
+        return weighted_output.sum(dim=1)  # Weighted sum over the sequence
+
+class EmotionClassifier(nn.Module):
+    """Emotion classification model based on a BiLSTM with attention"""
+    def __init__(self, feature_dim, num_labels, hidden_dim=128):
+        super(EmotionClassifier, self).__init__()
+        self.lstm = nn.LSTM(feature_dim, hidden_dim, batch_first=True, bidirectional=True)
+        self.attention = Attention(hidden_dim * 2)  # Bidirectional -> hidden_dim * 2
+        self.fc = nn.Linear(hidden_dim * 2, num_labels)  # Final classification layer
+
+    def forward(self, x):
+        lstm_out, _ = self.lstm(x)  # (batch_size, sequence_length, hidden_dim*2)
+        attention_out = self.attention(lstm_out)  # (batch_size, hidden_dim*2)
+        logits = self.fc(attention_out)  # (batch_size, num_labels)
+        return logits
+
+
+
+
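For reference, the classifier expects batches shaped (batch_size, sequence_length, feature_dim); feature_dim=40 and sequence_length=128 match the MFCC setup used in predict.py. A quick, illustrative shape check (not part of this PR, and it assumes the project root is on the Python path so `model.emotion_classifier` is importable):

# Illustrative shape check for the BiLSTM + attention classifier (not part of this PR).
import torch
from model.emotion_classifier import EmotionClassifier

model = EmotionClassifier(feature_dim=40, num_labels=3)
dummy_batch = torch.randn(2, 128, 40)  # (batch_size, sequence_length, feature_dim), as in predict.py
logits = model(dummy_batch)
print(logits.shape)  # torch.Size([2, 3])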
model/feature_extractor.py
ADDED
@@ -0,0 +1,6 @@
+import torch
+from transformers import Wav2Vec2Model, Wav2Vec2Processor
+from config import MODEL_NAME, DEVICE
+
+processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
+feature_extractor = Wav2Vec2Model.from_pretrained(MODEL_NAME).to(DEVICE)
model/transcriber.py
ADDED
@@ -0,0 +1,35 @@
+import os
+import torch
+import librosa
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+
+# Load the model and the processor
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+MODEL_NAME = "facebook/wav2vec2-large-xlsr-53-french"
+
+processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
+model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME).to(device)
+model.eval()
+
+def transcribe_audio(audio_path, sampling_rate=16000):
+    # Load the audio
+    audio, sr = librosa.load(audio_path, sr=sampling_rate)
+
+    # Turn the audio into model inputs
+    input_values = processor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_values.to(device)
+
+    # Get the predictions
+    with torch.no_grad():
+        logits = model(input_values).logits
+
+    # Decode the predictions into text
+    predicted_ids = torch.argmax(logits, dim=-1)
+    transcription = processor.batch_decode(predicted_ids)[0]
+    return transcription
+
+# Example usage
+if __name__ == "__main__":
+    base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data"))
+    audio_path = os.path.join(base_path, "colere", "c1af.wav")
+    texte = transcribe_audio(audio_path)
+    print(f"Transcription : {texte}")
predict.py
ADDED
@@ -0,0 +1,57 @@
+import sys
+import os
+import torch
+import librosa
+import numpy as np
+from model.emotion_classifier import EmotionClassifier
+from utils.preprocessing import collate_fn
+from config import DEVICE, NUM_LABELS, BEST_MODEL_NAME
+
+# Load the trained model
+feature_dim = 40  # Number of MFCCs used
+model = EmotionClassifier(feature_dim, NUM_LABELS).to(DEVICE)
+model.load_state_dict(torch.load(BEST_MODEL_NAME, map_location=DEVICE))
+model.eval()  # Evaluation mode
+
+# Emotion labels
+LABELS = {0: "colère", 1: "neutre", 2: "joie"}
+
+# Predict the emotion of an audio file, with probabilities
+def predict_emotion(audio_path, max_length=128):
+    # Load the audio
+    y, sr = librosa.load(audio_path, sr=16000)
+
+    # Extract the MFCCs
+    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=40)
+
+    # Adjust the MFCC size with padding/truncation
+    if mfcc.shape[1] > max_length:
+        mfcc = mfcc[:, :max_length]  # Truncate if too long
+    else:
+        pad_width = max_length - mfcc.shape[1]
+        mfcc = np.pad(mfcc, pad_width=((0, 0), (0, pad_width)), mode='constant')
+
+    # Convert to a PyTorch tensor
+    input_tensor = torch.tensor(mfcc.T, dtype=torch.float32).unsqueeze(0).to(DEVICE)  # (1, max_length, 40)
+
+    # Run the model
+    with torch.no_grad():
+        logits = model(input_tensor)
+        probabilities = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy().flatten()  # Convert to probabilities
+        predicted_class = torch.argmax(logits, dim=-1).item()
+
+    # Map the probabilities to their labels
+    probabilities_dict = {LABELS[i]: float(probabilities[i]) for i in range(NUM_LABELS)}
+
+    return LABELS[predicted_class], probabilities_dict
+
+
+# Example usage
+if __name__ == "__main__":
+    base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "data"))
+    audio_file = os.path.join(base_path, "colere", "c1ac.wav")
+
+    predicted_emotion, probabilities = predict_emotion(audio_file)
+
+    print(f"🎤 L'émotion prédite est : {predicted_emotion}")
+    print(f"📊 Probabilités par classe : {probabilities}")
requirements.txt
CHANGED
@@ -15,3 +15,5 @@ scikit-learn
 huggingface
 huggingface_hub
 pyaudio
+streamlit_audiorec
+dotenv
src/data/colere/c1ac.wav
DELETED
Binary file (110 kB)
src/data/colere/c1af.wav
DELETED
Binary file (157 kB)
src/data/colere/c1aj.wav
DELETED
Binary file (210 kB)
src/data/colere/c1an.wav
DELETED
Binary file (148 kB)
src/data/colere/c1bc.wav
DELETED
Binary file (65.8 kB)
src/data/colere/c1bf.wav
DELETED
Binary file (117 kB)
src/data/colere/c1bj.wav
DELETED
Binary file (76.9 kB)
src/data/colere/c1bn.wav
DELETED
Binary file (74.3 kB)
src/data/colere/c1cc.wav
DELETED
Binary file (112 kB)
src/data/colere/c1cf.wav
DELETED
Binary file (138 kB)
src/data/colere/c1cj.wav
DELETED
Binary file (101 kB)
src/data/colere/c2ac.wav
DELETED
Binary file (108 kB)
src/data/colere/c2af.wav
DELETED
Binary file (138 kB)
src/data/colere/c2aj.wav
DELETED
Binary file (115 kB)
src/data/colere/c2an.wav
DELETED
Binary file (140 kB)
src/data/colere/c2bc.wav
DELETED
Binary file (89.1 kB)
src/data/colere/c2bf.wav
DELETED
Binary file (115 kB)
src/data/colere/c2bj.wav
DELETED
Binary file (110 kB)
src/data/colere/c2bn.wav
DELETED
Binary file (138 kB)
src/data/colere/c2cn.wav
DELETED
Binary file (123 kB)
src/data/colere/c3ac.wav
DELETED
Binary file (119 kB)
src/data/colere/c3af.wav
DELETED
Binary file (127 kB)
src/data/colere/c3aj.wav
DELETED
Binary file (119 kB)
src/data/colere/c3an.wav
DELETED
Binary file (129 kB)
src/data/colere/c3bc.wav
DELETED
Binary file (115 kB)
src/data/colere/c3bf.wav
DELETED
Binary file (142 kB)
src/data/colere/c3bj.wav
DELETED
Binary file (99.7 kB)
src/data/colere/c3bn.wav
DELETED
Binary file (153 kB)
src/data/colere/c4aaf.wav
DELETED
Binary file (142 kB)
src/data/colere/c4ac.wav
DELETED
Binary file (108 kB)
src/data/colere/c4af.wav
DELETED
Binary file (127 kB)
src/data/colere/c4aj.wav
DELETED
Binary file (159 kB)
src/data/colere/c4an.wav
DELETED
Binary file (121 kB)
src/data/colere/c4bc.wav
DELETED
Binary file (112 kB)