# bertin-gpt-j-6B / gradio_app.py
import logging
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, set_seed

logger = logging.getLogger()
logger.setLevel(logging.INFO)  # the root logger defaults to WARNING, which would silence the info logs below
logger.addHandler(logging.StreamHandler())
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
DEVICE = os.environ.get("DEVICE", "cpu") # cuda:0
if DEVICE != "cpu" and not torch.cuda.is_available():
    DEVICE = "cpu"
logger.info(f"DEVICE {DEVICE}")
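# float16 halves memory use for GPU inference; full float32 is kept on CPU,
# where half-precision inference is poorly supported.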
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
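# All of the above can be overridden via environment variables, e.g. (illustrative invocation):
#   DEVICE=cuda:0 MODEL_NAME=bertin-project/bertin-gpt-j-6B MAX_LENGTH=512 python gradio_app.py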
HEADER_INFO = """
# BERTIN GPT-J-6B
Spanish BERTIN GPT-J-6B Model.
""".strip()
LOGO = "https://huggingface.co/bertin-project/bertin-roberta-base-spanish/resolve/main/images/bertin.png"
HEADER = f"""
<link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300&display=swap" rel="stylesheet">
<style>
.ltr,
textarea {{
font-family: Roboto !important;
text-align: left;
direction: ltr !important;
}}
.ltr-box {{
border-bottom: 1px solid #ddd;
padding-bottom: 20px;
}}
.rtl {{
text-align: left;
direction: ltr !important;
}}
span.result-text {{
padding: 3px 3px;
line-height: 32px;
}}
span.generated-text {{
background-color: rgb(118 200 147 / 13%);
}}
</style>
<div align=center>
<img src="{LOGO}" width=150/>
# BERTIN GPT-J-6B
BERTIN proporciona una serie de modelos de lenguaje en Español entrenados en abierto.
Este modelo ha sido entrenado con [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax) en TPUs proporcionadas por Google a través del programa Tensor Research Cloud, a partir del modelo [GPT-J de EleutherAI](https://huggingface.co/EleutherAI/gpt-j-6B) con el corpus [mC4-es-sampled (gaussian)](https://huggingface.co/datasets/bertin-project/mc4-es-sampled). Esta demo funciona sobre una GPU proporcionada por HuggingFace.
</div>
"""
FOOTER = """
Para más información, visite el [repositorio del modelo](https://huggingface.co/bertin-project/bertin-gpt-j-6B).
""".strip()
class Normalizer:
    def remove_repetitions(self, text):
        """Remove sentences that are exact duplicates of an earlier one."""
        first_occurrences = []
        for sentence in text.split("."):
            if sentence not in first_occurrences:
                first_occurrences.append(sentence)
        return ".".join(first_occurrences)

    def trim_last_sentence(self, text):
        """Trim the last sentence if it is incomplete (no closing period)."""
        return text[:text.rfind(".") + 1]

    def clean_txt(self, text):
        return self.trim_last_sentence(self.remove_repetitions(text))
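# A worked example of the cleanup above (illustrative):
#   Normalizer().clean_txt("Frase una.Frase una.Frase a med")
#   -> "Frase una."  (exact duplicate sentence dropped, trailing fragment trimmed)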
class TextGeneration:
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.generator = None
        self.task = "text-generation"
        self.model_name_or_path = MODEL_NAME
        set_seed(42)

    def load(self):
        logger.info("Loading model...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN,
            pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
            torch_dtype=DTYPE, low_cpu_mem_usage=DEVICE != "cpu",
        ).to(device=DEVICE, non_blocking=False)
        _ = self.model.eval()
        # transformers pipelines take a device index: -1 for CPU, N for "cuda:N"
        device_number = -1 if DEVICE == "cpu" else int(DEVICE.split(":")[-1])
        self.generator = pipeline(self.task, model=self.model, tokenizer=self.tokenizer, device=device_number)
        logger.info("Loading model done.")
    def generate(self, text, generation_kwargs):
        # Treat the requested max_length as a budget of new tokens on top of
        # the prompt, capped at the model's context window.
        max_length = len(self.tokenizer(text)["input_ids"]) + generation_kwargs["max_length"]
        generation_kwargs["max_length"] = min(max_length, self.model.config.n_positions)
        # do_clean is our own flag, not a generate() argument, so remove it
        # before forwarding the kwargs to the pipeline.
        do_clean = generation_kwargs.pop("do_clean")
        generated_text = None
        if text:
            # Cleaning can leave an empty string; retry a few times before giving up.
            for _ in range(10):
                generated_text = self.generator(
                    text,
                    **generation_kwargs,
                )[0]["generated_text"]
                if do_clean:
                    generated_text = cleaner.clean_txt(generated_text)
                if generated_text.strip().startswith(text):
                    generated_text = generated_text.replace(text, "", 1).strip()
                if generated_text:
                    return (
                        text + " " + generated_text,
                        [(text, None), (generated_text, "BERTIN")],
                    )
        if not generated_text:
            return (
                "",
                [("Tras 10 intentos BERTIN no generó nada. Pruebe cambiando las opciones", "ERROR")],
            )
def load_text_generator():
    text_generator = TextGeneration()
    text_generator.load()
    return text_generator


cleaner = Normalizer()
generator = load_text_generator()
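# A minimal sketch of calling the generator directly, outside the UI
# (parameter values are illustrative, not tuned):
#   raw, highlights = generator.generate(
#       "Érase una vez",
#       {"max_length": 50, "top_k": 50, "top_p": 0.95,
#        "temperature": 0.8, "do_sample": True, "do_clean": True},
#   )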
def complete_with_gpt(text, max_length, top_k, top_p, temperature, do_sample, do_clean):
    # Returns (raw completed text for the hidden textbox, highlighted segments for display).
    generation_kwargs = {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "do_sample": do_sample,
        "do_clean": do_clean,
    }
    return generator.generate(text, generation_kwargs)
with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        with gr.Group():
            with gr.Box():
                gr.Markdown("Opciones")
                max_length = gr.Slider(
                    label='Longitud máxima',
                    # Approximate maximum number of words to generate.
                    minimum=1,
                    maximum=MAX_LENGTH,
                    value=50,
                    step=1
                )
                top_k = gr.Slider(
                    label='Top-k',
                    # Number of highest-probability words kept for top-k filtering.
                    minimum=40,
                    maximum=80,
                    value=50,
                    step=1
                )
                top_p = gr.Slider(
                    label='Top-p',
                    # Only the most probable words whose probabilities add up to top_p or more are kept for generation.
                    minimum=0.0,
                    maximum=1.0,
                    value=0.95,
                    step=0.01
                )
                temperature = gr.Slider(
                    label='Temperatura',
                    # Value used to modulate the probabilities of the next generated words.
                    minimum=0.1,
                    maximum=10.0,
                    value=0.8,
                    step=0.05
                )
                do_sample = gr.Checkbox(
                    label='¿Muestrear?',
                    value=True,
                    # When unchecked, greedy decoding is used instead of sampling.
                )
                do_clean = gr.Checkbox(
                    label='¿Limpiar texto?',
                    value=True,
                    # Whether to remove repeated sentences and trim a trailing unfinished one.
                )
        with gr.Column():
            textbox = gr.Textbox(label="Texto", placeholder="Escriba algo y pulse 'Generar'...", lines=8)
            # The hidden textbox keeps the raw completed text so "Editar" can copy it back.
            hidden = gr.Textbox(visible=False, show_label=False)
            with gr.Box():
                output = gr.HighlightedText(label="Resultado", combine_adjacent=True, color_map={"BERTIN": "green", "ERROR": "red"})
            with gr.Row():
                btn = gr.Button("Generar")
                btn.click(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output])
                edit_btn = gr.Button("Editar", variant="secondary")
                edit_btn.click(lambda x: (x, "", []), inputs=[hidden], outputs=[textbox, hidden, output])
                clean_btn = gr.Button("Limpiar", variant="secondary")
                clean_btn.click(lambda: ("", "", []), inputs=[], outputs=[textbox, hidden, output])
    gr.Markdown(FOOTER)
demo.launch()
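# If generations are slow, queuing avoids request timeouts; a hypothetical
# variant of the launch above:
#   demo.queue().launch()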