welcometoFightclub committed
Commit 8e51149 · verified · 1 Parent(s): d83d3e2

Update app.py

Files changed (1):
  1. app.py +59 -110
app.py CHANGED
@@ -1,30 +1,31 @@
 import gradio as gr
 import torch
-import cv2
 import speech_recognition as sr
 from groq import Groq
 import os
-import time
-import base64
-from io import BytesIO
-from gtts import gTTS
 import tempfile
+from gtts import gTTS
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
-# Set device
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-print(f"Using device: {device}")
+# Set device (CPU only for Hugging Face Spaces free tier)
+device = torch.device("cpu")
+logger.info(f"Using device: {device}")
 
-# Clear GPU memory if using GPU
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
-
-# Grok API client with API key (stored as environment variable for security)
-GROQ_API_KEY = os.getenv("GROQ_API_KEY", "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") # Replace with your key or use env variable
+# Groq API client with API key from Hugging Face Secrets
+GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+if not GROQ_API_KEY:
+    logger.error("GROQ_API_KEY environment variable not set")
+    raise ValueError("GROQ_API_KEY environment variable not set")
+
 try:
     client = Groq(api_key=GROQ_API_KEY)
-    print("Grok client initialized successfully")
+    logger.info("Groq client initialized successfully")
 except Exception as e:
-    print(f"Error initializing Groq client: {str(e)}")
+    logger.error(f"Error initializing Groq client: {str(e)}")
     raise
 
 # Functions
@@ -35,72 +36,30 @@ def predict_text_emotion(text):
             model="llama3-70b-8192",
             messages=[{"role": "user", "content": prompt}],
             temperature=1,
-            max_completion_tokens=64,
+            max_tokens=64,
             top_p=1,
             stream=False,
-            stop=None,
         )
-        return completion.choices[0].message.content
+        return completion.choices[0].message.content.strip().lower()
     except Exception as e:
-        return f"Error with Grok API: {str(e)}"
+        logger.error(f"Error with Groq API (text emotion): {str(e)}")
+        return "neutral"
 
 def transcribe_audio(audio_path):
     r = sr.Recognizer()
-    with sr.AudioFile(audio_path) as source:
-        audio_text = r.listen(source)
     try:
+        with sr.AudioFile(audio_path) as source:
+            audio_text = r.listen(source)
         text = r.recognize_google(audio_text)
         return text
     except sr.UnknownValueError:
        return "I didn’t catch that—could you try again?"
-    except sr.RequestError:
+    except sr.RequestError as e:
+        logger.error(f"Speech recognition error: {str(e)}")
        return "Speech recognition unavailable—try typing instead."
-
-def capture_webcam_frame():
-    cap = cv2.VideoCapture(0)
-    if not cap.isOpened():
-        return None
-    start_time = time.time()
-    while time.time() - start_time < 2:
-        ret, frame = cap.read()
-        if ret:
-            _, buffer = cv2.imencode('.jpg', frame)
-            img_base64 = base64.b64encode(buffer).decode('utf-8')
-            img_url = f"data:image/jpeg;base64,{img_base64}"
-            cap.release()
-            return img_url
-    cap.release()
-    return None
-
-def detect_facial_emotion():
-    img_url = capture_webcam_frame()
-    if not img_url:
-        return "neutral"
-    try:
-        completion = client.chat.completions.create(
-            model="llama3-70b-8192",
-            messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": "Identify user's facial emotion into happy or sad or anxious or angry. Respond in one word only"},
-                        {"type": "image_url", "image_url": {"url": img_url}}
-                    ]
-                }
-            ],
-            temperature=1,
-            max_completion_tokens=20,
-            top_p=1,
-            stream=False,
-            stop=None,
-        )
-        emotion = completion.choices[0].message.content.strip().lower()
-        if emotion not in ["happy", "sad", "anxious", "angry"]:
-            return "neutral"
-        return emotion
     except Exception as e:
-        print(f"Error with Grok facial detection: {str(e)}")
-        return "neutral"
+        logger.error(f"Unexpected error in audio transcription: {str(e)}")
+        return "Error processing audio."
 
 def generate_response(user_input, emotion):
     prompt = f"The user is feeling {emotion}. They said: '{user_input}'. Respond in a friendly caring manner with the user so the user feels being loved."
@@ -109,24 +68,23 @@ def generate_response(user_input, emotion)
             model="llama3-70b-8192",
             messages=[{"role": "user", "content": prompt}],
             temperature=1,
-            max_completion_tokens=64,
+            max_tokens=64,
             top_p=1,
             stream=False,
-            stop=None,
         )
         return completion.choices[0].message.content
     except Exception as e:
-        return f"Error with Groq API: {str(e)}"
+        logger.error(f"Error with Groq API (response generation): {str(e)}")
+        return "I'm here for you, but something went wrong. How can I help?"
 
 def text_to_speech(text):
     try:
         tts = gTTS(text=text, lang='en', slow=False)
-        # Create a temporary file to store the audio
         with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
             tts.save(temp_audio.name)
             return temp_audio.name
     except Exception as e:
-        print(f"Error generating speech: {str(e)}")
+        logger.error(f"Error generating speech: {str(e)}")
         return None
 
 # Chat function for Gradio with voice output
@@ -138,45 +96,36 @@ def chat_function(input_type, text_input, audio_input, chat_history):
     else:
         return chat_history, "Please provide text or voice input.", gr.update(value=text_input), None
 
-    text_emotion = predict_text_emotion(user_input)
-    if not chat_history:
-        gr.Info("Please look at the camera for emotion detection...")
-        facial_emotion = detect_facial_emotion()
-    else:
-        facial_emotion = "neutral"
-
-    emotions = [e for e in [text_emotion, facial_emotion] if e and e != "neutral"]
-    combined_emotion = emotions[0] if emotions else "neutral"
-
-    response = generate_response(user_input, combined_emotion)
+    emotion = predict_text_emotion(user_input)
+    response = generate_response(user_input, emotion)
+
+    chat_history = chat_history or []
     chat_history.append({"role": "user", "content": user_input})
     chat_history.append({"role": "assistant", "content": response})
 
     audio_output = text_to_speech(response)
-    return chat_history, f"Detected Emotion: {combined_emotion}", "", audio_output
+    return chat_history, f"Detected Emotion: {emotion}", "", audio_output
 
-# Custom CSS for better styling
+# Custom CSS for styling
 css = """
-<style>
-.chatbot .message-user {
-    background-color: #e3f2fd;
-    border-radius: 10px;
-    padding: 10px;
-    margin: 5px 0;
-}
-.chatbot .message-assistant {
-    background-color: #c8e6c9;
-    border-radius: 10px;
-    padding: 10px;
-    margin: 5px 0;
-}
-.input-container {
-    padding: 10px;
-    background-color: #f9f9f9;
-    border-radius: 10px;
-    margin-top: 10px;
-}
-</style>
+.chatbot .message-user {
+    background-color: #e3f2fd;
+    border-radius: 10px;
+    padding: 10px;
+    margin: 5px 0;
+}
+.chatbot .message-assistant {
+    background-color: #c8e6c9;
+    border-radius: 10px;
+    padding: 10px;
+    margin: 5px 0;
+}
+.input-container {
+    padding: 10px;
+    background-color: #f9f9f9;
+    border-radius: 10px;
+    margin-top: 10px;
+}
 """
 
 # Build the Gradio interface
@@ -184,7 +133,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
     gr.Markdown(
         """
         # Multimodal Mental Health AI Agent
-        Chat with our empathetic AI designed to support you by understanding your emotions through text and facial expressions.
+        Chat with our empathetic AI designed to support you by understanding your emotions through text and voice.
         """
     )
 
@@ -198,7 +147,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
     with gr.Row(elem_classes="input-container"):
         input_type = gr.Radio(["text", "voice"], label="Input Method", value="text")
         text_input = gr.Textbox(label="Type Your Message", placeholder="How are you feeling today?", visible=True)
-        audio_input = gr.Audio(type="filepath", label="Record Your Message", visible=False)
+        audio_input = gr.Audio(sources=["microphone"], type="filepath", label="Record Your Message", visible=False)
         submit_btn = gr.Button("Send", variant="primary")
         clear_btn = gr.Button("Clear Chat", variant="secondary")
         audio_output = gr.Audio(label="Assistant Response", type="filepath", interactive=False, autoplay=True)
@@ -223,6 +172,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
         outputs=[chatbot, emotion_display, text_input, audio_output]
     )
 
-# Launch the app (for local testing)
-if __name__ == "__main__":
-    app.launch(server_name="0.0.0.0", server_port=7860)
+# Launch the app (commented out for Hugging Face Spaces)
+# if __name__ == "__main__":
+#     app.launch(server_name="0.0.0.0", server_port=7860)
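
Reviewer note on the Groq call signature: the hunks above replace `max_completion_tokens` with `max_tokens` and drop `stop=None` in both `client.chat.completions.create(...)` calls. Below is a minimal sketch for verifying the renamed argument against the same model outside the app; it assumes the `groq` package is installed and `GROQ_API_KEY` is exported, and the prompt string is illustrative only.

# Standalone sanity check for the renamed parameter (not part of this commit).
import os
from groq import Groq

client = Groq(api_key=os.environ["GROQ_API_KEY"])
completion = client.chat.completions.create(
    model="llama3-70b-8192",  # same model the app uses
    messages=[{"role": "user", "content": "Classify the emotion in one word: 'what a great day!'"}],
    temperature=1,
    max_tokens=64,  # renamed from max_completion_tokens in this commit
    top_p=1,
    stream=False,
)
print(completion.choices[0].message.content)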
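
Since the launch block is now commented out (per its comment, the Spaces runtime is expected to start the app), local testing needs an explicit entry point. A minimal sketch, assuming the updated file is saved as `app.py` in the working directory; the wrapper filename `local_run.py` is hypothetical:

# local_run.py: hypothetical local-testing wrapper, not part of this commit.
# Importing app executes its module top level, which raises ValueError
# when GROQ_API_KEY is missing from the environment.
from app import app

if __name__ == "__main__":
    # Same host/port as the block removed in the final hunk.
    app.launch(server_name="0.0.0.0", server_port=7860)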
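
The added `chat_history = chat_history or []` guard is worth noting: Gradio hands the chatbot state to the callback unchanged, so it can arrive as `None` on the first turn, and appending would previously have raised `AttributeError`. A hypothetical direct call showing the guarded path (it bypasses the UI and performs real Groq and gTTS requests):

# Hypothetical direct invocation of chat_function from the diff above;
# in the app, Gradio supplies these four arguments from the UI components.
history, emotion_label, cleared_text, audio_path = chat_function(
    "text",                         # input_type radio value
    "I feel a bit anxious today",   # text_input
    None,                           # audio_input (unused for text input)
    None,                           # chat_history: None on the first turn
)
print(emotion_label)  # e.g. "Detected Emotion: anxious"
print(len(history))   # 2: one user message plus one assistant reply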