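"""Health Voice Analyzer (Gradio app).

Transcribes a spoken symptom description with a local Whisper model,
classifies the transcription with a Symptom-2-Disease text classifier,
and answers general (non-treatment) health queries. Not a diagnostic tool.
"""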
import gradio as gr
import librosa
import numpy as np
import os
import hashlib
from datetime import datetime
import soundfile as sf
import torch
from tenacity import retry, stop_after_attempt, wait_fixed
from transformers import pipeline

# Initialize local models with retry logic
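# tenacity retries each loader up to 3 times with a fixed 2 s wait between
# attempts, so a transient network hiccup during model download does not
# abort startup.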
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def load_whisper_model():
    try:
        model = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-tiny",  # Multilingual model
            device=-1,  # CPU; use device=0 for GPU if available
            model_kwargs={"use_safetensors": True}
        )
        print("Whisper model loaded successfully.")
        return model
    except Exception as e:
        print(f"Failed to load Whisper model: {str(e)}")
        raise

@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def load_symptom_model():
    try:
        model = pipeline(
            "text-classification",
            model="abhirajeshbhai/symptom-2-disease-net",
            device=-1,  # CPU
            model_kwargs={"use_safetensors": True}
        )
        print("Symptom-2-Disease model loaded successfully.")
        return model, False  # primary model, not a fallback
    except Exception as e:
        print(f"Failed to load Symptom-2-Disease model: {str(e)}")
        # Fallback to a generic model
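        # Note: distilbert-base-uncased has no fine-tuned classification head,
        # so its labels are effectively placeholders; the fallback flag returned
        # below lets downstream code warn the user about this.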
        try:
            model = pipeline(
                "text-classification",
                model="distilbert-base-uncased",
                device=-1
            )
            print("Fallback to distilbert-base-uncased model.")
            return model, True  # flag that the generic fallback is in use
        except Exception as fallback_e:
            print(f"Fallback model failed: {str(fallback_e)}")
            raise

whisper = None
symptom_classifier = None
is_fallback_model = False

try:
    whisper = load_whisper_model()
except Exception as e:
    print(f"Whisper model initialization failed after retries: {str(e)}")

try:
    symptom_classifier, is_fallback_model = load_symptom_model()
except Exception as e:
    print(f"Symptom model initialization failed after retries: {str(e)}")
    symptom_classifier = None

def compute_file_hash(file_path):
    """Compute MD5 hash of a file to check uniqueness."""
    hash_md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()

def transcribe_audio(audio_file, language="en"):
    """Transcribe audio using local Whisper model."""
    if not whisper:
        return "Error: Whisper model not loaded. Check logs for details or ensure sufficient compute resources."
    try:
        # Load and validate audio
        audio, sr = librosa.load(audio_file, sr=16000)
        if len(audio) < sr:  # Less than 1 second
            return "Error: Audio too short. Please provide audio of at least 1 second."
        if np.max(np.abs(audio)) < 1e-4:  # Too quiet
            return "Error: Audio too quiet. Please provide clear audio describing symptoms."
        
        # Save as WAV for Whisper
        temp_wav = f"/tmp/{os.path.basename(audio_file)}.wav"
        sf.write(temp_wav, audio, sr)
        
        # Transcribe with beam search and language
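        # num_beams=5 trades decoding speed for accuracy; pinning "language"
        # keeps whisper-tiny from mis-detecting the language on short clips.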
        with torch.no_grad():
            result = whisper(temp_wav, generate_kwargs={"num_beams": 5, "language": language})
        transcription = result.get("text", "").strip()
        print(f"Transcription: {transcription}")
        
        # Clean up temp file
        try:
            os.remove(temp_wav)
        except Exception:
            pass
        
        if not transcription:
            return "Transcription empty. Please provide clear audio describing symptoms."
        # Check for repetitive transcription
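        # Heuristic: if fewer than half the words are unique, Whisper has
        # likely fallen into a repetition loop (common on noisy or quiet audio).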
        words = transcription.split()
        if len(words) > 5 and len(set(words)) < len(words) / 2:
            return "Error: Transcription appears repetitive. Please provide clear, non-repetitive audio describing symptoms."
        return transcription
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"

def analyze_symptoms(text):
    """Analyze symptoms using local Symptom-2-Disease model."""
    if not symptom_classifier:
        return "Error: Symptom-2-Disease model not loaded. Check logs for details or ensure sufficient compute resources.", 0.0
    try:
        if not text or text.startswith("Error"):
            return "No valid transcription for analysis.", 0.0
        with torch.no_grad():
            result = symptom_classifier(text)
        if result and isinstance(result, list) and len(result) > 0:
            prediction = result[0]["label"]
            score = result[0]["score"]
            if is_fallback_model:
                print("Warning: Using fallback model (distilbert-base-uncased). Results may be less accurate.")
                prediction = f"{prediction} (using fallback model)"
            print(f"Health Prediction: {prediction}, Score: {score:.4f}")
            return prediction, score
        return "No health condition predicted", 0.0
    except Exception as e:
        return f"Error analyzing symptoms: {str(e)}", 0.0

def handle_health_query(query, language="en"):
    """Handle health-related queries with a general response."""
    if not query:
        return "Please provide a valid health query."
    # Placeholder for Q&A logic (could integrate a model like BERT for Q&A)
    restricted_terms = ["medicine", "treatment", "drug", "prescription"]
    if any(term in query.lower() for term in restricted_terms):
        return "This tool does not provide medication or treatment advice. Please ask about symptoms or general health information (e.g., 'What are symptoms of asthma?')."
    return f"Response to query '{query}': For accurate health information, consult a healthcare provider."

def analyze_voice(audio_file, language="en"):
    """Analyze voice for health indicators and handle queries."""
    try:
        # Move the upload to a unique path so reused Gradio temp names can't collide
        os.makedirs("/tmp/gradio", exist_ok=True)  # os.rename fails if the target dir is missing
        unique_path = f"/tmp/gradio/{datetime.now().strftime('%Y%m%d%H%M%S%f')}_{os.path.basename(audio_file)}"
        os.rename(audio_file, unique_path)
        audio_file = unique_path
        
        # Log audio file info
        file_hash = compute_file_hash(audio_file)
        print(f"Processing audio file: {audio_file}, Hash: {file_hash}")
        
        # Load audio to verify format
        audio, sr = librosa.load(audio_file, sr=16000)
        print(f"Audio shape: {audio.shape}, Sampling rate: {sr}, Duration: {len(audio)/sr:.2f}s, Mean: {np.mean(audio):.4f}, Std: {np.std(audio):.4f}")
        
        # Transcribe audio
        transcription = transcribe_audio(audio_file, language)
        if "Error transcribing" in transcription:
            return transcription
        
        # Split transcription into symptom and query parts
        symptom_text = transcription
        query_text = None
        restricted_terms = ["medicine", "treatment", "drug", "prescription"]
        for term in restricted_terms:
            if term in transcription.lower():
                # Split at the first restricted term
                split_index = transcription.lower().find(term)
                symptom_text = transcription[:split_index].strip()
                query_text = transcription[split_index:].strip()
                break
        
        feedback = ""
        
        # Analyze symptoms if present
        if symptom_text:
            prediction, score = analyze_symptoms(symptom_text)
            if "Error analyzing" in prediction:
                feedback += prediction + "\n"
            elif prediction == "No health condition predicted":
                feedback += "No significant health indicators detected.\n"
            else:
                feedback += f"Possible health condition: {prediction} (confidence: {score:.4f}). Consult a doctor.\n"
        else:
            feedback += "No symptoms detected in the audio.\n"
        
        # Handle query if present
        if query_text:
            feedback += f"\nQuery detected: '{query_text}'\n"
            feedback += handle_health_query(query_text, language) + "\n"
        
        # Add debug info and disclaimer
        feedback += f"\n**Debug Info**: Transcription = '{transcription}', File Hash = {file_hash}"
        feedback += "\n**Disclaimer**: This is not a diagnostic tool. Consult a healthcare provider for medical advice."
        
        # Clean up temporary audio file
        try:
            os.remove(audio_file)
            print(f"Deleted temporary audio file: {audio_file}")
        except Exception as e:
            print(f"Failed to delete audio file: {str(e)}")
        
        return feedback
    except Exception as e:
        return f"Error processing audio: {str(e)}"

# Gradio interface
def create_gradio_interface():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # Health Voice Analyzer
            Record or upload a voice sample describing symptoms in English, Spanish, Hindi, or Mandarin (e.g., 'I have a fever').  
            Ask health questions in the text box below (e.g., 'What are symptoms of asthma?').  
            **Note**: Do not ask for medication or treatment advice; focus on symptoms or general health questions.  
            **Disclaimer**: This is not a diagnostic tool. Consult a healthcare provider for medical advice.  
            **Text-to-Speech**: Available in the web frontend (Salesforce Sites) using the browser's Web Speech API.
            """
        )
        with gr.Row():
            language = gr.Dropdown(
                choices=["en", "es", "hi", "zh"],
                label="Select Language",
                value="en"
            )
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Record or Upload Voice")
        with gr.Row():
            query_input = gr.Textbox(label="Ask a Health Question (e.g., 'What are symptoms of asthma?')")
        with gr.Row():
            output = gr.Textbox(label="Health Assessment Feedback")
        with gr.Row():
            analyze_button = gr.Button("Analyze Voice")
            query_button = gr.Button("Submit Query")
        
        analyze_button.click(
            fn=analyze_voice,
            inputs=[audio_input, language],
            outputs=output
        )
        query_button.click(
            fn=handle_health_query,
            inputs=[query_input, language],
            outputs=output
        )
    
    return demo

if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860)