salomonsky committed on
Commit b7431cd (verified)
Parent: 88f6f66

Update app.py

Files changed (1): app.py +44 -41
app.py CHANGED
@@ -1,18 +1,18 @@
-from huggingface_hub import InferenceClient
-from audiorecorder import audiorecorder
 import streamlit as st
 import base64
-import torch
+import io
+from huggingface_hub import InferenceClient
+from gtts import gTTS
+from audiorecorder import audiorecorder
 import speech_recognition as sr
-from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM
+from pydub import AudioSegment
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import torch
 
-processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
-model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
-tts_model_name = "facebook/mms-tts-spa"
-tokenizer = AutoTokenizer.from_pretrained(tts_model_name)
-tts_model = AutoModelForSeq2SeqLM.from_pretrained(tts_model_name)
+tokenizer = AutoTokenizer.from_pretrained("facebook/mms-fastspeech2-spa")
+model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mms-fastspeech2-spa")
 
-pre_prompt_text = "You are a behavioral AI, your answers should be brief, stoic and humanistic."
+pre_prompt_text = "Eres una IA conductual, tus respuestas deberán ser breves, estóicas y humanistas."
 
 if "history" not in st.session_state:
     st.session_state.history = []
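Note: the new revision loads "facebook/mms-fastspeech2-spa" through AutoModelForSeq2SeqLM and later treats the model's logits as audio. In transformers, the MMS TTS checkpoints (such as facebook/mms-tts-spa, used by the previous revision) are VITS models that emit a waveform directly. A minimal sketch of that conventional loading path, assuming transformers >= 4.33; the tts_tokenizer/tts_model names are illustrative, not from the commit:

    from transformers import VitsModel, AutoTokenizer
    import torch

    # VITS text-to-speech: the forward pass returns a waveform, not logits.
    tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-spa")
    tts_model = VitsModel.from_pretrained("facebook/mms-tts-spa")

    inputs = tts_tokenizer("Hola, ¿cómo estás?", return_tensors="pt")
    with torch.no_grad():
        waveform = tts_model(**inputs).waveform  # shape: (batch, samples)
    sample_rate = tts_model.config.sampling_rate  # 16 kHz for MMS-TTS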
@@ -30,18 +30,32 @@ def recognize_speech(audio_data, show_messages=True):
     try:
         audio_text = recognizer.recognize_google(audio, language="es-ES")
         if show_messages:
-            st.subheader("Recognized text:")
+            st.subheader("Texto Reconocido:")
             st.write(audio_text)
-            st.success("Completed.")
+            st.success("Reconocimiento de voz completado.")
     except sr.UnknownValueError:
-        st.warning("The audio could not be recognized. Did you try to record something?")
+        st.warning("No se pudo reconocer el audio. ¿Intentaste grabar algo?")
         audio_text = ""
     except sr.RequestError:
-        st.error("Talk to me to get started!")
+        st.error("Hablame para comenzar!")
         audio_text = ""
 
     return audio_text
 
+def format_prompt(message, history):
+    prompt = "<s>"
+
+    if not st.session_state.pre_prompt_sent:
+        prompt += f"[INST] {pre_prompt_text} [/INST]"
+        st.session_state.pre_prompt_sent = True
+
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
+
 def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
     client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
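Note: the diff elides the middle of generate() (old lines 48-62, new lines 62-76). Given the response_token.token.text loop in the next hunk, the elided span presumably formats the prompt and opens a token stream. A hypothetical reconstruction against the huggingface_hub streaming API; the sampling defaults and return_full_text choice are guesses, not the committed code, and format_prompt/text_to_speech come from elsewhere in the file:

    from huggingface_hub import InferenceClient

    def generate(audio_text, history, temperature=None, max_new_tokens=512,
                 top_p=0.95, repetition_penalty=1.0):
        client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
        formatted_prompt = format_prompt(audio_text, history)
        stream = client.text_generation(
            formatted_prompt,
            temperature=float(temperature) if temperature is not None else 0.9,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            stream=True,    # yields streaming response items
            details=True,   # each item exposes .token.text
            return_full_text=False,
        )
        response = ""
        for response_token in stream:
            response += response_token.token.text
        response = ' '.join(response.split()).replace('</s>', '')
        return response, text_to_speech(response, speed=1.3)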
 
@@ -63,33 +77,24 @@ def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.
 
     for response_token in stream:
         response += response_token.token.text
+
+    response = ' '.join(response.split()).replace('</s>', '')
+    audio_file = text_to_speech(response, speed=1.3)
+    return response, audio_file
 
-    return response
-
-
-def format_prompt(message, history):
-    prompt = "<s>"
-
-    if not st.session_state.pre_prompt_sent:
-        prompt += f"[INST] {pre_prompt_text} [/INST]"
-        st.session_state.pre_prompt_sent = True
-
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
 
-def text_to_speech(text):
-    input_ids = tokenizer(text, return_tensors="pt").input_ids
+def text_to_speech(text, speed=1.3):
+    inputs = tokenizer(text, return_tensors="pt", padding="longest", truncation=True)
     with torch.no_grad():
-        logits = tts_model.generate(input_ids)
-        audio = processor.decode(logits[0], skip_special_tokens=True)
-    return audio
+        output = model(**inputs)
+    audio = output['logits']
+    audio_bytes = io.BytesIO()
+    torch.save(audio, audio_bytes)
+    audio_bytes.seek(0)
+    return base64.b64encode(audio_bytes.read()).decode()
 
 def main():
-    audio_data = st.audio_recorder("Push to Play", "Stop Recording...")
+    audio_data = audiorecorder("Presiona para hablar", "Deteniendo la grabación...")
 
     if not audio_data.empty():
         st.audio(audio_data.export().read(), format="audio/wav")
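Note: the new text_to_speech() base64-encodes a torch.save()d logits tensor, which is not an mp3 the browser's <audio> tag can decode, and the newly added gtts/pydub imports are otherwise unused. A sketch of a gTTS-based replacement that honors the speed parameter via pydub; this function shape is an assumption, not the committed code, and the mp3 round-trip requires ffmpeg:

    import io
    from gtts import gTTS
    from pydub import AudioSegment
    from pydub.effects import speedup

    def text_to_speech(text, speed=1.3):
        # Synthesize Spanish speech into an in-memory mp3.
        buf = io.BytesIO()
        gTTS(text, lang="es").write_to_fp(buf)
        buf.seek(0)
        # Optionally speed up playback with pydub.
        segment = AudioSegment.from_file(buf, format="mp3")
        if speed and speed != 1.0:
            segment = speedup(segment, playback_speed=speed)
        out = io.BytesIO()
        segment.export(out, format="mp3")
        out.seek(0)
        return out  # file-like, so .read() in main() works
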
@@ -97,13 +102,11 @@ def main():
         audio_text = recognize_speech("audio.wav")
 
         if audio_text:
-            # Llama a la función generate para obtener la respuesta generada
-            generated_response = generate(audio_text, history=st.session_state.history)
-            output = text_to_speech(generated_response)
+            output, audio_file = generate(audio_text, history=st.session_state.history)
 
-            if output is not None:
+            if audio_file is not None:
                 st.markdown(
-                    f"""<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(output.encode()).decode()}" type="audio/mp3" id="audio_player"></audio>""",
+                    f"""<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(audio_file.read()).decode()}" type="audio/mp3" id="audio_player"></audio>""",
                     unsafe_allow_html=True)
 
 if __name__ == "__main__":
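Note: the raw-HTML <audio> tag is one way to autoplay a clip, but Streamlit can also play in-memory audio directly, which avoids hand-building a base64 data URI under unsafe_allow_html. A small alternative, assuming audio_file is a file-like mp3 as sketched above; the autoplay keyword needs Streamlit >= 1.33:

    # Equivalent playback without raw HTML; format must match the bytes.
    st.audio(audio_file.read(), format="audio/mp3", autoplay=True)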