import os
import time

import torch
from transformers import (
    AutoModelForPreTraining,
    AutoProcessor,
    AutoConfig,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import gradio as gr

MODEL_NAME = os.environ.get("MODEL_NAME", None)
assert MODEL_NAME is not None, "Set the MODEL_NAME environment variable."

MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")


def fix_compiled_state_dict(state_dict: dict):
    # Checkpoints saved from a torch.compile()-wrapped model prefix every
    # parameter name with "._orig_mod."; strip it so the weights load into
    # an uncompiled module.
    return {k.replace("._orig_mod.", "."): v for k, v in state_dict.items()}


def prepare_models():
    config = AutoConfig.from_pretrained(
        MODEL_NAME, use_cache=True, trust_remote_code=True
    )
    model = AutoModelForPreTraining.from_config(
        config, torch_dtype=torch.bfloat16, trust_remote_code=True
    )
    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

    state_dict = load_file(MODEL_PATH)
    state_dict = fix_compiled_state_dict(state_dict)
    model.load_state_dict(state_dict)

    model.eval()
    # generate() is fed CUDA tensors below, so the model must live on the GPU
    # before it is compiled.
    model = model.to("cuda")
    model = torch.compile(model)

    return model, processor


def demo():
    model, processor = prepare_models()

    @torch.inference_mode()
    def generate_tags(
        text: str,
        auto_detect: bool,
        copyright_tags: str,
        max_new_tokens: int = 128,
        do_sample: bool = False,
        temperature: float = 0.1,
        top_k: int = 10,
        top_p: float = 0.1,
    ):
        # Decoder prompt: fixed condition tokens followed by any
        # user-supplied copyright tags.
        tag_text = (
            "<|bos|>"
            "<|aspect_ratio:tall|><|rating:general|><|length:long|>"
            "<|reserved_2|><|reserved_3|><|reserved_4|>"
            "<|translate:exact|><|input_end|>"
            "" + copyright_tags.strip()
        )

        if not auto_detect:
            tag_text += ""

        inputs = processor(
            encoder_text=text, decoder_text=tag_text, return_tensors="pt"
        )

        start_time = time.time()
        outputs = model.generate(
            input_ids=inputs["input_ids"].to("cuda"),
            attention_mask=inputs["attention_mask"].to("cuda"),
            encoder_input_ids=inputs["encoder_input_ids"].to("cuda"),
            encoder_attention_mask=inputs["encoder_attention_mask"].to("cuda"),
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=processor.decoder_tokenizer.eos_token_id,
            pad_token_id=processor.decoder_tokenizer.pad_token_id,
        )
        elapsed = time.time() - start_time

        # Each generated token id decodes to one tag; join the non-empty ones.
        decoded = ", ".join(
            [
                tag
                for tag in processor.batch_decode(outputs[0], skip_special_tokens=True)
                if tag.strip() != ""
            ]
        )

        return [decoded, f"Time elapsed: {elapsed:.2f} seconds"]

    with gr.Blocks() as ui:
        with gr.Row():
            with gr.Column():
                text = gr.Text(label="Text", lines=4)
                auto_detect = gr.Checkbox(
                    label="Auto detect copyright tags.", value=False
                )
                copyright_tags = gr.Textbox(
                    label="Custom tags",
                    placeholder="Enter custom tags here, e.g. hatsune miku",
                )
                translate_btn = gr.Button(value="Translate")

                with gr.Accordion(label="Advanced", open=False):
                    max_new_tokens = gr.Number(label="Max new tokens", value=128)
                    do_sample = gr.Checkbox(label="Do sample", value=False)
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.1,
                        maximum=1.0,
                        value=0.1,
                        step=0.1,
                    )
                    top_k = gr.Number(
                        label="Top k",
                        value=10,
                    )
                    top_p = gr.Slider(
                        label="Top p",
                        minimum=0.1,
                        maximum=1.0,
                        value=0.1,
                        step=0.1,
                    )

            with gr.Column():
                output = gr.Textbox(label="Output", lines=4, interactive=False)
                time_elapsed = gr.Markdown(value="")

        gr.Examples(
            examples=[["Miku is looking at viewer.", True]],
            inputs=[text, auto_detect],
        )

        # Regenerate whenever any input changes or the button is clicked.
        gr.on(
            triggers=[
                text.change,
                auto_detect.change,
                copyright_tags.change,
                translate_btn.click,
            ],
            fn=generate_tags,
            inputs=[
                text,
                auto_detect,
                copyright_tags,
                max_new_tokens,
                do_sample,
                temperature,
                top_k,
                top_p,
            ],
            outputs=[output, time_elapsed],
        )

    ui.launch()


if __name__ == "__main__":
    demo()
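
# Usage sketch (assumptions: MODEL_NAME points at a Hugging Face repo that
# ships a `model.safetensors` checkpoint plus custom-code model/processor
# classes loadable via AutoModelForPreTraining/AutoProcessor, and a CUDA GPU
# is available; the script name `app.py` and the <org>/<repo> placeholder are
# hypothetical):
#
#   MODEL_NAME=<org>/<repo> python app.py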