salomonsky committed on
Commit
a34685e
verified
1 Parent(s): 92cff17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -39
app.py CHANGED
@@ -1,20 +1,12 @@
1
  import streamlit as st
2
  import base64
3
  import io
4
- import torch
5
- import os
6
  from huggingface_hub import InferenceClient
7
- from audiorecorder import audiorecorder
 
8
  import speech_recognition as sr
9
- from transformers import T5ForConditionalGeneration, T5Tokenizer
10
- import numpy as np
11
- from scipy.io.wavfile import write
12
- from pydub import AudioSegment
13
 
14
- model = T5ForConditionalGeneration.from_pretrained("facebook/mms-tts-spa")
15
- tokenizer = T5Tokenizer.from_pretrained("facebook/mms-tts-spa")
16
-
17
- pre_prompt_text = "Eres una IA conductual, tus respuestas deber谩n ser breves, est贸icas y humanistas."
18
 
19
  if "history" not in st.session_state:
20
  st.session_state.history = []
@@ -32,14 +24,14 @@ def recognize_speech(audio_data, show_messages=True):
32
  try:
33
  audio_text = recognizer.recognize_google(audio, language="es-ES")
34
  if show_messages:
35
- st.subheader("Texto Reconocido:")
36
  st.write(audio_text)
37
- st.success("Reconocimiento de voz completado.")
38
  except sr.UnknownValueError:
39
- st.warning("No se pudo reconocer el audio. 驴Intentaste grabar algo?")
40
  audio_text = ""
41
  except sr.RequestError:
42
- st.error("Hablame para comenzar!")
43
  audio_text = ""
44
 
45
  return audio_text
@@ -58,16 +50,41 @@ def format_prompt(message, history):
58
  prompt += f"[INST] {message} [/INST]"
59
  return prompt
60
 
61
- def generate(audio_text, history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  formatted_prompt = format_prompt(audio_text, history)
63
- inputs = tokenizer(formatted_prompt, return_tensors="pt", padding=True, truncation=True)
64
- with torch.no_grad():
65
- output = model.generate(**inputs)
66
- audio = output['audio']
67
- return audio
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  def main():
70
- audio_data = audiorecorder("Presiona para hablar", "Deteniendo la grabaci贸n...")
71
 
72
  if not audio_data.empty():
73
  st.audio(audio_data.export().read(), format="audio/wav")
@@ -75,24 +92,12 @@ def main():
75
  audio_text = recognize_speech("audio.wav")
76
 
77
  if audio_text:
78
- audio_file = generate(audio_text, history=st.session_state.history)
79
 
80
  if audio_file is not None:
81
- # Guardar el archivo WAV
82
- write("output.wav", model.config.sampling_rate, audio_file)
83
-
84
- # Convertir el archivo WAV a MP3 utilizando pydub
85
- audio = AudioSegment.from_wav("output.wav")
86
- audio.export("output.mp3", format="mp3")
87
-
88
- # Leer el archivo MP3 y mostrarlo en Streamlit
89
- with open("output.mp3", "rb") as file:
90
- audio_bytes = file.read()
91
- st.audio(audio_bytes, format="audio/mp3")
92
-
93
- # Eliminar archivos temporales (opcional)
94
- os.remove("output.wav")
95
- os.remove("output.mp3")
96
 
97
  if __name__ == "__main__":
98
- main()
 
1
  import streamlit as st
2
  import base64
3
  import io
 
 
4
  from huggingface_hub import InferenceClient
5
+ from gtts import gTTS
6
+ import audiorecorder
7
  import speech_recognition as sr
 
 
 
 
8
 
9
+ pre_prompt_text = "You are a behavioral AI, your answers should be brief, stoic and humanistic."
 
 
 
10
 
11
  if "history" not in st.session_state:
12
  st.session_state.history = []
 
24
  try:
25
  audio_text = recognizer.recognize_google(audio, language="es-ES")
26
  if show_messages:
27
+ st.subheader("Recognized text:")
28
  st.write(audio_text)
29
+ st.success("Voice Recognized.")
30
  except sr.UnknownValueError:
31
+ st.warning("The audio could not be recognized. Did you try to record something?")
32
  audio_text = ""
33
  except sr.RequestError:
34
+ st.error("Push/Talk to start!")
35
  audio_text = ""
36
 
37
  return audio_text
 
50
  prompt += f"[INST] {message} [/INST]"
51
  return prompt
52
 
53
def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Generate a chat reply for the recognized speech and synthesize it to audio.

    Parameters
    ----------
    audio_text : str
        Text recognized from the user's speech.
    history : list
        Conversation turns consumed by ``format_prompt``.
    temperature : float | None
        Sampling temperature; defaults to 0.9 and is clamped to >= 0.01
        because the endpoint rejects a temperature of 0.
    max_new_tokens : int
        Upper bound on generated tokens.
    top_p : float
        Nucleus-sampling probability mass.
    repetition_penalty : float
        Penalty applied to repeated tokens.

    Returns
    -------
    tuple
        ``(response_text, audio_fp)`` where ``audio_fp`` is the in-memory
        MP3 buffer returned by ``text_to_speech``.
    """
    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    temperature = float(temperature) if temperature is not None else 0.9
    temperature = max(temperature, 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(audio_text, history)
    # return_full_text=False: we only want the completion; echoing the prompt
    # back would leak the whole [INST] scaffold into the spoken reply.
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )

    response = ""
    for chunk in stream:
        # details=True yields token objects with a .special flag; skip special
        # tokens (<s>, </s>, ...) instead of concatenating them blindly.
        if not chunk.token.special:
            response += chunk.token.text

    # Normalize whitespace; keep the original '</s>' scrub as a safety net.
    response = ' '.join(response.split()).replace('</s>', '')
    audio_file = text_to_speech(response)
    return response, audio_file
78
+
79
def text_to_speech(text):
    """Synthesize *text* as Spanish speech and return an in-memory MP3 buffer.

    The returned ``io.BytesIO`` is rewound to position 0, ready for reading.
    """
    buffer = io.BytesIO()
    gTTS(text=text, lang='es').write_to_fp(buffer)
    buffer.seek(0)
    return buffer
85
 
86
  def main():
87
+ audio_data = audiorecorder.audiorecorder("Push to Talk", "Stop Recording...")
88
 
89
  if not audio_data.empty():
90
  st.audio(audio_data.export().read(), format="audio/wav")
 
92
  audio_text = recognize_speech("audio.wav")
93
 
94
  if audio_text:
95
+ output, audio_file = generate(audio_text, history=st.session_state.history)
96
 
97
  if audio_file is not None:
98
+ st.markdown(
99
+ f"""<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(audio_file.read()).decode()}" type="audio/mp3" id="audio_player"></audio>""",
100
+ unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  if __name__ == "__main__":
103
+ main()