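# Streamlit chatbot "Chaman 3.0": queries Mixtral-8x7B-Instruct through the
# Hugging Face Inference API and reads each reply aloud in Spanish (gTTS + pydub).
# Assumed dependencies (not declared in this file):
#   pip install streamlit huggingface_hub gtts pydub
#   (pydub also requires ffmpeg to be available on the system)
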
import streamlit as st
from huggingface_hub import InferenceClient
from gtts import gTTS
import base64
from pydub import AudioSegment

# Client for the Mixtral instruct model served by the Hugging Face Inference API.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# System prompt (in Spanish) that defines the assistant's persona.
system_prompt = "Tu nombre será Chaman 3.0, una IA conductual; hablarás español y tu rol es la bioética y el existencialismo estoico, holísticamente."

def format_prompt(message, history):
    # Rebuild the full Mixtral-style prompt on every call: system prompt first,
    # then each prior (user, bot) turn, then the new user message.
    prompt = "<s>"
    prompt += f"[INST] {system_prompt} [/INST]"

    if history:
        for user_prompt, bot_response in history:
            prompt += f"[INST] {user_prompt} [/INST]"
            prompt += f" {bot_response}</s> "

    prompt += f"[INST] {message} [/INST]"
    return prompt

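# Convert text to Spanish speech with gTTS, speed it up with pydub, and return
# the path of the saved MP3 file (written to 'output.mp3' in the working directory).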
def text_to_speech(text, speed=1.3):
    tts = gTTS(text=text, lang='es')
    audio_file_path = 'output.mp3'
    tts.save(audio_file_path)

    sound = AudioSegment.from_mp3(audio_file_path)
    sound = sound.speedup(playback_speed=speed)
    sound.export(audio_file_path, format="mp3")

    return audio_file_path

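# Build the prompt from the chat history, stream a completion from Mixtral,
# and return the cleaned reply text together with its synthesized audio bytes.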
def generate(user_input, history, temperature=None, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature) if temperature is not None else 0.9
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(user_input, history)
    # return_full_text=False streams only the newly generated tokens, not the echoed prompt.
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)

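    # Accumulate the streamed tokens into a single reply string.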
    response = ""
    for response_token in stream:
        response += response_token.token.text
    
    response = ' '.join(response.split()).replace('</s>', '')

    # Synthesize speech for the reply and read the MP3 bytes back for the browser player.
    audio_file_path = text_to_speech(response, speed=1.3)
    with open(audio_file_path, 'rb') as audio_file:
        audio_bytes = audio_file.read()

    return response, audio_bytes

if "history" not in st.session_state:
    st.session_state.history = []

user_input = st.text_input(label="Usuario", value="")

# Only query the model (and synthesize audio) once the user has typed something.
if user_input:
    output, audio_bytes = generate(user_input, history=st.session_state.history)
    st.session_state.history.append((user_input, output))

    st.text_area("Salida del Chatbot", value=output, height=400, max_chars=900, key="output_text", disabled=True)

    # Embed the MP3 reply as a base64 data URI so it autoplays in the browser.
    st.markdown(
        f"""
        <audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(audio_bytes).decode()}" type="audio/mp3" id="audio_player"></audio>
        """,
        unsafe_allow_html=True
    )