deveix commited on
Commit
a1b9bc0
·
1 Parent(s): e7838b2
Files changed (2) hide show
  1. app/main.py +122 -1
  2. requirements.txt +2 -1
app/main.py CHANGED
@@ -22,6 +22,9 @@ import opensmile
22
 
23
  import ffmpeg
24
  import noisereduce as nr
 
 
 
25
 
26
  default_sample_rate=22050
27
 
@@ -201,6 +204,124 @@ async def get_answer(item: Item, token: str = Depends(verify_token)):
201
  # If there's an error, return a 500 error with the error's details
202
  raise HTTPException(status_code=500, detail=str(e))
203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  # random forest
205
  model = joblib.load('app/1713661391.0946255_trained_model.joblib')
206
  pca = joblib.load('app/pca.pkl')
@@ -320,7 +441,7 @@ def repair_mp3_with_ffmpeg_python(input_path, output_path):
320
  print(f"Failed to repair file {input_path}: {str(e.stderr)}")
321
 
322
 
323
- @app.post("/mlp")
324
  async def handle_audio(file: UploadFile = File(...)):
325
  try:
326
  # Ensure that we are handling an MP3 file
 
22
 
23
  import ffmpeg
24
  import noisereduce as nr
25
+ from tensorflow.keras.models import load_model
26
+ from tensorflow.keras.utils import to_categorical
27
+ from tensorflow.keras.models import Sequential
28
 
29
  default_sample_rate=22050
30
 
 
204
  # If there's an error, return a 500 error with the error's details
205
  raise HTTPException(status_code=500, detail=str(e))
206
 
207
+ # ------- CNN
208
+
209
+ # Constants
210
+ TARGET_DURATION = 3 # seconds for each audio clip
211
+ SAMPLE_RATE = 44100 # sample rate to use
212
+ N_MELS = 128 # number of Mel bands to generate
213
+ HOP_LENGTH = 512 # number of samples between successive frames
214
+
215
+ def preprocess_audio(file_path):
216
+ try:
217
+ # Load the audio file
218
+ audio, sr = librosa.load(file_path, sr=SAMPLE_RATE)
219
+ audio_length = len(audio)/SAMPLE_RATE
220
+ except FileNotFoundError:
221
+ print(f"Error: File '{file_path}' not found.")
222
+ return None
223
+ except Exception as e:
224
+ print(f"Error loading audio file: {e}")
225
+ return None
226
+
227
+ # Check if audio signal is None
228
+ if audio is None:
229
+ print(f"Error: Audio signal is None for file '{file_path}'.")
230
+ return None
231
+
232
+ audio, _ = librosa.effects.trim(audio, top_db = 25)
233
+
234
+ audio = nr.reduce_noise(y = audio, sr=SAMPLE_RATE, thresh_n_mult_nonstationary=1,stationary=False)
235
+
236
+ # Determine how many 20-second clips can be made from the audio
237
+ if audio_length < TARGET_DURATION:
238
+ # If audio is shorter than 20 seconds, pad it
239
+ pad_length = int((TARGET_DURATION - audio_length) * sr)
240
+ padded_audio = np.pad(audio, (0, pad_length), mode='constant')
241
+ return [padded_audio] # Return as a list for consistent output format
242
+ else:
243
+ # If audio is longer than or equal to 20 seconds, split it into 20-second clips
244
+ clip_length = TARGET_DURATION * sr
245
+ clips = []
246
+ for start in range(0, len(audio), clip_length):
247
+ end = start + clip_length
248
+ # Ensure the last clip has enough samples
249
+ if end > len(audio):
250
+ # Here you can choose to pad the last clip or simply not use it if it's too short
251
+ last_clip = np.pad(audio[start:], (0, end - len(audio)), mode='constant')
252
+ clips.append(last_clip)
253
+ else:
254
+ clips.append(audio[start:end])
255
+ return clips
256
+
257
+ def generate_spectrogram(audio):
258
+ # Generate a Mel-scaled spectrogram
259
+ S = librosa.feature.melspectrogram(y=audio, sr=SAMPLE_RATE, n_mels=N_MELS, hop_length=HOP_LENGTH)
260
+ S_dB = librosa.power_to_db(S, ref=np.max)
261
+
262
+ # Normalize the spectrogram to be between 0 and 1
263
+ S_dB_norm = librosa.util.normalize(S_dB)
264
+
265
+ return S_dB_norm
266
+
267
+ cnn_model = load_model('app/cnn.h5')
268
+ cnn_label_encoder = joblib.load('app/cnn_label_encoder.pkl')
269
+
270
+ @app.post("/cnn")
271
+ async def handle_cnn(file: UploadFile = File(...)):
272
+ try:
273
+ # Ensure that we are handling an MP3 file
274
+ if file.content_type == "audio/mpeg" or file.content_type == "audio/mp3":
275
+ file_extension = ".mp3"
276
+ elif file.content_type == "audio/wav":
277
+ file_extension = ".wav"
278
+ else:
279
+ raise HTTPException(status_code=400, detail="Invalid file type. Supported types: MP3, WAV.")
280
+
281
+ # Read the file's content
282
+ contents = await file.read()
283
+ temp_filename = f"app/{uuid4().hex}{file_extension}"
284
+
285
+
286
+ # Save file to a temporary file if needed or process directly from memory
287
+ with open(temp_filename, "wb") as f:
288
+ f.write(contents)
289
+
290
+ spectrograms = []
291
+
292
+ clips = preprocess_audio(temp_filename)
293
+ for clip in clips:
294
+ spectrogram = generate_spectrogram(clip)
295
+ if np.isnan(spectrogram).any() or np.isinf(spectrogram).any():
296
+ print("Invalid spectrogram detected")
297
+ continue
298
+ spectrograms.append(spectrogram)
299
+ X = np.array(spectrograms)
300
+
301
+ X = X[..., np.newaxis]
302
+
303
+ # Make predictions
304
+ predictions = cnn_model.predict(X)
305
+
306
+ # Convert predictions to label indexes
307
+ predicted_label_indexes = np.argmax(predictions, axis=1)
308
+
309
+ # Convert label indexes to actual label names
310
+ predicted_labels = cnn_label_encoder.inverse_transform(predicted_label_indexes)
311
+
312
+ print('decoded', predicted_labels)
313
+ # .tolist()
314
+ # Clean up the temporary file
315
+ os.remove(temp_filename)
316
+ # Return a successful response with decoded predictions
317
+ return {"message": "File processed successfully", "sheikh": predicted_labels}
318
+ except Exception as e:
319
+ print(e)
320
+ # Handle possible exceptions
321
+ raise HTTPException(status_code=500, detail=str(e))
322
+
323
+
324
+
325
  # random forest
326
  model = joblib.load('app/1713661391.0946255_trained_model.joblib')
327
  pca = joblib.load('app/pca.pkl')
 
441
  print(f"Failed to repair file {input_path}: {str(e.stderr)}")
442
 
443
 
444
+ @app.post("/rf")
445
  async def handle_audio(file: UploadFile = File(...)):
446
  try:
447
  # Ensure that we are handling an MP3 file
requirements.txt CHANGED
@@ -19,4 +19,5 @@ matplotlib
19
  python-multipart
20
  ffmpeg-python
21
  noisereduce
22
- scikit-learn==1.2.2
 
 
19
  python-multipart
20
  ffmpeg-python
21
  noisereduce
22
+ scikit-learn==1.2.2
23
+ tensorflow