File size: 5,036 Bytes
1e1ecd3
 
 
0073001
1e1ecd3
 
 
 
23ba234
1e1ecd3
5630c13
1e1ecd3
 
 
23ba234
 
 
 
 
 
 
1e1ecd3
0073001
 
5630c13
0073001
 
 
 
 
 
 
5630c13
0073001
 
 
5630c13
0073001
1e1ecd3
 
5630c13
0073001
 
1e1ecd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23ba234
 
 
1e1ecd3
23ba234
1e1ecd3
23ba234
 
1e1ecd3
 
23ba234
 
 
 
 
1e1ecd3
 
23ba234
1e1ecd3
 
 
 
50c246b
1e1ecd3
23ba234
1e1ecd3
 
 
 
 
 
 
 
5630c13
459fd87
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import tempfile
import warnings
from typing import Optional

import numpy as np
import soundfile as sf
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pydub import AudioSegment

# Install a blanket "ignore" filter: no category or module is given, so every
# warning raised in this process after this point is suppressed.
warnings.filterwarnings("ignore")

# Application object; the @app.post route in this module registers on it.
app = FastAPI()

def convert_mp3_to_wav(mp3_path):
    """Convert an MP3 file to WAV and return the path of the WAV file.

    The WAV file is written next to the source file, using the same base
    name with a ``.wav`` extension.

    Args:
        mp3_path: Path of the MP3 file to convert.

    Returns:
        Path of the exported WAV file.
    """
    sound = AudioSegment.from_mp3(mp3_path)
    # Swap only the final extension. str.replace(".mp3", ".wav") would rewrite
    # every ".mp3" occurrence in the path (e.g. a directory called "my.mp3s")
    # and would leave an upper-case ".MP3" or extension-less path unchanged,
    # causing the export below to overwrite the source file.
    stem, _extension = os.path.splitext(mp3_path)
    wav_path = stem + ".wav"
    sound.export(wav_path, format="wav")
    return wav_path

def extract_audio_features(audio_file_path):
    """Read an audio file and derive a handful of coarse signal features.

    Args:
        audio_file_path: Path of the audio file, read with ``soundfile.read``.

    Returns:
        A 6-tuple ``(f0, energy, speech_rate, mfccs, waveform, sample_rate)``:
        ``f0`` is a rough pitch proxy (mean absolute sample-to-sample
        difference scaled by ``sample_rate / (2 * pi)``), ``energy`` is the
        mean squared amplitude, ``speech_rate`` is a fixed placeholder of
        4.0, ``mfccs`` is the mean magnitude of the first 13 FFT coefficients
        (not a real MFCC), ``waveform`` is the mono signal and
        ``sample_rate`` is the rate reported by ``soundfile``.
    """
    samples, rate = sf.read(audio_file_path)

    # Multi-channel input is collapsed to mono by averaging over axis 1.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)

    # Mean squared amplitude of the signal.
    mean_power = np.mean(np.square(samples))

    # Stand-in for MFCCs: average magnitude of the first 13 FFT bins.
    leading_bins = np.fft.fft(samples)[:13]
    spectral_summary = np.abs(leading_bins).mean(axis=0)

    # Speech rate is not measured; a constant placeholder is returned.
    placeholder_rate = 4.0

    # Rough pitch estimate from the average absolute first difference.
    mean_abs_step = np.mean(np.abs(np.diff(samples)))
    pitch_estimate = mean_abs_step * rate / (2 * np.pi)

    return pitch_estimate, mean_power, placeholder_rate, spectral_summary, samples, rate

def analyze_voice_stress(audio_file_path):
    """Score vocal stress for an audio file on a 0-100 scale.

    Features from ``extract_audio_features`` are turned into z-scores against
    fixed reference values, combined with weights 0.4 (pitch), 0.4 (speech
    rate) and 0.2 (energy), and squashed through a logistic function.

    Args:
        audio_file_path: Path of the audio file to analyse.

    Returns:
        Dict with ``stress_level`` (float, 0-100), ``category`` (one of five
        labels, one per 20-point band) and ``gender`` (``'male'`` when the
        pitch estimate is below 165, otherwise ``'female'``).
    """
    (pitch, loudness, rate,
     _mfccs, _waveform, _sample_rate) = extract_audio_features(audio_file_path)

    # The pitch threshold picks the gender label and the pitch reference mean.
    if pitch < 165:
        gender, pitch_reference = 'male', 110
    else:
        gender, pitch_reference = 'female', 220

    # Fixed reference statistics used for normalisation.
    pitch_spread = 20
    loudness_reference, loudness_spread = 0.02, 0.005
    rate_reference, rate_spread = 4.4, 0.5

    pitch_z = (pitch - pitch_reference) / pitch_spread
    loudness_z = (loudness - loudness_reference) / loudness_spread
    rate_z = (rate - rate_reference) / rate_spread

    combined = (0.4 * pitch_z) + (0.4 * rate_z) + (0.2 * loudness_z)
    # Logistic squashing maps the weighted score into the 0-100 range.
    stress_level = float(1 / (1 + np.exp(-combined)) * 100)

    labels = ("Very Low Stress", "Low Stress", "Moderate Stress", "High Stress", "Very High Stress")
    # 20-point bands; a score of exactly 100 is clamped into the last band.
    band = min(int(stress_level / 20), 4)

    return {"stress_level": stress_level, "category": labels[band], "gender": gender}

def analyze_text_stress(text: str):
    """Score stress in free text by counting distinct stress keywords.

    Each of the five keywords that occurs anywhere in the lower-cased text
    (substring match, counted at most once) adds 20 points, capped at 100.

    Args:
        text: Text to scan.

    Returns:
        Dict with ``stress_level`` (0-100 in steps of 20) and ``category``
        (one of five labels, one per 20-point band).
    """
    lowered = text.lower()

    hits = 0
    for keyword in ("anxious", "nervous", "stress", "panic", "tense"):
        if keyword in lowered:
            hits += 1

    level = min(hits * 20, 100)

    labels = ("Very Low Stress", "Low Stress", "Moderate Stress", "High Stress", "Very High Stress")
    # 20-point bands; a score of exactly 100 is clamped into the last band.
    band = min(int(level / 20), 4)

    return {"stress_level": level, "category": labels[band]}

class StressResponse(BaseModel):
    """Response schema declared for the ``/analyze-stress/`` endpoint."""

    # Stress estimate on a 0-100 scale.
    stress_level: float
    # Human-readable label for the stress band.
    category: str
    # Only present for audio analysis; text analysis leaves it unset.
    # Annotated Optional because the default is None, which plain ``str``
    # does not describe.
    gender: Optional[str] = None

@app.post("/analyze-stress/", response_model=StressResponse)
async def analyze_stress(
    file: UploadFile = File(None), 
    file_path: str = Form(None),
    text: str = Form(None)
):
    """Estimate stress from an uploaded audio file, a server-side audio path, or text.

    Audio input (``file`` first, then ``file_path``) takes precedence over
    ``text`` when several inputs are supplied.

    Args:
        file: Uploaded ``.wav`` or ``.mp3`` file.
        file_path: Path of a ``.wav`` or ``.mp3`` file on the server.
        text: Free text to scan for stress keywords.

    Returns:
        JSONResponse carrying the analysis result dict.

    Raises:
        HTTPException: 400 when no usable input is given, the extension is
            unsupported, or ``file_path`` does not exist; 500 when reading,
            converting, or analysing the audio fails.
    """
    # Empty strings are treated like missing fields; otherwise a request
    # carrying only text="" would fall through every branch and return None.
    if not (file or file_path or text):
        raise HTTPException(status_code=400, detail="Either a file, file path, or text input is required.")

    # Handle text analysis
    if not (file or file_path):
        return JSONResponse(content=analyze_text_stress(text))

    # Handle audio file analysis
    supported_suffixes = (".wav", ".mp3")
    # Files created by this request; all of them are removed in ``finally``.
    owned_paths = []
    try:
        if file:
            # filename may be missing on a malformed upload.
            upload_name = file.filename or ""
            if not upload_name.endswith(supported_suffixes):
                raise HTTPException(status_code=400, detail="Only .wav and .mp3 files are supported.")
            with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(upload_name)[-1]) as temp_file:
                # Register before writing so a failed write still gets cleaned up.
                owned_paths.append(temp_file.name)
                temp_file.write(await file.read())
            audio_path = temp_file.name
        else:
            # NOTE(security): file_path is client-supplied and is opened on the
            # server as-is, so any readable .wav/.mp3 on the host can be probed.
            # Restrict it to an allow-listed directory before exposing this
            # endpoint to untrusted clients.
            if not file_path.endswith(supported_suffixes):
                raise HTTPException(status_code=400, detail="Only .wav and .mp3 files are supported.")
            if not os.path.exists(file_path):
                raise HTTPException(status_code=400, detail="File path does not exist.")
            audio_path = file_path

        # Convert MP3 to WAV if needed
        if audio_path.endswith(".mp3"):
            audio_path = convert_mp3_to_wav(audio_path)
            if file:
                # The converted WAV of an upload is temporary as well. For a
                # caller-supplied path the WAV is left in place, since it may
                # have existed before the request.
                owned_paths.append(audio_path)

        result = analyze_voice_stress(audio_path)
    except HTTPException:
        # Validation errors raised above keep their own status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temporary files
        for path in owned_paths:
            try:
                os.remove(path)
            except OSError:
                # Best-effort cleanup; never mask the real outcome of the request.
                pass

    return JSONResponse(content=result)

if __name__ == "__main__":
    import uvicorn

    # Listen on $PORT when it is set, otherwise fall back to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    # "app:app" is an import string: uvicorn loads attribute `app` from a
    # module named `app` (presumably this file is saved as app.py; verify).
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, reload=True)