File size: 2,467 Bytes
173046e
 
 
 
886c73a
173046e
886c73a
 
173046e
886c73a
 
 
173046e
886c73a
 
 
 
 
 
 
173046e
886c73a
 
 
 
 
 
 
 
 
173046e
 
 
886c73a
 
 
 
 
 
173046e
886c73a
 
 
173046e
886c73a
 
 
 
173046e
886c73a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173046e
 
886c73a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

titulo = """# 🤖 Bienvenido al Chatbot con Yi-9B"""

descripcion = """Este chatbot utiliza el modelo Yi de 9B parámetros para generar respuestas. 
Puedes mantener una conversación fluida y realizar preguntas sobre diversos temas."""

# Definir el dispositivo y la ruta del modelo
dispositivo = "cuda" if torch.cuda.is_available() else "cpu"
ruta_modelo = "01-ai/Yi-9B-Chat"

# Cargar el tokenizador y el modelo
tokenizador = AutoTokenizer.from_pretrained(ruta_modelo)
modelo = AutoModelForCausalLM.from_pretrained(ruta_modelo, device_map="auto").eval()

def generar_respuesta(historial, usuario_input, max_longitud):
    mensajes = [
        {"role": "system", "content": "Eres un asistente útil y amigable. Proporciona respuestas claras y concisas."}
    ]
    
    for entrada in historial:
        mensajes.append({"role": "user", "content": entrada[0]})
        mensajes.append({"role": "assistant", "content": entrada[1]})
    
    mensajes.append({"role": "user", "content": usuario_input})
    
    texto = tokenizador.apply_chat_template(
        mensajes,
        tokenize=False,
        add_generation_prompt=True
    )
    
    entradas_modelo = tokenizador([texto], return_tensors="pt").to(dispositivo)
    ids_generados = modelo.generate(
        entradas_modelo.input_ids,
        max_new_tokens=max_longitud,
        eos_token_id=tokenizador.eos_token_id
    )
    
    ids_generados = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(entradas_modelo.input_ids, ids_generados)
    ]
    
    respuesta = tokenizador.batch_decode(ids_generados, skip_special_tokens=True)[0]
    historial.append((usuario_input, respuesta))
    return historial, ""

def interfaz_gradio():
    with gr.Blocks() as interfaz:
        gr.Markdown(titulo)
        gr.Markdown(descripcion)
        
        chatbot = gr.Chatbot(label="Historial de chat")
        msg = gr.Textbox(label="Tu mensaje")
        clear = gr.Button("Limpiar")
        
        max_longitud_slider = gr.Slider(minimum=1, maximum=1000, value=500, label="Longitud máxima de la respuesta")
        
        msg.submit(generar_respuesta, [chatbot, msg, max_longitud_slider], [chatbot, msg])
        clear.click(lambda: None, None, chatbot, queue=False)
    
    return interfaz

if __name__ == "__main__":
    interfaz = interfaz_gradio()
    interfaz.queue()
    interfaz.launch()