Coco-18 committed
Commit 95121a2 · verified · Parent(s): e13b2ed

Update evaluate.py

Files changed (1): evaluate.py (+82 -24)
evaluate.py CHANGED
@@ -20,6 +20,9 @@ from translator import get_asr_model, get_asr_processor, LANGUAGE_CODES
 # Configure logging
 logger = logging.getLogger("speech_api")
 
+if not hasattr(handle_evaluation_request, 'cache'):
+    handle_evaluation_request.cache = {}
+
 def calculate_similarity(text1, text2):
     """Calculate text similarity percentage."""
     def clean_text(text):
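
Note that the `+` lines above run at import time, before handle_evaluation_request is defined further down the file, so the hasattr guard would raise NameError when the module is first imported. A minimal import-safe sketch of the same idea is a plain module-level dict (EVALUATION_CACHE is the name a later in-code comment hints at; it is not defined in this commit):

    # Sketch: module-level result cache, safe to create at import time
    EVALUATION_CACHE = {}

    def get_cached(cache_key):
        # Returns a previously stored response, or None on a cache miss
        return EVALUATION_CACHE.get(cache_key)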
 
@@ -309,9 +312,9 @@ def handle_upload_reference(request, reference_dir, sample_rate):
         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
 
 def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
-    """Handle pronunciation evaluation requests with optimized parallel comparison"""
-    request_id = f"req-{id(request)}"  # Create unique ID for this request
-    logger.info(f"[{request_id}] 🆕 Starting new pronunciation evaluation request")
+    """Handle pronunciation evaluation requests with speed optimizations"""
+    request_id = f"req-{id(request)}"
+    logger.info(f"[{request_id}] 🆕 Starting pronunciation evaluation request")
 
     temp_dir = None
 
 
@@ -324,6 +327,7 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
         return jsonify({"error": "ASR model not available"}), 503
 
     try:
+        # OPTIMIZATION 1: Check cache first for identical audio
         if "audio" not in request.files:
            logger.warning(f"[{request_id}] ⚠️ Evaluation request missing audio file")
            return jsonify({"error": "No audio file uploaded"}), 400
 
@@ -337,6 +341,19 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
             logger.warning(f"[{request_id}] ⚠️ No reference locator provided")
             return jsonify({"error": "Reference locator is required"}), 400
 
+        # OPTIMIZATION 2: Simple caching based on audio content hash + reference_locator
+        audio_content = audio_file.read()
+        audio_file.seek(0)  # Reset file pointer after reading
+
+        import hashlib
+        audio_hash = hashlib.md5(audio_content).hexdigest()
+        cache_key = f"{audio_hash}_{reference_locator}_{language}"
+
+        # Check in-memory cache (define EVALUATION_CACHE at module level)
+        if hasattr(handle_evaluation_request, 'cache') and cache_key in handle_evaluation_request.cache:
+            logger.info(f"[{request_id}] ✅ Using cached evaluation result")
+            return handle_evaluation_request.cache[cache_key]
+
         # Construct full reference directory path
         reference_dir_path = os.path.join(reference_dir, reference_locator)
         logger.info(f"[{request_id}] 📁 Reference directory path: {reference_dir_path}")
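
The cache key built above can be reproduced in isolation; a self-contained sketch (make_cache_key is a hypothetical helper, not part of the commit):

    import hashlib

    def make_cache_key(audio_bytes, reference_locator, language):
        # MD5 serves as a fast content fingerprint here, not as a security measure
        audio_hash = hashlib.md5(audio_bytes).hexdigest()
        return f"{audio_hash}_{reference_locator}_{language}"

    # Identical audio bytes for the same locator and language yield the same key
    key = make_cache_key(b"\x00\x01\x02", "lesson-01", "eng")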
 
@@ -366,7 +383,7 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
             }), 404
 
         lang_code = LANGUAGE_CODES.get(language, language)
-        logger.info(f"[{request_id}] 🔄 Evaluating pronunciation for reference: {reference_locator} with language code: {lang_code}")
+        logger.info(f"[{request_id}] 🔄 Evaluating pronunciation for reference: {reference_locator}")
 
         # Create a request-specific temp directory to avoid conflicts
         temp_dir = os.path.join(output_dir, f"temp_{request_id}")
 
@@ -375,7 +392,7 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
         # Process user audio
         user_audio_path = os.path.join(temp_dir, "user_audio_input.wav")
         with open(user_audio_path, 'wb') as f:
-            f.write(audio_file.read())
+            f.write(audio_content)  # Use the content we already read
 
         try:
             logger.info(f"[{request_id}] 🔄 Processing user audio file")
 
@@ -415,18 +432,23 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
             logger.error(f"[{request_id}] ❌ ASR inference failed: {str(e)}")
             return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500
 
-        # OPTIMIZATION: Process all reference files at once
+        # OPTIMIZATION 3: Use a smaller sample of reference files
         import multiprocessing
+        import random
 
-        # Determine optimal number of workers based on CPU count
-        max_workers = min(multiprocessing.cpu_count(), len(reference_files))
-        results = []
-
-        # Use this if you want to limit the number of files to process
-        max_files_to_check = min(len(reference_files), 10)  # Increased from 5 to 10
-        reference_files = reference_files[:max_files_to_check]
+        # OPTIMIZATION 4: Limit to just a few files for initial comparison
+        # If we have many reference files, randomly sample some for quick evaluation
+        if len(reference_files) > 3:
+            # Randomly select 3 files for faster comparison
+            reference_files_sample = random.sample(reference_files, 3)
+        else:
+            reference_files_sample = reference_files
+
+        # Determine optimal number of workers based on CPU count (but keep it small)
+        max_workers = min(multiprocessing.cpu_count(), len(reference_files_sample), 3)
+        initial_results = []
 
-        logger.info(f"[{request_id}] 🔄 Processing {len(reference_files)} reference files in parallel with {max_workers} workers")
+        logger.info(f"[{request_id}] 🔄 Quick scan: processing {len(reference_files_sample)} reference files")
 
         # Function to process a single reference file
        def process_reference_file(ref_file):
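
Read in isolation, the sampling and worker-count logic above amounts to the following sketch (the reference_files list here is a hypothetical stand-in for the files found in the reference directory):

    import multiprocessing
    import random

    reference_files = [f"ref_{i}.wav" for i in range(8)]  # hypothetical stand-in

    # First pass looks at no more than 3 randomly chosen references
    sample = random.sample(reference_files, 3) if len(reference_files) > 3 else reference_files

    # Workers are capped by CPU count, sample size, and a hard limit of 3
    max_workers = min(multiprocessing.cpu_count(), len(sample), 3)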
 
@@ -438,8 +460,7 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
                 ref_waveform = torchaudio.transforms.Resample(ref_sr, sample_rate)(ref_waveform)
                 ref_waveform = ref_waveform.squeeze().numpy()
 
-                # Transcribe reference audio - use the local asr_model and asr_processor
-                # Remove language parameter if causing warnings
+                # Transcribe reference audio
                 inputs = asr_processor(
                     ref_waveform,
                     sampling_rate=sample_rate,
 
@@ -472,20 +493,40 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
                     "error": str(e)
                 }
 
-        # OPTIMIZATION: Process all files simultaneously using ThreadPoolExecutor
+        # Process the sample files in parallel
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
-            results = list(executor.map(process_reference_file, reference_files))
+            initial_results = list(executor.map(process_reference_file, reference_files_sample))
 
-        # Find the best result after all processing is complete
+        # Find the best result from the initial sample
         best_score = 0
         best_reference = None
         best_transcription = None
 
-        for result in results:
+        for result in initial_results:
             if result["similarity_score"] > best_score:
                 best_score = result["similarity_score"]
                 best_reference = result["reference_file"]
                 best_transcription = result["reference_text"]
+
+        # OPTIMIZATION 5: If we already found a very good match, don't process more files
+        all_results = initial_results.copy()
+        remaining_files = [f for f in reference_files if f not in reference_files_sample]
+
+        # Only process more files if our best score isn't already very good
+        if best_score < 80.0 and remaining_files:
+            logger.info(f"[{request_id}] 🔄 Score {best_score:.2f}% not high enough, checking {len(remaining_files)} more references")
+
+            # Process remaining files
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                additional_results = list(executor.map(process_reference_file, remaining_files[:5]))  # Process max 5 more
+                all_results.extend(additional_results)
+
+            # Update best result if we found a better one
+            for result in additional_results:
+                if result["similarity_score"] > best_score:
+                    best_score = result["similarity_score"]
+                    best_reference = result["reference_file"]
+                    best_transcription = result["reference_text"]
 
         # Clean up temp files
         try:
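
The quick-scan-then-widen flow above follows a general two-phase pattern; a compact sketch under the assumption that score_fn returns dicts with a "similarity_score" key, as process_reference_file does:

    from concurrent.futures import ThreadPoolExecutor

    def two_phase_best(items, sample, score_fn, good_enough=80.0, extra_limit=5, max_workers=3):
        # Phase 1: score a small sample in parallel
        with ThreadPoolExecutor(max_workers=max_workers) as ex:
            results = list(ex.map(score_fn, sample))
        best = max(results, key=lambda r: r["similarity_score"])
        # Phase 2: widen only if the sample did not produce a good enough match
        if best["similarity_score"] < good_enough:
            remaining = [i for i in items if i not in sample][:extra_limit]
            if remaining:
                with ThreadPoolExecutor(max_workers=max_workers) as ex:
                    results += list(ex.map(score_fn, remaining))
                best = max(results, key=lambda r: r["similarity_score"])
        return best, results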
 
@@ -514,17 +555,34 @@ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
         logger.info(f"[{request_id}] ✅ Evaluation complete")
 
         # Sort results by score descending
-        results.sort(key=lambda x: x["similarity_score"], reverse=True)
-
-        return jsonify({
+        all_results.sort(key=lambda x: x["similarity_score"], reverse=True)
+
+        # Create response
+        response = jsonify({
             "is_correct": is_correct,
             "score": best_score,
             "feedback": feedback,
             "user_transcription": user_transcription,
             "best_reference_transcription": best_transcription,
             "reference_locator": reference_locator,
-            "details": results
+            "details": all_results,
+            "total_references_compared": len(all_results),
+            "total_available_references": len(reference_files),
+            "quick_evaluation": True
         })
+
+        # OPTIMIZATION 6: Cache the result for future requests
+        if not hasattr(handle_evaluation_request, 'cache'):
+            handle_evaluation_request.cache = {}
+
+        # Store in cache (limit cache size to avoid memory issues)
+        MAX_CACHE_SIZE = 50
+        handle_evaluation_request.cache[cache_key] = response
+        if len(handle_evaluation_request.cache) > MAX_CACHE_SIZE:
+            # Remove oldest entry (simplified approach)
+            handle_evaluation_request.cache.pop(next(iter(handle_evaluation_request.cache)))
+
+        return response
 
     except Exception as e:
         logger.error(f"[{request_id}] ❌ Unhandled exception in evaluation endpoint: {str(e)}")
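
The eviction in the final `+` block relies on Python dicts preserving insertion order (guaranteed since 3.7), so next(iter(cache)) is always the oldest entry, i.e. FIFO eviction. The same bound, isolated as a sketch:

    MAX_CACHE_SIZE = 50  # same cap as in the commit
    cache = {}

    def cache_put(key, value):
        # Insert, then evict the first-inserted key once the cap is exceeded
        cache[key] = value
        if len(cache) > MAX_CACHE_SIZE:
            cache.pop(next(iter(cache)))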