Kukedlc's picture
Update app.py
886c73a verified
raw
history blame
2.47 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
titulo = """# 🤖 Bienvenido al Chatbot con Yi-9B"""
descripcion = """Este chatbot utiliza el modelo Yi de 9B parámetros para generar respuestas.
Puedes mantener una conversación fluida y realizar preguntas sobre diversos temas."""
# Definir el dispositivo y la ruta del modelo
dispositivo = "cuda" if torch.cuda.is_available() else "cpu"
ruta_modelo = "01-ai/Yi-9B-Chat"
# Cargar el tokenizador y el modelo
tokenizador = AutoTokenizer.from_pretrained(ruta_modelo)
modelo = AutoModelForCausalLM.from_pretrained(ruta_modelo, device_map="auto").eval()
def generar_respuesta(historial, usuario_input, max_longitud):
mensajes = [
{"role": "system", "content": "Eres un asistente útil y amigable. Proporciona respuestas claras y concisas."}
]
for entrada in historial:
mensajes.append({"role": "user", "content": entrada[0]})
mensajes.append({"role": "assistant", "content": entrada[1]})
mensajes.append({"role": "user", "content": usuario_input})
texto = tokenizador.apply_chat_template(
mensajes,
tokenize=False,
add_generation_prompt=True
)
entradas_modelo = tokenizador([texto], return_tensors="pt").to(dispositivo)
ids_generados = modelo.generate(
entradas_modelo.input_ids,
max_new_tokens=max_longitud,
eos_token_id=tokenizador.eos_token_id
)
ids_generados = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(entradas_modelo.input_ids, ids_generados)
]
respuesta = tokenizador.batch_decode(ids_generados, skip_special_tokens=True)[0]
historial.append((usuario_input, respuesta))
return historial, ""
def interfaz_gradio():
with gr.Blocks() as interfaz:
gr.Markdown(titulo)
gr.Markdown(descripcion)
chatbot = gr.Chatbot(label="Historial de chat")
msg = gr.Textbox(label="Tu mensaje")
clear = gr.Button("Limpiar")
max_longitud_slider = gr.Slider(minimum=1, maximum=1000, value=500, label="Longitud máxima de la respuesta")
msg.submit(generar_respuesta, [chatbot, msg, max_longitud_slider], [chatbot, msg])
clear.click(lambda: None, None, chatbot, queue=False)
return interfaz
if __name__ == "__main__":
interfaz = interfaz_gradio()
interfaz.queue()
interfaz.launch()