Denliner committed on
Commit 5bf67a7 · verified · 1 Parent(s): 13b77ec

try to update it

Files changed (1)
  1. app.py +304 -242
app.py CHANGED
@@ -1,54 +1,66 @@
-from __future__ import annotations
-
 import argparse
-import functools
-import html
 import os

-
 import gradio as gr
 import huggingface_hub
 import numpy as np
 import onnxruntime as rt
 import pandas as pd
-import piexif
-import piexif.helper
-
-import PIL.Image
+from PIL import Image

-from Utils import dbimutils
-
-TITLE = "WaifuDiffusion v1.4 Tags"
+TITLE = "WaifuDiffusion Tagger"
+
 DESCRIPTION = """
 This is an edited version of SmilingWolf's wd-1.4 tagger, which I have modified so that you don't have to remove the commas when you label an image for a booru website.

 https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags

-Demo for:
-- [SmilingWolf/wd-v1-4-moat-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-moat-tagger-v2)
-- [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
-- [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
-- [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
-- [SmilingWolf/wd-v1-4-convnextv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2)
-Includes "ready to copy" prompt and a prompt analyzer.
-
-Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
-Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru)
-
-PNG Info code forked from [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
+Demo for the WaifuDiffusion tagger models.

 Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
 """

 HF_TOKEN = os.environ["HF_TOKEN"]
-MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
-SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
-CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
-CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
-VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
+
+# Dataset v3 series of models:
+SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
+CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
+VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
+
+# Dataset v2 series of models:
+MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
+SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
+CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
+CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
+VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
+
+# Files to download from the repos
 MODEL_FILENAME = "model.onnx"
 LABEL_FILENAME = "selected_tags.csv"

+# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
+kaomojis = [
+    "0_0",
+    "(o)_(o)",
+    "+_+",
+    "+_-",
+    "._.",
+    "<o>_<o>",
+    "<|>_<|>",
+    "=_=",
+    ">_<",
+    "3_3",
+    "6_9",
+    ">_o",
+    "@_@",
+    "^_^",
+    "o_o",
+    "u_u",
+    "x_x",
+    "|_|",
+    "||_||",
+]
+

 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
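The kaomojis list added above exists so that emoticon tags keep their underscores when the tagger later swaps underscores for spaces in the prompt string (see predict() in the next hunk). A minimal sketch of that rule, using made-up tag names and assuming the kaomojis list from the hunk above:

    tags = ["long_hair", "^_^", "looking_at_viewer"]
    cleaned = [t.replace("_", " ") if t not in kaomojis else t for t in tags]
    print(cleaned)  # ['long hair', '^_^', 'looking at viewer']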
@@ -59,231 +71,281 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()


-def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
-    path = huggingface_hub.hf_hub_download(
-        model_repo, model_filename, use_auth_token=HF_TOKEN
-    )
-    model = rt.InferenceSession(path)
-    return model
-
-
-def change_model(model_name):
-    global loaded_models
-    if model_name == "MOAT":
-        model = load_model(MOAT_MODEL_REPO, MODEL_FILENAME)
-    elif model_name == "SwinV2":
-        model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
-    elif model_name == "ConvNext":
-        model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
-    elif model_name == "ViT":
-        model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
-    elif model_name == "ConvNextV2":
-        model = load_model(CONV2_MODEL_REPO, MODEL_FILENAME)
-
-    loaded_models[model_name] = model
-    return loaded_models[model_name]
-
-
-def load_labels() -> list[str]:
-    path = huggingface_hub.hf_hub_download(
-        MOAT_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
-    )
-    df = pd.read_csv(path)
-
-    tag_names = df["name"].tolist()
-    rating_indexes = list(np.where(df["category"] == 9)[0])
-    general_indexes = list(np.where(df["category"] == 0)[0])
-    character_indexes = list(np.where(df["category"] == 4)[0])
+def load_labels(dataframe) -> list[str]:
+    name_series = dataframe["name"]
+    # name_series = name_series.map(
+    #     lambda x: x.replace("_", " ") if x not in kaomojis else x
+    # )
+    tag_names = name_series.tolist()
+
+    rating_indexes = list(np.where(dataframe["category"] == 9)[0])
+    general_indexes = list(np.where(dataframe["category"] == 0)[0])
+    character_indexes = list(np.where(dataframe["category"] == 4)[0])
     return tag_names, rating_indexes, general_indexes, character_indexes


-def plaintext_to_html(text):
-    text = (
-        "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split("\n")]) + "</p>"
-    )
-    return text
-
-
-def predict(
-    image: PIL.Image.Image,
-    model_name: str,
-    general_threshold: float,
-    character_threshold: float,
-    tag_names: list[str],
-    rating_indexes: list[np.int64],
-    general_indexes: list[np.int64],
-    character_indexes: list[np.int64],
-):
-    global loaded_models
-
-    rawimage = image
-
-    model = loaded_models[model_name]
-    if model is None:
-        model = change_model(model_name)
-
-    _, height, width, _ = model.get_inputs()[0].shape
-
-    # Alpha to white
-    image = image.convert("RGBA")
-    new_image = PIL.Image.new("RGBA", image.size, "WHITE")
-    new_image.paste(image, mask=image)
-    image = new_image.convert("RGB")
-    image = np.asarray(image)
-
-    # PIL RGB to OpenCV BGR
-    image = image[:, :, ::-1]
-
-    image = dbimutils.make_square(image, height)
-    image = dbimutils.smart_resize(image, height)
-    image = image.astype(np.float32)
-    image = np.expand_dims(image, 0)
-
-    input_name = model.get_inputs()[0].name
-    label_name = model.get_outputs()[0].name
-    probs = model.run([label_name], {input_name: image})[0]
-
-    labels = list(zip(tag_names, probs[0].astype(float)))
-
-    # First 4 labels are actually ratings: pick one with argmax
-    ratings_names = [labels[i] for i in rating_indexes]
-    rating = dict(ratings_names)
-
-    # Then we have general tags: pick any where prediction confidence > threshold
-    general_names = [labels[i] for i in general_indexes]
-    general_res = [x for x in general_names if x[1] > general_threshold]
-    general_res = dict(general_res)
-
-    # Everything else is characters: pick any where prediction confidence > threshold
-    character_names = [labels[i] for i in character_indexes]
-    character_res = [x for x in character_names if x[1] > character_threshold]
-    character_res = dict(character_res)
-
-    b = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True))
-    a = (
-        ", ".join(list(b.keys()))
-        .replace("_", " ")
-        .replace("(", "\(")
-        .replace(")", "\)")
-    )
-    c = ", ".join(list(b.keys()))
-    d = " ".join(list(b.keys()))
-    items = rawimage.info
-    geninfo = ""
-
-    if "exif" in rawimage.info:
-        exif = piexif.load(rawimage.info["exif"])
-        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b"")
-        try:
-            exif_comment = piexif.helper.UserComment.load(exif_comment)
-        except ValueError:
-            exif_comment = exif_comment.decode("utf8", errors="ignore")
-
-        items["exif comment"] = exif_comment
-        geninfo = exif_comment
-
-    for field in [
-        "jfif",
-        "jfif_version",
-        "jfif_unit",
-        "jfif_density",
-        "dpi",
-        "exif",
-        "loop",
-        "background",
-        "timestamp",
-        "duration",
-    ]:
-        items.pop(field, None)
-
-    geninfo = items.get("parameters", geninfo)
-
-    info = f"""
-<p><h4>PNG Info</h4></p>
-"""
-    for key, text in items.items():
-        info += (
-            f"""
-<div>
-<p><b>{plaintext_to_html(str(key))}</b></p>
-<p>{plaintext_to_html(str(text))}</p>
-</div>
-""".strip()
-            + "\n"
+def mcut_threshold(probs):
+    """
+    Maximum Cut Thresholding (MCut)
+    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
+    for Multi-label Classification. In 11th International Symposium, IDA 2012
+    (pp. 172-183).
+    """
+    sorted_probs = probs[probs.argsort()[::-1]]
+    difs = sorted_probs[:-1] - sorted_probs[1:]
+    t = difs.argmax()
+    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
+    return thresh
+
+
+class Predictor:
+    def __init__(self):
+        self.model_target_size = None
+        self.last_loaded_repo = None
+
+    def download_model(self, model_repo):
+        csv_path = huggingface_hub.hf_hub_download(
+            model_repo,
+            LABEL_FILENAME,
+            use_auth_token=HF_TOKEN,
+        )
+        model_path = huggingface_hub.hf_hub_download(
+            model_repo,
+            MODEL_FILENAME,
+            use_auth_token=HF_TOKEN,
         )
+        return csv_path, model_path

+    def load_model(self, model_repo):
+        if model_repo == self.last_loaded_repo:
+            return

-    if len(info) == 0:
-        message = "Nothing found in the image."
-        info = f"<div><p>{message}<p></div>"
+        csv_path, model_path = self.download_model(model_repo)

-    return (a, c,d, rating, character_res, general_res, info)
+        tags_df = pd.read_csv(csv_path)
+        sep_tags = load_labels(tags_df)

+        self.tag_names = sep_tags[0]
+        self.rating_indexes = sep_tags[1]
+        self.general_indexes = sep_tags[2]
+        self.character_indexes = sep_tags[3]
+
+        model = rt.InferenceSession(model_path)
+        _, height, width, _ = model.get_inputs()[0].shape
+        self.model_target_size = height
+
+        self.last_loaded_repo = model_repo
+        self.model = model
+
+    def prepare_image(self, image):
+        target_size = self.model_target_size
+
+        canvas = Image.new("RGBA", image.size, (255, 255, 255))
+        canvas.alpha_composite(image)
+        image = canvas.convert("RGB")
+
+        # Pad image to square
+        image_shape = image.size
+        max_dim = max(image_shape)
+        pad_left = (max_dim - image_shape[0]) // 2
+        pad_top = (max_dim - image_shape[1]) // 2
+
+        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
+        padded_image.paste(image, (pad_left, pad_top))
+
+        # Resize
+        if max_dim != target_size:
+            padded_image = padded_image.resize(
+                (target_size, target_size),
+                Image.BICUBIC,
+            )
+
+        # Convert to numpy array
+        image_array = np.asarray(padded_image, dtype=np.float32)
+
+        # Convert PIL-native RGB to BGR
+        image_array = image_array[:, :, ::-1]
+
+        return np.expand_dims(image_array, axis=0)
+
+    def predict(
+        self,
+        image,
+        model_repo,
+        general_thresh,
+        general_mcut_enabled,
+        character_thresh,
+        character_mcut_enabled,
+    ):
+        self.load_model(model_repo)
+
+        image = self.prepare_image(image)
+
+        input_name = self.model.get_inputs()[0].name
+        label_name = self.model.get_outputs()[0].name
+        preds = self.model.run([label_name], {input_name: image})[0]
+
+        labels = list(zip(self.tag_names, preds[0].astype(float)))
+
+        # First 4 labels are actually ratings: pick one with argmax
+        ratings_names = [labels[i] for i in self.rating_indexes]
+        rating = dict(ratings_names)
+
+        # Then we have general tags: pick any where prediction confidence > threshold
+        general_names = [labels[i] for i in self.general_indexes]
+
+        if general_mcut_enabled:
+            general_probs = np.array([x[1] for x in general_names])
+            general_thresh = mcut_threshold(general_probs)
+
+        general_res = [x for x in general_names if x[1] > general_thresh]
+        general_res = dict(general_res)
+
+        # Everything else is characters: pick any where prediction confidence > threshold
+        character_names = [labels[i] for i in self.character_indexes]
+
+        if character_mcut_enabled:
+            character_probs = np.array([x[1] for x in character_names])
+            character_thresh = mcut_threshold(character_probs)
+            character_thresh = max(0.15, character_thresh)
+
+        character_res = [x for x in character_names if x[1] > character_thresh]
+        character_res = dict(character_res)
+
+        sorted_general_strings = sorted(
+            general_res.items(),
+            key=lambda x: x[1],
+            reverse=True,
+        )
+        sorted_general_strings = [x[0] for x in sorted_general_strings]
+        # Booru output: space-separated, underscores kept
+        sorted_booru_strings = " ".join(sorted_general_strings)
+        # Prompt output: replace underscores with spaces, but leave kaomoji tags intact
+        sorted_general_strings = [
+            x.replace("_", " ") if x not in kaomojis else x
+            for x in sorted_general_strings
+        ]
+        sorted_general_strings = (
+            ", ".join(sorted_general_strings).replace("(", "\\(").replace(")", "\\)")
+        )
+
+        return sorted_general_strings, sorted_booru_strings, rating, character_res, general_res
+
+
+def main():
-def main():
-    global loaded_models
-    loaded_models = {
-        "MOAT": None,
-        "SwinV2": None,
-        "ConvNext": None,
-        "ConvNextV2": None,
-        "ViT": None,
-    }
-
     args = parse_args()

-    change_model("MOAT")
-
-    tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
-
-    func = functools.partial(
-        predict,
-        tag_names=tag_names,
-        rating_indexes=rating_indexes,
-        general_indexes=general_indexes,
-        character_indexes=character_indexes,
-    )
-
-    gr.Interface(
-        fn=func,
-        inputs=[
-            gr.Image(type="pil", label="Input"),
-            gr.Radio(
-                ["MOAT", "SwinV2", "ConvNext", "ConvNextV2", "ViT"],
-                value="MOAT",
-                label="Model",
-            ),
-            gr.Slider(
-                0,
-                1,
-                step=args.score_slider_step,
-                value=args.score_general_threshold,
-                label="General Tags Threshold",
-            ),
-            gr.Slider(
-                0,
-                1,
-                step=args.score_slider_step,
-                value=args.score_character_threshold,
-                label="Character Tags Threshold",
-            ),
-        ],
-        outputs=[
-            gr.Textbox(label="Output (string)"),
-            gr.Textbox(label="Output (raw string)"),
-            gr.Textbox(label="Output (booru string)"),
-            gr.Label(label="Rating"),
-            gr.Label(label="Output (characters)"),
-            gr.Label(label="Output (tags)"),
-            gr.HTML(),
-        ],
-        examples=[["power.jpg", "MOAT", 0.1, 0.85]],
-        title=TITLE,
-        description=DESCRIPTION,
-        allow_flagging="never",
-    ).launch(
-        enable_queue=True,
-        share=args.share,
-    )
+    predictor = Predictor()
+
+    dropdown_list = [
+        SWINV2_MODEL_DSV3_REPO,
+        CONV_MODEL_DSV3_REPO,
+        VIT_MODEL_DSV3_REPO,
+        MOAT_MODEL_DSV2_REPO,
+        SWIN_MODEL_DSV2_REPO,
+        CONV_MODEL_DSV2_REPO,
+        CONV2_MODEL_DSV2_REPO,
+        VIT_MODEL_DSV2_REPO,
+    ]
+
+    with gr.Blocks(title=TITLE) as demo:
+        with gr.Column():
+            gr.Markdown(
+                value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
+            )
+            gr.Markdown(value=DESCRIPTION)
+            with gr.Row():
+                with gr.Column(variant="panel"):
+                    image = gr.Image(type="pil", image_mode="RGBA", label="Input")
+                    model_repo = gr.Dropdown(
+                        dropdown_list,
+                        value=SWINV2_MODEL_DSV3_REPO,
+                        label="Model",
+                    )
+                    with gr.Row():
+                        general_thresh = gr.Slider(
+                            0,
+                            1,
+                            step=args.score_slider_step,
+                            value=args.score_general_threshold,
+                            label="General Tags Threshold",
+                            scale=3,
+                        )
+                        general_mcut_enabled = gr.Checkbox(
+                            value=False,
+                            label="Use MCut threshold",
+                            scale=1,
+                        )
+                    with gr.Row():
+                        character_thresh = gr.Slider(
+                            0,
+                            1,
+                            step=args.score_slider_step,
+                            value=args.score_character_threshold,
+                            label="Character Tags Threshold",
+                            scale=3,
+                        )
+                        character_mcut_enabled = gr.Checkbox(
+                            value=False,
+                            label="Use MCut threshold",
+                            scale=1,
+                        )
+                    with gr.Row():
+                        clear = gr.ClearButton(
+                            components=[
+                                image,
+                                model_repo,
+                                general_thresh,
+                                general_mcut_enabled,
+                                character_thresh,
+                                character_mcut_enabled,
+                            ],
+                            variant="secondary",
+                            size="lg",
+                        )
+                        submit = gr.Button(value="Submit", variant="primary", size="lg")
+                with gr.Column(variant="panel"):
+                    sorted_general_strings = gr.Textbox(label="Output (string)")
+                    sorted_booru_strings = gr.Textbox(label="Output (booru string)")
+                    rating = gr.Label(label="Rating")
+                    character_res = gr.Label(label="Output (characters)")
+                    general_res = gr.Label(label="Output (tags)")
+                    clear.add(
+                        [
+                            sorted_general_strings,
+                            sorted_booru_strings,
+                            rating,
+                            character_res,
+                            general_res,
+                        ]
+                    )
+
+        submit.click(
+            predictor.predict,
+            inputs=[
+                image,
+                model_repo,
+                general_thresh,
+                general_mcut_enabled,
+                character_thresh,
+                character_mcut_enabled,
+            ],
+            outputs=[sorted_general_strings, sorted_booru_strings, rating, character_res, general_res],
+        )
+
+        gr.Examples(
+            [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
+            inputs=[
+                image,
+                model_repo,
+                general_thresh,
+                general_mcut_enabled,
+                character_thresh,
+                character_mcut_enabled,
+            ],
+        )
+
+    demo.queue(max_size=10)
+    demo.launch()


 if __name__ == "__main__":
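mcut_threshold() in the hunk above implements the cited Largeron et al. (2012) strategy: sort the scores, find the largest gap between consecutive values, and place the threshold in the middle of that gap. A worked toy example with made-up probabilities, assuming the function from the diff:

    import numpy as np

    probs = np.array([0.91, 0.85, 0.62, 0.20, 0.12, 0.03])
    # Gaps between sorted neighbours: 0.06, 0.23, 0.42, 0.08, 0.09.
    # The largest gap separates 0.62 from 0.20, so the threshold is
    # (0.62 + 0.20) / 2.
    print(mcut_threshold(probs))  # ~0.41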
 
 
 
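With HF_TOKEN set, the new Predictor class can also be driven without the Gradio UI. A minimal sketch against the signatures in the diff; the file name and threshold values are simply the ones from the gr.Examples block, not a recommendation:

    from PIL import Image

    predictor = Predictor()
    tags, booru_tags, rating, characters, general = predictor.predict(
        Image.open("power.jpg").convert("RGBA"),  # prepare_image() alpha-composites, so it needs RGBA
        SWINV2_MODEL_DSV3_REPO,
        0.35,   # general tags threshold
        False,  # general MCut disabled
        0.85,   # character tags threshold
        False,  # character MCut disabled
    )
    print(tags)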