Coco-18 committed on
Commit da8916e · verified · 1 Parent(s): 0812080

Update evaluate.py

Files changed (1)
  1. evaluate.py +369 -340
evaluate.py CHANGED
@@ -1,341 +1,370 @@
- # evaluate.py - Handles evaluation and comparing tasks
-
- import os
- import glob
- import logging
- import traceback
- import tempfile
- import shutil
- from difflib import SequenceMatcher
- import torch
- import torchaudio
- from pydub import AudioSegment
- from flask import jsonify
- from werkzeug.utils import secure_filename
- from concurrent.futures import ThreadPoolExecutor
-
- # Import necessary functions from translator.py
- from translator import asr_model, asr_processor, LANGUAGE_CODES
-
- # Configure logging
- logger = logging.getLogger("speech_api")
-
- def calculate_similarity(text1, text2):
-     """Calculate text similarity percentage."""
-     def clean_text(text):
-         return text.lower()
-
-     clean1 = clean_text(text1)
-     clean2 = clean_text(text2)
-
-     matcher = SequenceMatcher(None, clean1, clean2)
-     return matcher.ratio() * 100
-
- def init_reference_audio(reference_dir, output_dir):
-     try:
-         # Create the output directory first
-         os.makedirs(output_dir, exist_ok=True)
-         logger.info(f"📁 Created output directory: {output_dir}")
-
-         # Check if the reference audio directory exists in the repository
-         if os.path.exists(reference_dir):
-             logger.info(f"✅ Found reference audio directory: {reference_dir}")
-
-             # Log the contents to verify
-             pattern_dirs = [d for d in os.listdir(reference_dir)
-                             if os.path.isdir(os.path.join(reference_dir, d))]
-             logger.info(f"📁 Found reference patterns: {pattern_dirs}")
-
-             # Check each pattern directory for wav files
-             for pattern_dir_name in pattern_dirs:
-                 pattern_path = os.path.join(reference_dir, pattern_dir_name)
-                 wav_files = glob.glob(os.path.join(pattern_path, "*.wav"))
-                 logger.info(f"📁 Found {len(wav_files)} wav files in {pattern_dir_name}")
-         else:
-             logger.warning(f"⚠️ Reference audio directory not found: {reference_dir}")
-             # Create the directory if it doesn't exist
-             os.makedirs(reference_dir, exist_ok=True)
-             logger.info(f"📁 Created reference audio directory: {reference_dir}")
-     except Exception as e:
-         logger.error(f"❌ Failed to set up reference audio directory: {str(e)}")
-
- def handle_upload_reference(request, reference_dir, sample_rate):
-     """Handle upload of reference audio files"""
-     try:
-         if "audio" not in request.files:
-             logger.warning("⚠️ Reference upload missing audio file")
-             return jsonify({"error": "No audio file uploaded"}), 400
-
-         reference_word = request.form.get("reference_word", "").strip()
-         if not reference_word:
-             logger.warning("⚠️ Reference upload missing reference word")
-             return jsonify({"error": "No reference word provided"}), 400
-
-         # Validate reference word
-         reference_patterns = [
-             "mayap_a_abak", "mayap_a_ugtu", "mayap_a_gatpanapun", "mayap_a_bengi",
-             "komusta_ka", "malaus_ko_pu", "malaus_kayu", "agaganaka_da_ka",
-             "pagdulapan_da_ka", "kaluguran_da_ka", "dakal_a_salamat", "panapaya_mu_ku"
-         ]
-
-         if reference_word not in reference_patterns:
-             logger.warning(f"⚠️ Invalid reference word: {reference_word}")
-             return jsonify({"error": f"Invalid reference word. Available: {reference_patterns}"}), 400
-
-         # Create directory for reference pattern if it doesn't exist
-         pattern_dir = os.path.join(reference_dir, reference_word)
-         os.makedirs(pattern_dir, exist_ok=True)
-
-         # Save the reference audio file
-         audio_file = request.files["audio"]
-         file_path = os.path.join(pattern_dir, secure_filename(audio_file.filename))
-         audio_file.save(file_path)
-
-         # Convert to WAV if not already in that format
-         if not file_path.lower().endswith('.wav'):
-             base_path = os.path.splitext(file_path)[0]
-             wav_path = f"{base_path}.wav"
-             try:
-                 audio = AudioSegment.from_file(file_path)
-                 audio = audio.set_frame_rate(sample_rate).set_channels(1)
-                 audio.export(wav_path, format="wav")
-                 # Remove original file if conversion successful
-                 os.unlink(file_path)
-                 file_path = wav_path
-             except Exception as e:
-                 logger.error(f"❌ Reference audio conversion failed: {str(e)}")
-                 return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500
-
-         logger.info(f"✅ Reference audio saved successfully for {reference_word}: {file_path}")
-
-         # Count how many references we have now
-         references = glob.glob(os.path.join(pattern_dir, "*.wav"))
-         return jsonify({
-             "message": "Reference audio uploaded successfully",
-             "reference_word": reference_word,
-             "file": os.path.basename(file_path),
-             "total_references": len(references)
-         })
-
-     except Exception as e:
-         logger.error(f"❌ Unhandled exception in reference upload: {str(e)}")
-         logger.debug(f"Stack trace: {traceback.format_exc()}")
-         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
-
-
- def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
-     """Handle pronunciation evaluation requests"""
-     request_id = f"req-{id(request)}" # Create unique ID for this request
-     logger.info(f"[{request_id}] 🆕 Starting new pronunciation evaluation request")
-
-     temp_dir = None
-
-     if asr_model is None or asr_processor is None:
-         logger.error(f"[{request_id}] ❌ Evaluation endpoint called but ASR models aren't loaded")
-         return jsonify({"error": "ASR model not available"}), 503
-
-     try:
-         if "audio" not in request.files:
-             logger.warning(f"[{request_id}] ⚠️ Evaluation request missing audio file")
-             return jsonify({"error": "No audio file uploaded"}), 400
-
-         audio_file = request.files["audio"]
-         reference_locator = request.form.get("reference_locator", "").strip()
-         language = request.form.get("language", "kapampangan").lower()
-
-         # Validate reference locator
-         if not reference_locator:
-             logger.warning(f"[{request_id}] ⚠️ No reference locator provided")
-             return jsonify({"error": "Reference locator is required"}), 400
-
-         # Construct full reference directory path
-         reference_dir_path = os.path.join(reference_dir, reference_locator)
-         logger.info(f"[{request_id}] 📁 Reference directory path: {reference_dir_path}")
-
-         if not os.path.exists(reference_dir_path):
-             logger.warning(f"[{request_id}] ⚠️ Reference directory not found: {reference_dir_path}")
-             return jsonify({"error": f"Reference audio directory not found: {reference_locator}"}), 404
-
-         reference_files = glob.glob(os.path.join(reference_dir_path, "*.wav"))
-         logger.info(f"[{request_id}] 📁 Found {len(reference_files)} reference files")
-
-         if not reference_files:
-             logger.warning(f"[{request_id}] ⚠️ No reference audio files found in {reference_dir_path}")
-             return jsonify({"error": f"No reference audio found for {reference_locator}"}), 404
-
-         lang_code = LANGUAGE_CODES.get(language, language)
-         logger.info(
-             f"[{request_id}] 🔄 Evaluating pronunciation for reference: {reference_locator} with language code: {lang_code}")
-
-         # Create a request-specific temp directory to avoid conflicts
-         temp_dir = os.path.join(output_dir, f"temp_{request_id}")
-         os.makedirs(temp_dir, exist_ok=True)
-
-         # Process user audio
-         user_audio_path = os.path.join(temp_dir, "user_audio_input.wav")
-         with open(user_audio_path, 'wb') as f:
-             f.write(audio_file.read())
-
-         try:
-             logger.info(f"[{request_id}] 🔄 Processing user audio file")
-             audio = AudioSegment.from_file(user_audio_path)
-             audio = audio.set_frame_rate(sample_rate).set_channels(1)
-
-             processed_path = os.path.join(temp_dir, "processed_user_audio.wav")
-             audio.export(processed_path, format="wav")
-
-             user_waveform, sr = torchaudio.load(processed_path)
-             user_waveform = user_waveform.squeeze().numpy()
-             logger.info(f"[{request_id}] ✅ User audio processed: {sr}Hz, length: {len(user_waveform)} samples")
-
-             user_audio_path = processed_path
-         except Exception as e:
-             logger.error(f"[{request_id}] ❌ Audio processing failed: {str(e)}")
-             return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
-
-         # Transcribe user audio
-         try:
-             logger.info(f"[{request_id}] 🔄 Transcribing user audio")
-             inputs = asr_processor(
-                 user_waveform,
-                 sampling_rate=sample_rate,
-                 return_tensors="pt",
-                 language=lang_code
-             )
-             inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
-
-             with torch.no_grad():
-                 logits = asr_model(**inputs).logits
-             ids = torch.argmax(logits, dim=-1)[0]
-             user_transcription = asr_processor.decode(ids)
-
-             logger.info(f"[{request_id}] ✅ User transcription: '{user_transcription}'")
-         except Exception as e:
-             logger.error(f"[{request_id}] ❌ ASR inference failed: {str(e)}")
-             return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500
-
-         # Process reference files in batches
-         batch_size = 2 # Process 2 files at a time - adjust based on your hardware
-         results = []
-         best_score = 0
-         best_reference = None
-         best_transcription = None
-
-         # Use this if you want to limit the number of files to process
-         max_files_to_check = min(5, len(reference_files)) # Check at most 5 files
-         reference_files = reference_files[:max_files_to_check]
-
-         logger.info(f"[{request_id}] 🔄 Processing {len(reference_files)} reference files in batches of {batch_size}")
-
-         # Function to process a single reference file
-         def process_reference_file(ref_file):
-             ref_filename = os.path.basename(ref_file)
-             try:
-                 # Load and resample reference audio
-                 ref_waveform, ref_sr = torchaudio.load(ref_file)
-                 if ref_sr != sample_rate:
-                     ref_waveform = torchaudio.transforms.Resample(ref_sr, sample_rate)(ref_waveform)
-                 ref_waveform = ref_waveform.squeeze().numpy()
-
-                 # Transcribe reference audio
-                 inputs = asr_processor(
-                     ref_waveform,
-                     sampling_rate=sample_rate,
-                     return_tensors="pt",
-                     language=lang_code
-                 )
-                 inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
-
-                 with torch.no_grad():
-                     logits = asr_model(**inputs).logits
-                 ids = torch.argmax(logits, dim=-1)[0]
-                 ref_transcription = asr_processor.decode(ids)
-
-                 # Calculate similarity
-                 similarity = calculate_similarity(user_transcription, ref_transcription)
-
-                 logger.info(
-                     f"[{request_id}] 📊 Similarity with {ref_filename}: {similarity:.2f}%, transcription: '{ref_transcription}'")
-
-                 return {
-                     "reference_file": ref_filename,
-                     "reference_text": ref_transcription,
-                     "similarity_score": similarity
-                 }
-             except Exception as e:
-                 logger.error(f"[{request_id}] ❌ Error processing {ref_filename}: {str(e)}")
-                 return {
-                     "reference_file": ref_filename,
-                     "reference_text": "Error",
-                     "similarity_score": 0,
-                     "error": str(e)
-                 }
-
-         # Process files in batches using ThreadPoolExecutor
-         with ThreadPoolExecutor(max_workers=batch_size) as executor:
-             batch_results = list(executor.map(process_reference_file, reference_files))
-             results.extend(batch_results)
-
-             # Find the best result
-             for result in batch_results:
-                 if result["similarity_score"] > best_score:
-                     best_score = result["similarity_score"]
-                     best_reference = result["reference_file"]
-                     best_transcription = result["reference_text"]
-
-                 # Exit early if we found a very good match (optional)
-                 if best_score > 80.0:
-                     logger.info(f"[{request_id}] 🏁 Found excellent match: {best_score:.2f}%")
-                     break
-
-         # Clean up temp files
-         try:
-             if temp_dir and os.path.exists(temp_dir):
-                 shutil.rmtree(temp_dir)
-                 logger.debug(f"[{request_id}] 🧹 Cleaned up temporary directory")
-         except Exception as e:
-             logger.warning(f"[{request_id}] ⚠️ Failed to clean up temp files: {str(e)}")
-
-         # Determine feedback based on score
-         is_correct = best_score >= 70.0
-
-         if best_score >= 90.0:
-             feedback = "Perfect pronunciation! Excellent job!"
-         elif best_score >= 80.0:
-             feedback = "Great pronunciation! Your accent is very good."
-         elif best_score >= 70.0:
-             feedback = "Good pronunciation. Keep practicing!"
-         elif best_score >= 50.0:
-             feedback = "Fair attempt. Try focusing on the syllables that differ from the sample."
-         else:
-             feedback = "Try again. Listen carefully to the sample pronunciation."
-
-         logger.info(f"[{request_id}] 📊 Final evaluation results: score={best_score:.2f}%, is_correct={is_correct}")
-         logger.info(f"[{request_id}] 📝 Feedback: '{feedback}'")
-         logger.info(f"[{request_id}] ✅ Evaluation complete")
-
-         # Sort results by score descending
-         results.sort(key=lambda x: x["similarity_score"], reverse=True)
-
-         return jsonify({
-             "is_correct": is_correct,
-             "score": best_score,
-             "feedback": feedback,
-             "user_transcription": user_transcription,
-             "best_reference_transcription": best_transcription,
-             "reference_locator": reference_locator,
-             "details": results
-         })
-
-     except Exception as e:
-         logger.error(f"[{request_id}] ❌ Unhandled exception in evaluation endpoint: {str(e)}")
-         logger.debug(f"[{request_id}] Stack trace: {traceback.format_exc()}")
-
-         # Clean up on error
-         try:
-             if temp_dir and os.path.exists(temp_dir):
-                 shutil.rmtree(temp_dir)
-         except:
-             pass
-
+ # evaluate.py - Handles evaluation and comparing tasks
+
+ import os
+ import glob
+ import logging
+ import traceback
+ import tempfile
+ import shutil
+ from difflib import SequenceMatcher
+ import torch
+ import torchaudio
+ from pydub import AudioSegment
+ from flask import jsonify
+ from werkzeug.utils import secure_filename
+ from concurrent.futures import ThreadPoolExecutor
+
+ # Import necessary functions from translator.py
+ from translator import asr_model, asr_processor, LANGUAGE_CODES
+
+ # Configure logging
+ logger = logging.getLogger("speech_api")
+
+ def calculate_similarity(text1, text2):
+     """Calculate text similarity percentage."""
+     def clean_text(text):
+         return text.lower()
+
+     clean1 = clean_text(text1)
+     clean2 = clean_text(text2)
+
+     matcher = SequenceMatcher(None, clean1, clean2)
+     return matcher.ratio() * 100
+
+ # In evaluate.py, modify the init_reference_audio function
+
+ def init_reference_audio(reference_dir, output_dir):
+     try:
+         # Create the output directory first
+         os.makedirs(output_dir, exist_ok=True)
+         logger.info(f"📁 Created output directory: {output_dir}")
+
+         # Change reference_dir to be inside /tmp if the original location doesn't work
+         if not os.path.exists(reference_dir) or not os.access(os.path.dirname(reference_dir), os.W_OK):
+             # Use a directory in /tmp instead
+             new_reference_dir = os.path.join('/tmp', 'reference_audio')
+             logger.warning(f"⚠️ Changing reference directory from {reference_dir} to {new_reference_dir}")
+             reference_dir = new_reference_dir
+
+         # Check if the reference audio directory exists
+         if os.path.exists(reference_dir):
+             logger.info(f"✅ Found reference audio directory: {reference_dir}")
+
+             # Log the contents to verify
+             pattern_dirs = [d for d in os.listdir(reference_dir)
+                             if os.path.isdir(os.path.join(reference_dir, d))]
+             logger.info(f"📁 Found reference patterns: {pattern_dirs}")
+
+             # Check each pattern directory for wav files
+             for pattern_dir_name in pattern_dirs:
+                 pattern_path = os.path.join(reference_dir, pattern_dir_name)
+                 wav_files = glob.glob(os.path.join(pattern_path, "*.wav"))
+                 logger.info(f"📁 Found {len(wav_files)} wav files in {pattern_dir_name}")
+         else:
+             logger.warning(f"⚠️ Reference audio directory not found: {reference_dir}")
+             # Create the directory if it doesn't exist
+             try:
+                 os.makedirs(reference_dir, exist_ok=True)
+                 logger.info(f"📁 Created reference audio directory: {reference_dir}")
+             except PermissionError:
+                 logger.error(f"❌ Permission denied when creating {reference_dir}")
+                 # Try alternate location in /tmp
+                 reference_dir = os.path.join('/tmp', 'reference_audio')
+                 os.makedirs(reference_dir, exist_ok=True)
+                 logger.info(f"📁 Created reference audio directory in alternate location: {reference_dir}")
+
+         # Return the (possibly updated) reference directory path
+         return reference_dir
+
+     except Exception as e:
+         logger.error(f"❌ Failed to set up reference audio directory: {str(e)}")
+         # Use /tmp as fallback
+         fallback_dir = os.path.join('/tmp', 'reference_audio')
+         try:
+             os.makedirs(fallback_dir, exist_ok=True)
+             logger.info(f"📁 Created fallback reference audio directory: {fallback_dir}")
+             return fallback_dir
+         except Exception as e2:
+             logger.critical(f"❌ Failed to create fallback directory: {str(e2)}")
+             return None
+
+ def handle_upload_reference(request, reference_dir, sample_rate):
+     """Handle upload of reference audio files"""
+     try:
+         if "audio" not in request.files:
+             logger.warning("⚠️ Reference upload missing audio file")
+             return jsonify({"error": "No audio file uploaded"}), 400
+
+         reference_word = request.form.get("reference_word", "").strip()
+         if not reference_word:
+             logger.warning("⚠️ Reference upload missing reference word")
+             return jsonify({"error": "No reference word provided"}), 400
+
+         # Validate reference word
+         reference_patterns = [
+             "mayap_a_abak", "mayap_a_ugtu", "mayap_a_gatpanapun", "mayap_a_bengi",
+             "komusta_ka", "malaus_ko_pu", "malaus_kayu", "agaganaka_da_ka",
+             "pagdulapan_da_ka", "kaluguran_da_ka", "dakal_a_salamat", "panapaya_mu_ku"
+         ]
+
+         if reference_word not in reference_patterns:
+             logger.warning(f"⚠️ Invalid reference word: {reference_word}")
+             return jsonify({"error": f"Invalid reference word. Available: {reference_patterns}"}), 400
+
+         # Create directory for reference pattern if it doesn't exist
+         pattern_dir = os.path.join(reference_dir, reference_word)
+         os.makedirs(pattern_dir, exist_ok=True)
+
+         # Save the reference audio file
+         audio_file = request.files["audio"]
+         file_path = os.path.join(pattern_dir, secure_filename(audio_file.filename))
+         audio_file.save(file_path)
+
+         # Convert to WAV if not already in that format
+         if not file_path.lower().endswith('.wav'):
+             base_path = os.path.splitext(file_path)[0]
+             wav_path = f"{base_path}.wav"
+             try:
+                 audio = AudioSegment.from_file(file_path)
+                 audio = audio.set_frame_rate(sample_rate).set_channels(1)
+                 audio.export(wav_path, format="wav")
+                 # Remove original file if conversion successful
+                 os.unlink(file_path)
+                 file_path = wav_path
+             except Exception as e:
+                 logger.error(f"❌ Reference audio conversion failed: {str(e)}")
+                 return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500
+
+         logger.info(f"✅ Reference audio saved successfully for {reference_word}: {file_path}")
+
+         # Count how many references we have now
+         references = glob.glob(os.path.join(pattern_dir, "*.wav"))
+         return jsonify({
+             "message": "Reference audio uploaded successfully",
+             "reference_word": reference_word,
+             "file": os.path.basename(file_path),
+             "total_references": len(references)
+         })
+
+     except Exception as e:
+         logger.error(f"❌ Unhandled exception in reference upload: {str(e)}")
+         logger.debug(f"Stack trace: {traceback.format_exc()}")
+         return jsonify({"error": f"Internal server error: {str(e)}"}), 500
+
+
+ def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
+     """Handle pronunciation evaluation requests"""
+     request_id = f"req-{id(request)}" # Create unique ID for this request
+     logger.info(f"[{request_id}] 🆕 Starting new pronunciation evaluation request")
+
+     temp_dir = None
+
+     if asr_model is None or asr_processor is None:
+         logger.error(f"[{request_id}] ❌ Evaluation endpoint called but ASR models aren't loaded")
+         return jsonify({"error": "ASR model not available"}), 503
+
+     try:
+         if "audio" not in request.files:
+             logger.warning(f"[{request_id}] ⚠️ Evaluation request missing audio file")
+             return jsonify({"error": "No audio file uploaded"}), 400
+
+         audio_file = request.files["audio"]
+         reference_locator = request.form.get("reference_locator", "").strip()
+         language = request.form.get("language", "kapampangan").lower()
+
+         # Validate reference locator
+         if not reference_locator:
+             logger.warning(f"[{request_id}] ⚠️ No reference locator provided")
+             return jsonify({"error": "Reference locator is required"}), 400
+
+         # Construct full reference directory path
+         reference_dir_path = os.path.join(reference_dir, reference_locator)
+         logger.info(f"[{request_id}] 📁 Reference directory path: {reference_dir_path}")
+
+         if not os.path.exists(reference_dir_path):
+             logger.warning(f"[{request_id}] ⚠️ Reference directory not found: {reference_dir_path}")
+             return jsonify({"error": f"Reference audio directory not found: {reference_locator}"}), 404
+
+         reference_files = glob.glob(os.path.join(reference_dir_path, "*.wav"))
+         logger.info(f"[{request_id}] 📁 Found {len(reference_files)} reference files")
+
+         if not reference_files:
+             logger.warning(f"[{request_id}] ⚠️ No reference audio files found in {reference_dir_path}")
+             return jsonify({"error": f"No reference audio found for {reference_locator}"}), 404
+
+         lang_code = LANGUAGE_CODES.get(language, language)
+         logger.info(
+             f"[{request_id}] 🔄 Evaluating pronunciation for reference: {reference_locator} with language code: {lang_code}")
+
+         # Create a request-specific temp directory to avoid conflicts
+         temp_dir = os.path.join(output_dir, f"temp_{request_id}")
+         os.makedirs(temp_dir, exist_ok=True)
+
+         # Process user audio
+         user_audio_path = os.path.join(temp_dir, "user_audio_input.wav")
+         with open(user_audio_path, 'wb') as f:
+             f.write(audio_file.read())
+
+         try:
+             logger.info(f"[{request_id}] 🔄 Processing user audio file")
+             audio = AudioSegment.from_file(user_audio_path)
+             audio = audio.set_frame_rate(sample_rate).set_channels(1)
+
+             processed_path = os.path.join(temp_dir, "processed_user_audio.wav")
+             audio.export(processed_path, format="wav")
+
+             user_waveform, sr = torchaudio.load(processed_path)
+             user_waveform = user_waveform.squeeze().numpy()
+             logger.info(f"[{request_id}] ✅ User audio processed: {sr}Hz, length: {len(user_waveform)} samples")
+
+             user_audio_path = processed_path
+         except Exception as e:
+             logger.error(f"[{request_id}] ❌ Audio processing failed: {str(e)}")
+             return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500
+
+         # Transcribe user audio
+         try:
+             logger.info(f"[{request_id}] 🔄 Transcribing user audio")
+             inputs = asr_processor(
+                 user_waveform,
+                 sampling_rate=sample_rate,
+                 return_tensors="pt",
+                 language=lang_code
+             )
+             inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 logits = asr_model(**inputs).logits
+             ids = torch.argmax(logits, dim=-1)[0]
+             user_transcription = asr_processor.decode(ids)
+
+             logger.info(f"[{request_id}] ✅ User transcription: '{user_transcription}'")
+         except Exception as e:
+             logger.error(f"[{request_id}] ❌ ASR inference failed: {str(e)}")
+             return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500
+
+         # Process reference files in batches
+         batch_size = 2 # Process 2 files at a time - adjust based on your hardware
+         results = []
+         best_score = 0
+         best_reference = None
+         best_transcription = None
+
+         # Use this if you want to limit the number of files to process
+         max_files_to_check = min(5, len(reference_files)) # Check at most 5 files
+         reference_files = reference_files[:max_files_to_check]
+
+         logger.info(f"[{request_id}] 🔄 Processing {len(reference_files)} reference files in batches of {batch_size}")
+
+         # Function to process a single reference file
+         def process_reference_file(ref_file):
+             ref_filename = os.path.basename(ref_file)
+             try:
+                 # Load and resample reference audio
+                 ref_waveform, ref_sr = torchaudio.load(ref_file)
+                 if ref_sr != sample_rate:
+                     ref_waveform = torchaudio.transforms.Resample(ref_sr, sample_rate)(ref_waveform)
+                 ref_waveform = ref_waveform.squeeze().numpy()
+
+                 # Transcribe reference audio
+                 inputs = asr_processor(
+                     ref_waveform,
+                     sampling_rate=sample_rate,
+                     return_tensors="pt",
+                     language=lang_code
+                 )
+                 inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}
+
+                 with torch.no_grad():
+                     logits = asr_model(**inputs).logits
+                 ids = torch.argmax(logits, dim=-1)[0]
+                 ref_transcription = asr_processor.decode(ids)
+
+                 # Calculate similarity
+                 similarity = calculate_similarity(user_transcription, ref_transcription)
+
+                 logger.info(
+                     f"[{request_id}] 📊 Similarity with {ref_filename}: {similarity:.2f}%, transcription: '{ref_transcription}'")
+
+                 return {
+                     "reference_file": ref_filename,
+                     "reference_text": ref_transcription,
+                     "similarity_score": similarity
+                 }
+             except Exception as e:
+                 logger.error(f"[{request_id}] ❌ Error processing {ref_filename}: {str(e)}")
+                 return {
+                     "reference_file": ref_filename,
+                     "reference_text": "Error",
+                     "similarity_score": 0,
+                     "error": str(e)
+                 }
+
+         # Process files in batches using ThreadPoolExecutor
+         with ThreadPoolExecutor(max_workers=batch_size) as executor:
+             batch_results = list(executor.map(process_reference_file, reference_files))
+             results.extend(batch_results)
+
+             # Find the best result
+             for result in batch_results:
+                 if result["similarity_score"] > best_score:
+                     best_score = result["similarity_score"]
+                     best_reference = result["reference_file"]
+                     best_transcription = result["reference_text"]
+
+                 # Exit early if we found a very good match (optional)
+                 if best_score > 80.0:
+                     logger.info(f"[{request_id}] 🏁 Found excellent match: {best_score:.2f}%")
+                     break
+
+         # Clean up temp files
+         try:
+             if temp_dir and os.path.exists(temp_dir):
+                 shutil.rmtree(temp_dir)
+                 logger.debug(f"[{request_id}] 🧹 Cleaned up temporary directory")
+         except Exception as e:
+             logger.warning(f"[{request_id}] ⚠️ Failed to clean up temp files: {str(e)}")
+
+         # Determine feedback based on score
+         is_correct = best_score >= 70.0
+
+         if best_score >= 90.0:
+             feedback = "Perfect pronunciation! Excellent job!"
+         elif best_score >= 80.0:
+             feedback = "Great pronunciation! Your accent is very good."
+         elif best_score >= 70.0:
+             feedback = "Good pronunciation. Keep practicing!"
+         elif best_score >= 50.0:
+             feedback = "Fair attempt. Try focusing on the syllables that differ from the sample."
+         else:
+             feedback = "Try again. Listen carefully to the sample pronunciation."
+
+         logger.info(f"[{request_id}] 📊 Final evaluation results: score={best_score:.2f}%, is_correct={is_correct}")
+         logger.info(f"[{request_id}] 📝 Feedback: '{feedback}'")
+         logger.info(f"[{request_id}] ✅ Evaluation complete")
+
+         # Sort results by score descending
+         results.sort(key=lambda x: x["similarity_score"], reverse=True)
+
+         return jsonify({
+             "is_correct": is_correct,
+             "score": best_score,
+             "feedback": feedback,
+             "user_transcription": user_transcription,
+             "best_reference_transcription": best_transcription,
+             "reference_locator": reference_locator,
+             "details": results
+         })
+
+     except Exception as e:
+         logger.error(f"[{request_id}] ❌ Unhandled exception in evaluation endpoint: {str(e)}")
+         logger.debug(f"[{request_id}] Stack trace: {traceback.format_exc()}")
+
+         # Clean up on error
+         try:
+             if temp_dir and os.path.exists(temp_dir):
+                 shutil.rmtree(temp_dir)
+         except:
+             pass
+
          return jsonify({"error": f"Internal server error: {str(e)}"}), 500
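
Note on the change: the updated init_reference_audio falls back to /tmp/reference_audio when the configured reference directory is missing or not writable, and it now returns the directory it ended up using (or None if even the fallback could not be created) instead of returning nothing. Callers therefore need to capture that return value. The sketch below is a hypothetical caller, not part of this commit; the Flask app wiring, route names, and the REFERENCE_AUDIO_DIR / OUTPUT_DIR / SAMPLE_RATE constants are assumptions for illustration.

# Hypothetical caller sketch (not part of this commit): keep the directory
# returned by init_reference_audio, since it may differ from the configured
# path after the /tmp fallback introduced in this change.
from flask import Flask, request

from evaluate import (
    handle_evaluation_request,
    handle_upload_reference,
    init_reference_audio,
)

app = Flask(__name__)

REFERENCE_AUDIO_DIR = "./reference_audio"  # assumed configured location
OUTPUT_DIR = "./output"                    # assumed output/temp location
SAMPLE_RATE = 16000                        # assumed ASR sample rate

# Previously a fire-and-forget call; now the (possibly relocated) path must be
# captured and passed to every handler. A None return means even the /tmp
# fallback could not be created.
REFERENCE_AUDIO_DIR = init_reference_audio(REFERENCE_AUDIO_DIR, OUTPUT_DIR)


@app.route("/upload_reference", methods=["POST"])
def upload_reference():
    return handle_upload_reference(request, REFERENCE_AUDIO_DIR, SAMPLE_RATE)


@app.route("/evaluate", methods=["POST"])
def evaluate():
    return handle_evaluation_request(request, REFERENCE_AUDIO_DIR, OUTPUT_DIR, SAMPLE_RATE)

Because the handlers take the reference directory as a parameter rather than reading a module-level constant, the fallback stays transparent to the upload and evaluation endpoints as long as the caller passes along whatever init_reference_audio returned.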