# evaluate.py - Handles pronunciation evaluation and reference comparison tasks

import os
import glob
import logging
import traceback
import uuid
import shutil
from difflib import SequenceMatcher
import torch
import torchaudio
from pydub import AudioSegment
from flask import jsonify
from werkzeug.utils import secure_filename
from concurrent.futures import ThreadPoolExecutor

# Import necessary functions from translator.py
from translator import asr_model, asr_processor, LANGUAGE_CODES

# Configure logging
logger = logging.getLogger("speech_api")

def calculate_similarity(text1, text2):
    """Calculate text similarity percentage."""
    def clean_text(text):
        return text.lower()

    clean1 = clean_text(text1)
    clean2 = clean_text(text2)

    matcher = SequenceMatcher(None, clean1, clean2)
    return matcher.ratio() * 100

def init_reference_audio(reference_dir, output_dir):
    try:
        # Create the output directory first
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"πŸ“ Created output directory: {output_dir}")

        # Check if the reference audio directory exists in the repository
        if os.path.exists(reference_dir):
            logger.info(f"βœ… Found reference audio directory: {reference_dir}")

            # Log the contents to verify
            pattern_dirs = [d for d in os.listdir(reference_dir)
                            if os.path.isdir(os.path.join(reference_dir, d))]
            logger.info(f"πŸ“ Found reference patterns: {pattern_dirs}")

            # Check each pattern directory for wav files
            for pattern_dir_name in pattern_dirs:
                pattern_path = os.path.join(reference_dir, pattern_dir_name)
                wav_files = glob.glob(os.path.join(pattern_path, "*.wav"))
                logger.info(f"πŸ“ Found {len(wav_files)} wav files in {pattern_dir_name}")
        else:
            logger.warning(f"⚠️ Reference audio directory not found: {reference_dir}")
            # Create the directory if it doesn't exist
            os.makedirs(reference_dir, exist_ok=True)
            logger.info(f"πŸ“ Created reference audio directory: {reference_dir}")
    except Exception as e:
        logger.error(f"❌ Failed to set up reference audio directory: {str(e)}")

def handle_upload_reference(request, reference_dir, sample_rate):
    """Handle upload of reference audio files"""
    try:
        if "audio" not in request.files:
            logger.warning("⚠️ Reference upload missing audio file")
            return jsonify({"error": "No audio file uploaded"}), 400

        reference_word = request.form.get("reference_word", "").strip()
        if not reference_word:
            logger.warning("⚠️ Reference upload missing reference word")
            return jsonify({"error": "No reference word provided"}), 400

        # Validate reference word
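        # (each pattern name doubles as a folder name under reference_dir and as the
        # "reference_locator" value accepted by handle_evaluation_request)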
        reference_patterns = [
            "mayap_a_abak", "mayap_a_ugtu", "mayap_a_gatpanapun", "mayap_a_bengi",
            "komusta_ka", "malaus_ko_pu", "malaus_kayu", "agaganaka_da_ka",
            "pagdulapan_da_ka", "kaluguran_da_ka", "dakal_a_salamat", "panapaya_mu_ku"
        ]

        if reference_word not in reference_patterns:
            logger.warning(f"⚠️ Invalid reference word: {reference_word}")
            return jsonify({"error": f"Invalid reference word. Available: {reference_patterns}"}), 400

        # Create directory for reference pattern if it doesn't exist
        pattern_dir = os.path.join(reference_dir, reference_word)
        os.makedirs(pattern_dir, exist_ok=True)

        # Save the reference audio file
        audio_file = request.files["audio"]
        file_path = os.path.join(pattern_dir, secure_filename(audio_file.filename))
        audio_file.save(file_path)

        # Convert to WAV if not already in that format
        if not file_path.lower().endswith('.wav'):
            base_path = os.path.splitext(file_path)[0]
            wav_path = f"{base_path}.wav"
            try:
                audio = AudioSegment.from_file(file_path)
                audio = audio.set_frame_rate(sample_rate).set_channels(1)
                audio.export(wav_path, format="wav")
                # Remove original file if conversion successful
                os.unlink(file_path)
                file_path = wav_path
            except Exception as e:
                logger.error(f"❌ Reference audio conversion failed: {str(e)}")
                return jsonify({"error": f"Audio conversion failed: {str(e)}"}), 500

        logger.info(f"βœ… Reference audio saved successfully for {reference_word}: {file_path}")

        # Count how many references we have now
        references = glob.glob(os.path.join(pattern_dir, "*.wav"))
        return jsonify({
            "message": "Reference audio uploaded successfully",
            "reference_word": reference_word,
            "file": os.path.basename(file_path),
            "total_references": len(references)
        })

    except Exception as e:
        logger.error(f"❌ Unhandled exception in reference upload: {str(e)}")
        logger.debug(f"Stack trace: {traceback.format_exc()}")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500


def handle_evaluation_request(request, reference_dir, output_dir, sample_rate):
    """Handle pronunciation evaluation requests"""
    request_id = f"req-{id(request)}"  # Create unique ID for this request
    logger.info(f"[{request_id}] πŸ†• Starting new pronunciation evaluation request")

    temp_dir = None

    if asr_model is None or asr_processor is None:
        logger.error(f"[{request_id}] ❌ Evaluation endpoint called but ASR models aren't loaded")
        return jsonify({"error": "ASR model not available"}), 503

    try:
        if "audio" not in request.files:
            logger.warning(f"[{request_id}] ⚠️ Evaluation request missing audio file")
            return jsonify({"error": "No audio file uploaded"}), 400

        audio_file = request.files["audio"]
        reference_locator = request.form.get("reference_locator", "").strip()
        language = request.form.get("language", "kapampangan").lower()

        # Validate reference locator
        if not reference_locator:
            logger.warning(f"[{request_id}] ⚠️ No reference locator provided")
            return jsonify({"error": "Reference locator is required"}), 400

        # Construct full reference directory path
        reference_dir_path = os.path.join(reference_dir, reference_locator)
        logger.info(f"[{request_id}] πŸ“ Reference directory path: {reference_dir_path}")

        if not os.path.exists(reference_dir_path):
            logger.warning(f"[{request_id}] ⚠️ Reference directory not found: {reference_dir_path}")
            return jsonify({"error": f"Reference audio directory not found: {reference_locator}"}), 404

        reference_files = glob.glob(os.path.join(reference_dir_path, "*.wav"))
        logger.info(f"[{request_id}] πŸ“ Found {len(reference_files)} reference files")

        if not reference_files:
            logger.warning(f"[{request_id}] ⚠️ No reference audio files found in {reference_dir_path}")
            return jsonify({"error": f"No reference audio found for {reference_locator}"}), 404

        lang_code = LANGUAGE_CODES.get(language, language)
        logger.info(
            f"[{request_id}] πŸ”„ Evaluating pronunciation for reference: {reference_locator} with language code: {lang_code}")

        # Create a request-specific temp directory to avoid conflicts
        temp_dir = os.path.join(output_dir, f"temp_{request_id}")
        os.makedirs(temp_dir, exist_ok=True)

        # Process user audio
        user_audio_path = os.path.join(temp_dir, "user_audio_input.wav")
        with open(user_audio_path, 'wb') as f:
            f.write(audio_file.read())

        try:
            logger.info(f"[{request_id}] πŸ”„ Processing user audio file")
            audio = AudioSegment.from_file(user_audio_path)
            audio = audio.set_frame_rate(sample_rate).set_channels(1)

            processed_path = os.path.join(temp_dir, "processed_user_audio.wav")
            audio.export(processed_path, format="wav")

            user_waveform, sr = torchaudio.load(processed_path)
            user_waveform = user_waveform.squeeze().numpy()
            logger.info(f"[{request_id}] βœ… User audio processed: {sr}Hz, length: {len(user_waveform)} samples")

            user_audio_path = processed_path
        except Exception as e:
            logger.error(f"[{request_id}] ❌ Audio processing failed: {str(e)}")
            return jsonify({"error": f"Audio processing failed: {str(e)}"}), 500

        # Transcribe user audio
        try:
            logger.info(f"[{request_id}] πŸ”„ Transcribing user audio")
            inputs = asr_processor(
                user_waveform,
                sampling_rate=sample_rate,
                return_tensors="pt",
                language=lang_code
            )
            inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}

            with torch.no_grad():
                logits = asr_model(**inputs).logits
            ids = torch.argmax(logits, dim=-1)[0]
            user_transcription = asr_processor.decode(ids)

            logger.info(f"[{request_id}] βœ… User transcription: '{user_transcription}'")
        except Exception as e:
            logger.error(f"[{request_id}] ❌ ASR inference failed: {str(e)}")
            return jsonify({"error": f"ASR inference failed: {str(e)}"}), 500

        # Transcribe the reference files concurrently
        batch_size = 2  # Number of worker threads - adjust based on your hardware
        results = []
        best_score = 0
        best_reference = None
        best_transcription = None

        # Limit how many reference files are transcribed per request
        max_files_to_check = min(5, len(reference_files))  # Check at most 5 files
        reference_files = reference_files[:max_files_to_check]

        logger.info(f"[{request_id}] πŸ”„ Processing {len(reference_files)} reference files with {batch_size} workers")

        # Function to process a single reference file
        def process_reference_file(ref_file):
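            # Runs in a worker thread: transcribes one reference clip and returns a
            # result dict (filename, reference transcription, similarity score), or an
            # error entry if processing fails.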
            ref_filename = os.path.basename(ref_file)
            try:
                # Load and resample reference audio
                ref_waveform, ref_sr = torchaudio.load(ref_file)
                if ref_sr != sample_rate:
                    ref_waveform = torchaudio.transforms.Resample(ref_sr, sample_rate)(ref_waveform)
                ref_waveform = ref_waveform.squeeze().numpy()

                # Transcribe reference audio
                inputs = asr_processor(
                    ref_waveform,
                    sampling_rate=sample_rate,
                    return_tensors="pt",
                    language=lang_code
                )
                inputs = {k: v.to(asr_model.device) for k, v in inputs.items()}

                with torch.no_grad():
                    logits = asr_model(**inputs).logits
                ids = torch.argmax(logits, dim=-1)[0]
                ref_transcription = asr_processor.decode(ids)

                # Calculate similarity
                similarity = calculate_similarity(user_transcription, ref_transcription)

                logger.info(
                    f"[{request_id}] πŸ“Š Similarity with {ref_filename}: {similarity:.2f}%, transcription: '{ref_transcription}'")

                return {
                    "reference_file": ref_filename,
                    "reference_text": ref_transcription,
                    "similarity_score": similarity
                }
            except Exception as e:
                logger.error(f"[{request_id}] ❌ Error processing {ref_filename}: {str(e)}")
                return {
                    "reference_file": ref_filename,
                    "reference_text": "Error",
                    "similarity_score": 0,
                    "error": str(e)
                }

        # Transcribe the reference files concurrently using a thread pool
        with ThreadPoolExecutor(max_workers=batch_size) as executor:
            batch_results = list(executor.map(process_reference_file, reference_files))
            results.extend(batch_results)

            # Find the best result
            for result in batch_results:
                if result["similarity_score"] > best_score:
                    best_score = result["similarity_score"]
                    best_reference = result["reference_file"]
                    best_transcription = result["reference_text"]

                    # Note: executor.map has already transcribed every file by now, so
                    # this break only stops scanning the collected results once an
                    # excellent match has been found.
                    if best_score > 80.0:
                        logger.info(f"[{request_id}] 🏁 Found excellent match: {best_score:.2f}%")
                        break

        # Clean up temp files
        try:
            if temp_dir and os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
                logger.debug(f"[{request_id}] 🧹 Cleaned up temporary directory")
        except Exception as e:
            logger.warning(f"[{request_id}] ⚠️ Failed to clean up temp files: {str(e)}")

        # Determine feedback based on score
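        # (the 50/70/80/90 cut-offs are heuristic bands on the 0-100 similarity scale;
        # a score of at least 70 is treated as a correct pronunciation)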
        is_correct = best_score >= 70.0

        if best_score >= 90.0:
            feedback = "Perfect pronunciation! Excellent job!"
        elif best_score >= 80.0:
            feedback = "Great pronunciation! Your accent is very good."
        elif best_score >= 70.0:
            feedback = "Good pronunciation. Keep practicing!"
        elif best_score >= 50.0:
            feedback = "Fair attempt. Try focusing on the syllables that differ from the sample."
        else:
            feedback = "Try again. Listen carefully to the sample pronunciation."

        logger.info(f"[{request_id}] πŸ“Š Final evaluation results: score={best_score:.2f}%, is_correct={is_correct}")
        logger.info(f"[{request_id}] πŸ“ Feedback: '{feedback}'")
        logger.info(f"[{request_id}] βœ… Evaluation complete")

        # Sort results by score descending
        results.sort(key=lambda x: x["similarity_score"], reverse=True)

        return jsonify({
            "is_correct": is_correct,
            "score": best_score,
            "feedback": feedback,
            "user_transcription": user_transcription,
            "best_reference_transcription": best_transcription,
            "reference_locator": reference_locator,
            "details": results
        })

    except Exception as e:
        logger.error(f"[{request_id}] ❌ Unhandled exception in evaluation endpoint: {str(e)}")
        logger.debug(f"[{request_id}] Stack trace: {traceback.format_exc()}")

        # Clean up on error
        try:
            if temp_dir and os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
        except Exception:
            pass

        return jsonify({"error": f"Internal server error: {str(e)}"}), 500