Coco-18 commited on
Commit
53ae3c9
Β·
verified Β·
1 Parent(s): edd4f88

Update evaluate.py

Browse files
Files changed (1) hide show
  1. evaluate.py +174 -174
evaluate.py CHANGED
@@ -326,112 +326,70 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
326
  logger.warning(f"[{request_id}] ⚠️ Created missing reference directory: {reference_dir_path}")
327
  except Exception as e:
328
  logger.error(f"[{request_id}] ❌ Failed to create reference directory: {str(e)}")
329
- return jsonify({"error": f"Reference audio directory not found: {reference_locator}"}), 404
330
 
331
- # Check for reference files
332
- reference_files = glob.glob(os.path.join(reference_dir_path, "*.wav"))
333
- logger.info(f"[{request_id}] πŸ“ Found {len(reference_files)} reference files")
334
 
335
- # If no reference files exist, create a dummy reference file
336
- if not reference_files:
337
- logger.warning(f"[{request_id}] ⚠️ No reference audio files found in {reference_dir_path}")
338
 
339
- # Create a dummy reference file
340
- try:
341
- dummy_file_path = os.path.join(reference_dir_path, "dummy_reference.wav")
342
- logger.info(f"[{request_id}] πŸ”„ Creating dummy reference file: {dummy_file_path}")
343
-
344
- # Create a 1-second audio file with a slight sound
345
- silent_audio = AudioSegment.silent(duration=1000, frame_rate=sample_rate)
346
- # Add a tiny bit of noise to help ASR
347
- for i in range(50, 950, 300):
348
- silent_audio = silent_audio.overlay(AudioSegment.silent(duration=50, frame_rate=sample_rate) + 3, position=i)
349
- silent_audio.export(dummy_file_path, format="wav")
350
-
351
- # Add it to the list of reference files
352
- reference_files = [dummy_file_path]
353
- logger.info(f"[{request_id}] βœ… Created dummy reference file for testing")
354
- except Exception as e:
355
- logger.error(f"[{request_id}] ❌ Failed to create dummy reference: {str(e)}")
356
- return jsonify({"error": f"No reference audio found for {reference_locator}"}), 404
357
-
358
- lang_code = LANGUAGE_CODES.get(language, language)
359
- logger.info(f"[{request_id}] πŸ”„ Evaluating pronunciation for reference: {reference_locator} with language code: {lang_code}")
360
-
361
- # Create a request-specific temp directory to avoid conflicts
362
- temp_dir = os.path.join(output_dir, f"temp_{request_id}")
363
- os.makedirs(temp_dir, exist_ok=True)
364
-
365
- # Process user audio
366
- user_audio_path = os.path.join(temp_dir, "user_audio_input.wav")
367
- with open(user_audio_path, 'wb') as f:
368
- f.write(audio_file.read())
369
-
370
- try:
371
- logger.info(f"[{request_id}] πŸ”„ Processing user audio file")
372
- audio = AudioSegment.from_file(user_audio_path)
373
- audio = audio.set_frame_rate(sample_rate).set_channels(1)
374
 
375
- processed_path = os.path.join(temp_dir, "processed_user_audio.wav")
376
- audio.export(processed_path, format="wav")
377
 
378
- user_waveform, sr = torchaudio.load(processed_path)
379
- user_waveform = user_waveform.squeeze().numpy()
380
- logger.info(f"[{request_id}] βœ… User audio processed: {sr}Hz, length: {len(user_waveform)} samples")
381
 
382
- user_audio_path = processed_path
383
- except Exception as e:
384
- logger.error(f"[{request_id}] ❌ Audio processing failed: {str(e)}")
385
- return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
386
 
387
- # Transcribe user audio
388
- try:
389
- logger.info(f"[{request_id}] πŸ”„ Transcribing user audio")
390
- # Remove language parameter if causing warnings
391
- inputs = asr_processor(
392
- user_waveform,
393
- sampling_rate=sample_rate,
394
- return_tensors="pt"
395
- )
396
- inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
397
-
398
- with torch.no_grad():
399
- logits = asr_model(**inputs).logits
400
- ids = torch.argmax(logits, dim=-1)[0]
401
- user_transcription = asr_processor.decode(ids)
402
-
403
- logger.info(f"[{request_id}] βœ… User transcription: '{user_transcription}'")
404
- except Exception as e:
405
- logger.error(f"[{request_id}] ❌ ASR inference failed: {str(e)}")
406
- return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500
407
 
408
- # Process reference files in batches
409
- batch_size = 2 # Process 2 files at a time - adjust based on your hardware
410
- results = []
411
- best_score = 0
412
- best_reference = None
413
- best_transcription = None
414
 
415
- # Use this if you want to limit the number of files to process
416
- max_files_to_check = min(5, len(reference_files)) # Check at most 5 files
417
- reference_files = reference_files[:max_files_to_check]
418
 
419
- logger.info(f"[{request_id}] πŸ”„ Processing {len(reference_files)} reference files in batches of {batch_size}")
 
 
 
420
 
421
- # Function to process a single reference file
422
- def process_reference_file(ref_file):
423
- ref_filename = os.path.basename(ref_file)
424
  try:
425
- # Load and resample reference audio
426
- ref_waveform, ref_sr = torchaudio.load(ref_file)
427
- if ref_sr != sample_rate:
428
- ref_waveform = torchaudio.transforms.Resample(ref_sr, sample_rate)(ref_waveform)
429
- ref_waveform = ref_waveform.squeeze().numpy()
430
-
431
- # Transcribe reference audio - use the local asr_model and asr_processor
432
  # Remove language parameter if causing warnings
433
  inputs = asr_processor(
434
- ref_waveform,
435
  sampling_rate=sample_rate,
436
  return_tensors="pt"
437
  )
@@ -440,93 +398,135 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
440
  with torch.no_grad():
441
  logits = asr_model(**inputs).logits
442
  ids = torch.argmax(logits, dim=-1)[0]
443
- ref_transcription = asr_processor.decode(ids)
444
 
445
- # Calculate similarity
446
- similarity = calculate_similarity(user_transcription, ref_transcription)
 
 
447
 
448
- logger.info(
449
- f"[{request_id}] πŸ“Š Similarity with {ref_filename}: {similarity:.2f}%, transcription: '{ref_transcription}'")
 
 
 
 
450
 
451
- return {
452
- "reference_file": ref_filename,
453
- "reference_text": ref_transcription,
454
- "similarity_score": similarity
455
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
  except Exception as e:
457
- logger.error(f"[{request_id}] ❌ Error processing {ref_filename}: {str(e)}")
458
- return {
459
- "reference_file": ref_filename,
460
- "reference_text": "Error",
461
- "similarity_score": 0,
462
- "error": str(e)
463
- }
464
 
465
- # Process files in batches using ThreadPoolExecutor
466
- with ThreadPoolExecutor(max_workers=batch_size) as executor:
467
- batch_results = list(executor.map(process_reference_file, reference_files))
468
- results.extend(batch_results)
469
-
470
- # Find the best result
471
- for result in batch_results:
472
- if result["similarity_score"] > best_score:
473
- best_score = result["similarity_score"]
474
- best_reference = result["reference_file"]
475
- best_transcription = result["reference_text"]
476
-
477
- # Exit early if we found a very good match (optional)
478
- if best_score > 80.0:
479
- logger.info(f"[{request_id}] 🏁 Found excellent match: {best_score:.2f}%")
480
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
 
482
- # Clean up temp files
483
- try:
484
- if temp_dir and os.path.exists(temp_dir):
485
- shutil.rmtree(temp_dir)
486
- logger.debug(f"[{request_id}] 🧹 Cleaned up temporary directory")
487
  except Exception as e:
488
- logger.warning(f"[{request_id}] ⚠️ Failed to clean up temp files: {str(e)}")
489
-
490
- # Determine feedback based on score
491
- is_correct = best_score >= 70.0
492
-
493
- if best_score >= 90.0:
494
- feedback = "Perfect pronunciation! Excellent job!"
495
- elif best_score >= 80.0:
496
- feedback = "Great pronunciation! Your accent is very good."
497
- elif best_score >= 70.0:
498
- feedback = "Good pronunciation. Keep practicing!"
499
- elif best_score >= 50.0:
500
- feedback = "Fair attempt. Try focusing on the syllables that differ from the sample."
501
- else:
502
- feedback = "Try again. Listen carefully to the sample pronunciation."
503
-
504
- logger.info(f"[{request_id}] πŸ“Š Final evaluation results: score={best_score:.2f}%, is_correct={is_correct}")
505
- logger.info(f"[{request_id}] πŸ“ Feedback: '{feedback}'")
506
- logger.info(f"[{request_id}] βœ… Evaluation complete")
507
-
508
- # Sort results by score descending
509
- results.sort(key=lambda x: x["similarity_score"], reverse=True)
510
-
511
- return jsonify({
512
- "is_correct": is_correct,
513
- "score": best_score,
514
- "feedback": feedback,
515
- "user_transcription": user_transcription,
516
- "best_reference_transcription": best_transcription,
517
- "reference_locator": reference_locator,
518
- "details": results
519
- })
520
-
521
- except Exception as e:
522
- logger.error(f"[{request_id}] ❌ Unhandled exception in evaluation endpoint: {str(e)}")
523
- logger.debug(f"[{request_id}] Stack trace: {traceback.format_exc()}")
524
-
525
- # Clean up on error
526
- try:
527
- if temp_dir and os.path.exists(temp_dir):
528
- shutil.rmtree(temp_dir)
529
- except:
530
- pass
531
 
532
- return jsonify({"error": f"Internal server error: {str(e)}"}), 500
 
326
  logger.warning(f"[{request_id}] ⚠️ Created missing reference directory: {reference_dir_path}")
327
  except Exception as e:
328
  logger.error(f"[{request_id}] ❌ Failed to create reference directory: {str(e)}")
329
+ return jsonify({"error": f"Reference audio directory not found: {reference_locator}"}), 404
330
 
331
+ # Check for reference files
332
+ reference_files = glob.glob(os.path.join(reference_dir_path, "*.wav"))
333
+ logger.info(f"[{request_id}] πŸ“ Found {len(reference_files)} reference files")
334
 
335
+ # If no reference files exist, create a dummy reference file
336
+ if not reference_files:
337
+ logger.warning(f"[{request_id}] ⚠️ No reference audio files found in {reference_dir_path}")
338
 
339
+ # Create a dummy reference file
340
+ try:
341
+ dummy_file_path = os.path.join(reference_dir_path, "dummy_reference.wav")
342
+ logger.info(f"[{request_id}] πŸ”„ Creating dummy reference file: {dummy_file_path}")
343
+
344
+ # Create a 1-second audio file with a slight sound
345
+ silent_audio = AudioSegment.silent(duration=1000, frame_rate=sample_rate)
346
+ # Add a tiny bit of noise to help ASR
347
+ for i in range(50, 950, 300):
348
+ silent_audio = silent_audio.overlay(AudioSegment.silent(duration=50, frame_rate=sample_rate) + 3, position=i)
349
+ silent_audio.export(dummy_file_path, format="wav")
350
+
351
+ # Add it to the list of reference files
352
+ reference_files = [dummy_file_path]
353
+ logger.info(f"[{request_id}] βœ… Created dummy reference file for testing")
354
+ except Exception as e:
355
+ logger.error(f"[{request_id}] ❌ Failed to create dummy reference: {str(e)}")
356
+ return jsonify({"error": f"No reference audio found for {reference_locator}"}), 404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
 
358
+ lang_code = LANGUAGE_CODES.get(language, language)
359
+ logger.info(f"[{request_id}] πŸ”„ Evaluating pronunciation for reference: {reference_locator} with language code: {lang_code}")
360
 
361
+ # Create a request-specific temp directory to avoid conflicts
362
+ temp_dir = os.path.join(output_dir, f"temp_{request_id}")
363
+ os.makedirs(temp_dir, exist_ok=True)
364
 
365
+ # Process user audio
366
+ user_audio_path = os.path.join(temp_dir, "user_audio_input.wav")
367
+ with open(user_audio_path, 'wb') as f:
368
+ f.write(audio_file.read())
369
 
370
+ try:
371
+ logger.info(f"[{request_id}] πŸ”„ Processing user audio file")
372
+ audio = AudioSegment.from_file(user_audio_path)
373
+ audio = audio.set_frame_rate(sample_rate).set_channels(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
+ processed_path = os.path.join(temp_dir, "processed_user_audio.wav")
376
+ audio.export(processed_path, format="wav")
 
 
 
 
377
 
378
+ user_waveform, sr = torchaudio.load(processed_path)
379
+ user_waveform = user_waveform.squeeze().numpy()
380
+ logger.info(f"[{request_id}] βœ… User audio processed: {sr}Hz, length: {len(user_waveform)} samples")
381
 
382
+ user_audio_path = processed_path
383
+ except Exception as e:
384
+ logger.error(f"[{request_id}] ❌ Audio processing failed: {str(e)}")
385
+ return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
386
 
387
+ # Transcribe user audio
 
 
388
  try:
389
+ logger.info(f"[{request_id}] πŸ”„ Transcribing user audio")
 
 
 
 
 
 
390
  # Remove language parameter if causing warnings
391
  inputs = asr_processor(
392
+ user_waveform,
393
  sampling_rate=sample_rate,
394
  return_tensors="pt"
395
  )
 
398
  with torch.no_grad():
399
  logits = asr_model(**inputs).logits
400
  ids = torch.argmax(logits, dim=-1)[0]
401
+ user_transcription = asr_processor.decode(ids)
402
 
403
+ logger.info(f"[{request_id}] βœ… User transcription: '{user_transcription}'")
404
+ except Exception as e:
405
+ logger.error(f"[{request_id}] ❌ ASR inference failed: {str(e)}")
406
+ return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500
407
 
408
+ # Process reference files in batches
409
+ batch_size = 2 # Process 2 files at a time - adjust based on your hardware
410
+ results = []
411
+ best_score = 0
412
+ best_reference = None
413
+ best_transcription = None
414
 
415
+ # Use this if you want to limit the number of files to process
416
+ max_files_to_check = min(5, len(reference_files)) # Check at most 5 files
417
+ reference_files = reference_files[:max_files_to_check]
418
+
419
+ logger.info(f"[{request_id}] πŸ”„ Processing {len(reference_files)} reference files in batches of {batch_size}")
420
+
421
+ # Function to process a single reference file
422
+ def process_reference_file(ref_file):
423
+ ref_filename = os.path.basename(ref_file)
424
+ try:
425
+ # Load and resample reference audio
426
+ ref_waveform, ref_sr = torchaudio.load(ref_file)
427
+ if ref_sr != sample_rate:
428
+ ref_waveform = torchaudio.transforms.Resample(ref_sr, sample_rate)(ref_waveform)
429
+ ref_waveform = ref_waveform.squeeze().numpy()
430
+
431
+ # Transcribe reference audio - use the local asr_model and asr_processor
432
+ # Remove language parameter if causing warnings
433
+ inputs = asr_processor(
434
+ ref_waveform,
435
+ sampling_rate=sample_rate,
436
+ return_tensors="pt"
437
+ )
438
+ inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
439
+
440
+ with torch.no_grad():
441
+ logits = asr_model(**inputs).logits
442
+ ids = torch.argmax(logits, dim=-1)[0]
443
+ ref_transcription = asr_processor.decode(ids)
444
+
445
+ # Calculate similarity
446
+ similarity = calculate_similarity(user_transcription, ref_transcription)
447
+
448
+ logger.info(
449
+ f"[{request_id}] πŸ“Š Similarity with {ref_filename}: {similarity:.2f}%, transcription: '{ref_transcription}'")
450
+
451
+ return {
452
+ "reference_file": ref_filename,
453
+ "reference_text": ref_transcription,
454
+ "similarity_score": similarity
455
+ }
456
+ except Exception as e:
457
+ logger.error(f"[{request_id}] ❌ Error processing {ref_filename}: {str(e)}")
458
+ return {
459
+ "reference_file": ref_filename,
460
+ "reference_text": "Error",
461
+ "similarity_score": 0,
462
+ "error": str(e)
463
+ }
464
+
465
+ # Process files in batches using ThreadPoolExecutor
466
+ with ThreadPoolExecutor(max_workers=batch_size) as executor:
467
+ batch_results = list(executor.map(process_reference_file, reference_files))
468
+ results.extend(batch_results)
469
+
470
+ # Find the best result
471
+ for result in batch_results:
472
+ if result["similarity_score"] > best_score:
473
+ best_score = result["similarity_score"]
474
+ best_reference = result["reference_file"]
475
+ best_transcription = result["reference_text"]
476
+
477
+ # Exit early if we found a very good match (optional)
478
+ if best_score > 80.0:
479
+ logger.info(f"[{request_id}] 🏁 Found excellent match: {best_score:.2f}%")
480
+ break
481
+
482
+ # Clean up temp files
483
+ try:
484
+ if temp_dir and os.path.exists(temp_dir):
485
+ shutil.rmtree(temp_dir)
486
+ logger.debug(f"[{request_id}] 🧹 Cleaned up temporary directory")
487
  except Exception as e:
488
+ logger.warning(f"[{request_id}] ⚠️ Failed to clean up temp files: {str(e)}")
 
 
 
 
 
 
489
 
490
+ # Determine feedback based on score
491
+ is_correct = best_score >= 70.0
492
+
493
+ if best_score >= 90.0:
494
+ feedback = "Perfect pronunciation! Excellent job!"
495
+ elif best_score >= 80.0:
496
+ feedback = "Great pronunciation! Your accent is very good."
497
+ elif best_score >= 70.0:
498
+ feedback = "Good pronunciation. Keep practicing!"
499
+ elif best_score >= 50.0:
500
+ feedback = "Fair attempt. Try focusing on the syllables that differ from the sample."
501
+ else:
502
+ feedback = "Try again. Listen carefully to the sample pronunciation."
503
+
504
+ logger.info(f"[{request_id}] πŸ“Š Final evaluation results: score={best_score:.2f}%, is_correct={is_correct}")
505
+ logger.info(f"[{request_id}] πŸ“ Feedback: '{feedback}'")
506
+ logger.info(f"[{request_id}] βœ… Evaluation complete")
507
+
508
+ # Sort results by score descending
509
+ results.sort(key=lambda x: x["similarity_score"], reverse=True)
510
+
511
+ return jsonify({
512
+ "is_correct": is_correct,
513
+ "score": best_score,
514
+ "feedback": feedback,
515
+ "user_transcription": user_transcription,
516
+ "best_reference_transcription": best_transcription,
517
+ "reference_locator": reference_locator,
518
+ "details": results
519
+ })
520
 
 
 
 
 
 
521
  except Exception as e:
522
+ logger.error(f"[{request_id}] ❌ Unhandled exception in evaluation endpoint: {str(e)}")
523
+ logger.debug(f"[{request_id}] Stack trace: {traceback.format_exc()}")
524
+
525
+ # Clean up on error
526
+ try:
527
+ if temp_dir and os.path.exists(temp_dir):
528
+ shutil.rmtree(temp_dir)
529
+ except:
530
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531
 
532
+ return jsonify({"error": f"Internal server error: {str(e)}"}), 500