File size: 2,869 Bytes
7d2c473
35a3d64
8307fd0
35a3d64
1118eaf
8307fd0
5100159
f8a4c47
8307fd0
bf2cd44
b78309b
c501864
bf2cd44
a0a031a
c501864
bf2cd44
 
 
c501864
b5f7ced
 
 
c501864
a0a031a
 
8307fd0
 
 
 
 
 
 
 
 
 
 
3ab8e88
bf2cd44
3bbf5c2
e000aac
 
d9f7657
8307fd0
d9f7657
 
 
 
 
 
 
 
595f1c1
4a766c1
8307fd0
 
 
 
 
3a22490
8307fd0
 
 
 
 
328d058
8307fd0
89ff019
c015fb0
 
595f1c1
8307fd0
 
 
 
4e91bb7
8307fd0
 
35bdc4e
4e91bb7
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import streamlit as st
from huggingface_hub import InferenceClient
from gtts import gTTS
import base64
from pydub import AudioSegment
from pydub.playback import play

# Hugging Face inference client bound to the Mixtral-8x7B-Instruct model;
# generate() sends every formatted prompt through it.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")

# Persona instruction (Spanish) that format_prompt places in front of the
# conversation: it names the assistant "Chaman 3.0" and states its principles.
pre_prompt = (
    "tu nombre es Chaman 3.0 una IA conducual, "
    "tus principios son el trashumanísmo ecológico."
)

# Set to True by format_prompt once it has emitted the persona instruction.
pre_prompt_sent = False

def format_prompt(message, history, system_prompt=None):
    """Build the Mixtral-instruct prompt string for the next model turn.

    The prompt is rebuilt from scratch on every call and the inference
    endpoint keeps no conversation state, so the persona instruction is
    emitted every time (previously a module-level flag dropped it on every
    call after the first in the same process).

    Layout::

        <s>[INST] {system_prompt} [/INST]
        [INST] {user_1} [/INST] {bot_1}</s> ...
        [INST] {message} [/INST]

    Args:
        message: The new user message the model should answer.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, oldest
            first.
        system_prompt: Persona instruction to prepend. ``None`` (default)
            uses the module-level ``pre_prompt``; an empty string omits the
            persona block entirely.

    Returns:
        The complete prompt string, starting with ``<s>``.
    """
    global pre_prompt_sent
    if system_prompt is None:
        system_prompt = pre_prompt

    parts = ["<s>"]
    if system_prompt:
        parts.append(f"[INST] {system_prompt} [/INST]")

    for user_prompt, bot_response in history:
        parts.append(f"[INST] {user_prompt} [/INST]")
        # Each completed turn is closed with the </s> end-of-sequence marker.
        parts.append(f" {bot_response}</s> ")

    parts.append(f"[INST] {message} [/INST]")

    # Kept for backward compatibility with code reading the flag; it no
    # longer controls whether the persona is included.
    pre_prompt_sent = True
    return "".join(parts)

def text_to_speech(text, speed=1.3, audio_file_path='output.mp3', lang='es'):
    """Synthesize *text* to an MP3 file with gTTS, optionally sped up.

    Args:
        text: Text to speak. NOTE(review): gTTS appears to reject empty or
            whitespace-only text; callers should guard against it.
        speed: Playback-speed factor applied with pydub's ``speedup``
            effect. Values ``<= 1`` leave the synthesized audio unchanged.
        audio_file_path: Destination MP3 path. The file is overwritten on
            every call, so concurrent sessions sharing the default path can
            clobber each other; pass a unique path when that matters.
        lang: gTTS language code (Spanish by default).

    Returns:
        ``audio_file_path``, now containing the generated MP3.
    """
    tts = gTTS(text=text, lang=lang)
    tts.save(audio_file_path)

    # pydub's speedup shortens audio by dropping chunks, so it only supports
    # factors > 1; a factor <= 1 yields a zero/negative removal size.
    if speed > 1:
        sound = AudioSegment.from_mp3(audio_file_path)
        sound = sound.speedup(playback_speed=speed)
        sound.export(audio_file_path, format="mp3")

    return audio_file_path

def generate(user_input, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Ask the chat model to answer *user_input* and voice the reply.

    Args:
        user_input: The new user message.
        history: ``(user_prompt, bot_response)`` pairs passed through to
            ``format_prompt``.
        temperature: Sampling temperature; ``None`` means 0.9. Values below
            0.01 are clamped up to 0.01.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling threshold (coerced to ``float``).
        repetition_penalty: Repetition penalty forwarded to the endpoint.

    Returns:
        ``(response, audio_bytes)``: the cleaned reply text and the MP3
        bytes produced by ``text_to_speech``. ``audio_bytes`` is ``None``
        when the reply is empty or whitespace-only, since there is nothing
        to synthesize.
    """
    temperature = float(temperature) if temperature is not None else 0.9
    # Keep temperature strictly positive for sampling.
    temperature = max(temperature, 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        # Fixed seed, presumably so identical prompts sample identically.
        seed=42,
    )

    formatted_prompt = format_prompt(user_input, history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)

    # Join the streamed token texts once instead of growing a string in a loop.
    response = "".join(chunk.token.text for chunk in stream)

    # Collapse all whitespace runs (newlines included) to single spaces and
    # drop any literal </s> end-of-sequence markers.
    response = ' '.join(response.split()).replace('</s>', '')

    if not response.strip():
        return response, None

    audio_file_path = text_to_speech(response, speed=1.3)
    # Close the handle promptly; the original left it open.
    with open(audio_file_path, 'rb') as audio_file:
        audio_bytes = audio_file.read()

    return response, audio_bytes

# Conversation history survives Streamlit reruns via session_state.
if "history" not in st.session_state:
    st.session_state.history = []

# Defaults for runs where no query is made (e.g. the initial page load).
output = ""
audio_bytes = None

with st.container():
    user_input = st.text_input(label="Usuario", value="")

    # Only call the model when the user actually submitted text; otherwise
    # every rerun (including the first load) would send an empty prompt and
    # autoplay an unsolicited answer.
    if user_input:
        output, audio_bytes = generate(user_input, history=st.session_state.history)
        st.session_state.history.append((user_input, output))

    st.text_area("Respuesta", height=400, value=output, key="output_text", disabled=True)

if audio_bytes is not None:
    # Embed the MP3 as a base64 data URI so the browser autoplays it.
    st.markdown(
        f"""
        <audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{base64.b64encode(audio_bytes).decode()}" type="audio/mp3" id="audio_player"></audio>
        """,
        unsafe_allow_html=True
    )