File size: 3,352 Bytes
57f1cca
870d938
 
c91e886
870d938
bf82696
d49accf
a512b51
902252f
0a6ca66
3d57cbf
 
 
0a6ca66
 
870d938
0146d6f
3d57cbf
9e3d7eb
35d0fde
39f14d4
 
e3f9866
7dd4127
c91e886
 
be363ff
c91e886
be363ff
c91e886
be363ff
c91e886
 
be363ff
c91e886
3d57cbf
 
 
0173625
3d57cbf
 
0a6ca66
 
 
 
3d57cbf
bf82696
 
3d57cbf
bf82696
39f14d4
3d57cbf
bf82696
3d57cbf
 
 
6036fba
3d57cbf
 
 
 
 
 
 
 
6036fba
3d57cbf
 
 
 
 
 
 
39f14d4
3d57cbf
580400e
3d57cbf
 
31a303b
c48dd31
7ed9a89
 
 
31a303b
6c2ddbb
bf82696
be363ff
ae120d2
bf82696
7ed9a89
 
 
39f14d4
c91e886
39f14d4
bf82696
 
 
7ed9a89
bf82696
badb078
0a7af93
1938cef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import streamlit as st
import base64
import io
from huggingface_hub import InferenceClient
from gtts import gTTS
from audiorecorder import audiorecorder
import speech_recognition as sr

# Instruction text that format_prompt places ahead of the conversation.
pre_prompt_text = ""

# Streamlit re-executes this script on every interaction, so each per-session
# key is seeded only the first time it is missing.
if "history" not in st.session_state:
    # Past turns, consumed by format_prompt as (user_prompt, bot_response) pairs.
    st.session_state["history"] = []
if "pre_prompt_sent" not in st.session_state:
    st.session_state["pre_prompt_sent"] = False

def recognize_speech(audio_data, show_messages=True):
    """Transcribe a recording to Spanish text with Google's web recognizer.

    Args:
        audio_data: Audio source handed to ``sr.AudioFile``.
        show_messages: When true, show the transcript and a success notice on
            the page. Warnings and errors are displayed regardless.

    Returns:
        The recognized text, or an empty string when recognition fails.
    """
    engine = sr.Recognizer()

    with sr.AudioFile(audio_data) as stream:
        captured = engine.record(stream)

    try:
        transcript = engine.recognize_google(captured, language="es-ES")
    except sr.UnknownValueError:
        # The service answered but found nothing it could transcribe.
        st.warning("No se pudo reconocer el audio. ¿Intentaste grabar algo?")
        transcript = ""
    except sr.RequestError:
        # NOTE(review): RequestError signals a failed request to the
        # recognition service, yet the text shown is a "press/speak to start"
        # prompt -- presumably aimed at an empty recording reaching this
        # point. Confirm before relying on this message for API failures.
        st.error("¡Presiona/Habla para comenzar!")
        transcript = ""
    else:
        if show_messages:
            st.subheader("Texto Reconocido:")
            st.write(transcript)
            st.success("Voz reconocida.")

    return transcript

def format_prompt(message, history):
    """Build a Mixtral-instruct prompt for the next model request.

    Layout: ``<s>``, then ``[INST] pre_prompt_text [/INST]`` when the
    pre-prompt is non-empty, then ``[INST] user [/INST] bot</s> `` for every
    past turn, and finally ``[INST] message [/INST]`` for the new turn.

    Args:
        message: The new user utterance.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, oldest first.

    Returns:
        The complete prompt string.
    """
    parts = ["<s>"]

    # generate() sends this whole string as a self-contained request, so the
    # pre-prompt has to be included every time; previously it was added only
    # on the first request of the session and silently dropped afterwards.
    # An empty pre-prompt is skipped instead of producing "[INST]  [/INST]".
    if pre_prompt_text:
        parts.append(f"[INST] {pre_prompt_text} [/INST]")
    # Still maintained in case other code reads this flag.
    st.session_state.pre_prompt_sent = True

    for user_prompt, bot_response in history:
        parts.append(f"[INST] {user_prompt} [/INST]")
        parts.append(f" {bot_response}</s> ")

    parts.append(f"[INST] {message} [/INST]")
    return "".join(parts)

def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Ask Mixtral for a reply to *audio_text* and synthesize it as speech.

    Args:
        audio_text: The transcribed user utterance.
        history: Past ``(user_prompt, bot_response)`` pairs, passed to
            ``format_prompt``.
        temperature: Sampling temperature; defaults to 0.9 when None and is
            clamped to at least 1e-2.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        ``(response, audio_file)``. ``response`` is the reply with all
        whitespace collapsed to single spaces and ``</s>`` removed.
        ``audio_file`` is the buffer from ``text_to_speech``, or ``None`` when
        the reply contains no speakable text.
    """
    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    # Sampling is enabled, so keep the temperature strictly positive.
    temperature = 0.9 if temperature is None else float(temperature)
    temperature = max(temperature, 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42)

    formatted_prompt = format_prompt(audio_text, history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)

    # Join the streamed token texts once instead of growing a string with +=.
    response = "".join(chunk.token.text for chunk in stream)
    # Flatten whitespace/newlines for speech and drop the end-of-sequence marker.
    response = " ".join(response.split()).replace("</s>", "")

    if not response.strip():
        # gTTS cannot synthesize empty text; main() already skips playback
        # when the audio is None.
        return response, None
    return response, text_to_speech(response)

def text_to_speech(text):
    """Synthesize *text* as Spanish speech into an in-memory buffer.

    Returns:
        An ``io.BytesIO`` holding the audio written by gTTS, rewound to the
        start so it can be read immediately.
    """
    speech = gTTS(text=text, lang='es')
    buffer = io.BytesIO()
    speech.write_to_fp(buffer)
    buffer.seek(0)  # rewind so callers read from the beginning
    return buffer

def main():
    """Run one voice-chat turn: record, transcribe, answer, and speak the reply."""
    audio_data = audiorecorder("Presiona para Hablar", "Detener Grabación...")

    # audiorecorder appears to return a pydub AudioSegment (it is exported
    # below). AudioSegment.empty() constructs a *new* empty segment rather
    # than testing this one, so `not audio_data.empty()` was always true and
    # empty recordings were processed. len() is the duration in milliseconds.
    if len(audio_data) == 0:
        return

    # Export once as WAV into memory. This serves both playback (previously
    # the default export format was played while labelled WAV) and
    # recognition (previously a fixed "audio.wav" in the working directory,
    # which concurrent sessions of the same server share and can overwrite).
    wav_buffer = io.BytesIO()
    audio_data.export(wav_buffer, format="wav")
    wav_bytes = wav_buffer.getvalue()

    st.audio(wav_bytes, format="audio/wav")
    audio_text = recognize_speech(io.BytesIO(wav_bytes))
    if not audio_text:
        return

    output, audio_file = generate(audio_text, history=st.session_state.history)
    # Nothing else records finished turns, so without this the history passed
    # to generate() would always be empty and the model would have no context.
    st.session_state.history.append((audio_text, output))

    if audio_file is not None:
        audio_b64 = base64.b64encode(audio_file.read()).decode()
        st.markdown(
            f"""<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{audio_b64}" type="audio/mp3" id="audio_player"></audio>""",
            unsafe_allow_html=True)


if __name__ == "__main__":
    main()