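"""JARVIS voice assistant (Streamlit app).

Listens to the microphone, waits for the Spanish wake word "jarvis",
transcribes speech with Wav2Vec2, and speaks the reply via gTTS. Run with:

    streamlit run app.py

Assumed dependencies (not pinned in the original file): streamlit, torch,
transformers, gTTS, sounddevice, numpy.
"""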
import streamlit as st
import torch
import numpy as np
import sounddevice as sd  # microphone capture; torchaudio has no AudioStream API
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from gtts import gTTS


class VoiceAssistant:
    def __init__(self):
        # Spanish ASR model with a CTC head for transcription.
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-xlsr-53-spanish")
        self.model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-xlsr-53-spanish")
        self.sample_rate = 16000  # Wav2Vec2 expects 16 kHz mono audio
        self.chunk_size = 480  # 30 ms frames at 16 kHz
        self.keyword_activation = "jarvis"
        self.keyword_deactivation = "detente"
        self.listening = False
    def vad_collector(self):
        """Capture microphone audio until a keyword is heard or listening stops.

        The original code string-matched the keywords against the raw sample
        tensor, which can never succeed; here each ~1 s window is transcribed
        and the keywords are searched in the transcription instead. sounddevice
        replaces the non-existent torchaudio.io.AudioStream.
        """
        audio_chunks, keyword_detected = [], False
        window = []
        with sd.InputStream(samplerate=self.sample_rate, channels=1, dtype="float32") as stream:
            while self.listening:
                try:
                    data, _ = stream.read(self.chunk_size)
                    window.append(data.flatten())
                    # Check for keywords roughly once per second of audio.
                    if len(window) * self.chunk_size >= self.sample_rate:
                        transcription = self.transcribe_audio(window).lower()
                        if self.keyword_activation in transcription:
                            keyword_detected = True
                            break
                        if self.keyword_deactivation in transcription:
                            self.listening = False
                            break
                        audio_chunks.extend(window)
                        window = []
                except Exception as e:
                    st.error(f"Audio capture error: {e}")
                    break
        return audio_chunks, keyword_detected
    def transcribe_audio(self, audio_chunks):
        # Join the buffered chunks into one waveform and run greedy CTC decoding.
        audio_data = np.concatenate(audio_chunks)
        input_values = self.processor(
            audio_data, return_tensors="pt", sampling_rate=self.sample_rate
        ).input_values
        with torch.no_grad():
            logits = self.model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        return self.processor.decode(predicted_ids[0])
    def generate_response(self, text):
        # Placeholder: echoes the transcription; swap in a real dialogue model here.
        return "Respuesta generada para: " + text
    def text_to_speech(self, text):
        # Synthesize the reply as Spanish speech and save it as an MP3.
        tts = gTTS(text=text, lang="es")
        output_path = "response.mp3"
        tts.save(output_path)
        return output_path
    def run(self):
        st.title("Asistente de Voz JARVIS")
        # Streamlit re-runs the script on every interaction, so the listening
        # flag has to live in session_state to survive reruns; a plain instance
        # attribute is reset each time.
        if "listening" not in st.session_state:
            st.session_state.listening = False
        if st.button("Iniciar/Detener Escucha"):
            st.session_state.listening = not st.session_state.listening
        self.listening = st.session_state.listening
        st.write("Escucha activada." if self.listening else "Escucha desactivada.")
        if self.listening:
            audio_chunks, keyword_detected = self.vad_collector()
            st.session_state.listening = self.listening  # "detente" may have stopped listening
            if keyword_detected and audio_chunks:
                st.success("Palabra clave 'JARVIS' detectada. Procesando...")
                transcribed_text = self.transcribe_audio(audio_chunks)
                st.write(f"Texto transcrito: {transcribed_text}")
                response = self.generate_response(transcribed_text)
                st.write(f"Respuesta: {response}")
                audio_path = self.text_to_speech(response)
                st.audio(audio_path)
def main():
    assistant = VoiceAssistant()
    assistant.run()


if __name__ == "__main__":
    main()
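
# Note: sounddevice reads from the microphone of the machine running the
# script, so this capture loop only works when the app runs locally; a hosted
# deployment would need a browser-side recorder component instead.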