File size: 4,027 Bytes
3d57cbf
1760204
3d57cbf
7255a1d
d49accf
7255a1d
8842007
d6b9b98
 
 
 
 
 
84e2e9f
3d57cbf
 
 
6f460e4
 
 
d6b9b98
3d57cbf
 
d6b9b98
 
 
 
 
 
 
 
 
 
 
 
 
3d57cbf
 
 
0173625
3d57cbf
 
0173625
6f460e4
 
3d57cbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574c2e1
 
 
 
 
 
 
 
 
badb078
 
 
 
713e319
badb078
d6b9b98
 
 
 
 
 
 
 
 
 
 
 
 
badb078
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import io
import base64
from gtts import gTTS
import streamlit as st
import speech_recognition as sr
from huggingface_hub import InferenceClient
from streamlit_mic_recorder import mic_recorder
import wave
import numpy as np
import os

# Behaviour instruction for the chat model (Spanish: "you are a behavioural AI,
# your answers will be brief").
pre_prompt_text = "eres una IA conductual, tus respuestas serán breves."
# Path that main() writes the latest recording to as a WAV file.
temp_audio_file_path = "./output.wav"

# Create the per-session keys only when they are missing, so values that
# already exist in this session are kept.
st.session_state.setdefault("history", [])
st.session_state.setdefault("pre_prompt_sent", False)

def recognize_speech(audio_data, show_messages=True):
    """Transcribe Spanish (es-ES) speech from an in-memory audio file.

    Args:
        audio_data: Bytes of a complete audio file in a format that
            ``sr.AudioFile`` can decode (PCM WAV, AIFF/AIFF-C or FLAC).
        show_messages: When true, render the recognised text and a success
            notice on the Streamlit page.

    Returns:
        The recognised text, or ``""`` if the audio could not be decoded or
        transcribed (a warning/error is shown on the page in those cases).
    """
    recognizer = sr.Recognizer()

    try:
        # recognize_google() only accepts an sr.AudioData, not a raw file
        # object, so decode the bytes through sr.AudioFile and record the
        # whole clip first.
        with sr.AudioFile(io.BytesIO(audio_data)) as source:
            audio = recognizer.record(source)
    except ValueError:
        # sr.AudioFile raises ValueError for data it cannot parse as
        # WAV/AIFF/FLAC; report it instead of crashing the page.
        st.warning("No se pudo leer el audio (se requiere WAV, AIFF o FLAC).")
        return ""

    try:
        audio_text = recognizer.recognize_google(audio, language="es-ES")
    except sr.UnknownValueError:
        st.warning("No se pudo reconocer el audio. ¿Intentaste grabar algo?")
        return ""
    except sr.RequestError:
        # NOTE(review): RequestError signals a failed request to the
        # recognition service; this message may not describe that - confirm.
        st.error("Hablame para comenzar!")
        return ""

    if show_messages:
        st.subheader("Texto Reconocido:")
        st.write(audio_text)
        st.success("Reconocimiento de voz completado.")

    return audio_text

def format_prompt(message, history):
    """Build a Mixtral-instruct style prompt from the chat history.

    The prompt starts with ``<s>``. While ``st.session_state.pre_prompt_sent``
    is false, ``pre_prompt_text`` is inserted as its own ``[INST]...[/INST]``
    turn. Each ``(user, bot)`` pair in *history* becomes
    ``[INST] user [/INST] bot</s> ``, and *message* is appended last as the
    unanswered ``[INST] message [/INST]`` turn.

    NOTE(review): ``generate`` sends this full prompt on every call, while
    ``main`` flips ``pre_prompt_sent`` to True on its first run, so the
    behaviour instruction is likely never included. Confirm this is intended.
    """
    segments = ["<s>"]

    if not st.session_state.pre_prompt_sent:
        segments.append(f"[INST]{pre_prompt_text}[/INST]")

    segments.extend(
        f"[INST] {past_user} [/INST] {past_bot}</s> "
        for past_user, past_bot in history
    )
    segments.append(f"[INST] {message} [/INST]")

    return "".join(segments)

def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Ask Mixtral-8x7B-Instruct for a reply and synthesise it as speech.

    Args:
        audio_text: The user's latest message (typically transcribed speech).
        history: Sequence of ``(user, bot)`` pairs passed to ``format_prompt``.
        temperature: Sampling temperature; ``None`` means 0.9. Values below
            0.01 are raised to 0.01.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty forwarded to the inference endpoint.

    Returns:
        ``(response, audio_file)``. ``response`` is the reply text with
        whitespace collapsed to single spaces. ``audio_file`` is the MP3
        buffer from ``text_to_speech``, or ``None`` when the model produced
        no text.
    """
    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    # Default to 0.9 and keep the temperature strictly positive for sampling.
    temperature = float(temperature) if temperature is not None else 0.9
    temperature = max(temperature, 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(audio_text, history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)

    # Skip special tokens (e.g. the </s> end-of-sequence marker); they are
    # control symbols, not part of the reply text.
    pieces = [chunk.token.text for chunk in stream if not chunk.token.special]

    # Drop any literal "</s>" *before* collapsing whitespace so that removing
    # it cannot leave a stray trailing space behind.
    response = " ".join("".join(pieces).replace("</s>", "").split())

    # gTTS rejects empty text, so only synthesise audio when there is a reply.
    audio_file = text_to_speech(response, speed=1.3) if response else None
    return response, audio_file

def text_to_speech(text, speed=1.3):
    """Synthesise Spanish speech for *text* with gTTS.

    Args:
        text: The text to speak.
        speed: Currently unused; kept for caller compatibility.

    Returns:
        An ``io.BytesIO`` holding the MP3 audio, rewound to position 0.
    """
    synthesizer = gTTS(text=text, lang='es')
    mp3_buffer = io.BytesIO()
    synthesizer.write_to_fp(mp3_buffer)
    # Rewind so the caller can read the audio from the start.
    mp3_buffer.seek(0)
    return mp3_buffer

def audio_play(audio_fp):
    """Render a Streamlit audio player for *audio_fp* as MP3.

    Reads from the stream's current position to its end.
    """
    mp3_bytes = audio_fp.read()
    st.audio(mp3_bytes, format="audio/mp3", start_time=0)

def display_recognition_result(audio_text, output, audio_file):
    """Store the exchange in the chat history and autoplay the reply audio.

    Args:
        audio_text: The user's message; the ``(audio_text, output)`` pair is
            appended to ``st.session_state.history`` only if it is non-empty.
        output: The model's reply text.
        audio_file: Readable MP3 stream, or ``None`` to skip playback. It is
            read from its current position.
    """
    if audio_text:
        st.session_state.history.append((audio_text, output))

    if audio_file is None:
        return

    # Embed the MP3 as a base64 data URI in a raw HTML <audio> tag so that it
    # can carry the autoplay attribute.
    encoded_mp3 = base64.b64encode(audio_file.read()).decode()
    player_html = f"""<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{encoded_mp3}" type="audio/mp3" id="audio_player"></audio>"""
    st.markdown(player_html, unsafe_allow_html=True)

def main():
    """Render the recorder widget and save the latest recording as a WAV file.

    NOTE(review): the WAV written to ``temp_audio_file_path`` is not read
    anywhere in this file, and recognize_speech / generate /
    display_recognition_result are never called from here. Confirm whether
    the pipeline is meant to be wired up.
    """
    if not st.session_state.pre_prompt_sent:
        st.session_state.pre_prompt_sent = True

    recording = mic_recorder(start_prompt="▶️", stop_prompt="🛑", key='recorder')
    if not recording:
        return

    st.audio(recording['bytes'])

    # Sample width and rate come from the recorder's result dict.
    # NOTE(review): the bytes are written as raw frames; if the recorder
    # returns an encoded container (e.g. WAV with its own header) rather than
    # raw PCM, this file will be malformed. Confirm the recorder's format.
    frames = recording["bytes"]
    bytes_per_sample = recording["sample_width"]
    frame_rate = recording["sample_rate"]

    with wave.open(temp_audio_file_path, 'w') as wav_out:
        wav_out.setnchannels(1)  # assumes a mono recording
        wav_out.setsampwidth(bytes_per_sample)
        wav_out.setframerate(frame_rate)
        wav_out.writeframes(frames)

# Script entry point: run the app only when this file is executed directly.
if __name__ == "__main__":
    main()