Update app.py
app.py
@@ -602,46 +602,31 @@ def evaluate_pronunciation():
         temp_dir = os.path.join(OUTPUT_DIR, f"temp_{request_id}")
         os.makedirs(temp_dir, exist_ok=True)
 
-        #
+        # Process user audio
         user_audio_path = os.path.join(temp_dir, "user_audio_input.wav")
         with open(user_audio_path, 'wb') as f:
             f.write(audio_file.read())
-
-
-        # Convert to WAV if necessary and ensure correct format
+
         try:
             logger.info(f"[{request_id}] 🔄 Processing user audio file")
-            # First try using pydub for consistent processing
             audio = AudioSegment.from_file(user_audio_path)
             audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1)
-
-            # Save processed audio
+
             processed_path = os.path.join(temp_dir, "processed_user_audio.wav")
             audio.export(processed_path, format="wav")
-
-
-            # Load the processed audio for ASR
+
             user_waveform, sr = torchaudio.load(processed_path)
             user_waveform = user_waveform.squeeze().numpy()
-            logger.info(f"[{request_id}] ✅ User audio processed")
-
-            # Update user_audio_path to processed file
+            logger.info(f"[{request_id}] ✅ User audio processed: {sr}Hz, length: {len(user_waveform)} samples")
+
             user_audio_path = processed_path
         except Exception as e:
             logger.error(f"[{request_id}] ❌ Audio processing failed: {str(e)}")
-            logger.debug(f"[{request_id}] Stack trace: {traceback.format_exc()}")
-            # Clean up temp directory
-            try:
-                import shutil
-                shutil.rmtree(temp_dir)
-            except:
-                pass
             return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
 
         # Transcribe user audio
         try:
             logger.info(f"[{request_id}] 🔄 Transcribing user audio")
-            # Process audio for ASR
             inputs = asr_processor(
                 user_waveform,
                 sampling_rate=SAMPLE_RATE,
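
The pydub path kept above is what normalizes an arbitrary browser upload to the model's expected format before ASR. A minimal standalone sketch of that step, assuming SAMPLE_RATE is 16000 as the ASR code implies (the helper name is illustrative, not from app.py):

from pydub import AudioSegment

SAMPLE_RATE = 16000  # assumed value of the app's constant

def normalize_upload(src_path, dst_path):
    # Decode any container ffmpeg understands, force 16 kHz mono, write WAV.
    audio = AudioSegment.from_file(src_path)
    audio = audio.set_frame_rate(SAMPLE_RATE).set_channels(1)
    audio.export(dst_path, format="wav")
    return dst_path
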
@@ -650,7 +635,6 @@ def evaluate_pronunciation():
             )
             inputs = {k: v.to(device) for k, v in inputs.items()}
 
-            # Perform ASR
             with torch.no_grad():
                 logits = asr_model(**inputs).logits
                 ids = torch.argmax(logits, dim=-1)[0]
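
asr_processor and asr_model are used CTC-style throughout: run the model, argmax the logits, decode the token ids back to text. A self-contained sketch of the same inference path with a transformers CTC checkpoint; the model id is an assumption (the language= kwarg elsewhere in this diff suggests an MMS-style multilingual model, whose per-language adapter loading is not shown here):

import torch
from transformers import AutoProcessor, Wav2Vec2ForCTC

MODEL_ID = "facebook/mms-1b-all"  # assumed checkpoint, not named in the diff
SAMPLE_RATE = 16000

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID).eval()

def transcribe(waveform, device="cpu"):
    # waveform: 1-D float array sampled at SAMPLE_RATE
    inputs = processor(waveform, sampling_rate=SAMPLE_RATE, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model.to(device)(**inputs).logits
    ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(ids)
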
@@ -659,37 +643,32 @@ def evaluate_pronunciation():
             logger.info(f"[{request_id}] ✅ User transcription: '{user_transcription}'")
         except Exception as e:
             logger.error(f"[{request_id}] ❌ ASR inference failed: {str(e)}")
-            # Clean up temp directory
-            try:
-                import shutil
-                shutil.rmtree(temp_dir)
-            except:
-                pass
             return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500
 
-        #
+        # Process reference files in batches
+        batch_size = 2  # Process 2 files at a time - adjust based on your hardware
         results = []
         best_score = 0
         best_reference = None
         best_transcription = None
-
-
-        for ref_idx, ref_file in enumerate(reference_files):
-            ref_filename = os.path.basename(ref_file)
+
+        # Use this if you want to limit the number of files to process
+        max_files_to_check = min(5, len(reference_files))  # Check at most 5 files
+        reference_files = reference_files[:max_files_to_check]
+
+        logger.info(f"[{request_id}] 🔄 Processing {len(reference_files)} reference files in batches of {batch_size}")
+
+        # Function to process a single reference file
+        def process_reference_file(ref_file):
+            ref_filename = os.path.basename(ref_file)
             try:
-
-                logger.info(f"[{request_id}] 🔄 [{ref_idx+1}/{len(reference_files)}] Processing reference file: {ref_filename}")
-
-                # Load reference audio using torchaudio instead of librosa
+                # Load and resample reference audio
                 ref_waveform, ref_sr = torchaudio.load(ref_file)
                 if ref_sr != SAMPLE_RATE:
-                    logger.debug(f"[{request_id}] 🔄 Resampling reference audio from {ref_sr}Hz to {SAMPLE_RATE}Hz")
                     ref_waveform = torchaudio.transforms.Resample(ref_sr, SAMPLE_RATE)(ref_waveform)
                 ref_waveform = ref_waveform.squeeze().numpy()
-
-
+
                 # Transcribe reference audio
-                logger.debug(f"[{request_id}] 🔄 Transcribing reference audio: {ref_filename}")
                 inputs = asr_processor(
                     ref_waveform,
                     sampling_rate=SAMPLE_RATE,
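
process_reference_file, introduced in this hunk (its body continues in the next one), ranks each reference with calculate_similarity, whose definition is not part of this diff. A plausible stand-in, assuming it returns a 0-100 percentage as the 70/80/90 thresholds below imply:

from difflib import SequenceMatcher

def calculate_similarity(hypothesis, reference):
    # Hypothetical implementation: normalized character-level ratio scaled to 0-100.
    hypothesis, reference = hypothesis.lower().strip(), reference.lower().strip()
    if not hypothesis or not reference:
        return 0.0
    return SequenceMatcher(None, hypothesis, reference).ratio() * 100.0
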
@@ -697,50 +676,61 @@ def evaluate_pronunciation():
                     language=lang_code
                 )
                 inputs = {k: v.to(device) for k, v in inputs.items()}
-
+
                 with torch.no_grad():
                     logits = asr_model(**inputs).logits
                     ids = torch.argmax(logits, dim=-1)[0]
                     ref_transcription = asr_processor.decode(ids)
-
-
+
                 # Calculate similarity
                 similarity = calculate_similarity(user_transcription, ref_transcription)
-
-
-                results.append({
+
+                logger.info(f"[{request_id}] 📊 Similarity with {ref_filename}: {similarity:.2f}%, transcription: '{ref_transcription}'")
+
+                return {
                     "reference_file": ref_filename,
                     "reference_text": ref_transcription,
                     "similarity_score": similarity
-                })
-
-                if similarity > best_score:
-                    best_score = similarity
-                    best_reference = ref_filename
-                    best_transcription = ref_transcription
-                    logger.info(f"[{request_id}] 🏆 New best match: {best_reference} with score {best_score:.2f}%")
-
-                # Add this early exit condition here
-                if similarity > 80.0:  # If we find a really good match
-                    logger.info(f"[{request_id}] 🎉 Found excellent match (>80%). Stopping evaluation early.")
-                    break  # Exit the loop early
-
+                }
             except Exception as e:
-                logger.error(f"[{request_id}] ❌ Error processing {ref_filename}: {str(e)}")
-
+                logger.error(f"[{request_id}] ❌ Error processing {ref_filename}: {str(e)}")
+                return {
+                    "reference_file": ref_filename,
+                    "reference_text": "Error",
+                    "similarity_score": 0,
+                    "error": str(e)
+                }
+
+        # Process files in batches using ThreadPoolExecutor
+        from concurrent.futures import ThreadPoolExecutor
+
+        for batch_start in range(0, len(reference_files), batch_size):
+            batch = reference_files[batch_start:batch_start + batch_size]
+            with ThreadPoolExecutor(max_workers=batch_size) as executor:
+                batch_results = list(executor.map(process_reference_file, batch))
+            results.extend(batch_results)
+
+            # Find the best result in this batch
+            for result in batch_results:
+                if result["similarity_score"] > best_score:
+                    best_score = result["similarity_score"]
+                    best_reference = result["reference_file"]
+                    best_transcription = result["reference_text"]
+
+            # Exit early if we found a very good match (optional)
+            if best_score > 80.0:
+                logger.info(f"[{request_id}] 🎉 Found excellent match: {best_score:.2f}%")
+                break
 
         # Clean up temp files
         try:
             import shutil
             shutil.rmtree(temp_dir)
+            logger.debug(f"[{request_id}] 🧹 Cleaned up temporary directory")
         except Exception as e:
             logger.warning(f"[{request_id}] ⚠️ Failed to clean up temp files: {str(e)}")
 
-        #
+        # Determine feedback based on score
        is_correct = best_score >= 70.0
-
-
+
         if best_score >= 90.0:
             feedback = "Perfect pronunciation! Excellent job!"
         elif best_score >= 80.0:
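
Note that early exit only works when the executor map is chunked: break needs an enclosing loop, so mapping the whole file list in a single executor.map call leaves nothing to break out of. The chunked pattern in isolation (names are illustrative):

from concurrent.futures import ThreadPoolExecutor

def best_match(items, score_fn, batch_size=2, good_enough=80.0):
    # Score items one batch at a time so the loop can stop early.
    best_item, best_score = None, 0.0
    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        with ThreadPoolExecutor(max_workers=batch_size) as executor:
            scores = list(executor.map(score_fn, batch))
        for item, score in zip(batch, scores):
            if score > best_score:
                best_item, best_score = item, score
        if best_score > good_enough:
            break  # legal here: we are inside the batch loop
    return best_item, best_score
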
@@ -772,8 +762,17 @@ def evaluate_pronunciation():
     except Exception as e:
         logger.error(f"[{request_id}] ❌ Unhandled exception in evaluation endpoint: {str(e)}")
         logger.debug(f"[{request_id}] Stack trace: {traceback.format_exc()}")
+
+        # Clean up on error
+        try:
+            import shutil
+            shutil.rmtree(temp_dir)
+        except:
+            pass
+
         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
 
+
 @app.route("/upload_reference", methods=["POST"])
 def upload_reference_audio():
     try:
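
The handler now repeats the try/shutil.rmtree cleanup in several places (success path, failure paths, and the outer except). One alternative worth noting: tempfile.TemporaryDirectory scopes the scratch directory to a with block and removes it on every exit path. A sketch, not a drop-in change:

import os
import tempfile

def evaluate_with_scratch_dir(audio_file):
    # TemporaryDirectory deletes the directory and its contents on exit,
    # even when an exception propagates out of the block.
    with tempfile.TemporaryDirectory(prefix="pronounce_") as temp_dir:
        user_audio_path = os.path.join(temp_dir, "user_audio_input.wav")
        with open(user_audio_path, "wb") as f:
            f.write(audio_file.read())
        # ... process, transcribe, and score as in the handler above ...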