import gradio as gr
import torch
import cv2
import speech_recognition as sr
from groq import Groq
import os
import time
import base64
from io import BytesIO
from gtts import gTTS
import tempfile
# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Clear GPU memory if using GPU
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# Groq API client (API key stored as an environment variable for security)
GROQ_API_KEY = os.getenv("GROQ_API_KEY")  # set this as an environment variable (e.g., a Space secret); never hardcode the key
try:
    client = Groq(api_key=GROQ_API_KEY)
    print("Groq client initialized successfully")
except Exception as e:
    print(f"Error initializing Groq client: {str(e)}")
    raise
# Functions
def predict_text_emotion(text):
    prompt = f"The user has entered the text '{text}'. Classify the user's emotion as happy, sad, anxious, or angry. Respond in only one word."
    try:
        completion = client.chat.completions.create(
            model="llama-3.2-90b-vision-preview",
            messages=[{"role": "user", "content": prompt}],
            temperature=1,
            max_completion_tokens=64,
            top_p=1,
            stream=False,
            stop=None,
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"Error with Groq API: {str(e)}"
def transcribe_audio(audio_path):
    r = sr.Recognizer()
    with sr.AudioFile(audio_path) as source:
        audio_text = r.record(source)  # read the whole recording from the file
    try:
        text = r.recognize_google(audio_text)
        return text
    except sr.UnknownValueError:
        return "I didn’t catch that—could you try again?"
    except sr.RequestError:
        return "Speech recognition unavailable—try typing instead."
def capture_webcam_frame():
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        return None
    start_time = time.time()
    # Try to grab a frame for up to 2 seconds
    while time.time() - start_time < 2:
        ret, frame = cap.read()
        if ret:
            # Encode the frame as a base64 data URL for the vision model
            _, buffer = cv2.imencode('.jpg', frame)
            img_base64 = base64.b64encode(buffer).decode('utf-8')
            img_url = f"data:image/jpeg;base64,{img_base64}"
            cap.release()
            return img_url
    cap.release()
    return None
def detect_facial_emotion():
    img_url = capture_webcam_frame()
    if not img_url:
        return "neutral"
    try:
        completion = client.chat.completions.create(
            model="llama-3.2-90b-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Classify the user's facial emotion as happy, sad, anxious, or angry. Respond in one word only."},
                        {"type": "image_url", "image_url": {"url": img_url}}
                    ]
                }
            ],
            temperature=1,
            max_completion_tokens=20,
            top_p=1,
            stream=False,
            stop=None,
        )
        emotion = completion.choices[0].message.content.strip().lower()
        if emotion not in ["happy", "sad", "anxious", "angry"]:
            return "neutral"
        return emotion
    except Exception as e:
        print(f"Error with Groq facial detection: {str(e)}")
        return "neutral"
def generate_response(user_input, emotion):
    prompt = f"The user is feeling {emotion}. They said: '{user_input}'. Respond in a friendly, caring manner so the user feels loved."
    try:
        completion = client.chat.completions.create(
            model="llama-3.2-90b-vision-preview",
            messages=[{"role": "user", "content": prompt}],
            temperature=1,
            max_completion_tokens=64,
            top_p=1,
            stream=False,
            stop=None,
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"Error with Groq API: {str(e)}"
def text_to_speech(text):
    try:
        tts = gTTS(text=text, lang='en', slow=False)
        # Create a temporary file to store the audio
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
            tts.save(temp_audio.name)
        return temp_audio.name
    except Exception as e:
        print(f"Error generating speech: {str(e)}")
        return None
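# Design note: delete=False keeps the MP3 on disk so Gradio can serve it by file path
# through the Audio output component; nothing here deletes the temp files afterwards,
# so a long-running Space may want to clean them up once played (left out of this sketch).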
# Chat function for Gradio with voice output
def chat_function(input_type, text_input, audio_input, chat_history):
    if input_type == "text" and text_input:
        user_input = text_input
    elif input_type == "voice" and audio_input:
        user_input = transcribe_audio(audio_input)
    else:
        return chat_history, "Please provide text or voice input.", gr.update(value=text_input), None
    text_emotion = predict_text_emotion(user_input)
    # Only capture a webcam frame on the first turn of the conversation
    if not chat_history:
        gr.Info("Please look at the camera for emotion detection...")
        facial_emotion = detect_facial_emotion()
    else:
        facial_emotion = "neutral"
    # Prefer any non-neutral emotion; the text emotion takes priority over the facial one
    emotions = [e for e in [text_emotion, facial_emotion] if e and e != "neutral"]
    combined_emotion = emotions[0] if emotions else "neutral"
    response = generate_response(user_input, combined_emotion)
    chat_history.append({"role": "user", "content": user_input})
    chat_history.append({"role": "assistant", "content": response})
    audio_output = text_to_speech(response)
    return chat_history, f"Detected Emotion: {combined_emotion}", "", audio_output
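# A minimal sketch (not part of the app) of how chat_function could be exercised directly,
# bypassing the Gradio UI, for quick local testing. The sample message and the non-empty
# history below are illustrative; passing a non-empty history skips the webcam-capture branch:
#   history, emotion_msg, _, audio_path = chat_function(
#       "text", "I had a rough day", None,
#       [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "Hello!"}],
#   )
#   print(emotion_msg)               # e.g. "Detected Emotion: sad"
#   print(history[-1]["content"])    # the assistant's generated reply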
# Custom CSS for better styling (gr.Blocks expects raw CSS, without <style> tags)
css = """
.chatbot .message-user {
    background-color: #e3f2fd;
    border-radius: 10px;
    padding: 10px;
    margin: 5px 0;
}
.chatbot .message-assistant {
    background-color: #c8e6c9;
    border-radius: 10px;
    padding: 10px;
    margin: 5px 0;
}
.input-container {
    padding: 10px;
    background-color: #f9f9f9;
    border-radius: 10px;
    margin-top: 10px;
}
"""
# Build the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    gr.Markdown(
        """
        # Multimodal Mental Health AI Agent
        Chat with our empathetic AI designed to support you by understanding your emotions through text and facial expressions.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            emotion_display = gr.Textbox(label="Emotion", interactive=False, placeholder="Detected emotion will appear here")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Conversation History", height=500, type="messages", elem_classes="chatbot")
    with gr.Row(elem_classes="input-container"):
        input_type = gr.Radio(["text", "voice"], label="Input Method", value="text")
        text_input = gr.Textbox(label="Type Your Message", placeholder="How are you feeling today?", visible=True)
        audio_input = gr.Audio(type="filepath", label="Record Your Message", visible=False)
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear Chat", variant="secondary")
    audio_output = gr.Audio(label="Assistant Response", type="filepath", interactive=False, autoplay=True)

    # Dynamic visibility based on input type
    def update_visibility(input_type):
        return gr.update(visible=input_type == "text"), gr.update(visible=input_type == "voice")

    input_type.change(fn=update_visibility, inputs=input_type, outputs=[text_input, audio_input])

    # Submit action with voice output
    submit_btn.click(
        fn=chat_function,
        inputs=[input_type, text_input, audio_input, chatbot],
        outputs=[chatbot, emotion_display, text_input, audio_output]
    )

    # Clear chat and audio
    clear_btn.click(
        lambda: ([], "", "", None),
        inputs=None,
        outputs=[chatbot, emotion_display, text_input, audio_output]
    )
# Launch the app (for local testing)
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)
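# Based on the imports above, the Space is assumed to need roughly these packages in its
# requirements.txt (usual PyPI names; treat this as a sketch, not the definitive manifest):
#   gradio, torch, opencv-python, SpeechRecognition, groq, gTTS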