File size: 2,841 Bytes
7d2c473
35a3d64
 
 
5100159
f8a4c47
595f1c1
1fe1359
a0a031a
c501864
 
a0a031a
c501864
af78765
 
 
 
c501864
af78765
 
 
c501864
a0a031a
 
 
9a784d5
3d98a19
09559ae
 
5d746cc
a79a56f
c015fb0
c501864
3bbf5c2
e000aac
 
d9f7657
 
 
 
 
 
 
 
 
 
595f1c1
4a766c1
e000aac
595f1c1
f27b29a
 
 
09559ae
b0b5cd6
ce0dbfb
09559ae
 
 
595f1c1
c015fb0
89ff019
c015fb0
 
595f1c1
c015fb0
 
 
595f1c1
c015fb0
 
5c113c1
c015fb0
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import streamlit as st
from huggingface_hub import InferenceClient
from gtts import gTTS
import base64

# Remote text-generation endpoint (Mixtral 8x7B Instruct) served through the
# Hugging Face Inference API.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
# System prompt injected once per conversation (Spanish: "Your name is
# Chaman 3.0, a behavioral AI").
system_prompt = "Tu nombre es Chaman 3.0 una IA conductual"
# NOTE(review): this flag is written by format_prompt() but never read
# anywhere in this file — it appears vestigial.
system_prompt_sent = False

def format_prompt(message, history):
    """Build a Mixtral-instruct prompt from the chat history plus a new message.

    Args:
        message: The latest user message.
        history: List of ``(user_prompt, bot_response)`` tuples, or None /
            any non-list value (then the history section is skipped).

    Returns:
        Prompt string in the ``<s>[INST] ... [/INST] answer</s>`` format
        expected by Mixtral instruct models.
    """
    global system_prompt_sent
    prompt = "<s>"

    # isinstance() already rejects None, so the former separate
    # "history is not None" check was redundant.
    if isinstance(history, list):
        # Prepend the system prompt only if it is not already recorded in
        # some earlier turn of the history.
        if not any(f"[INST] {system_prompt} [/INST]" in user_prompt for user_prompt, _ in history):
            prompt += f"[INST] {system_prompt} [/INST]"
            # Kept for backward compatibility; nothing in this file reads it.
            system_prompt_sent = True

        for user_prompt, bot_response in history:
            prompt += f"[INST] {user_prompt} [/INST]"
            prompt += f" {bot_response}</s> "

    prompt += f"[INST] {message} [/INST]"
    return prompt

def text_to_speech(text, speed=2.0):
    """Render *text* as Spanish speech with gTTS and return the MP3 file path.

    Args:
        text: Text to synthesize.
        speed: Unused; kept only for backward compatibility (gTTS exposes no
            playback-speed option).

    Returns:
        Filesystem path of the generated MP3 file. The caller is responsible
        for deleting it when done.
    """
    import tempfile  # stdlib-only local import; keeps module imports untouched

    tts = gTTS(text=text, lang='es')
    # Unique temp file instead of a fixed 'output.mp3', so concurrent
    # Streamlit sessions do not clobber each other's audio.
    with tempfile.NamedTemporaryFile(suffix='.mp3', delete=False) as tmp:
        audio_file_path = tmp.name
    tts.save(audio_file_path)
    return audio_file_path

def generate(user_input, history, temperature=None, max_new_tokens=2048, top_p=0.95, repetition_penalty=1.0):
    """Query the Mixtral endpoint with the conversation and voice the reply.

    Args:
        user_input: Latest user message.
        history: List of ``(user_prompt, bot_response)`` tuples.
        temperature: Sampling temperature; defaults to 0.9, floored at 0.01.
        max_new_tokens: Cap on generated tokens.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        Tuple ``(response_text, audio_bytes)``: the whitespace-normalized
        model reply and the synthesized MP3 bytes.
    """
    temperature = float(temperature) if temperature is not None else 0.9
    # Clamp away from zero; near-zero temperatures are rejected upstream.
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed for reproducible sampling
    )

    formatted_prompt = format_prompt(user_input, history)
    # NOTE(review): return_full_text=True may echo the prompt back into the
    # stream — confirm against the Inference API text_generation docs.
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)

    response = ""
    for response_token in stream:
        response += response_token.token.text

    # Collapse runs of whitespace and strip the end-of-sequence marker.
    response = ' '.join(response.split()).replace('</s>', '')

    audio_file_path = text_to_speech(response)
    # Context manager closes the handle (the original left the file open).
    with open(audio_file_path, 'rb') as audio_file:
        audio_bytes = audio_file.read()

    return response, audio_bytes

# --- Streamlit UI (the script reruns top-to-bottom on every interaction) ---

# Persist conversation history across Streamlit reruns.
if "history" not in st.session_state:
    st.session_state.history = []

user_input = st.text_input(label="Usuario", value="")

# Only query the model when the user actually typed something; the original
# code issued a request with an empty message on the very first page load.
if user_input:
    output, audio_bytes = generate(user_input, history=st.session_state.history)
    # NOTE(review): Streamlit reruns the whole script on every widget event,
    # so an unchanged input gets re-appended; consider de-duplicating.
    st.session_state.history.append((user_input, output))

    with st.container():
        st.text_area("Salida del Chatbot", value=output, height=200, max_chars=500, key="output_text", disabled=True)

        # Render the full transcript so far.
        for i, (user_prompt, bot_response) in enumerate(st.session_state.history):
            st.write(f"Usuario {i + 1}: {user_prompt}")
            st.write(f"Respuesta {i + 1}: {bot_response}")

        # Inline auto-playing audio player fed the MP3 as a base64 data URI.
        st.markdown(
            f"""
            <audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(audio_bytes).decode()}" type="audio/mp3" speed="1.5" id="audio_player"></audio>
            """,
            unsafe_allow_html=True
        )