import streamlit as st
from huggingface_hub import InferenceClient
import base64
import pyttsx3

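# Inference client for the Mixtral instruct endpoint. An optional system-style
# pre-prompt can be injected once at the start of the conversation.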
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
pre_prompt = ""
pre_prompt_sent = False

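# Build a Mixtral-style prompt string, wrapping each user turn in [INST] ... [/INST]
# and closing each bot response with </s>.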
def format_prompt(message, history):
    global pre_prompt_sent
    prompt = "<s>"

    if not pre_prompt_sent and all(f"[INST] {pre_prompt} [/INST]" not in user_prompt for user_prompt, _ in history):
        prompt += f"[INST] {pre_prompt} [/INST]"
        pre_prompt_sent = True

    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "

    prompt += f"[INST] {message} [/INST]"
    return prompt

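# Convert the model response to speech with pyttsx3 (offline TTS). Note that the
# actual audio encoding depends on the local pyttsx3 driver, even though the
# output file is named .mp3.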
def text_to_speech(text):
    engine = pyttsx3.init()
    engine.save_to_file(text, "output_pyttsx3.mp3")
    engine.runAndWait()
    return "output_pyttsx3.mp3"
    
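# Stream a completion for the formatted prompt and return both the cleaned-up
# text and the synthesized audio bytes.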
def generate(user_input, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    global pre_prompt_sent
    temperature = float(temperature) if temperature is not None else 0.9
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(user_input, history)
    # return_full_text=False so the stream carries only newly generated tokens,
    # not the prompt echoed back.
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    response = ""

    for response_token in stream:
        response += response_token.token.text
    
    response = ' '.join(response.split()).replace('</s>', '')

    audio_file_path = text_to_speech(response)
    with open(audio_file_path, 'rb') as audio_file:
        audio_bytes = audio_file.read()

    return response, audio_bytes

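# Streamlit UI: keep the conversation history in session state so it persists
# across reruns.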
if "history" not in st.session_state:
    st.session_state.history = []

user_input = st.text_input(label="", value="")

if user_input:
    output, audio_bytes = generate(user_input, history=st.session_state.history)
    st.session_state.history.append((user_input, output))
    st.text_area("Response", value=output, height=400, key="output_text", disabled=True)

    # Embed the synthesized audio as a base64 data URI so it autoplays in the browser.
    st.markdown(
        f"""
        <audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(audio_bytes).decode()}" type="audio/mp3" id="audio_player"></audio>
        """,
        unsafe_allow_html=True
    )