import gradio as gr
import librosa
import numpy as np
import os
import shutil
import hashlib
from datetime import datetime
from transformers import pipeline
import soundfile as sf
import torch
from tenacity import retry, stop_after_attempt, wait_fixed
from gtts import gTTS
# Initialize local models with retry logic (tenacity retries transient
# download/initialization failures a few times before giving up)
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def load_whisper_model():
    try:
        model = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-tiny.en",
            device=-1,  # CPU; use device=0 for GPU if available
            model_kwargs={"use_safetensors": True}
        )
        print("Whisper model loaded successfully.")
        return model
    except Exception as e:
        print(f"Failed to load Whisper model: {str(e)}")
        raise
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def load_symptom_model():
    """Load the symptom classifier, falling back to a generic model on failure.

    Returns a (model, is_fallback) tuple so callers can flag reduced accuracy.
    """
    try:
        model = pipeline(
            "text-classification",
            model="abhirajeshbhai/symptom-2-disease-net",
            device=-1,  # CPU
            model_kwargs={"use_safetensors": True}
        )
        print("Symptom-2-Disease model loaded successfully.")
        return model, False
    except Exception as e:
        print(f"Failed to load Symptom-2-Disease model: {str(e)}")
        # Fallback to a generic model (untrained classification head;
        # predictions are unreliable and flagged as such downstream)
        try:
            model = pipeline(
                "text-classification",
                model="distilbert-base-uncased",
                device=-1
            )
            print("Fell back to distilbert-base-uncased model.")
            return model, True
        except Exception as fallback_e:
            print(f"Fallback model failed: {str(fallback_e)}")
            raise
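# Optional sketch (not wired into the loaders above, which deliberately pin
# CPU): device selection could be made dynamic. device=0 selects the first
# CUDA GPU when torch sees one; -1 stays on CPU.
def pick_device():
    """Return a transformers pipeline device id based on GPU availability."""
    return 0 if torch.cuda.is_available() else -1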
whisper = None
symptom_classifier = None
is_fallback_model = False  # True when the generic fallback classifier is in use

try:
    whisper = load_whisper_model()
except Exception as e:
    print(f"Whisper model initialization failed after retries: {str(e)}")

try:
    symptom_classifier, is_fallback_model = load_symptom_model()
except Exception as e:
    print(f"Symptom model initialization failed after retries: {str(e)}")
    symptom_classifier = None
def compute_file_hash(file_path):
    """Compute the MD5 hash of a file to check uniqueness."""
    hash_md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
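# Optional sketch (not wired into the app): the hash above can short-circuit
# repeat uploads. `_seen_hashes` is a hypothetical in-memory cache, adequate
# only for a single-process deployment.
_seen_hashes = set()

def is_duplicate_upload(file_path):
    """Return True if this exact file content was already processed."""
    file_hash = compute_file_hash(file_path)
    if file_hash in _seen_hashes:
        return True
    _seen_hashes.add(file_hash)
    return False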
def transcribe_audio(audio_file):
    """Transcribe audio using the local Whisper model."""
    if not whisper:
        return "Error: Whisper model not loaded. Check logs for details or ensure sufficient compute resources."
    try:
        # Load and validate audio
        audio, sr = librosa.load(audio_file, sr=16000)
        if len(audio) < 16000:  # Less than 1 second at 16 kHz
            return "Error: Audio too short. Please provide audio of at least 1 second."
        if np.max(np.abs(audio)) < 1e-4:  # Too quiet
            return "Error: Audio too quiet. Please provide clear audio describing symptoms in English."
        # Save as WAV for Whisper
        temp_wav = f"/tmp/{os.path.basename(audio_file)}.wav"
        sf.write(temp_wav, audio, sr)
        # Transcribe with beam search
        with torch.no_grad():
            result = whisper(temp_wav, generate_kwargs={"num_beams": 5})
        transcription = result.get("text", "").strip()
        print(f"Transcription: {transcription}")
        # Clean up temp file
        try:
            os.remove(temp_wav)
        except Exception:
            pass
        if not transcription:
            return "Error: Transcription empty. Please provide clear audio describing symptoms in English."
        # Reject degenerate transcriptions where most words repeat
        words = transcription.split()
        if len(words) > 5 and len(set(words)) < len(words) / 2:
            return "Error: Transcription appears repetitive. Please provide clear, non-repetitive audio describing symptoms."
        return transcription
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"
def analyze_symptoms(text):
    """Analyze symptoms using the local Symptom-2-Disease model."""
    if not symptom_classifier:
        return "Error: Symptom-2-Disease model not loaded. Check logs for details or ensure sufficient compute resources.", 0.0
    try:
        if not text or text.startswith("Error"):
            return "No valid transcription for analysis.", 0.0
        with torch.no_grad():
            result = symptom_classifier(text)
        if result and isinstance(result, list) and len(result) > 0:
            prediction = result[0]["label"]
            score = result[0]["score"]
            if is_fallback_model:
                print("Warning: Using fallback model (distilbert-base-uncased). Results may be less accurate.")
                prediction = f"{prediction} (using fallback model)"
            print(f"Health Prediction: {prediction}, Score: {score:.4f}")
            return prediction, score
        return "No health condition predicted", 0.0
    except Exception as e:
        return f"Error analyzing symptoms: {str(e)}", 0.0
def generate_voice_feedback(text):
    """Generate voice feedback from text using gTTS."""
    try:
        # Drop debug info and disclaimer for cleaner voice output
        clean_text = text.split("\n\n**Debug Info**")[0]
        tts = gTTS(text=clean_text, lang='en')
        output_file = f"/tmp/feedback_{datetime.now().strftime('%Y%m%d%H%M%S%f')}.mp3"
        tts.save(output_file)
        return output_file
    except Exception as e:
        print(f"Error generating voice feedback: {str(e)}")
        return None
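# Optional sketch: identical feedback strings recur often (e.g. the "no audio"
# error), so TTS output could be cached by text hash instead of re-synthesized
# on every call. The cache path scheme is an assumption for illustration.
def generate_voice_feedback_cached(text):
    """Like generate_voice_feedback, but reuses an MP3 for repeated text."""
    clean_text = text.split("\n\n**Debug Info**")[0]
    cache_path = f"/tmp/tts_{hashlib.md5(clean_text.encode('utf-8')).hexdigest()}.mp3"
    if not os.path.exists(cache_path):
        gTTS(text=clean_text, lang='en').save(cache_path)
    return cache_path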
def analyze_voice(audio_file):
    """Analyze voice for health indicators and provide text and voice feedback."""
    try:
        if audio_file is None:
            feedback = "Error: No audio provided. Please record or upload a valid audio file."
            voice_file = generate_voice_feedback(feedback)
            return feedback, voice_file
        # Copy to a uniquely named file so Gradio cannot reuse a stale path
        # (copy rather than move, so sample/test inputs survive repeat runs)
        os.makedirs("/tmp/gradio", exist_ok=True)
        unique_path = f"/tmp/gradio/{datetime.now().strftime('%Y%m%d%H%M%S%f')}_{os.path.basename(audio_file)}"
        shutil.copy(audio_file, unique_path)
        audio_file = unique_path
        # Log audio file info
        file_hash = compute_file_hash(audio_file)
        print(f"Processing audio file: {audio_file}, Hash: {file_hash}")
        # Load audio to verify format
        audio, sr = librosa.load(audio_file, sr=16000)
        print(f"Audio shape: {audio.shape}, Sampling rate: {sr}, Duration: {len(audio)/sr:.2f}s, Mean: {np.mean(audio):.4f}, Std: {np.std(audio):.4f}")
        # Transcribe audio; every failure path starts with "Error"
        transcription = transcribe_audio(audio_file)
        if transcription.startswith("Error"):
            voice_file = generate_voice_feedback(transcription)
            return transcription, voice_file
        # Refuse medication-related queries
        if "medicine" in transcription.lower() or "treatment" in transcription.lower():
            feedback = "Error: This tool does not provide medication or treatment advice. Please describe symptoms only (e.g., 'I have a fever')."
            feedback += f"\n\n**Debug Info**: Transcription = '{transcription}', File Hash = {file_hash}"
            feedback += "\n**Disclaimer**: This is not a diagnostic tool. Consult a healthcare provider for medical advice."
            voice_file = generate_voice_feedback(feedback)
            return feedback, voice_file
        # Analyze symptoms
        prediction, score = analyze_symptoms(transcription)
        if prediction.startswith("Error"):
            voice_file = generate_voice_feedback(prediction)
            return prediction, voice_file
        # Generate feedback
        if prediction == "No health condition predicted":
            feedback = "No significant health indicators detected."
        else:
            feedback = f"Possible health condition: {prediction} (confidence: {score:.4f}). Consult a doctor."
        feedback += f"\n\n**Debug Info**: Transcription = '{transcription}', Prediction = {prediction}, Confidence = {score:.4f}, File Hash = {file_hash}"
        feedback += "\n**Disclaimer**: This is not a diagnostic tool. Consult a healthcare provider for medical advice."
        # Generate voice feedback
        voice_file = generate_voice_feedback(feedback)
        # Clean up the temporary copy
        try:
            os.remove(audio_file)
            print(f"Deleted temporary audio file: {audio_file}")
        except Exception as e:
            print(f"Failed to delete audio file: {str(e)}")
        return feedback, voice_file
    except Exception as e:
        feedback = f"Error processing audio: {str(e)}"
        voice_file = generate_voice_feedback(feedback)
        return feedback, voice_file
def test_with_sample_audio():
    """Test the app with sample audio files."""
    samples = ["audio_samples/sample.wav", "audio_samples/common_voice_en.wav"]
    results = []
    for sample in samples:
        if os.path.exists(sample):
            text, voice = analyze_voice(sample)
            results.append(f"Text: {text}\nVoice: {voice}")
        else:
            results.append(f"Sample not found: {sample}")
    return "\n".join(results)
# Gradio interface
iface = gr.Interface(
    fn=analyze_voice,
    inputs=gr.Audio(type="filepath", label="Record or Upload Voice"),
    outputs=[
        gr.Textbox(label="Health Assessment Feedback"),
        gr.Audio(label="Voice Feedback", type="filepath")
    ],
    title="Health Voice Analyzer",
    description="Record or upload a voice sample describing symptoms (e.g., 'I have a fever') for a preliminary health assessment. Supports English only. Use clear audio (WAV, 16 kHz). Do not ask for medication or treatment advice."
)
if __name__ == "__main__":
    print(test_with_sample_audio())
    iface.launch(server_name="0.0.0.0", server_port=7860)
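# Note: if the Space sees concurrent users, Gradio's request queue can
# serialize the CPU-bound model calls. An optional variant (an assumption,
# not the configuration used above):
#     iface.queue(max_size=8).launch(server_name="0.0.0.0", server_port=7860)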