# Spaces: Sleeping  (page-header residue captured with the source; not part of the program)
import hashlib
import os
import tempfile
from datetime import datetime

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from tenacity import retry, stop_after_attempt, wait_fixed
from transformers import pipeline
# Initialize local models with retry logic
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2), reraise=True)
def load_whisper_model():
    """Build the Whisper speech-recognition pipeline on CPU.

    Loading is attempted up to three times, two seconds apart. With
    ``reraise=True`` the exception from the final attempt propagates
    unchanged (instead of tenacity's ``RetryError``), so callers that
    catch ``Exception`` keep seeing the real cause.

    Returns:
        The ``automatic-speech-recognition`` pipeline for
        ``openai/whisper-tiny``.

    Raises:
        Exception: whatever ``pipeline`` raised on the last attempt.
    """
    try:
        model = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-tiny",  # Multilingual model
            device=-1,  # CPU; use device=0 for GPU if available
            model_kwargs={"use_safetensors": True},
        )
    except Exception as e:
        # Logged on every failed attempt, then re-raised so the retry
        # wrapper can decide whether to try again.
        print(f"Failed to load Whisper model: {str(e)}")
        raise
    print("Whisper model loaded successfully.")
    return model
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2), reraise=True)
def load_symptom_model():
    """Build the symptom text-classification pipeline on CPU.

    The ``abhirajeshbhai/symptom-2-disease-net`` model is tried first; if it
    cannot be loaded, ``distilbert-base-uncased`` is loaded instead. The
    module-level ``is_fallback_model`` flag is set to record which of the two
    was returned, so that predictions can be marked as coming from the
    fallback. If both loads fail, the whole sequence is retried (three
    attempts, two seconds apart) before the last error is re-raised.

    Returns:
        A ``text-classification`` pipeline.

    Raises:
        Exception: whatever the fallback ``pipeline`` call raised on the
            final attempt.
    """
    global is_fallback_model

    try:
        model = pipeline(
            "text-classification",
            model="abhirajeshbhai/symptom-2-disease-net",
            device=-1,  # CPU
            model_kwargs={"use_safetensors": True},
        )
        print("Symptom-2-Disease model loaded successfully.")
        is_fallback_model = False
        return model
    except Exception as e:
        print(f"Failed to load Symptom-2-Disease model: {str(e)}")

    # Fallback to a generic model
    try:
        model = pipeline(
            "text-classification",
            model="distilbert-base-uncased",
            device=-1,
        )
    except Exception as fallback_e:
        print(f"Fallback model failed: {str(fallback_e)}")
        raise
    # Only flag the fallback once it has actually been loaded.
    is_fallback_model = True
    print("Fallback to distilbert-base-uncased model.")
    return model
# Module-level model handles. A handle left as None means the corresponding
# loader raised; the request handlers check for that before using it.
whisper = None
symptom_classifier = None
is_fallback_model = False

# Speech-to-text model.
try:
    whisper = load_whisper_model()
except Exception as whisper_error:
    print(f"Whisper model initialization failed after retries: {str(whisper_error)}")

# Symptom classifier.
try:
    symptom_classifier = load_symptom_model()
except Exception as symptom_error:
    print(f"Symptom model initialization failed after retries: {str(symptom_error)}")
    symptom_classifier = None
    is_fallback_model = True
def compute_file_hash(file_path):
    """Return the hexadecimal MD5 digest of the file at *file_path*.

    The file is read in 4096-byte blocks so it never has to be held in
    memory as a whole.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as stream:
        while True:
            block = stream.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def transcribe_audio(audio_file, language="en"):
    """Transcribe audio using local Whisper model.

    Args:
        audio_file: Path of the audio file to transcribe.
        language: Language code forwarded to Whisper's ``generate_kwargs``.

    Returns:
        The stripped transcription text on success. On failure a
        human-readable message is returned instead of raising: it starts
        with ``"Error:"`` (model missing, audio too short/quiet, repetitive
        output), ``"Transcription empty."``, or ``"Error transcribing
        audio:"`` (any exception raised while processing).
    """
    if whisper is None:
        return "Error: Whisper model not loaded. Check logs for details or ensure sufficient compute resources."
    try:
        # Load and validate audio (resampled to 16 kHz)
        audio, sr = librosa.load(audio_file, sr=16000)
        # Hard floor is 1600 samples = 0.1 s at 16 kHz; the message below
        # asks for 1 second.
        if len(audio) < 1600:
            return "Error: Audio too short. Please provide audio of at least 1 second."
        if np.max(np.abs(audio)) < 1e-4:  # Too quiet
            return "Error: Audio too quiet. Please provide clear audio describing symptoms."

        # Save as WAV for Whisper. mkstemp gives every call its own file, so
        # two requests whose uploads share a basename cannot overwrite each
        # other, and the finally-clause removes it even if Whisper raises.
        fd, temp_wav = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            sf.write(temp_wav, audio, sr)
            # Transcribe with beam search and language
            with torch.no_grad():
                result = whisper(temp_wav, generate_kwargs={"num_beams": 5, "language": language})
        finally:
            # Best-effort clean-up: a leftover temp file must not turn a
            # successful transcription into an error.
            try:
                os.remove(temp_wav)
            except OSError:
                pass

        transcription = result.get("text", "").strip()
        print(f"Transcription: {transcription}")
        if not transcription:
            return "Transcription empty. Please provide clear audio describing symptoms."

        # Check for repetitive transcription: more than five words of which
        # fewer than half are distinct.
        words = transcription.split()
        if len(words) > 5 and len(set(words)) < len(words) / 2:
            return "Error: Transcription appears repetitive. Please provide clear, non-repetitive audio describing symptoms."
        return transcription
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"
def analyze_symptoms(text):
    """Classify *text* with the module-level symptom classifier.

    Returns:
        A ``(message, score)`` pair. On success ``message`` is the label of
        the first classifier result (suffixed with ``" (using fallback
        model)"`` when ``is_fallback_model`` is set) and ``score`` its score.
        Otherwise ``message`` is a status/error string and ``score`` is 0.0.
    """
    if not symptom_classifier:
        return (
            "Error: Symptom-2-Disease model not loaded. Check logs for details or ensure sufficient compute resources.",
            0.0,
        )
    try:
        # Refuse empty input and transcription-failure messages.
        if not text or "Error transcribing" in text:
            return "No valid transcription for analysis.", 0.0

        with torch.no_grad():
            outputs = symptom_classifier(text)

        # Anything other than a non-empty list carries no usable prediction.
        if not outputs or not isinstance(outputs, list):
            return "No health condition predicted", 0.0

        top = outputs[0]
        label = top["label"]
        confidence = top["score"]
        if is_fallback_model:
            print("Warning: Using fallback model (distilbert-base-uncased). Results may be less accurate.")
            label = f"{label} (using fallback model)"
        print(f"Health Prediction: {label}, Score: {confidence:.4f}")
        return label, confidence
    except Exception as err:
        return f"Error analyzing symptoms: {str(err)}", 0.0
def handle_health_query(query, language="en"):
    """Return a canned reply to a free-text health question.

    An empty query gets a prompt to provide one; a query containing any of
    the restricted terms (case-insensitive substring match) gets a refusal;
    anything else gets a generic "consult a healthcare provider" reply that
    echoes the query. ``language`` is accepted but not used.
    """
    if not query:
        return "Please provide a valid health query."

    # Placeholder for Q&A logic (could integrate a model like BERT for Q&A)
    lowered = query.lower()
    for blocked in ("medicine", "treatment", "drug", "prescription"):
        if blocked in lowered:
            return "This tool does not provide medication or treatment advice. Please ask about symptoms or general health information (e.g., 'What are symptoms of asthma?')."

    return f"Response to query '{query}': For accurate health information, consult a healthcare provider."
def _split_symptoms_and_query(transcription):
    """Split *transcription* at the earliest restricted term it contains.

    Returns ``(symptom_text, query_text)``; ``query_text`` is ``None`` when
    no restricted term occurs.
    """
    restricted_terms = ("medicine", "treatment", "drug", "prescription")
    lowered = transcription.lower()
    hits = [pos for pos in (lowered.find(term) for term in restricted_terms) if pos != -1]
    if not hits:
        return transcription, None
    # Earliest occurrence in the text, not the first term in the list.
    split_index = min(hits)
    return transcription[:split_index].strip(), transcription[split_index:].strip()


def analyze_voice(audio_file, language="en"):
    """Analyze voice for health indicators and handle queries.

    The upload is moved to a uniquely named file under ``/tmp/gradio``,
    transcribed, split into a symptom part and an optional query part, and
    the two parts are passed to ``analyze_symptoms`` and
    ``handle_health_query``. The moved file is deleted before returning,
    whether or not processing succeeded.

    Args:
        audio_file: Path of the uploaded/recorded audio, or ``None``/empty
            when nothing was provided.
        language: Language code forwarded to transcription and query handling.

    Returns:
        Feedback text for the user. Failures are reported as text: a
        transcription failure message is returned as-is, and any exception
        becomes ``"Error processing audio: ..."``.
    """
    if not audio_file:
        return "Error processing audio: No audio provided. Please record or upload a voice sample."

    # Messages transcribe_audio returns instead of a real transcription.
    transcription_failures = ("Error:", "Error transcribing", "Transcription empty.")

    moved_path = None  # set once the upload has been moved; drives clean-up
    try:
        # Ensure unique file name
        unique_path = f"/tmp/gradio/{datetime.now().strftime('%Y%m%d%H%M%S%f')}_{os.path.basename(audio_file)}"
        os.makedirs(os.path.dirname(unique_path), exist_ok=True)
        os.rename(audio_file, unique_path)
        moved_path = unique_path
        audio_file = unique_path

        # Log audio file info
        file_hash = compute_file_hash(audio_file)
        print(f"Processing audio file: {audio_file}, Hash: {file_hash}")

        # Load audio to verify format
        audio, sr = librosa.load(audio_file, sr=16000)
        print(f"Audio shape: {audio.shape}, Sampling rate: {sr}, Duration: {len(audio)/sr:.2f}s, Mean: {np.mean(audio):.4f}, Std: {np.std(audio):.4f}")

        # Transcribe audio; a failure message must not be analysed as speech.
        transcription = transcribe_audio(audio_file, language)
        if transcription.startswith(transcription_failures):
            return transcription

        symptom_text, query_text = _split_symptoms_and_query(transcription)

        sections = []
        # Analyze symptoms if present
        if symptom_text:
            prediction, score = analyze_symptoms(symptom_text)
            is_status_message = (
                "Error analyzing" in prediction
                or prediction.startswith("Error:")
                or prediction == "No valid transcription for analysis."
            )
            if is_status_message:
                # Report classifier problems verbatim, never as a condition.
                sections.append(prediction + "\n")
            elif prediction == "No health condition predicted":
                sections.append("No significant health indicators detected.\n")
            else:
                sections.append(f"Possible health condition: {prediction} (confidence: {score:.4f}). Consult a doctor.\n")
        else:
            sections.append("No symptoms detected in the audio.\n")

        # Handle query if present
        if query_text:
            sections.append(f"\nQuery detected: '{query_text}'\n")
            sections.append(handle_health_query(query_text, language) + "\n")

        # Add debug info and disclaimer
        sections.append(f"\n**Debug Info**: Transcription = '{transcription}', File Hash = {file_hash}")
        sections.append("\n**Disclaimer**: This is not a diagnostic tool. Consult a healthcare provider for medical advice.")
        return "".join(sections)
    except Exception as e:
        return f"Error processing audio: {str(e)}"
    finally:
        # Clean up temporary audio file on every exit path.
        if moved_path is not None:
            try:
                os.remove(moved_path)
                print(f"Deleted temporary audio file: {moved_path}")
            except OSError as e:
                print(f"Failed to delete audio file: {str(e)}")
# Gradio interface
def create_gradio_interface():
    """Assemble and return the Gradio Blocks UI.

    The layout is one row each for the language dropdown, the audio input,
    the question textbox, the output textbox and the two buttons.
    "Analyze Voice" runs ``analyze_voice`` on the audio and language;
    "Submit Query" runs ``handle_health_query`` on the question and language.
    Both write to the same output textbox.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as interface:
        gr.Markdown(
            """
            # Health Voice Analyzer
            Record or upload a voice sample describing symptoms in English, Spanish, Hindi, or Mandarin (e.g., 'I have a fever').
            Ask health questions in the text box below (e.g., 'What are symptoms of asthma?').
            **Note**: Do not ask for medication or treatment advice; focus on symptoms or general health questions.
            **Disclaimer**: This is not a diagnostic tool. Consult a healthcare provider for medical advice.
            **Text-to-Speech**: Available in the web frontend (Salesforce Sites) using the browser's Web Speech API.
            """
        )

        # Inputs
        with gr.Row():
            language_choice = gr.Dropdown(
                label="Select Language",
                choices=["en", "es", "hi", "zh"],
                value="en",
            )
        with gr.Row():
            voice_sample = gr.Audio(label="Record or Upload Voice", type="filepath")
        with gr.Row():
            question_box = gr.Textbox(label="Ask a Health Question (e.g., 'What are symptoms of asthma?')")

        # Output shared by both actions
        with gr.Row():
            feedback_box = gr.Textbox(label="Health Assessment Feedback")

        # Actions
        with gr.Row():
            voice_button = gr.Button("Analyze Voice")
            question_button = gr.Button("Submit Query")

        voice_button.click(
            fn=analyze_voice,
            inputs=[voice_sample, language_choice],
            outputs=feedback_box,
        )
        question_button.click(
            fn=handle_health_query,
            inputs=[question_box, language_choice],
            outputs=feedback_box,
        )
    return interface
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces, port 7860.
    demo = create_gradio_interface()
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )