John6666 commited on
Commit
1d7686b
·
verified ·
1 Parent(s): deba324

Delete v2.py

Browse files
Files changed (1) hide show
  1. v2.py +0 -260
v2.py DELETED
@@ -1,260 +0,0 @@
1
- import time
2
- import torch
3
- from typing import Callable
4
- from pathlib import Path
5
-
6
- from dartrs.v2 import (
7
- V2Model,
8
- MixtralModel,
9
- MistralModel,
10
- compose_prompt,
11
- LengthTag,
12
- AspectRatioTag,
13
- RatingTag,
14
- IdentityTag,
15
- )
16
- from dartrs.dartrs import DartTokenizer
17
- from dartrs.utils import get_generation_config
18
-
19
-
20
- import gradio as gr
21
- from gradio.components import Component
22
-
23
-
24
- try:
25
- from output import UpsamplingOutput
26
- except:
27
- from .output import UpsamplingOutput
28
-
29
-
30
- V2_ALL_MODELS = {
31
- "dart-v2-moe-sft": {
32
- "repo": "p1atdev/dart-v2-moe-sft",
33
- "type": "sft",
34
- "class": MixtralModel,
35
- },
36
- "dart-v2-sft": {
37
- "repo": "p1atdev/dart-v2-sft",
38
- "type": "sft",
39
- "class": MistralModel,
40
- },
41
- }
42
-
43
-
44
- def prepare_models(model_config: dict):
45
- model_name = model_config["repo"]
46
- tokenizer = DartTokenizer.from_pretrained(model_name)
47
- model = model_config["class"].from_pretrained(model_name)
48
-
49
- return {
50
- "tokenizer": tokenizer,
51
- "model": model,
52
- }
53
-
54
-
55
- def normalize_tags(tokenizer: DartTokenizer, tags: str):
56
- """Just remove unk tokens."""
57
- return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])
58
-
59
-
60
- @torch.no_grad()
61
- def generate_tags(
62
- model: V2Model,
63
- tokenizer: DartTokenizer,
64
- prompt: str,
65
- ban_token_ids: list[int],
66
- ):
67
- output = model.generate(
68
- get_generation_config(
69
- prompt,
70
- tokenizer=tokenizer,
71
- temperature=1,
72
- top_p=0.9,
73
- top_k=100,
74
- max_new_tokens=256,
75
- ban_token_ids=ban_token_ids,
76
- ),
77
- )
78
-
79
- return output
80
-
81
-
82
- def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
83
- return (
84
- [f"1{noun}"]
85
- + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
86
- + [f"{maximum+1}+{noun}s"]
87
- )
88
-
89
-
90
- PEOPLE_TAGS = (
91
- _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
92
- )
93
-
94
-
95
- def gen_prompt_text(output: UpsamplingOutput):
96
- # separate people tags (e.g. 1girl)
97
- people_tags = []
98
- other_general_tags = []
99
-
100
- for tag in output.general_tags.split(","):
101
- tag = tag.strip()
102
- if tag in PEOPLE_TAGS:
103
- people_tags.append(tag)
104
- else:
105
- other_general_tags.append(tag)
106
-
107
- return ", ".join(
108
- [
109
- part.strip()
110
- for part in [
111
- *people_tags,
112
- output.character_tags,
113
- output.copyright_tags,
114
- *other_general_tags,
115
- output.upsampled_tags,
116
- output.rating_tag,
117
- ]
118
- if part.strip() != ""
119
- ]
120
- )
121
-
122
-
123
- def elapsed_time_format(elapsed_time: float) -> str:
124
- return f"Elapsed: {elapsed_time:.2f} seconds"
125
-
126
-
127
- def parse_upsampling_output(
128
- upsampler: Callable[..., UpsamplingOutput],
129
- ):
130
- def _parse_upsampling_output(*args) -> tuple[str, str, dict]:
131
- output = upsampler(*args)
132
-
133
- return (
134
- gen_prompt_text(output),
135
- elapsed_time_format(output.elapsed_time),
136
- gr.update(interactive=True),
137
- gr.update(interactive=True),
138
- )
139
-
140
- return _parse_upsampling_output
141
-
142
-
143
- class V2UI:
144
- model_name: str | None = None
145
- model: V2Model
146
- tokenizer: DartTokenizer
147
-
148
- input_components: list[Component] = []
149
- generate_btn: gr.Button
150
-
151
- def on_generate(
152
- self,
153
- model_name: str,
154
- copyright_tags: str,
155
- character_tags: str,
156
- general_tags: str,
157
- rating_tag: RatingTag,
158
- aspect_ratio_tag: AspectRatioTag,
159
- length_tag: LengthTag,
160
- identity_tag: IdentityTag,
161
- ban_tags: str,
162
- *args,
163
- ) -> UpsamplingOutput:
164
- if self.model_name is None or self.model_name != model_name:
165
- models = prepare_models(V2_ALL_MODELS[model_name])
166
- self.model = models["model"]
167
- self.tokenizer = models["tokenizer"]
168
- self.model_name = model_name
169
-
170
- # normalize tags
171
- # copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
172
- # character_tags = normalize_tags(self.tokenizer, character_tags)
173
- # general_tags = normalize_tags(self.tokenizer, general_tags)
174
-
175
- ban_token_ids = self.tokenizer.encode(ban_tags.strip())
176
-
177
- prompt = compose_prompt(
178
- prompt=general_tags,
179
- copyright=copyright_tags,
180
- character=character_tags,
181
- rating=rating_tag,
182
- aspect_ratio=aspect_ratio_tag,
183
- length=length_tag,
184
- identity=identity_tag,
185
- )
186
-
187
- start = time.time()
188
- upsampled_tags = generate_tags(
189
- self.model,
190
- self.tokenizer,
191
- prompt,
192
- ban_token_ids,
193
- )
194
- elapsed_time = time.time() - start
195
-
196
- return UpsamplingOutput(
197
- upsampled_tags=upsampled_tags,
198
- copyright_tags=copyright_tags,
199
- character_tags=character_tags,
200
- general_tags=general_tags,
201
- rating_tag=rating_tag,
202
- aspect_ratio_tag=aspect_ratio_tag,
203
- length_tag=length_tag,
204
- identity_tag=identity_tag,
205
- elapsed_time=elapsed_time,
206
- )
207
-
208
-
209
- def parse_upsampling_output_simple(upsampler: UpsamplingOutput):
210
- return gen_prompt_text(upsampler)
211
-
212
-
213
- v2 = V2UI()
214
-
215
-
216
- def v2_upsampling_prompt(model: str = "dart-v2-moe-sft", copyright: str = "", character: str = "",
217
- general_tags: str = "", rating: str = "nsfw", aspect_ratio: str = "square",
218
- length: str = "very_long", identity: str = "lax", ban_tags: str = "censored"):
219
- raw_prompt = parse_upsampling_output_simple(v2.on_generate(model, copyright, character, general_tags,
220
- rating, aspect_ratio, length, identity, ban_tags))
221
- return raw_prompt
222
-
223
-
224
- def load_dict_from_csv(filename):
225
- dict = {}
226
- if not Path(filename).exists():
227
- if Path('./tagger/', filename).exists(): filename = str(Path('./tagger/', filename))
228
- else: return dict
229
- try:
230
- with open(filename, 'r', encoding="utf-8") as f:
231
- lines = f.readlines()
232
- except Exception:
233
- print(f"Failed to open dictionary file: {filename}")
234
- return dict
235
- for line in lines:
236
- parts = line.strip().split(',')
237
- dict[parts[0]] = parts[1]
238
- return dict
239
-
240
-
241
- anime_series_dict = load_dict_from_csv('character_series_dict.csv')
242
-
243
-
244
- def select_random_character(series: str, character: str):
245
- from random import seed, randrange
246
- seed()
247
- character_list = list(anime_series_dict.keys())
248
- character = character_list[randrange(len(character_list) - 1)]
249
- series = anime_series_dict.get(character.split(",")[0].strip(), "")
250
- return series, character
251
-
252
-
253
- def v2_random_prompt(general_tags: str = "", copyright: str = "", character: str = "", rating: str = "nsfw",
254
- aspect_ratio: str = "square", length: str = "very_long", identity: str = "lax",
255
- ban_tags: str = "censored", model: str = "dart-v2-moe-sft"):
256
- if copyright == "" and character == "":
257
- copyright, character = select_random_character("", "")
258
- raw_prompt = v2_upsampling_prompt(model, copyright, character, general_tags, rating,
259
- aspect_ratio, length, identity, ban_tags)
260
- return raw_prompt, copyright, character