File size: 2,470 Bytes
173046e
 
 
c98587e
173046e
886c73a
173046e
886c73a
 
173046e
886c73a
 
 
173046e
886c73a
 
 
 
2cbff7e
c98587e
886c73a
c98587e
 
173046e
886c73a
 
173046e
 
 
886c73a
 
 
 
 
173046e
886c73a
 
173046e
886c73a
c98587e
173046e
886c73a
 
 
 
 
c98587e
 
 
 
 
 
 
886c73a
 
c98587e
 
 
 
 
 
886c73a
 
173046e
 
886c73a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces

titulo = """# 🤖 Bienvenido al Chatbot con Yi-9B"""

descripcion = """Este chatbot utiliza el modelo Yi de 9B parámetros para generar respuestas. 
Puedes mantener una conversación fluida y realizar preguntas sobre diversos temas."""

# Definir el dispositivo y la ruta del modelo
dispositivo = "cuda" if torch.cuda.is_available() else "cpu"
ruta_modelo = "01-ai/Yi-9B-Chat"

# Cargar el tokenizador y el modelo
tokenizador = AutoTokenizer.from_pretrained(ruta_modelo)
modelo = AutoModelForCausalLM.from_pretrained(ruta_modelo, device_map="auto").eval()

@spaces.GPU(duration=130)
def generar_respuesta(prompt_sistema, prompt_usuario, max_longitud):
    mensajes = [
        {"role": "system", "content": prompt_sistema},
        {"role": "user", "content": prompt_usuario}
    ]
    texto = tokenizador.apply_chat_template(
        mensajes,
        tokenize=False,
        add_generation_prompt=True
    )
    entradas_modelo = tokenizador([texto], return_tensors="pt").to(dispositivo)
    ids_generados = modelo.generate(
        entradas_modelo.input_ids,
        max_new_tokens=max_longitud,
        eos_token_id=tokenizador.eos_token_id
    )
    ids_generados = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(entradas_modelo.input_ids, ids_generados)
    ]
    respuesta = tokenizador.batch_decode(ids_generados, skip_special_tokens=True)[0]
    return respuesta

def interfaz_gradio():
    with gr.Blocks() as interfaz:
        gr.Markdown(titulo)
        gr.Markdown(descripcion)
        
        prompt_sistema = gr.Textbox(
            label="Instrucción del sistema:",
            value="Eres un asistente útil y amigable. Proporciona respuestas claras y concisas.",
            lines=2
        )
        prompt_usuario = gr.Textbox(label="Tu mensaje", lines=3)
        respuesta = gr.Textbox(label="Respuesta del asistente", lines=10)
        max_longitud_slider = gr.Slider(minimum=1, maximum=1000, value=500, label="Longitud máxima de la respuesta")
        
        boton_generar = gr.Button("Generar respuesta")
        boton_generar.click(
            generar_respuesta,
            inputs=[prompt_sistema, prompt_usuario, max_longitud_slider],
            outputs=respuesta
        )
    
    return interfaz

if __name__ == "__main__":
    interfaz = interfaz_gradio()
    interfaz.queue()
    interfaz.launch()