p1atdev committed on
Commit
d457afd
·
1 Parent(s): b53691a
Files changed (3) hide show
  1. README.md +2 -2
  2. app.py +162 -4
  3. requirements.txt +5 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Tag To Nl Test
3
- emoji: 🏆
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
 
1
  ---
2
+ title: Natural Language Text To Tag Test
3
+ emoji: 🔖
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
app.py CHANGED
@@ -1,7 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+
5
+ import torch
6
+ from transformers import (
7
+ AutoModelForPreTraining,
8
+ AutoProcessor,
9
+ AutoConfig,
10
+ )
11
+ from huggingface_hub import hf_hub_download
12
+ from safetensors.torch import load_file
13
  import gradio as gr
14
 
 
 
15
 
16
# Hub repository id of the model to serve; it must be supplied through the
# MODEL_NAME environment variable.
MODEL_NAME = os.environ.get("MODEL_NAME", None)
if MODEL_NAME is None:
    # Raise explicitly instead of using `assert`: assertions are stripped
    # under `python -O`, which would let `None` reach hf_hub_download below.
    raise RuntimeError("The MODEL_NAME environment variable must be set.")
# Local path of the checkpoint weights fetched from that repository.
MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
19
+
20
+
21
def fix_compiled_state_dict(state_dict: dict) -> dict:
    """Return a copy of *state_dict* with ``torch.compile`` key markers removed.

    A module wrapped by ``torch.compile`` keeps the original module under the
    ``_orig_mod`` attribute, so its parameter names gain an ``_orig_mod``
    path segment. Both placements are normalised here:

    * a leading ``"_orig_mod."`` prefix (the root module was compiled), and
    * an inner ``"._orig_mod."`` segment (a submodule was compiled).

    Args:
        state_dict: Mapping from parameter names to values.

    Returns:
        A new dict with the same values under the cleaned-up names. Keys
        without the marker are passed through unchanged.
    """
    prefix = "_orig_mod."
    fixed = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        fixed[key.replace("._orig_mod.", ".")] = value
    return fixed
23
+
24
+
25
def prepare_models():
    """Build the model and processor and load the checkpoint weights.

    The architecture is instantiated from the Hub config of ``MODEL_NAME``
    in bfloat16, the weights are read from the safetensors file at
    ``MODEL_PATH``, and the model is switched to eval mode and wrapped with
    ``torch.compile``.

    Returns:
        tuple: ``(model, processor)`` ready for generation.
    """
    config = AutoConfig.from_pretrained(
        MODEL_NAME, use_cache=True, trust_remote_code=True
    )
    model = AutoModelForPreTraining.from_config(
        config, torch_dtype=torch.bfloat16, trust_remote_code=True
    )
    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

    # The checkpoint keys may carry torch.compile's `_orig_mod` marker
    # (presumably it was saved from a compiled model); reuse the shared
    # helper instead of repeating the key rewrite inline.
    state_dict = fix_compiled_state_dict(load_file(MODEL_PATH))
    model.load_state_dict(state_dict)

    # generate_tags sends every input tensor to CUDA, so the weights have to
    # live there as well; stay on CPU when no GPU is present.
    if torch.cuda.is_available():
        model = model.to("cuda")

    model.eval()
    model = torch.compile(model)

    return model, processor
42
+
43
+
44
def demo():
    """Load the model, build the Gradio interface and launch it.

    The UI takes a natural-language description plus optional copyright
    tags and shows the tags generated by the model, together with the time
    the generation call took.
    """
    model, processor = prepare_models()
    # Send inputs to wherever the weights actually are instead of assuming
    # CUDA, so inputs and model can never end up on different devices.
    device = next(model.parameters()).device

    @torch.inference_mode()
    def generate_tags(
        text: str,
        auto_detect: bool,
        copyright_tags: str,
        max_new_tokens: int = 128,
        do_sample: bool = False,
        temperature: float = 0.1,
        top_k: int = 10,
        top_p: float = 0.1,
    ):
        """Generate tags for *text* and report how long generation took.

        Args:
            text: Natural-language description fed to the encoder.
            auto_detect: When true, the ``<copyright>`` section of the
                decoder prompt is left open so the model continues it;
                otherwise the prompt closes it, adds an empty character
                section and opens ``<general>``.
            copyright_tags: Tags placed inside the ``<copyright>`` section.
            max_new_tokens: Upper bound on the number of generated tokens.
            do_sample: Whether to sample instead of decoding greedily.
            temperature: Sampling temperature.
            top_k: Top-k sampling cutoff.
            top_p: Nucleus sampling cutoff.

        Returns:
            list: ``[tags, timing]`` where ``tags`` is the comma-separated
            tag string and ``timing`` a human-readable elapsed-time message.
        """
        # gr.Number can hand values over as floats (e.g. 10.0), while the
        # generation options below are integer counts. None is passed
        # through untouched (e.g. if the field was cleared in the UI).
        if max_new_tokens is not None:
            max_new_tokens = int(max_new_tokens)
        if top_k is not None:
            top_k = int(top_k)

        tag_text = (
            "<|bos|>"
            "<|aspect_ratio:tall|><|rating:general|><|length:long|>"
            "<|reserved_2|><|reserved_3|><|reserved_4|>"
            "<|translate:exact|><|input_end|>"
            "<copyright>" + copyright_tags.strip()
        )
        if not auto_detect:
            tag_text += "</copyright><character></character><general>"
        inputs = processor(
            encoder_text=text, decoder_text=tag_text, return_tensors="pt"
        )

        # perf_counter is monotonic, so the measurement cannot be skewed by
        # system clock adjustments.
        start_time = time.perf_counter()
        outputs = model.generate(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            encoder_input_ids=inputs["encoder_input_ids"].to(device),
            encoder_attention_mask=inputs["encoder_attention_mask"].to(device),
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=processor.decoder_tokenizer.eos_token_id,
            pad_token_id=processor.decoder_tokenizer.pad_token_id,
        )
        elapsed = time.perf_counter() - start_time

        # Each token id of the first sequence is decoded on its own; entries
        # that decode to nothing (special tokens are skipped) are dropped.
        decoded = ", ".join(
            [
                tag
                for tag in processor.batch_decode(outputs[0], skip_special_tokens=True)
                if tag.strip() != ""
            ]
        )
        return [decoded, f"Time elapsed: {elapsed:.2f} seconds"]

    with gr.Blocks() as ui:
        with gr.Row():
            with gr.Column():
                text = gr.Text(label="Text", lines=4)
                auto_detect = gr.Checkbox(
                    label="Auto detect copyright tags.", value=False
                )
                copyright_tags = gr.Textbox(
                    label="Custom tags",
                    placeholder="Enter custom tags here. e.g.) hatsune miku",
                )
                translate_btn = gr.Button(value="Translate")

                with gr.Accordion(label="Advanced", open=False):
                    max_new_tokens = gr.Number(label="Max new tokens", value=128)
                    do_sample = gr.Checkbox(label="Do sample", value=False)
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.1,
                        maximum=1.0,
                        value=0.1,
                        step=0.1,
                    )
                    top_k = gr.Number(
                        label="Top k",
                        value=10,
                    )
                    top_p = gr.Slider(
                        label="Top p",
                        minimum=0.1,
                        maximum=1.0,
                        value=0.1,
                        step=0.1,
                    )

            with gr.Column():
                output = gr.Textbox(label="Output", lines=4, interactive=False)
                time_elapsed = gr.Markdown(value="")

        gr.Examples(
            examples=[["Miku is looking at viewer.", True]],
            inputs=[text, auto_detect],
        )

        # Regenerate whenever one of the main inputs changes or the button
        # is pressed.
        gr.on(
            triggers=[
                text.change,
                auto_detect.change,
                copyright_tags.change,
                translate_btn.click,
            ],
            fn=generate_tags,
            inputs=[
                text,
                auto_detect,
                copyright_tags,
                max_new_tokens,
                do_sample,
                temperature,
                top_k,
                top_p,
            ],
            outputs=[output, time_elapsed],
        )

    ui.launch()
162
+
163
+
164
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ accelerate
4
+ safetensors
5
+ huggingface_hub