Coco-18 committed · verified
Commit c5def88 · 1 Parent(s): c0eb848

Update app.py

Files changed (1)
  1. app.py +65 -66
app.py CHANGED
@@ -602,46 +602,31 @@ def evaluate_pronunciation():
         temp_dir = os.path.join(OUTPUT_DIR, f"temp_{request_id}")
         os.makedirs(temp_dir, exist_ok=True)
 
-        # Save the uploaded file temporarily
+        # Process user audio
         user_audio_path = os.path.join(temp_dir, "user_audio_input.wav")
         with open(user_audio_path, 'wb') as f:
             f.write(audio_file.read())
-        logger.debug(f"[{request_id}] 📁 User audio saved to {user_audio_path}")
-
-        # Convert to WAV if necessary and ensure correct format
+
         try:
             logger.info(f"[{request_id}] 🔄 Processing user audio file")
-            # First try using pydub for consistent processing
            audio = AudioSegment.from_file(user_audio_path)
             audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1)
-
-            # Save processed audio
+
             processed_path = os.path.join(temp_dir, "processed_user_audio.wav")
             audio.export(processed_path, format="wav")
-            logger.debug(f"[{request_id}] 📁 Processed user audio saved to {processed_path}")
-
-            # Load the processed audio for ASR
+
             user_waveform, sr = torchaudio.load(processed_path)
             user_waveform = user_waveform.squeeze().numpy()
-            logger.info(f"[{request_id}] ✅ User audio processed successfully: {sr}Hz, length: {len(user_waveform)} samples")
-
-            # Update user_audio_path to processed file
+            logger.info(f"[{request_id}] ✅ User audio processed: {sr}Hz, length: {len(user_waveform)} samples")
+
             user_audio_path = processed_path
         except Exception as e:
             logger.error(f"[{request_id}] ❌ Audio processing failed: {str(e)}")
-            logger.debug(f"[{request_id}] Stack trace: {traceback.format_exc()}")
-            # Clean up temp directory
-            try:
-                import shutil
-                shutil.rmtree(temp_dir)
-            except:
-                pass
             return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
 
         # Transcribe user audio
         try:
             logger.info(f"[{request_id}] 🔄 Transcribing user audio")
-            # Process audio for ASR
             inputs = asr_processor(
                 user_waveform,
                 sampling_rate=SAMPLE_RATE,
@@ -650,7 +635,6 @@ def evaluate_pronunciation():
             )
             inputs = {k: v.to(device) for k, v in inputs.items()}
 
-            # Perform ASR
             with torch.no_grad():
                 logits = asr_model(**inputs).logits
                 ids = torch.argmax(logits, dim=-1)[0]
@@ -659,37 +643,32 @@ def evaluate_pronunciation():
             logger.info(f"[{request_id}] ✅ User transcription: '{user_transcription}'")
         except Exception as e:
             logger.error(f"[{request_id}] ❌ ASR inference failed: {str(e)}")
-            # Clean up temp directory
-            try:
-                import shutil
-                shutil.rmtree(temp_dir)
-            except:
-                pass
             return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500
 
-        # Compare with reference audios
+        # Process reference files in batches
+        batch_size = 2  # Process 2 files at a time - adjust based on your hardware
         results = []
         best_score = 0
         best_reference = None
         best_transcription = None
-
-        logger.info(f"[{request_id}] 🔄 Beginning comparison with {len(reference_files)} reference files")
-
-        for ref_idx, ref_file in enumerate(reference_files):
+
+        # Use this if you want to limit the number of files to process
+        max_files_to_check = min(5, len(reference_files))  # Check at most 5 files
+        reference_files = reference_files[:max_files_to_check]
+
+        logger.info(f"[{request_id}] 🔄 Processing {len(reference_files)} reference files in batches of {batch_size}")
+
+        # Function to process a single reference file
+        def process_reference_file(ref_file):
+            ref_filename = os.path.basename(ref_file)
             try:
-                ref_filename = os.path.basename(ref_file)
-                logger.info(f"[{request_id}] 🔄 [{ref_idx+1}/{len(reference_files)}] Processing reference file: {ref_filename}")
-
-                # Load reference audio using torchaudio instead of librosa
+                # Load and resample reference audio
                 ref_waveform, ref_sr = torchaudio.load(ref_file)
                 if ref_sr != SAMPLE_RATE:
-                    logger.debug(f"[{request_id}] 🔄 Resampling reference audio from {ref_sr}Hz to {SAMPLE_RATE}Hz")
                     ref_waveform = torchaudio.transforms.Resample(ref_sr, SAMPLE_RATE)(ref_waveform)
                 ref_waveform = ref_waveform.squeeze().numpy()
-                logger.debug(f"[{request_id}] ✅ Reference audio loaded: {len(ref_waveform)} samples")
-
+
                 # Transcribe reference audio
-                logger.debug(f"[{request_id}] 🔄 Transcribing reference audio: {ref_filename}")
                 inputs = asr_processor(
                     ref_waveform,
                     sampling_rate=SAMPLE_RATE,
@@ -697,50 +676,61 @@ def evaluate_pronunciation():
                     language=lang_code
                 )
                 inputs = {k: v.to(device) for k, v in inputs.items()}
-
+
                 with torch.no_grad():
                     logits = asr_model(**inputs).logits
                     ids = torch.argmax(logits, dim=-1)[0]
                     ref_transcription = asr_processor.decode(ids)
-                logger.info(f"[{request_id}] ✅ Reference transcription for {ref_filename}: '{ref_transcription}'")
-
+
                 # Calculate similarity
                 similarity = calculate_similarity(user_transcription, ref_transcription)
-                logger.info(f"[{request_id}] 📊 Similarity with {ref_filename}: {similarity:.2f}%")
-
-                results.append({
+
+                logger.info(f"[{request_id}] 📊 Similarity with {ref_filename}: {similarity:.2f}%, transcription: '{ref_transcription}'")
+
+                return {
                     "reference_file": ref_filename,
                     "reference_text": ref_transcription,
                     "similarity_score": similarity
-                })
-
-                if similarity > best_score:
-                    best_score = similarity
-                    best_reference = ref_filename
-                    best_transcription = ref_transcription
-                    logger.info(f"[{request_id}] 📊 New best match: {best_reference} with score {best_score:.2f}%")
-
-                # Add this early exit condition here
-                if similarity > 80.0:  # If we find a really good match
-                    logger.info(f"[{request_id}] 🏁 Found excellent match (>80%). Stopping evaluation early.")
-                    break  # Exit the loop early
-
+                }
             except Exception as e:
-                logger.error(f"[{request_id}] ❌ Error processing reference audio {ref_file}: {str(e)}")
-                logger.debug(f"[{request_id}] Stack trace: {traceback.format_exc()}")
+                logger.error(f"[{request_id}] ❌ Error processing {ref_filename}: {str(e)}")
+                return {
+                    "reference_file": ref_filename,
+                    "reference_text": "Error",
+                    "similarity_score": 0,
+                    "error": str(e)
+                }
+
+        # Process files in batches using ThreadPoolExecutor
+        from concurrent.futures import ThreadPoolExecutor
+
+        with ThreadPoolExecutor(max_workers=batch_size) as executor:
+            batch_results = list(executor.map(process_reference_file, reference_files))
+            results.extend(batch_results)
+
+            # Find the best result
+            for result in batch_results:
+                if result["similarity_score"] > best_score:
+                    best_score = result["similarity_score"]
+                    best_reference = result["reference_file"]
+                    best_transcription = result["reference_text"]
+
+                # Exit early if we found a very good match (optional)
+                if best_score > 80.0:
+                    logger.info(f"[{request_id}] 🏁 Found excellent match: {best_score:.2f}%")
+                    break
 
         # Clean up temp files
         try:
             import shutil
             shutil.rmtree(temp_dir)
-            logger.debug(f"[{request_id}] 🧹 Cleaned up temporary directory: {temp_dir}")
+            logger.debug(f"[{request_id}] 🧹 Cleaned up temporary directory")
         except Exception as e:
             logger.warning(f"[{request_id}] ⚠️ Failed to clean up temp files: {str(e)}")
 
-        # Enhanced feedback based on score range
+        # Determine feedback based on score
        is_correct = best_score >= 70.0
-        feedback = ""
-
+
         if best_score >= 90.0:
             feedback = "Perfect pronunciation! Excellent job!"
         elif best_score >= 80.0:
@@ -772,8 +762,17 @@ def evaluate_pronunciation():
     except Exception as e:
         logger.error(f"[{request_id}] ❌ Unhandled exception in evaluation endpoint: {str(e)}")
         logger.debug(f"[{request_id}] Stack trace: {traceback.format_exc()}")
+
+        # Clean up on error
+        try:
+            import shutil
+            shutil.rmtree(temp_dir)
+        except:
+            pass
+
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
 
+
 @app.route("/upload_reference", methods=["POST"])
 def upload_reference_audio():
     try:
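Note: the pattern this commit introduces for the reference check is: cap the candidate list, score each reference concurrently with ThreadPoolExecutor.map, then scan the collected results for the best match. The sketch below is a minimal, self-contained illustration of that pattern only; score_reference is a hypothetical stand-in for the app's per-file load/transcribe/calculate_similarity work, not a function from app.py. Because executor.map is drained into a list before the scan, the break on a >80% score only shortens the scan of already-computed results, not the transcription work itself.

from concurrent.futures import ThreadPoolExecutor

def score_reference(ref_file):
    # Hypothetical stand-in for app.py's per-file work (load audio, run ASR,
    # compare the reference transcription against the user transcription).
    dummy_scores = {"ref_a.wav": 65.0, "ref_b.wav": 91.5, "ref_c.wav": 40.0}
    return {"reference_file": ref_file,
            "similarity_score": dummy_scores.get(ref_file, 0.0)}

def best_match(reference_files, batch_size=2, max_files=5):
    # Cap how many references are checked, mirroring max_files_to_check.
    candidates = reference_files[:max_files]

    # Score candidates concurrently; map() preserves input order and all
    # work has finished by the time the list is materialized.
    with ThreadPoolExecutor(max_workers=batch_size) as executor:
        results = list(executor.map(score_reference, candidates))

    best = None
    for result in results:
        if best is None or result["similarity_score"] > best["similarity_score"]:
            best = result
        # Early exit mirrors the commit's >80% check; it only skips the
        # rest of the scan, since all scoring above is already complete.
        if best["similarity_score"] > 80.0:
            break
    return best, results

if __name__ == "__main__":
    top, all_results = best_match(["ref_a.wav", "ref_b.wav", "ref_c.wav"])
    print(top)  # {'reference_file': 'ref_b.wav', 'similarity_score': 91.5}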