Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import time | |
import spaces | |
import torch | |
from transformers import ( | |
AutoModelForPreTraining, | |
AutoProcessor, | |
AutoConfig, | |
) | |
from huggingface_hub import hf_hub_download | |
from safetensors.torch import load_file | |
import gradio as gr | |
# Hub repo id of the model to serve, read from the MODEL_NAME environment variable.
MODEL_NAME = os.environ.get("MODEL_NAME")
# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# which would let a missing value fail later with a less obvious error.
if not MODEL_NAME:
    raise RuntimeError("The MODEL_NAME environment variable must be set.")
# Local path of the checkpoint downloaded from the repo above.
MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
DEVICE = torch.device("cuda")
def fix_compiled_state_dict(state_dict: dict) -> dict:
    """Return a copy of *state_dict* with ``torch.compile`` wrapper names removed.

    ``torch.compile`` wraps a module and keeps the original under the attribute
    ``_orig_mod``. Checkpoints saved from compiled modules therefore carry keys
    like ``encoder._orig_mod.layer.weight`` (compiled submodule) or
    ``_orig_mod.layer.weight`` (compiled top-level model). Both forms are mapped
    back to the plain names so the checkpoint loads into an uncompiled model.

    Args:
        state_dict: Mapping of parameter/buffer names to values.

    Returns:
        A new dict with the same values and cleaned-up keys. The input dict is
        not modified.
    """
    prefix = "_orig_mod."
    cleaned = {}
    for key, value in state_dict.items():
        # Compiled submodules leave "._orig_mod." in the middle of the key.
        key = key.replace("._orig_mod.", ".")
        # A compiled top-level model leaves "_orig_mod." at the very start.
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned
def prepare_models():
    """Build the model and processor for MODEL_NAME and load its checkpoint.

    The architecture is instantiated from the repo's config in bfloat16. The
    weights come from the ``model.safetensors`` file at MODEL_PATH, after
    stripping any ``torch.compile`` wrapper prefixes from the key names.

    Returns:
        tuple: ``(model, processor)`` - the model in eval mode on DEVICE, and
        the processor loaded from the same repo.
    """
    config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForPreTraining.from_config(
        config, torch_dtype=torch.bfloat16, trust_remote_code=True
    )
    # NOTE(review): presumably enables KV caching in the decoder for faster
    # generation - confirm against the remote model code.
    model.decoder_model.use_cache = True
    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

    # Reuse the shared helper rather than duplicating the key-renaming logic.
    state_dict = fix_compiled_state_dict(load_file(MODEL_PATH))
    model.load_state_dict(state_dict)
    model.eval()
    model = model.to(DEVICE)
    # model = torch.compile(model)
    return model, processor
def demo():
    """Load the model, run one warmup generation, and launch the Gradio UI."""
    model, processor = prepare_models()

    # NOTE(review): `spaces` is imported at file top but never used, and this
    # Space appears to run on ZeroGPU, where GPU work is normally wrapped in
    # `@spaces.GPU`. Confirm whether generate_tags should be decorated.
    def generate_tags(
        text: str,
        auto_detect: bool,
        copyright_tags: str = "",
        max_new_tokens: int = 128,
        do_sample: bool = False,
        temperature: float = 0.1,
        top_k: int = 10,
        top_p: float = 0.1,
    ):
        """Translate free-form *text* into a comma-separated list of tags.

        Args:
            text: Natural-language description passed to the processor as
                ``encoder_text``.
            auto_detect: If True, the decoder prompt ends inside the open
                ``<copyright>`` section, so generation continues from there.
                If False, the prompt closes it and opens the ``<general>``
                section.
            copyright_tags: Tags placed after ``<copyright>`` in the decoder
                prompt (surrounding whitespace is stripped).
            max_new_tokens: Maximum number of tokens to generate.
            do_sample, temperature, top_k, top_p: Forwarded unchanged to
                ``model.generate``.

        Returns:
            list: ``[tags, timing]`` where *tags* is the non-blank decoded
            tokens joined with ", " and *timing* is a human-readable
            generation time.
        """
        tag_text = (
            "<|bos|>"
            "<|aspect_ratio:tall|><|rating:general|><|length:long|>"
            "<|reserved_2|><|reserved_3|><|reserved_4|>"
            "<|translate:exact|><|input_end|>"
            "<copyright>" + copyright_tags.strip()
        )
        if not auto_detect:
            tag_text += "</copyright><character></character><general>"

        inputs = processor(
            encoder_text=text, decoder_text=tag_text, return_tensors="pt"
        )

        # perf_counter is monotonic, unlike time.time(), so it is the right
        # clock for measuring elapsed durations.
        start_time = time.perf_counter()
        outputs = model.generate(
            input_ids=inputs["input_ids"].to(model.device),
            attention_mask=inputs["attention_mask"].to(model.device),
            encoder_input_ids=inputs["encoder_input_ids"].to(model.device),
            encoder_attention_mask=inputs["encoder_attention_mask"].to(model.device),
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=processor.decoder_tokenizer.eos_token_id,
            pad_token_id=processor.decoder_tokenizer.pad_token_id,
        )
        elapsed = time.perf_counter() - start_time

        # NOTE(review): outputs[0] is presumably the 1-D id sequence of the
        # single batch item, so batch_decode decodes each id separately, giving
        # one string per token. Blank entries (e.g. skipped special tokens)
        # are dropped.
        decoded = ", ".join(
            [
                tag
                for tag in processor.batch_decode(outputs[0], skip_special_tokens=True)
                if tag.strip() != ""
            ]
        )
        return [decoded, f"Time elapsed: {elapsed:.2f} seconds"]

    # warmup
    print("warming up...")
    print(generate_tags("Miku is looking at viewer.", True))
    print("done.")

    with gr.Blocks() as ui:
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    text = gr.Text(label="Text", lines=4)
                    auto_detect = gr.Checkbox(
                        label="Auto detect copyright tags.", value=False
                    )
                    copyright_tags = gr.Textbox(
                        label="Copyright tags",
                        placeholder="Enter copyright tags here. e.g.) hatsune miku",
                    )
                    translate_btn = gr.Button(value="Translate")
                    with gr.Accordion(label="Advanced", open=False):
                        # precision=0 makes Gradio deliver an int instead of
                        # its default float, as expected for a token count.
                        max_new_tokens = gr.Number(
                            label="Max new tokens", value=128, precision=0
                        )
                        do_sample = gr.Checkbox(label="Do sample", value=False)
                        temperature = gr.Slider(
                            label="Temperature",
                            minimum=0.1,
                            maximum=1.0,
                            value=0.1,
                            step=0.1,
                        )
                        top_k = gr.Slider(
                            label="Top k",
                            minimum=1,
                            maximum=100,
                            value=10,
                            step=1,
                        )
                        top_p = gr.Slider(
                            label="Top p",
                            minimum=0.1,
                            maximum=1.0,
                            value=0.1,
                            step=0.1,
                        )
                with gr.Column():
                    output = gr.Textbox(label="Output", lines=4, interactive=False)
                    time_elapsed = gr.Markdown(value="")
            gr.Examples(
                examples=[["Miku is looking at viewer.", True]],
                inputs=[text, auto_detect],
            )
        gr.on(
            triggers=[
                # text.change,
                # auto_detect.change,
                # copyright_tags.change,
                translate_btn.click,
            ],
            fn=generate_tags,
            inputs=[
                text,
                auto_detect,
                copyright_tags,
                max_new_tokens,
                do_sample,
                temperature,
                top_k,
                top_p,
            ],
            outputs=[output, time_elapsed],
        )
    ui.launch()


if __name__ == "__main__":
    demo()