salomonsky committed on
Commit
e89f269
·
verified ·
1 Parent(s): a8fb390

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -38
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import streamlit as st
2
  import base64
3
  import io
4
- import audiorecorder
5
- import speech_recognition as sr
6
  from huggingface_hub import InferenceClient
 
 
 
7
 
8
- pre_prompt_text = "You are a behavioral AI, your answers should be brief, stoic and humanistic."
9
 
10
  if "history" not in st.session_state:
11
  st.session_state.history = []
@@ -49,7 +50,7 @@ def format_prompt(message, history):
49
  prompt += f"[INST] {message} [/INST]"
50
  return prompt
51
 
52
- def generate(audio_text, history, temperature=None, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0):
53
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
54
 
55
  temperature = float(temperature) if temperature is not None else 0.9
@@ -67,19 +68,12 @@ def generate(audio_text, history, temperature=None, max_new_tokens=2048, top_p=0
67
  formatted_prompt = format_prompt(audio_text, history)
68
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
69
  response = ""
70
- response_tokens = []
71
- total_tokens = 0
72
-
73
- progress_bar = st.progress(0)
74
 
75
  for response_token in stream:
76
- total_tokens += len(response_token.token.text)
77
- response_tokens.append(response_token.token.text)
78
- response = ' '.join(response_tokens).replace('</s>', '')
79
- progress = min(total_tokens / max_new_tokens, 1.0) # Asegurar que el progreso esté en el rango [0.0, 1.0]
80
- progress_bar.progress(progress)
81
-
82
- audio_file = text_to_speech(response)
83
  return response, audio_file
84
 
85
  def text_to_speech(text):
@@ -90,28 +84,46 @@ def text_to_speech(text):
90
  return audio_fp
91
 
92
  def main():
93
- option = st.radio("Select Input Method:", ("Text", "Voice"))
94
-
95
- if option == "Text":
96
- prompt = st.text_area("Enter your prompt here:")
97
- else:
98
- st.write("Push and hold the button to record.")
99
- audio_data = audiorecorder.audiorecorder("Push to Talk", "Stop Recording...")
100
-
101
- if not audio_data.empty():
102
- st.audio(audio_data.export().read(), format="audio/wav")
103
- audio_data.export("audio.wav", format="wav")
104
- prompt = recognize_speech("audio.wav")
105
- st.text("Recognized prompt:")
106
- st.write(prompt)
107
-
108
- if prompt:
109
- output, audio_file = generate(prompt, history=st.session_state.history)
110
-
111
- if audio_file is not None:
112
- st.markdown(
113
- f"""<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(audio_file.read()).decode()}" type="audio/mp3" id="audio_player"></audio>""",
114
- unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  if __name__ == "__main__":
117
- main()
 
1
  import streamlit as st
2
  import base64
3
  import io
 
 
4
  from huggingface_hub import InferenceClient
5
+ from gtts import gTTS
6
+ from audiorecorder import audiorecorder
7
+ import speech_recognition as sr
8
 
9
+ pre_prompt_text = ""
10
 
11
  if "history" not in st.session_state:
12
  st.session_state.history = []
 
50
  prompt += f"[INST] {message} [/INST]"
51
  return prompt
52
 
53
+ def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
54
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
55
 
56
  temperature = float(temperature) if temperature is not None else 0.9
 
68
  formatted_prompt = format_prompt(audio_text, history)
69
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
70
  response = ""
 
 
 
 
71
 
72
  for response_token in stream:
73
+ response += response_token.token.text
74
+
75
+ response = ' '.join(response.split()).replace('</s>', '')
76
+ audio_file = text_to_speech(response, speed=1.3)
 
 
 
77
  return response, audio_file
78
 
79
  def text_to_speech(text):
 
84
  return audio_fp
85
 
86
  def main():
87
+ audio_data = audiorecorder("Push to Talk", "Stop Recording...")
88
+
89
+ if not audio_data.empty():
90
+ st.audio(audio_data.export().read(), format="audio/wav")
91
+ audio_data.export("audio.wav", format="wav")
92
+ audio_text = recognize_speech("audio.wav")
93
+
94
+ if audio_text:
95
+ output, audio_file = generate(audio_text, history=st.session_state.history)
96
+
97
+ if audio_file is not None:
98
+ st.markdown(
99
+ f"""<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(audio_file.read()).decode()}" type="audio/mp3" id="audio_player"></audio>""",
100
+ unsafe_allow_html=True)
101
+
102
+ # Voice Activity Detection (VAD) using speech_recognition
103
+ def vad_loop():
104
+ command = ""
105
+ while True:
106
+ with sr.Microphone() as source:
107
+ st.write("Listening...")
108
+ recognizer = sr.Recognizer()
109
+ recognizer.adjust_for_ambient_noise(source)
110
+ audio = recognizer.listen(source)
111
+
112
+ try:
113
+ command = recognizer.recognize_google(audio, language="en-US")
114
+ st.write(f"Command: {command}")
115
+ except sr.UnknownValueError:
116
+ st.write("Could not understand audio")
117
+ except sr.RequestError as e:
118
+ st.write(f"Error: {e}")
119
+
120
+ if command.lower() == "xaman":
121
+ st.write("Voice capture activated. Say 'Detente', 'Alto', or 'Basta' to stop.")
122
+ audio_data = audiorecorder("Push to Talk", "Stop Recording...")
123
+ main()
124
+ elif command.lower() in ["detente", "alto", "basta"]:
125
+ st.write("Voice capture stopped.")
126
+ break
127
 
128
  if __name__ == "__main__":
129
+ vad_loop()