salomonsky committed on
Commit
0f213dd
verified
1 Parent(s): ad653b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -28
app.py CHANGED
@@ -1,13 +1,10 @@
1
- from huggingface_hub import InferenceClient
2
  from audiorecorder import audiorecorder
3
- import speech_recognition as sr
4
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
5
- from datasets import load_dataset
6
  import streamlit as st
7
  import base64
8
- import io
9
  import torch
10
- import os
 
11
 
12
  model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-tts-spa")
13
  pre_prompt_text = "You are a behavioral AI, your answers should be brief, stoic and humanistic."
@@ -17,6 +14,7 @@ if "history" not in st.session_state:
17
 
18
  if "pre_prompt_sent" not in st.session_state:
19
  st.session_state.pre_prompt_sent = False
 
20
  def recognize_speech(audio_data, show_messages=True):
21
  recognizer = sr.Recognizer()
22
  audio_recording = sr.AudioFile(audio_data)
@@ -39,20 +37,6 @@ def recognize_speech(audio_data, show_messages=True):
39
 
40
  return audio_text
41
 
42
- def format_prompt(message, history):
43
- prompt = "<s>"
44
-
45
- if not st.session_state.pre_prompt_sent:
46
- prompt += f"[INST] {pre_prompt_text} [/INST]"
47
- st.session_state.pre_prompt_sent = True
48
-
49
- for user_prompt, bot_response in history:
50
- prompt += f"[INST] {user_prompt} [/INST]"
51
- prompt += f" {bot_response}</s> "
52
-
53
- prompt += f"[INST] {message} [/INST]"
54
- return prompt
55
-
56
  def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
57
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
58
 
@@ -74,10 +58,22 @@ def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.
74
 
75
  for response_token in stream:
76
  response += response_token.token.text
77
-
78
- response = ' '.join(response.split()).replace('</s>', '')
79
- audio_file = text_to_speech(response)
80
- return response, audio_file
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  def text_to_speech(text):
83
  with torch.no_grad():
@@ -91,7 +87,7 @@ def text_to_speech(text):
91
  return encoded_audio
92
 
93
  def main():
94
- audio_data = audiorecorder("Push to Play", "Stop Recording...")
95
 
96
  if not audio_data.empty():
97
  st.audio(audio_data.export().read(), format="audio/wav")
@@ -99,11 +95,13 @@ def main():
99
  audio_text = recognize_speech("audio.wav")
100
 
101
  if audio_text:
102
- output, audio_file = generate(audio_text, history=st.session_state.history)
 
 
103
 
104
- if audio_file is not None:
105
  st.markdown(
106
- f"""<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(audio_file.read()).decode()}" type="audio/mp3" id="audio_player"></audio>""",
107
  unsafe_allow_html=True)
108
 
109
  if __name__ == "__main__":
 
 
1
  from audiorecorder import audiorecorder
 
 
 
2
  import streamlit as st
3
  import base64
4
+ import soundfile as sf
5
  import torch
6
+ import speech_recognition as sr
7
+ from transformers import Wav2Vec2ForCTC
8
 
9
  model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-tts-spa")
10
  pre_prompt_text = "You are a behavioral AI, your answers should be brief, stoic and humanistic."
 
14
 
15
  if "pre_prompt_sent" not in st.session_state:
16
  st.session_state.pre_prompt_sent = False
17
+
18
  def recognize_speech(audio_data, show_messages=True):
19
  recognizer = sr.Recognizer()
20
  audio_recording = sr.AudioFile(audio_data)
 
37
 
38
  return audio_text
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
41
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
42
 
 
58
 
59
  for response_token in stream:
60
  response += response_token.token.text
61
+
62
+ return response
63
+
64
+ def format_prompt(message, history):
65
+ prompt = "<s>"
66
+
67
+ if not st.session_state.pre_prompt_sent:
68
+ prompt += f"[INST] {pre_prompt_text} [/INST]"
69
+ st.session_state.pre_prompt_sent = True
70
+
71
+ for user_prompt, bot_response in history:
72
+ prompt += f"[INST] {user_prompt} [/INST]"
73
+ prompt += f" {bot_response}</s> "
74
+
75
+ prompt += f"[INST] {message} [/INST]"
76
+ return prompt
77
 
78
  def text_to_speech(text):
79
  with torch.no_grad():
 
87
  return encoded_audio
88
 
89
  def main():
90
+ audio_data = st.audio_recorder("Push to Play", "Stop Recording...")
91
 
92
  if not audio_data.empty():
93
  st.audio(audio_data.export().read(), format="audio/wav")
 
95
  audio_text = recognize_speech("audio.wav")
96
 
97
  if audio_text:
98
+ # Llama a la función generate para obtener la respuesta generada
99
+ generated_response = generate(audio_text, history=st.session_state.history)
100
+ output = text_to_speech(generated_response)
101
 
102
+ if output is not None:
103
  st.markdown(
104
+ f"""<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{output}" type="audio/mp3" id="audio_player"></audio>""",
105
  unsafe_allow_html=True)
106
 
107
  if __name__ == "__main__":