from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional
import torchaudio
import numpy as np
import tempfile
import os
import warnings

warnings.filterwarnings("ignore")

app = FastAPI()


def extract_audio_features(audio_file_path):
    # Load the audio file using torchaudio
    waveform, sample_rate = torchaudio.load(audio_file_path)

    # Ensure the waveform is mono by averaging channels if necessary
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    waveform = waveform.squeeze()  # Remove the channel dimension if it's 1

    # Extract pitch (fundamental frequency); detect_pitch_frequency returns a
    # single tensor of per-frame pitch estimates.
    pitch = torchaudio.functional.detect_pitch_frequency(
        waveform, sample_rate, frame_time=0.01, win_length=30
    )
    # Keep only frames whose estimate falls in a plausible speech F0 range
    # (a rough voicing heuristic).
    f0 = pitch[(pitch > 50) & (pitch < 500)]

    # Extract per-sample energy
    energy = waveform.pow(2).numpy()

    # Extract MFCCs
    mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=13)
    mfccs = mfcc_transform(waveform.unsqueeze(0)).squeeze(0).numpy()

    # Estimate speech rate with a rough heuristic: count peaks in the
    # short-time energy envelope as a proxy for syllable nuclei
    # (torchaudio has no built-in tempo/speech-rate estimator).
    frame_len = int(0.025 * sample_rate)
    hop_len = int(0.010 * sample_rate)
    frame_energy = np.array([
        energy[i:i + frame_len].sum()
        for i in range(0, max(len(energy) - frame_len, 1), hop_len)
    ])
    is_peak = (
        (frame_energy[1:-1] > frame_energy[:-2])
        & (frame_energy[1:-1] > frame_energy[2:])
        & (frame_energy[1:-1] > frame_energy.mean())
    )
    duration = waveform.shape[-1] / sample_rate
    speech_rate = float(is_peak.sum() / duration) if duration > 0 else 0.0  # ~syllables/second

    return f0.numpy(), energy, speech_rate, mfccs, waveform.numpy(), sample_rate


def analyze_voice_stress(audio_file_path):
    f0, energy, speech_rate, mfccs, waveform, sample_rate = extract_audio_features(audio_file_path)

    if len(f0) == 0:
        raise ValueError("Could not extract fundamental frequency from the audio.")

    mean_f0 = np.mean(f0)
    std_f0 = np.std(f0)
    mean_energy = np.mean(energy)
    std_energy = np.std(energy)

    # Crude gender guess from mean F0, used only to pick the reference pitch
    gender = 'male' if mean_f0 < 165 else 'female'

    # Reference values for a neutral speaking voice
    norm_mean_f0 = 110 if gender == 'male' else 220
    norm_std_f0 = 20
    norm_mean_energy = 0.02
    norm_std_energy = 0.005
    norm_speech_rate = 4.4  # syllables per second
    norm_std_speech_rate = 0.5

    # z-scores of the measured features relative to the reference values
    z_f0 = (mean_f0 - norm_mean_f0) / norm_std_f0
    z_energy = (mean_energy - norm_mean_energy) / norm_std_energy
    z_speech_rate = (speech_rate - norm_speech_rate) / norm_std_speech_rate

    # Weighted combination squashed onto a 0-100 scale with a sigmoid
    stress_score = (0.4 * z_f0) + (0.4 * z_speech_rate) + (0.2 * z_energy)
    stress_level = float(1 / (1 + np.exp(-stress_score)) * 100)

    categories = ["Very Low Stress", "Low Stress", "Moderate Stress", "High Stress", "Very High Stress"]
    category_idx = min(int(stress_level / 20), 4)
    stress_category = categories[category_idx]

    return {"stress_level": stress_level, "category": stress_category, "gender": gender}


def analyze_text_stress(text: str):
    # Simple keyword-based score: 20 points per stress-related keyword, capped at 100
    stress_keywords = ["anxious", "nervous", "stress", "panic", "tense"]
    stress_score = sum(1 for word in stress_keywords if word in text.lower())
    stress_level = min(stress_score * 20, 100)

    categories = ["Very Low Stress", "Low Stress", "Moderate Stress", "High Stress", "Very High Stress"]
    category_idx = min(int(stress_level / 20), 4)
    stress_category = categories[category_idx]

    return {"stress_level": stress_level, "category": stress_category}


class StressResponse(BaseModel):
    stress_level: float
    category: str
    gender: Optional[str] = None  # Only present for audio analysis


@app.post("/analyze-stress/", response_model=StressResponse)
async def analyze_stress(
    file: UploadFile = File(None),
    file_path: str = Form(None),
    text: str = Form(None)
):
    if file is None and file_path is None and text is None:
        raise HTTPException(status_code=400, detail="Either a file, file path, or text input is required.")

    # Handle audio file analysis
    if file or file_path:
        if file:
            if not file.filename.endswith(".wav"):
                raise HTTPException(status_code=400, detail="Only .wav files are supported.")
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
                temp_file.write(await file.read())
                temp_wav_path = temp_file.name
        else:
            if not file_path.endswith(".wav"):
                raise HTTPException(status_code=400, detail="Only .wav files are supported.")
            if not os.path.exists(file_path):
                raise HTTPException(status_code=400, detail="File path does not exist.")
            temp_wav_path = file_path

        try:
            result = analyze_voice_stress(temp_wav_path)
            return JSONResponse(content=result)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
        finally:
            # Clean up the temporary file created for an uploaded clip
            if file:
                os.remove(temp_wav_path)

    # Handle text analysis
    elif text:
        result = analyze_text_stress(text)
        return JSONResponse(content=result)


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", 7860))  # Use the PORT environment variable if set
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=True)
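# Example client calls (a minimal sketch, not part of the service itself):
# it assumes the server above is running locally on port 7860 and that the
# `requests` package is installed; "sample.wav" is a placeholder path.
#
#     import requests
#
#     # Audio analysis: upload a .wav file under the "file" form field
#     with open("sample.wav", "rb") as f:
#         resp = requests.post(
#             "http://localhost:7860/analyze-stress/",
#             files={"file": ("sample.wav", f, "audio/wav")},
#         )
#     print(resp.json())  # {"stress_level": ..., "category": ..., "gender": ...}
#
#     # Text analysis: send the transcript under the "text" form field
#     resp = requests.post(
#         "http://localhost:7860/analyze-stress/",
#         data={"text": "I feel anxious and tense before the exam."},
#     )
#     print(resp.json())  # {"stress_level": ..., "category": ...}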