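# app.py: Multimodal Mental Health AI Agent (Gradio + Groq).
# Dependency note (assumed pip package names for the imports below):
#   gradio, torch, opencv-python (cv2), SpeechRecognition (speech_recognition), groq, gTTS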
import gradio as gr
import torch
import cv2
import speech_recognition as sr
from groq import Groq
import os
import time
import base64
from io import BytesIO
from gtts import gTTS
import tempfile
# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Clear GPU memory if using GPU
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# Groq API client with API key (read from an environment variable, with a hardcoded fallback)
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "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")  # Replace with your key or use the env variable
try:
    client = Groq(api_key=GROQ_API_KEY)
    print("Groq client initialized successfully")
except Exception as e:
    print(f"Error initializing Groq client: {str(e)}")
    raise
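# Set the key before launching, e.g. `export GROQ_API_KEY=...`, so the hardcoded fallback is never used.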
# Functions
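# predict_text_emotion: asks the Groq chat model to label the user's text with a
# single-word emotion (happy, sad, anxious, or angry).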
def predict_text_emotion(text):
    prompt = f"The user has entered the text '{text}'. Classify the user's emotion as happy, sad, anxious, or angry. Respond in only one word."
    try:
        completion = client.chat.completions.create(
            model="llama-3.2-90b-vision-preview",
            messages=[{"role": "user", "content": prompt}],
            temperature=1,
            max_completion_tokens=64,
            top_p=1,
            stream=False,
            stop=None,
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"Error with Groq API: {str(e)}"
def transcribe_audio(audio_path):
    r = sr.Recognizer()
    with sr.AudioFile(audio_path) as source:
        audio_text = r.record(source)  # read the whole file (record is the usual call for AudioFile sources)
    try:
        text = r.recognize_google(audio_text)
        return text
    except sr.UnknownValueError:
        return "I didn’t catch that—could you try again?"
    except sr.RequestError:
        return "Speech recognition unavailable—try typing instead."
def capture_webcam_frame():
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        return None
    start_time = time.time()
    while time.time() - start_time < 2:
        ret, frame = cap.read()
        if ret:
            _, buffer = cv2.imencode('.jpg', frame)
            img_base64 = base64.b64encode(buffer).decode('utf-8')
            img_url = f"data:image/jpeg;base64,{img_base64}"
            cap.release()
            return img_url
    cap.release()
    return None
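# detect_facial_emotion: sends the captured webcam frame to the Groq vision model and
# maps the reply to one of the four supported emotions, defaulting to "neutral" on any failure.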
def detect_facial_emotion():
    img_url = capture_webcam_frame()
    if not img_url:
        return "neutral"
    try:
        completion = client.chat.completions.create(
            model="llama-3.2-90b-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Classify the user's facial emotion as happy, sad, anxious, or angry. Respond in one word only."},
                        {"type": "image_url", "image_url": {"url": img_url}}
                    ]
                }
            ],
            temperature=1,
            max_completion_tokens=20,
            top_p=1,
            stream=False,
            stop=None,
        )
        emotion = completion.choices[0].message.content.strip().lower()
        if emotion not in ["happy", "sad", "anxious", "angry"]:
            return "neutral"
        return emotion
    except Exception as e:
        print(f"Error with Groq facial detection: {str(e)}")
        return "neutral"
def generate_response(user_input, emotion):
    prompt = f"The user is feeling {emotion}. They said: '{user_input}'. Respond in a friendly, caring manner so the user feels loved and supported."
    try:
        completion = client.chat.completions.create(
            model="llama-3.2-90b-vision-preview",
            messages=[{"role": "user", "content": prompt}],
            temperature=1,
            max_completion_tokens=64,
            top_p=1,
            stream=False,
            stop=None,
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"Error with Groq API: {str(e)}"
def text_to_speech(text):
    try:
        tts = gTTS(text=text, lang='en', slow=False)
        # Create a temporary file to store the audio
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
            tts.save(temp_audio.name)
        return temp_audio.name
    except Exception as e:
        print(f"Error generating speech: {str(e)}")
        return None
# Chat function for Gradio with voice output
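# It returns four values in the order expected by the Gradio outputs wired up below:
# the updated chat history, the detected-emotion label, the cleared text box, and the
# path of the synthesized audio reply.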
def chat_function(input_type, text_input, audio_input, chat_history):
    if input_type == "text" and text_input:
        user_input = text_input
    elif input_type == "voice" and audio_input:
        user_input = transcribe_audio(audio_input)
    else:
        return chat_history, "Please provide text or voice input.", gr.update(value=text_input), None
    text_emotion = predict_text_emotion(user_input)
    if not chat_history:
        gr.Info("Please look at the camera for emotion detection...")
        facial_emotion = detect_facial_emotion()
    else:
        facial_emotion = "neutral"
    emotions = [e for e in [text_emotion, facial_emotion] if e and e != "neutral"]
    combined_emotion = emotions[0] if emotions else "neutral"
    response = generate_response(user_input, combined_emotion)
    chat_history.append({"role": "user", "content": user_input})
    chat_history.append({"role": "assistant", "content": response})
    audio_output = text_to_speech(response)
    return chat_history, f"Detected Emotion: {combined_emotion}", "", audio_output
# Custom CSS for better styling
css = """
<style>
.chatbot .message-user {
background-color: #e3f2fd;
border-radius: 10px;
padding: 10px;
margin: 5px 0;
}
.chatbot .message-assistant {
background-color: #c8e6c9;
border-radius: 10px;
padding: 10px;
margin: 5px 0;
}
.input-container {
padding: 10px;
background-color: #f9f9f9;
border-radius: 10px;
margin-top: 10px;
}
</style>
"""
# Build the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    gr.Markdown(
        """
        # Multimodal Mental Health AI Agent
        Chat with our empathetic AI designed to support you by understanding your emotions through text and facial expressions.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            emotion_display = gr.Textbox(label="Emotion", interactive=False, placeholder="Detected emotion will appear here")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Conversation History", height=500, type="messages", elem_classes="chatbot")
    with gr.Row(elem_classes="input-container"):
        input_type = gr.Radio(["text", "voice"], label="Input Method", value="text")
        text_input = gr.Textbox(label="Type Your Message", placeholder="How are you feeling today?", visible=True)
        audio_input = gr.Audio(type="filepath", label="Record Your Message", visible=False)
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear Chat", variant="secondary")
    audio_output = gr.Audio(label="Assistant Response", type="filepath", interactive=False, autoplay=True)

    # Dynamic visibility based on input type
    def update_visibility(input_type):
        return gr.update(visible=input_type == "text"), gr.update(visible=input_type == "voice")

    input_type.change(fn=update_visibility, inputs=input_type, outputs=[text_input, audio_input])

    # Submit action with voice output
    submit_btn.click(
        fn=chat_function,
        inputs=[input_type, text_input, audio_input, chatbot],
        outputs=[chatbot, emotion_display, text_input, audio_output]
    )

    # Clear chat and audio
    clear_btn.click(
        lambda: ([], "", "", None),
        inputs=None,
        outputs=[chatbot, emotion_display, text_input, audio_output]
    )
# Launch the app (for local testing)
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)