# Source captured from a Hugging Face Space page (Space status shown: "Runtime error").
import random | |
import os | |
import gradio as gr | |
import torch | |
from transformers import pipeline, set_seed | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import logging | |
# Module logger. A named logger (rather than the root logger) lets this app log
# at INFO without also enabling INFO output from every other library.
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler())
# Without an explicit level the effective level is inherited from the root
# logger (WARNING by default), which silently drops every logger.info(...) call.
logger.setLevel(logging.INFO)

# Optional Hugging Face token for gated/private models (None when unset).
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
# Target torch device, e.g. "cpu", "cuda" or "cuda:0". Falls back to CPU when a
# non-CPU device is requested but CUDA is not available.
DEVICE = os.environ.get("DEVICE", "cpu")
if DEVICE != "cpu" and not torch.cuda.is_available():
    logger.warning(f"DEVICE {DEVICE} requested but CUDA is unavailable; using cpu")
    DEVICE = "cpu"
logger.info(f"DEVICE {DEVICE}")
# float16 weights only when running on a GPU; CPU stays in float32.
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
# Upper bound of the "Longitud máxima" slider in the UI.
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
# Short plain-Markdown title/description of the demo.
HEADER_INFO = """
# BERTIN GPT-J-6B
Spanish BERTIN GPT-J-6B Model.
""".strip()
# BERTIN project logo shown at the top of the page.
LOGO = "https://huggingface.co/bertin-project/bertin-roberta-base-spanish/resolve/main/images/bertin.png"
# Page header rendered through gr.Markdown: web font, CSS for the result widgets,
# and a centred logo plus description. CSS braces are doubled because this is an
# f-string; only {LOGO} is interpolated.
# The blank lines around the Markdown inside <div> are required: a CommonMark
# HTML block only ends at a blank line, so Markdown written directly after an
# HTML tag would be emitted as raw text instead of a heading/links.
HEADER = f"""
<link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300&display=swap" rel="stylesheet">
<style>
.ltr,
textarea {{
    font-family: Roboto !important;
    text-align: left;
    direction: ltr !important;
}}
.ltr-box {{
    border-bottom: 1px solid #ddd;
    padding-bottom: 20px;
}}
.rtl {{
    text-align: left;
    direction: ltr !important;
}}
span.result-text {{
    padding: 3px 3px;
    line-height: 32px;
}}
span.generated-text {{
    background-color: rgb(118 200 147 / 13%);
}}
</style>
<div align=center>
<img src="{LOGO}" width="150"/>

# BERTIN GPT-J-6B

BERTIN proporciona una serie de modelos de lenguaje en Español entrenados en abierto.
Este modelo ha sido entrenado con [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax) en TPUs proporcionadas por Google a través del programa TPU Research Cloud, a partir del modelo [GPT-J de EleutherAI](https://huggingface.co/EleutherAI/gpt-j-6B) con el corpus [mC4-es-sampled (gaussian)](https://huggingface.co/datasets/bertin-project/mc4-es-sampled). Esta demo funciona sobre una GPU proporcionada por HuggingFace.

</div>
"""
# Footer Markdown with a link to the model repository.
FOOTER = """
Para más información, visite el [repositorio del modelo](https://huggingface.co/bertin-project/bertin-gpt-j-6B).
""".strip()
class Normalizer:
    """Light post-processing of generated text; sentences are split on '.'."""

    def remove_repetitions(self, text):
        """Drop repeated sentences, keeping the first occurrence of each.

        Sentences are the pieces between '.' characters and are compared
        ignoring surrounding whitespace, so " Hola" counts as a repeat of
        "Hola". Empty pieces (from "..", "..." or after a final ".") are
        always kept; deduplicating them would shorten ellipses and drop the
        text's final period.
        """
        seen = set()
        kept = []
        for sentence in text.split("."):
            key = sentence.strip()
            if key:
                if key in seen:
                    continue
                seen.add(key)
            kept.append(sentence)
        return ".".join(kept)

    def trim_last_sentence(self, text):
        """Cut everything after the last '.'; returns "" when there is no '.'."""
        return text[: text.rfind(".") + 1]

    def clean_txt(self, text):
        """Remove repeated sentences, then trim a trailing unfinished one."""
        return self.trim_last_sentence(self.remove_repetitions(text))
class TextGeneration:
    """Wrapper around a causal LM and its transformers text-generation pipeline."""

    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.generator = None
        self.task = "text-generation"
        self.model_name_or_path = MODEL_NAME
        # Fixed seed so sampling is reproducible across restarts.
        set_seed(42)

    @staticmethod
    def _device_index(device):
        """Map a device string to the pipeline's integer ``device`` argument.

        "cpu" -> -1, "cuda" -> 0, "cuda:N" -> N. Assumes any non-CPU device is
        a CUDA device string.
        """
        if device == "cpu":
            return -1
        _, _, index = device.partition(":")
        return int(index) if index else 0

    def load(self):
        """Load tokenizer and model onto DEVICE and build the generation pipeline."""
        logger.info("Loading model...")
        token = HF_AUTH_TOKEN if HF_AUTH_TOKEN else None
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path, use_auth_token=token,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name_or_path,
            use_auth_token=token,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            torch_dtype=DTYPE,
            low_cpu_mem_usage=DEVICE != "cpu",
        ).to(device=DEVICE, non_blocking=False)
        self.model.eval()
        self.generator = pipeline(
            self.task,
            model=self.model,
            tokenizer=self.tokenizer,
            device=self._device_index(DEVICE),
        )
        logger.info("Loading model done.")

    def generate(self, text, generation_kwargs):
        """Complete ``text`` with the model, retrying up to 10 times.

        Args:
            text: prompt typed by the user.
            generation_kwargs: UI options. "max_length" is the number of new
                tokens wanted; "do_clean" (app-only, default False) applies the
                module-level ``cleaner``; every other key is passed to the
                pipeline. The dict is not modified.

        Returns:
            ``(full_text, spans)`` where ``spans`` feeds a HighlightedText
            (prompt unlabelled, completion labelled "BERTIN"), or
            ``("", [(error_message, "ERROR")])`` when nothing was generated.
        """
        error = (
            "",
            [("Tras 10 intentos BERTIN no generó nada. Pruebe cambiando las opciones", "ERROR")],
        )
        if not text:
            return error
        # Work on a copy so the caller's dict is untouched, and strip the
        # app-only "do_clean" flag: it is not a generation argument, and newer
        # transformers versions raise on model kwargs that generate() does not use.
        pipeline_kwargs = dict(generation_kwargs)
        do_clean = pipeline_kwargs.pop("do_clean", False)
        # The pipeline's max_length includes the prompt, so add the prompt
        # length to the requested number of new tokens; clamp to the context size.
        prompt_length = len(self.tokenizer(text)["input_ids"])
        pipeline_kwargs["max_length"] = min(
            prompt_length + int(pipeline_kwargs["max_length"]),
            self.model.config.n_positions,
        )
        for _ in range(10):
            generated_text = self.generator(text, **pipeline_kwargs)[0]["generated_text"]
            if do_clean:
                generated_text = cleaner.clean_txt(generated_text)
            # When the output echoes the prompt, keep only the continuation.
            if generated_text.strip().startswith(text):
                generated_text = generated_text.replace(text, "", 1).strip()
            if generated_text:
                return (
                    text + " " + generated_text,
                    [(text, None), (generated_text, "BERTIN")],
                )
        return error
def load_text_generator():
    """Build a TextGeneration wrapper, load its model, and hand it back ready to use."""
    wrapper = TextGeneration()
    wrapper.load()
    return wrapper
# Shared module-level instances: the sentence normalizer (read by
# TextGeneration.generate as ``cleaner``) and the loaded model wrapper.
cleaner = Normalizer()
generator = load_text_generator()


def complete_with_gpt(text, max_length, top_k, top_p, temperature, do_sample, do_clean):
    """Gradio callback: bundle the UI option values and run the generator.

    Returns whatever ``generator.generate`` returns (hidden text, highlighted spans).
    """
    options = dict(
        max_length=max_length,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        do_sample=do_sample,
        do_clean=do_clean,
    )
    return generator.generate(text, options)
def _restore_for_editing(last_text):
    """'Editar' handler: put the last full text back in the input box and clear
    the hidden buffer and the highlighted result."""
    return last_text, "", []


def _clear_everything():
    """'Limpiar' handler: empty the input box, the hidden buffer and the result."""
    return "", "", []


with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        # Left-hand panel: decoding options.
        with gr.Group():
            with gr.Box():
                gr.Markdown("Opciones")
                # Approximate maximum number of words to generate.
                max_length = gr.Slider(
                    minimum=1, maximum=MAX_LENGTH, step=1, value=50,
                    label='Longitud máxima',
                )
                # Number of highest-probability words kept by top-k filtering.
                top_k = gr.Slider(
                    minimum=40, maximum=80, step=1, value=50,
                    label='Top-k',
                )
                # Only the most probable words whose probabilities add up to
                # top_p or more are kept for generation.
                top_p = gr.Slider(
                    minimum=0.0, maximum=1.0, step=0.01, value=0.95,
                    label='Top-p',
                )
                # Modulates the probabilities of the next generated words.
                temperature = gr.Slider(
                    minimum=0.1, maximum=10.0, step=0.05, value=0.8,
                    label='Temperatura',
                )
                # When unchecked, greedy decoding is used instead of sampling.
                do_sample = gr.Checkbox(value=True, label='¿Muestrear?')
                # Whether to remove repeated sentences and trim the last unfinished one.
                do_clean = gr.Checkbox(value=True, label='¿Limpiar texto?')
        # Right-hand panel: prompt, result and action buttons.
        with gr.Column():
            textbox = gr.Textbox(
                lines=8,
                label="Texto",
                placeholder="Escriba algo y pulse 'Generar'...",
            )
            # Invisible buffer holding the last full text (prompt + completion).
            hidden = gr.Textbox(show_label=False, visible=False)
            with gr.Box():
                output = gr.HighlightedText(
                    label="Resultado",
                    combine_adjacent=True,
                    color_map={"BERTIN": "green", "ERROR": "red"},
                )
            with gr.Row():
                btn = gr.Button("Generar")
                edit_btn = gr.Button("Editar", variant="secondary")
                clean_btn = gr.Button("Limpiar", variant="secondary")
    gr.Markdown(FOOTER)

    # Event wiring: .click() registers listeners only and adds no components.
    generation_inputs = [textbox, max_length, top_k, top_p, temperature, do_sample, do_clean]
    btn.click(complete_with_gpt, inputs=generation_inputs, outputs=[hidden, output])
    edit_btn.click(_restore_for_editing, inputs=[hidden], outputs=[textbox, hidden, output])
    clean_btn.click(_clear_everything, inputs=[], outputs=[textbox, hidden, output])

demo.launch()