Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import librosa | |
| import numpy as np | |
| import tempfile | |
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
| from librosa.sequence import dtw | |
# Application object for the service; the keyword arguments below are the
# metadata handed to the FastAPI constructor.
app = FastAPI(
    title="Quran Recitation Comparer API",
    description="Compares two Quran recitations using a deep wav2vec2 model.",
    version="1.0",
)
| # --- Core Class Definition --- | |
class QuranRecitationComparer:
    """Compare two audio recitations using Wav2Vec2 hidden states and DTW.

    Each audio file is converted into a sequence of frame-wise embeddings by a
    pretrained Wav2Vec2 CTC model. Two such sequences are aligned with dynamic
    time warping (DTW); the accumulated cost, divided by the warping-path
    length, is mapped to a textual interpretation and a 0-100 score.
    """

    # (exclusive upper bound on the normalized distance, interpretation, score),
    # evaluated in ascending order by interpret_similarity().
    _SIMILARITY_BANDS = (
        (1, "The recitations are extremely similar.", 95),
        (5, "The recitations are very similar with minor differences.", 80),
        (10, "The recitations show moderate similarity.", 60),
        (20, "The recitations show some noticeable differences.", 40),
    )

    def __init__(
        self,
        model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
        auth_token=None,
        max_cache_entries=32,
    ):
        """Load the processor and model once and prepare the embedding cache.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                processor and the model.
            auth_token: Optional access token. It is forwarded as ``token``
                only when truthy, so no ``token`` keyword is sent otherwise.
            max_cache_entries: Maximum number of embeddings kept in
                ``embedding_cache``. ``None`` means unbounded; ``0`` (or a
                negative value) disables caching.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Build the optional keyword once instead of duplicating both calls.
        load_kwargs = {"token": auth_token} if auth_token else {}
        self.processor = Wav2Vec2Processor.from_pretrained(model_name, **load_kwargs)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name, **load_kwargs)

        self.model = self.model.to(self.device)
        self.model.eval()

        # Cache of embeddings keyed by (absolute path, mtime_ns, size); bounded
        # so a long-running process does not accumulate embeddings forever.
        self.max_cache_entries = max_cache_entries
        self.embedding_cache = {}

    def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
        """Load an audio file and optionally normalize and trim it.

        Args:
            file_path: Path of the audio file to read.
            target_sr: Sample rate requested from ``librosa.load``.
            trim_silence: When true, apply ``librosa.effects.trim`` with
                ``top_db=30``.
            normalize: When true, apply ``librosa.util.normalize``.

        Returns:
            The processed samples as returned by librosa.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If no samples remain after loading/trimming.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        y, _ = librosa.load(file_path, sr=target_sr)
        # Skip the processing steps on an empty signal so the explicit error
        # below is what the caller sees.
        if normalize and y.size:
            y = librosa.util.normalize(y)
        if trim_silence and y.size:
            y, _ = librosa.effects.trim(y, top_db=30)
        if y.size == 0:
            raise ValueError(f"Audio file contains no samples: {file_path}")
        return y

    def get_deep_embedding(self, audio, sr=16000):
        """Extract frame-wise deep embeddings using the pretrained model.

        Args:
            audio: Audio samples handed to the processor.
            sr: Sampling rate reported to the processor.

        Returns:
            The last entry of the model's ``hidden_states`` with its leading
            axis squeezed, as a NumPy array on the CPU. Presumably shaped
            (frames, hidden_size) for a single clip -- TODO confirm.
        """
        input_values = self.processor(
            audio,
            sampling_rate=sr,
            return_tensors="pt",
        ).input_values.to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.model(input_values, output_hidden_states=True)

        hidden_states = outputs.hidden_states[-1]
        embedding_seq = hidden_states.squeeze(0).cpu().numpy()
        return embedding_seq

    def compute_dtw_distance(self, features1, features2):
        """Compute the path-normalized DTW distance between two feature sequences.

        Args:
            features1: First feature matrix, passed to librosa's ``dtw`` as ``X``.
            features2: Second feature matrix, passed as ``Y``.

        Returns:
            The final accumulated cost divided by the warping-path length, as a
            built-in ``float`` (so it serializes like any Python number).
        """
        D, wp = dtw(X=features1, Y=features2, metric='euclidean')
        distance = D[-1, -1]
        return float(distance / len(wp))

    def interpret_similarity(self, norm_distance):
        """Map a normalized distance to an interpretation string and a score.

        Args:
            norm_distance: Normalized DTW distance.

        Returns:
            ``(interpretation, score)``. The score is 100 for a distance of
            exactly 0, then 95/80/60/40 for the bands below 1, 5, 10 and 20.
            From 20 upwards it decays linearly from 40 down to 0
            (``60 - norm_distance``, clamped at 0), so a larger distance never
            yields a higher score than a smaller one.
        """
        if norm_distance == 0:
            return "The recitations are identical based on the deep embeddings.", 100

        for upper_bound, result, score in self._SIMILARITY_BANDS:
            if norm_distance < upper_bound:
                return result, score

        # The previous formula (100 - norm_distance) gave e.g. 75 for a
        # distance of 25, above the 40 of the closer band; continue from 40.
        return "The recitations are quite different.", max(0, 60 - norm_distance)

    def get_embedding_for_file(self, file_path):
        """Return the embedding for a file, reusing a cached one when valid.

        The cache key includes the file's modification time and size, so a
        file that is rewritten at the same path is re-embedded instead of
        returning a stale result.

        Args:
            file_path: Path of the audio file.

        Returns:
            The embedding produced by ``get_deep_embedding``.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        stat = os.stat(file_path)
        cache_key = (os.path.abspath(file_path), stat.st_mtime_ns, stat.st_size)

        cached = self.embedding_cache.get(cache_key)
        if cached is not None:
            return cached

        audio = self.load_audio(file_path)
        embedding = self.get_deep_embedding(audio)
        self._store_embedding(cache_key, embedding)
        return embedding

    def _store_embedding(self, cache_key, embedding):
        """Insert an embedding into the cache, evicting the oldest entries if full."""
        limit = self.max_cache_entries
        if limit is not None:
            if limit <= 0:
                return
            # Dicts preserve insertion order, so the first key is the oldest.
            while len(self.embedding_cache) >= limit:
                self.embedding_cache.pop(next(iter(self.embedding_cache)))
        self.embedding_cache[cache_key] = embedding

    def predict(self, file_path1, file_path2):
        """Predict the similarity between two audio files.

        Args:
            file_path1 (str): Path to first audio file.
            file_path2 (str): Path to second audio file.

        Returns:
            (float, str): Similarity score and interpretation.
        """
        embedding1 = self.get_embedding_for_file(file_path1)
        embedding2 = self.get_embedding_for_file(file_path2)

        # The sequences are transposed before DTW; assumes embeddings are
        # (frames, features) and dtw expects (features, frames) -- TODO confirm.
        norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
        interpretation, similarity_score = self.interpret_similarity(norm_distance)

        # Optionally log the results instead of printing in production
        print(f"Similarity Score: {similarity_score:.1f}/100")
        print(f"Interpretation: {interpretation}")
        return similarity_score, interpretation

    def clear_cache(self):
        """Clear the embedding cache to free memory."""
        self.embedding_cache = {}
| # --- FastAPI Startup Event --- | |
| # In production, consider loading sensitive tokens from environment variables or configuration files. | |
def startup_event():
    """Build the module-level ``comparer`` used by the endpoints.

    The Hugging Face token is read from the ``HF_TOKEN`` environment variable
    (``None`` when unset) rather than being hardcoded.
    """
    # NOTE(review): no startup registration (e.g. a decorator) is visible in
    # this view; confirm this function is actually wired to `app`, otherwise
    # `comparer` is never assigned.
    global comparer
    hf_token = os.environ.get("HF_TOKEN")
    comparer = QuranRecitationComparer(
        auth_token=hf_token,
        model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    )
    print("Model initialized and ready for predictions!")
| # --- API Endpoints --- | |
async def root():
    # Returns a fixed status payload; takes no input and touches no state.
    # NOTE(review): no route registration is visible in this view -- confirm
    # how this handler is attached to `app`.
    status = {"message": "Quran Recitation Comparer API is up and running."}
    return status
async def _save_upload_to_temp(upload, created_paths):
    # Write one uploaded file to a named temporary file and return its path.
    # The path is appended to `created_paths` *before* the write so that the
    # caller's cleanup also removes the file when reading/writing fails.
    # `filename` can be missing on an upload; fall back to ".wav" in that case
    # instead of letting os.path.splitext(None) raise.
    suffix = os.path.splitext(upload.filename or "")[1] or ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        created_paths.append(tmp.name)
        tmp.write(await upload.read())
    return tmp.name


async def predict(file1: UploadFile = File(...), file2: UploadFile = File(...)):
    """
    Compare two uploaded audio files and return a similarity score along with an interpretation.
    - **file1**: The first audio file.
    - **file2**: The second audio file.
    """
    # NOTE(review): no route registration is visible in this view -- confirm
    # how this handler is attached to `app`.
    # NOTE(review): comparer.predict is a synchronous call made directly from
    # this coroutine, so the event loop is busy for its whole duration.
    tmp_paths = []
    try:
        tmp1_path = await _save_upload_to_temp(file1, tmp_paths)
        tmp2_path = await _save_upload_to_temp(file2, tmp_paths)
        similarity_score, interpretation = comparer.predict(tmp1_path, tmp2_path)
        return {"similarity_score": similarity_score, "interpretation": interpretation}
    except HTTPException:
        # Let deliberate HTTP errors through unchanged rather than turning
        # them into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temporary files; a file that is already gone is fine.
        for path in tmp_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
async def clear_cache():
    """
    Clear the embedding cache. This can help free memory if many comparisons have been made.
    """
    # Delegates to the module-level comparer, then reports success.
    # NOTE(review): no route registration is visible in this view -- confirm
    # how this handler is attached to `app`.
    comparer.clear_cache()
    response = {"message": "Cache cleared."}
    return response