import base64
import io

import streamlit as st
import speech_recognition as sr
import librosa
from huggingface_hub import InferenceClient
from gtts import gTTS
from audiorecorder import audiorecorder

def recognize_speech(audio_data, show_messages=True):
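    # Transcribe the recorded WAV file to Spanish text via Google's speech recognition endpoint.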
    recognizer = sr.Recognizer()
    audio_recording = sr.AudioFile(audio_data)

    with audio_recording as source:
        audio = recognizer.record(source)

    try:
        audio_text = recognizer.recognize_google(audio, language="es-ES")
        if show_messages:
            st.subheader("Texto Reconocido:")
            st.write(audio_text)
            st.success("Reconocimiento de voz completado.")
    except sr.UnknownValueError:
        st.warning("No se pudo reconocer el audio. ¿Intentaste grabar algo?")
        audio_text = ""
    except sr.RequestError:
        st.error("Error al conectar con el servicio de reconocimiento de voz. Por favor, inténtalo de nuevo.")
        audio_text = ""

    return audio_text

def format_prompt(message, history):
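    # Build a Mixtral-Instruct prompt: each past turn is wrapped in [INST] ... [/INST]
    # tags, with the bot's reply closed by </s>.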
    prompt = "<s>"

    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "

    prompt += f"[INST] {message} [/INST]"
    return prompt

def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
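    # Stream a completion from the hosted Mixtral-8x7B-Instruct model and voice the reply with gTTS.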
    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    temperature = float(temperature) if temperature is not None else 0.9
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(audio_text, history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)  # False keeps the prompt out of the streamed output
    response = ""
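    # Accumulate streamed tokens; any trailing </s> marker is stripped below.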

    for response_token in stream:
        response += response_token.token.text
    
    response = ' '.join(response.split()).replace('</s>', '')
    audio_file = text_to_speech(response)
    return response, audio_file

def text_to_speech(text):
    # gTTS only supports normal or slow playback (slow=True/False), so no
    # speech-rate parameter is exposed; the MP3 is written to an in-memory buffer.
    tts = gTTS(text=text, lang='es')
    audio_fp = io.BytesIO()
    tts.write_to_fp(audio_fp)
    audio_fp.seek(0)
    return audio_fp

def detect_vocal_activity(audio_data):
    # Frame-level heuristic: any frame whose RMS energy exceeds a fixed threshold
    # counts as vocal activity. Returns a boolean array with one value per frame.
    y, _ = librosa.load(audio_data, sr=None)  # the sample rate is not needed here
    umbral_actividad_vocal = 0.01
    amplitud_media = librosa.feature.rms(y=y)
    actividad_vocal = amplitud_media > umbral_actividad_vocal

    return actividad_vocal


def main():
    # Capture microphone input in the browser; audiorecorder returns a pydub AudioSegment.
    # The button labels are illustrative placeholders.
    audio_data = audiorecorder("Presiona para hablar", "Grabando...")

    if "history" not in st.session_state:
        st.session_state.history = []

    if len(audio_data) > 0:
        st.audio(audio_data.export().read(), format="audio/wav")
        audio_data.export("audio.wav", format="wav")
        audio_text = recognize_speech("audio.wav")

        if not st.session_state.history:
            pre_prompt = "Te Llamarás Chaman 4.0 y tus respuestas serán sumamente breves."
            output, _ = generate(pre_prompt, history=st.session_state.history)
            st.session_state.history.append((pre_prompt, output))

        if audio_text:
            actividad_vocal = detect_vocal_activity("audio.wav")

            if actividad_vocal.any():
                output, audio_file = generate(audio_text, history=st.session_state.history)
                st.session_state.history.append((audio_text, output))

                if audio_file is not None:
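                    # Embed the MP3 as a base64 data URI so the browser autoplays the reply.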
                    st.markdown(
                        f"""
                        <audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(audio_file.read()).decode()}" type="audio/mp3" id="audio_player"></audio>
                        """,
                        unsafe_allow_html=True
                    )
            else:
                st.warning("No se detectó actividad vocal.")
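

# Assumed entry point: Streamlit executes the script top to bottom, so main()
# must be invoked at module level for the app to do anything.
if __name__ == "__main__":
    main()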