File size: 3,638 Bytes
7d2c473
da45dce
 
56f4168
093b41a
 
 
ac7712e
093b41a
5a44809
 
 
 
 
 
 
 
093b41a
 
 
 
5a44809
 
093b41a
5a44809
 
093b41a
 
 
 
 
 
 
 
 
 
 
 
 
8307fd0
882add6
3625f99
 
da45dce
 
 
 
8307fd0
da45dce
 
 
 
 
 
 
 
595f1c1
093b41a
da45dce
 
8307fd0
da45dce
 
 
 
ca38fef
da45dce
89ff019
ca38fef
68ebbc0
 
 
 
3625f99
da45dce
68ebbc0
882add6
 
da45dce
093b41a
d163c7a
093b41a
 
 
88ca80b
644763d
88ca80b
 
 
 
093b41a
8fb4803
093b41a
d163c7a
093b41a
 
d163c7a
 
 
 
 
 
 
 
5b6fd29
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import streamlit as st
import base64
import io
from huggingface_hub import InferenceClient
from gtts import gTTS
from audiorecorder import audiorecorder
import speech_recognition as sr

def recognize_speech(audio_data, show_messages=True):
    """Transcribe a Spanish (es-ES) recording with Google's web speech recognizer.

    Args:
        audio_data: Anything ``sr.AudioFile`` accepts — a path or a binary
            file-like object containing WAV/AIFF/FLAC audio.
        show_messages: When true, show the recognized text and a success notice
            on the Streamlit page. Warnings and errors are always shown.

    Returns:
        The recognized text, or ``""`` if the audio could not be read,
        could not be understood, or the recognition service failed.
    """
    recognizer = sr.Recognizer()

    try:
        with sr.AudioFile(audio_data) as source:
            audio = recognizer.record(source)
    except ValueError:
        # sr.AudioFile raises ValueError when the data is not readable
        # WAV/AIFF/FLAC; report it instead of crashing the whole page.
        st.error("No se pudo leer el audio grabado. Por favor, inténtalo de nuevo.")
        return ""

    try:
        audio_text = recognizer.recognize_google(audio, language="es-ES")
    except sr.UnknownValueError:
        st.warning("No se pudo reconocer el audio. ¿Intentaste grabar algo?")
        return ""
    except sr.RequestError:
        # RequestError means the recognition service was unreachable or
        # rejected the request — it does not mean the audio was missing.
        st.error(
            "No se pudo contactar con el servicio de reconocimiento de voz. "
            "Por favor, inténtalo de nuevo."
        )
        return ""

    if show_messages:
        st.subheader("Texto Reconocido:")
        st.write(audio_text)
        st.success("Reconocimiento de voz completado.")

    return audio_text

def format_prompt(message, history):
    """Build a Mixtral-Instruct style prompt from past turns and a new message.

    Each ``(user_prompt, bot_response)`` pair in ``history`` becomes
    ``[INST] user [/INST] reply</s> ``. The whole prompt is prefixed with
    ``<s>``, and ``message`` is appended as the final, unanswered
    ``[INST] ... [/INST]`` block.
    """
    answered_turns = [
        f"[INST] {user_turn} [/INST] {assistant_turn}</s> "
        for user_turn, assistant_turn in history
    ]
    pending_turn = f"[INST] {message} [/INST]"
    return "<s>" + "".join(answered_turns) + pending_turn

def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Ask Mixtral-8x7B-Instruct for a reply to ``audio_text`` and synthesize it as speech.

    Args:
        audio_text: The new user message.
        history: Prior ``(user_prompt, bot_response)`` pairs, oldest first.
        temperature: Sampling temperature; ``None`` means 0.9. Values below
            0.01 are raised to 0.01.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Passed through to the inference endpoint.

    Returns:
        ``(response, audio_file)``. ``response`` is the reply with runs of
        whitespace (including newlines) collapsed to single spaces.
        ``audio_file`` is the result of ``text_to_speech``, or ``None`` when
        the reply is empty and there is nothing to speak.
    """
    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    # Sampling requires a strictly positive temperature, so clamp to a small floor.
    temperature = max(float(temperature) if temperature is not None else 0.9, 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(audio_text, history)
    # Only generated tokens are consumed below, so the echoed prompt is never wanted.
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )

    # Drop special tokens (e.g. the end-of-sequence marker) at the source.
    response = "".join(item.token.text for item in stream if not item.token.special)

    # Remove any literal "</s>" BEFORE collapsing whitespace. Doing it after
    # (as before) could leave stray double or trailing spaces behind.
    response = " ".join(response.replace("</s>", "").split())

    # gTTS refuses empty text; callers already treat a None audio file as "nothing to play".
    audio_file = text_to_speech(response, speed=1.3) if response else None
    return response, audio_file

def text_to_speech(text, speed=1.3):
    """Synthesize Spanish speech for ``text`` with gTTS into an in-memory buffer.

    Args:
        text: Text to speak.
        speed: Accepted for backward compatibility but currently ignored.
            gTTS only offers a normal/slow toggle, not an arbitrary rate.

    Returns:
        An ``io.BytesIO`` holding the gTTS output, rewound to the start.
        Returns ``None`` when ``text`` is empty or whitespace-only, because
        gTTS rejects such input with an error.
    """
    if not text or not text.strip():
        return None

    tts = gTTS(text=text, lang='es')
    audio_fp = io.BytesIO()
    tts.write_to_fp(audio_fp)
    # Rewind so callers can read() the audio from the beginning.
    audio_fp.seek(0)
    return audio_fp

def main():
    """Streamlit entry point: record speech, transcribe it, answer with Mixtral and play the answer.

    The conversation is stored in ``st.session_state.history`` as a list of
    ``(user_text, bot_reply)`` tuples so it survives Streamlit reruns.
    """
    if "history" not in st.session_state:
        st.session_state.history = []

    audio_data = audiorecorder("Habla para grabar", "Deteniendo la grabación...")

    # NOTE: `audio_data.empty()` is pydub's classmethod constructor for a *new*
    # empty segment, so `not audio_data.empty()` was always true. Check the
    # recording's own length (milliseconds) instead.
    if len(audio_data) == 0:
        return

    # pydub's export() defaults to MP3; request WAV explicitly so the bytes match
    # the declared MIME type and can be parsed by speech_recognition. The
    # context manager closes the temporary file export() creates.
    with audio_data.export(format="wav") as wav_file:
        wav_bytes = wav_file.read()
    st.audio(wav_bytes, format="audio/wav")

    # Transcribe from memory rather than a fixed "audio.wav" path, which
    # concurrent sessions would overwrite.
    audio_text = recognize_speech(io.BytesIO(wav_bytes))
    if not audio_text:
        return

    if not st.session_state.history:
        # Prime the conversation with a persona instruction on the first turn.
        pre_prompt = "Te Llamarás Chaman 4.0 y tus respuestas serán sumamente breves."
        output, _ = generate(pre_prompt, history=st.session_state.history)
        st.session_state.history.append((pre_prompt, output))

    output, audio_file = generate(audio_text, history=st.session_state.history)
    st.session_state.history.append((audio_text, output))

    if audio_file is not None:
        audio_b64 = base64.b64encode(audio_file.read()).decode()
        st.markdown(
            f'<audio autoplay="autoplay" controls="controls" '
            f'src="data:audio/mp3;base64,{audio_b64}" type="audio/mp3" '
            f'id="audio_player"></audio>',
            unsafe_allow_html=True,
        )


if __name__ == "__main__":
    main()