nl-to-tag-test / app.py
p1atdev's picture
chore: use gpu
ff91c77
raw
history blame
5.71 kB
import os
import time
import spaces
import torch
from transformers import (
AutoModelForPreTraining,
AutoProcessor,
AutoConfig,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import gradio as gr
# Hugging Face Hub repository id of the model to serve, read from the
# MODEL_NAME environment variable.
MODEL_NAME = os.environ.get("MODEL_NAME", None)
if MODEL_NAME is None:
    # Explicit raise instead of `assert`: assertions are stripped when Python
    # runs with -O, which would let a missing setting slip through to the
    # download call below.
    raise RuntimeError("The MODEL_NAME environment variable must be set.")

# Local path of the weights file, downloaded from the Hub at import time.
MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")

# Device the model is moved to in prepare_models().
DEVICE = torch.device("cuda")
def fix_compiled_state_dict(state_dict: dict) -> dict:
    """Return a copy of *state_dict* with ``_orig_mod`` key components removed.

    Checkpoints saved from ``torch.compile``-wrapped modules carry an extra
    ``_orig_mod`` component in their parameter names. It is stripped both when
    it appears inside a key (``enc._orig_mod.w`` -> ``enc.w``, a compiled
    sub-module) and when it is the leading prefix (``_orig_mod.enc.w`` ->
    ``enc.w``, a model compiled at the top level).

    Args:
        state_dict: Mapping from parameter name to value. It is not modified.

    Returns:
        A new dict with cleaned keys and the same values.
    """
    marker = "_orig_mod."
    cleaned = {}
    for key, value in state_dict.items():
        # Nested occurrence: drop the component but keep one separating dot.
        key = key.replace("." + marker, ".")
        # Top-level occurrence: there is no preceding dot to match on.
        if key.startswith(marker):
            key = key[len(marker):]
        cleaned[key] = value
    return cleaned
def prepare_models():
    """Build the model and processor for ``MODEL_NAME`` and load the weights.

    The model is instantiated from its config in bfloat16 and then populated
    from the safetensors file at ``MODEL_PATH``.

    Returns:
        A ``(model, processor)`` tuple. The model is in eval mode and has been
        moved to ``DEVICE``.
    """
    config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForPreTraining.from_config(
        config, torch_dtype=torch.bfloat16, trust_remote_code=True
    )
    model.decoder_model.use_cache = True

    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

    # Reuse the shared helper rather than repeating the key rewrite inline.
    state_dict = fix_compiled_state_dict(load_file(MODEL_PATH))
    model.load_state_dict(state_dict)

    model.eval()
    model = model.to(DEVICE)
    # model = torch.compile(model)

    return model, processor
def demo():
    """Load the model, run one warm-up generation, then build and launch the UI."""
    model, processor = prepare_models()

    @spaces.GPU(duration=5)
    @torch.inference_mode()
    def generate_tags(
        text: str,
        auto_detect: bool,
        copyright_tags: str = "",
        max_new_tokens: int = 128,
        do_sample: bool = False,
        temperature: float = 0.1,
        top_k: int = 10,
        top_p: float = 0.1,
    ):
        """Generate tags for *text* and report how long generation took.

        Args:
            text: Natural-language input passed to the processor as encoder text.
            auto_detect: When False, the copyright section of the decoder
                prompt is closed right after ``copyright_tags`` and an empty
                character section is appended, so generation starts at the
                general section. When True, the prompt is left open after
                ``copyright_tags``.
            copyright_tags: Text inserted after ``<copyright>`` in the prompt.
            max_new_tokens: Generation length limit; coerced to ``int``.
            do_sample: Passed through to ``model.generate``.
            temperature: Passed through to ``model.generate``.
            top_k: Passed through to ``model.generate``; coerced to ``int``.
            top_p: Passed through to ``model.generate``.

        Returns:
            A two-item list: the comma-separated tag string and a
            ``"Time elapsed: ..."`` message.
        """
        tag_text = (
            "<|bos|>"
            "<|aspect_ratio:tall|><|rating:general|><|length:long|>"
            "<|reserved_2|><|reserved_3|><|reserved_4|>"
            "<|translate:exact|><|input_end|>"
            "<copyright>" + copyright_tags.strip()
        )
        if not auto_detect:
            tag_text += "</copyright><character></character><general>"

        inputs = processor(
            encoder_text=text, decoder_text=tag_text, return_tensors="pt"
        )

        device = model.device
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        outputs = model.generate(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            encoder_input_ids=inputs["encoder_input_ids"].to(device),
            encoder_attention_mask=inputs["encoder_attention_mask"].to(device),
            # Numeric UI widgets may deliver floats; these two must be integers.
            max_new_tokens=int(max_new_tokens),
            do_sample=do_sample,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            eos_token_id=processor.decoder_tokenizer.eos_token_id,
            pad_token_id=processor.decoder_tokenizer.pad_token_id,
        )
        elapsed = time.perf_counter() - start_time

        # Decode the first output sequence and drop blank entries.
        # NOTE(review): presumably batch_decode yields one string per token
        # id here, i.e. one tag per token -- confirm against the processor.
        decoded = ", ".join(
            [
                tag
                for tag in processor.batch_decode(outputs[0], skip_special_tokens=True)
                if tag.strip() != ""
            ]
        )

        return [decoded, f"Time elapsed: {elapsed:.2f} seconds"]

    # warmup
    print("warming up...")
    print(generate_tags("Miku is looking at viewer.", True))
    print("done.")

    with gr.Blocks() as ui:
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    text = gr.Text(label="Text", lines=4)
                    auto_detect = gr.Checkbox(
                        label="Auto detect copyright tags.", value=False
                    )
                    copyright_tags = gr.Textbox(
                        label="Copyright tags",
                        placeholder="Enter copyright tags here. e.g.) hatsune miku",
                    )
                    translate_btn = gr.Button(value="Translate")

                    with gr.Accordion(label="Advanced", open=False):
                        # precision=0 makes the component round to an integer.
                        max_new_tokens = gr.Number(
                            label="Max new tokens", value=128, precision=0
                        )
                        do_sample = gr.Checkbox(label="Do sample", value=False)
                        temperature = gr.Slider(
                            label="Temperature",
                            minimum=0.1,
                            maximum=1.0,
                            value=0.1,
                            step=0.1,
                        )
                        top_k = gr.Slider(
                            label="Top k",
                            minimum=1,
                            maximum=100,
                            value=10,
                            step=1,
                        )
                        top_p = gr.Slider(
                            label="Top p",
                            minimum=0.1,
                            maximum=1.0,
                            value=0.1,
                            step=0.1,
                        )

                with gr.Column():
                    output = gr.Textbox(label="Output", lines=4, interactive=False)
                    time_elapsed = gr.Markdown(value="")

            gr.Examples(
                examples=[["Miku is looking at viewer.", True]],
                inputs=[text, auto_detect],
            )

        gr.on(
            triggers=[
                # text.change,
                # auto_detect.change,
                # copyright_tags.change,
                translate_btn.click,
            ],
            fn=generate_tags,
            inputs=[
                text,
                auto_detect,
                copyright_tags,
                max_new_tokens,
                do_sample,
                temperature,
                top_k,
                top_p,
            ],
            outputs=[output, time_elapsed],
        )

    ui.launch()
# Build and launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    demo()