Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
titulo = """# 🤖 Bienvenido al Chatbot con Yi-9B""" | |
descripcion = """Este chatbot utiliza el modelo Yi de 9B parámetros para generar respuestas. | |
Puedes mantener una conversación fluida y realizar preguntas sobre diversos temas.""" | |
# Definir el dispositivo y la ruta del modelo | |
dispositivo = "cuda" if torch.cuda.is_available() else "cpu" | |
ruta_modelo = "01-ai/Yi-9B-Chat" | |
# Cargar el tokenizador y el modelo | |
tokenizador = AutoTokenizer.from_pretrained(ruta_modelo) | |
modelo = AutoModelForCausalLM.from_pretrained(ruta_modelo, device_map="auto").eval() | |
def generar_respuesta(historial, usuario_input, max_longitud): | |
mensajes = [ | |
{"role": "system", "content": "Eres un asistente útil y amigable. Proporciona respuestas claras y concisas."} | |
] | |
for entrada in historial: | |
mensajes.append({"role": "user", "content": entrada[0]}) | |
mensajes.append({"role": "assistant", "content": entrada[1]}) | |
mensajes.append({"role": "user", "content": usuario_input}) | |
texto = tokenizador.apply_chat_template( | |
mensajes, | |
tokenize=False, | |
add_generation_prompt=True | |
) | |
entradas_modelo = tokenizador([texto], return_tensors="pt").to(dispositivo) | |
ids_generados = modelo.generate( | |
entradas_modelo.input_ids, | |
max_new_tokens=max_longitud, | |
eos_token_id=tokenizador.eos_token_id | |
) | |
ids_generados = [ | |
output_ids[len(input_ids):] for input_ids, output_ids in zip(entradas_modelo.input_ids, ids_generados) | |
] | |
respuesta = tokenizador.batch_decode(ids_generados, skip_special_tokens=True)[0] | |
historial.append((usuario_input, respuesta)) | |
return historial, "" | |
def interfaz_gradio(): | |
with gr.Blocks() as interfaz: | |
gr.Markdown(titulo) | |
gr.Markdown(descripcion) | |
chatbot = gr.Chatbot(label="Historial de chat") | |
msg = gr.Textbox(label="Tu mensaje") | |
clear = gr.Button("Limpiar") | |
max_longitud_slider = gr.Slider(minimum=1, maximum=1000, value=500, label="Longitud máxima de la respuesta") | |
msg.submit(generar_respuesta, [chatbot, msg, max_longitud_slider], [chatbot, msg]) | |
clear.click(lambda: None, None, chatbot, queue=False) | |
return interfaz | |
if __name__ == "__main__": | |
interfaz = interfaz_gradio() | |
interfaz.queue() | |
interfaz.launch() |