Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import librosa | |
| import numpy as np | |
| from typing import List, Dict, Any, Optional | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
| from librosa.sequence import dtw | |
| import tempfile | |
| import uuid | |
| import shutil | |
# Initialize FastAPI app
app = FastAPI(
    title="Quran Recitation Comparison API",
    description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
    version="1.0.0",
)

# CORS configuration.
# Allowed origins are read from the comma-separated ALLOWED_ORIGINS
# environment variable and default to "*" (any origin), which keeps the
# previous open policy when nothing is configured.
_allowed_origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    # A wildcard origin must not be combined with credentialed requests
    # (browsers reject "Access-Control-Allow-Origin: *" on them, and the
    # FastAPI CORS documentation advises against the combination), so
    # credentials are only enabled when an explicit origin list is given.
    allow_credentials="*" not in _allowed_origins,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
# Module-level state shared by the request handlers.
# Both names stay None until initialize_model() assigns them.
MODEL = None
PROCESSOR = None

# Scratch directory for uploaded audio, placed under the system temp dir.
UPLOAD_DIR = os.path.join(tempfile.gettempdir(), "quran_comparison_uploads")

# Create the directory at import time; exist_ok makes repeated imports safe.
os.makedirs(UPLOAD_DIR, exist_ok=True)
| # Response models | |
class SimilarityResponse(BaseModel):
    """Response body describing how similar two recitations are."""

    # Numeric similarity score for the pair of recordings.
    similarity_score: float
    # Human-readable summary that accompanies the score.
    interpretation: str
class ErrorResponse(BaseModel):
    """Response body carrying a single error message."""

    # Description of what went wrong.
    error: str
| # Initialize model from environment variable | |
def initialize_model():
    """Load the Wav2Vec2 processor and model into the module globals.

    Configuration is read from the environment:

    * ``HF_TOKEN`` - optional Hugging Face token, forwarded to
      ``from_pretrained`` only when it is set and non-empty.
    * ``MODEL_NAME`` - checkpoint to load; defaults to
      ``jonatasgrosman/wav2vec2-large-xlsr-53-arabic``.

    The model is moved to CUDA when ``torch.cuda.is_available()`` reports a
    GPU, otherwise to the CPU, and is switched to eval mode.

    Raises:
        Exception: any error raised while loading is printed and re-raised.
    """
    global MODEL, PROCESSOR

    token = os.environ.get("HF_TOKEN", None)
    checkpoint = os.environ.get("MODEL_NAME", "jonatasgrosman/wav2vec2-large-xlsr-53-arabic")

    # Only pass the auth argument when a token is actually configured.
    auth_kwargs = {"use_auth_token": token} if token else {}

    try:
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model on device: {target}")

        # The processor is assigned first, so it stays set even if the
        # model load below fails.
        PROCESSOR = Wav2Vec2Processor.from_pretrained(checkpoint, **auth_kwargs)
        MODEL = Wav2Vec2ForCTC.from_pretrained(checkpoint, **auth_kwargs)

        MODEL = MODEL.to(target)
        MODEL.eval()
        print("Model loaded successfully")
    except Exception as exc:
        print(f"Error loading model: {exc}")
        raise
| # Load audio file | |
def load_audio(file_path, target_sr=16000, trim_silence=True, normalize=True):
    """Read an audio file and apply the optional preprocessing steps.

    Args:
        file_path: Path of the audio file to read.
        target_sr: Sample rate passed to ``librosa.load`` as ``sr``.
        trim_silence: When true, run ``librosa.effects.trim`` with
            ``top_db=30`` on the signal.
        normalize: When true, run ``librosa.util.normalize`` on the signal
            (this happens before trimming).

    Returns:
        The processed waveform returned by librosa.

    Raises:
        HTTPException: status 400 if loading or preprocessing fails.
    """
    try:
        samples, _rate = librosa.load(file_path, sr=target_sr)

        if normalize:
            samples = librosa.util.normalize(samples)

        if trim_silence:
            samples, _interval = librosa.effects.trim(samples, top_db=30)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Error loading audio: {exc}")

    return samples
| # Get deep embedding | |
def get_deep_embedding(audio, sr=16000):
    """Return the last hidden-state sequence of the model for ``audio``.

    Args:
        audio: Waveform handed to the processor (presumably the 1-D array
            produced by ``load_audio`` - confirm against callers).
        sr: Sampling rate reported to the processor.

    Returns:
        A NumPy array built from the final entry of ``hidden_states`` with
        the leading (batch) axis squeezed away.

    Raises:
        HTTPException: status 500 if the model/processor are not loaded or
            if anything fails during feature extraction.
    """
    # MODEL and PROCESSOR are only read here, so no ``global`` is needed.
    if MODEL is None or PROCESSOR is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    try:
        # Run on whichever device the model weights currently live on.
        model_device = next(MODEL.parameters()).device

        encoded = PROCESSOR(audio, sampling_rate=sr, return_tensors="pt")
        inputs = encoded.input_values.to(model_device)

        with torch.no_grad():
            result = MODEL(inputs, output_hidden_states=True)
            last_layer = result.hidden_states[-1]

        return last_layer.squeeze(0).cpu().numpy()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error extracting embeddings: {exc}")
| # Compute DTW distance | |
def compute_dtw_distance(features1, features2):
    """Return the path-length-normalised DTW cost between two sequences.

    Both inputs are passed straight to ``librosa.sequence.dtw`` as ``X`` and
    ``Y`` with the Euclidean metric. The accumulated cost in the last cell
    of the cost matrix is divided by the number of steps in the warping
    path.

    Raises:
        HTTPException: status 500 if the DTW computation fails.
    """
    try:
        cost_matrix, warp_path = dtw(X=features1, Y=features2, metric='euclidean')
        total_cost = cost_matrix[-1, -1]
        return total_cost / len(warp_path)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error computing DTW distance: {exc}")
| # Interpret similarity | |
def interpret_similarity(norm_distance):
    """Map a normalised DTW distance to a description and a 0-100 score.

    Tiers (upper bounds are exclusive):

    ==============  =====  ====================================
    distance        score  meaning
    ==============  =====  ====================================
    ``== 0``        100    identical
    ``< 1``         95     extremely similar
    ``< 5``         80     very similar, minor differences
    ``< 10``        60     moderate similarity
    ``< 20``        40     noticeable differences
    ``>= 20``       <= 20  quite different
    ==============  =====  ====================================

    In the last tier the score is ``100 - norm_distance`` clamped to the
    range ``[0, 20]``. The upper clamp is the fix: without it a distance of
    20 scored 80, i.e. higher than every tier except "identical" and
    "extremely similar", so the score was not monotonic in the distance.

    Args:
        norm_distance: Normalised distance between the two recitations.

    Returns:
        A ``(interpretation, score)`` tuple.
    """
    if norm_distance == 0:
        return "The recitations are identical based on the deep embeddings.", 100

    # (exclusive upper bound, interpretation, score), checked in order.
    tiers = (
        (1, "The recitations are extremely similar.", 95),
        (5, "The recitations are very similar with minor differences.", 80),
        (10, "The recitations show moderate similarity.", 60),
        (20, "The recitations show some noticeable differences.", 40),
    )
    for upper_bound, text, tier_score in tiers:
        if norm_distance < upper_bound:
            return text, tier_score

    # Beyond the last bound: decay linearly to 0, but never exceed the next
    # step down from the 40-point tier so the score keeps decreasing.
    score = min(20, max(0, 100 - norm_distance))
    return "The recitations are quite different.", score
| # Clean up temporary files | |
def cleanup_temp_files(file_paths):
    """Delete each path in ``file_paths``, never raising on failure.

    Removal is attempted directly instead of checking ``os.path.exists``
    first: that avoids an extra filesystem call and the window in which the
    file could disappear between the check and the removal.

    Args:
        file_paths: Iterable of file paths to delete.

    Paths that do not exist are skipped silently. Any other failure is
    printed and the loop continues with the next path, so one bad entry
    cannot prevent the remaining files from being cleaned up.
    """
    for path in file_paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone - nothing to clean up.
            continue
        except Exception as exc:
            # Best-effort cleanup: report and move on.
            print(f"Error removing temporary file {path}: {exc}")
| # API endpoints | |
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on ``app`` (e.g. with ``@app.post``), otherwise
# the endpoint is unreachable.
# NOTE(review): the audio loading, model inference and DTW below are
# synchronous calls inside an ``async def``, so they run on the event loop
# for the duration of the request - consider a thread pool if concurrency
# matters.
async def compare_recitations(
    background_tasks: BackgroundTasks,
    file1: UploadFile = File(...),
    file2: UploadFile = File(...)
):
    """
    Compare two Quran recitations and return similarity metrics.

    - **file1**: First audio file
    - **file2**: Second audio file

    Returns:
    - **similarity_score**: Score between 0-100 indicating similarity
    - **interpretation**: Text interpretation of the similarity
    """
    # ``background_tasks`` is kept only for signature compatibility; cleanup
    # now runs in ``finally`` (see below).

    # Check if model is initialized
    if MODEL is None or PROCESSOR is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    # Unique temporary paths so concurrent requests never collide.
    temp_file1 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
    temp_file2 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")

    try:
        # Save uploaded files
        with open(temp_file1, "wb") as f:
            shutil.copyfileobj(file1.file, f)
        with open(temp_file2, "wb") as f:
            shutil.copyfileobj(file2.file, f)

        # Load audio files
        audio1 = load_audio(temp_file1)
        audio2 = load_audio(temp_file2)

        # Extract embeddings
        embedding1 = get_deep_embedding(audio1)
        embedding2 = get_deep_embedding(audio2)

        # Compute DTW distance. librosa's dtw takes feature matrices with
        # the feature axis first and the frame axis last, hence the
        # transposes of the per-frame embeddings.
        norm_distance = compute_dtw_distance(embedding1.T, embedding2.T)

        # Interpret results
        interpretation, similarity_score = interpret_similarity(norm_distance)

        return {
            "similarity_score": similarity_score,
            "interpretation": interpretation
        }
    except HTTPException:
        # The helpers already chose a status code (e.g. 400 for unreadable
        # audio); let it through instead of rewrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Remove the uploads on every exit path. Tasks queued on
        # BackgroundTasks are tied to the handler's normal response, so
        # queuing the cleanup from an error path would leave the files
        # behind; deleting two small files inline is cheap.
        cleanup_temp_files([temp_file1, temp_file2])
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on ``app``.
async def health_check():
    """Health check endpoint."""
    model_ready = MODEL is not None and PROCESSOR is not None
    if model_ready:
        return {"status": "ok", "model_loaded": True}

    # Report 503 until both the model and the processor have been loaded.
    return JSONResponse(
        status_code=503,
        content={"status": "error", "message": "Model not initialized"},
    )
| # Initialize model on startup | |
# NOTE(review): no startup-hook registration is visible in this view;
# confirm this coroutine is attached to the app's startup, otherwise the
# model is never loaded.
async def startup_event():
    """Load the model and processor by delegating to ``initialize_model``."""
    initialize_model()
| # Run the FastAPI app | |
# Run the FastAPI app when the file is executed as a script.
if __name__ == "__main__":
    import uvicorn

    # PORT overrides the listen port; 7860 is the default for Hugging Face
    # Spaces.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run("main:app", host="0.0.0.0", port=listen_port, reload=False)