File size: 3,984 Bytes
a202e44
 
 
e32117f
a34685e
376b54f
e32117f
88f6f66
a34685e
d67f0a9
2dc98a7
 
 
 
 
0f213dd
bdccd83
2a43b85
bdccd83
173c390
bdccd83
 
47759f3
2a43b85
bdccd83
 
a34685e
bdccd83
a34685e
b303f3d
a34685e
b303f3d
 
a34685e
bdccd83
 
 
 
b7431cd
 
 
 
 
 
 
 
 
 
 
 
 
 
d21ed65
a34685e
 
 
 
 
 
 
 
 
 
 
 
 
 
bdccd83
a34685e
 
b967b64
 
a34685e
fb054e7
 
a34685e
b967b64
 
 
d21ed65
fb054e7
b967b64
376b54f
a34685e
 
376b54f
1ba9aea
a34685e
 
 
376b54f
415313d
bdccd83
b967b64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdccd83
 
a34685e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import streamlit as st
import base64
import io
from huggingface_hub import InferenceClient
from gtts import gTTS
import audiorecorder
import speech_recognition as sr

# System instruction injected once per session (see format_prompt) to steer the model's tone.
pre_prompt_text = "You are a behavioral AI, your answers should be brief, stoic and humanistic."

# Conversation history as (user_prompt, bot_response) pairs; survives Streamlit reruns.
if "history" not in st.session_state:
    st.session_state.history = []

# Flag: has the system pre-prompt already been embedded into a generated prompt?
if "pre_prompt_sent" not in st.session_state:
    st.session_state.pre_prompt_sent = False

def recognize_speech(audio_data, show_messages=True):
    """Transcribe a recorded audio file to Spanish text.

    Uses Google's free speech recognition endpoint (``recognize_google``)
    with ``es-ES``. On any recognition failure an empty string is returned
    and a status message is shown in the Streamlit UI.

    Args:
        audio_data: Path (or file-like object) accepted by ``sr.AudioFile``.
        show_messages: When True, echo the recognized text in the UI.

    Returns:
        The recognized text, or ``""`` if recognition failed.
    """
    recognizer = sr.Recognizer()

    # Load the whole recording into an AudioData object.
    with sr.AudioFile(audio_data) as source:
        recording = recognizer.record(source)

    try:
        recognized = recognizer.recognize_google(recording, language="es-ES")
    except sr.UnknownValueError:
        # Speech was unintelligible (or silence was recorded).
        st.warning("The audio could not be recognized. Did you try to record something?")
        return ""
    except sr.RequestError:
        # The recognition service could not be reached.
        st.error("Push/Talk to start!")
        return ""

    if show_messages:
        st.subheader("Recognized text:")
        st.write(recognized)
        st.success("Voice Recognized.")

    return recognized

def format_prompt(message, history):
    """Assemble a Mixtral-instruct style prompt string.

    Prepends the session's system pre-prompt exactly once (tracked via
    ``st.session_state.pre_prompt_sent``), then replays the conversation
    history as ``[INST] ... [/INST]`` turns, and finally appends the new
    user message.

    Args:
        message: The new user message to answer.
        history: Iterable of (user_prompt, bot_response) pairs.

    Returns:
        The fully formatted prompt string.
    """
    parts = ["<s>"]

    # Inject the system instruction only on the first prompt of the session.
    if not st.session_state.pre_prompt_sent:
        parts.append(f"[INST] {pre_prompt_text} [/INST]")
        st.session_state.pre_prompt_sent = True

    for user_msg, bot_msg in history:
        parts.append(f"[INST] {user_msg} [/INST]")
        parts.append(f" {bot_msg}</s> ")

    parts.append(f"[INST] {message} [/INST]")
    return "".join(parts)

def generate(audio_text, history, temperature=None, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0):
    """Stream a completion from Mixtral and synthesize it to speech.

    Builds the prompt via ``format_prompt``, streams tokens from the
    Hugging Face Inference API while updating a progress bar, then runs
    the finished text through ``text_to_speech``.

    Args:
        audio_text: The user's (possibly transcribed) message.
        history: Iterable of (user_prompt, bot_response) pairs for context.
        temperature: Sampling temperature; defaults to 0.9, floored at 1e-2.
        max_new_tokens: Generation cap, also used to scale the progress bar.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        Tuple ``(response_text, audio_file)`` where ``audio_file`` is an
        in-memory MP3 buffer.
    """
    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    temperature = float(temperature) if temperature is not None else 0.9
    temperature = max(temperature, 1e-2)  # avoid degenerate sampling at T ~ 0
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42)

    formatted_prompt = format_prompt(audio_text, history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)

    response_tokens = []
    total_chars = 0

    progress_bar = st.progress(0)

    for chunk in stream:
        token_text = chunk.token.text
        total_chars += len(token_text)
        response_tokens.append(token_text)
        # Clamp so progress stays within [0.0, 1.0] (len counts chars, not tokens).
        progress_bar.progress(min(total_chars / max_new_tokens, 1.0))

    # BUG FIX: tokens must be concatenated directly. The original joined with
    # ' ', which inserted a spurious space between every subword token and
    # corrupted the response text. Strip the end-of-sequence marker as before.
    response = "".join(response_tokens).replace('</s>', '')

    audio_file = text_to_speech(response)
    return response, audio_file

def text_to_speech(text):
    """Synthesize *text* as Spanish speech with gTTS.

    Args:
        text: The text to speak.

    Returns:
        An in-memory audio buffer (``io.BytesIO``) rewound to the start.
    """
    buffer = io.BytesIO()
    gTTS(text=text, lang='es').write_to_fp(buffer)
    buffer.seek(0)  # rewind so callers can read from the beginning
    return buffer

def main():
    """Streamlit entry point: accept text or voice input, generate a reply, play it back."""
    option = st.radio("Select Input Method:", ("Text", "Voice"))

    # BUG FIX: initialize prompt. Previously it was left unbound when "Voice"
    # was selected before any recording existed, raising NameError at
    # `if prompt:` below.
    prompt = ""

    if option == "Text":
        prompt = st.text_area("Enter your prompt here:")
    else:
        st.write("Push and hold the button to record.")
        audio_data = audiorecorder.audiorecorder("Push to Talk", "Stop Recording...")

        if not audio_data.empty():
            st.audio(audio_data.export().read(), format="audio/wav")
            audio_data.export("audio.wav", format="wav")
            prompt = recognize_speech("audio.wav")
            st.text("Recognized prompt:")
            st.write(prompt)

    if prompt:
        output, audio_file = generate(prompt, history=st.session_state.history)

        # BUG FIX: persist the exchange. The history list was read by
        # format_prompt but never appended to, so every turn lost all
        # prior conversation context.
        st.session_state.history.append((prompt, output))

        if audio_file is not None:
            # Embed the MP3 as a base64 data URI so it autoplays in the browser.
            st.markdown(
                f"""<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(audio_file.read()).decode()}" type="audio/mp3" id="audio_player"></audio>""",
                unsafe_allow_html=True)

# Standard script entry guard: run the app only when executed directly.
if __name__ == "__main__":
    main()