Denliner commited on
Commit
b7ecbff
·
1 Parent(s): c6e87b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -26,7 +26,7 @@ Demo for:
26
  - [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
27
  - [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
28
  - [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
29
-
30
  Includes "ready to copy" prompt and a prompt analyzer.
31
 
32
  Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
@@ -40,6 +40,7 @@ Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
40
  HF_TOKEN = os.environ["HF_TOKEN"]
41
  SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
42
  CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
 
43
  VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
44
  MODEL_FILENAME = "model.onnx"
45
  LABEL_FILENAME = "selected_tags.csv"
@@ -71,6 +72,8 @@ def change_model(model_name):
71
  model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
72
  elif model_name == "ViT":
73
  model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
 
 
74
 
75
  loaded_models[model_name] = model
76
  return loaded_models[model_name]
@@ -78,7 +81,7 @@ def change_model(model_name):
78
 
79
  def load_labels() -> list[str]:
80
  path = huggingface_hub.hf_hub_download(
81
- SWIN_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
82
  )
83
  df = pd.read_csv(path)
84
 
@@ -213,11 +216,11 @@ def predict(
213
 
214
  def main():
215
  global loaded_models
216
- loaded_models = {"SwinV2": None, "ConvNext": None, "ViT": None}
217
 
218
  args = parse_args()
219
 
220
- change_model("SwinV2")
221
 
222
  tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
223
 
@@ -233,7 +236,7 @@ def main():
233
  fn=func,
234
  inputs=[
235
  gr.Image(type="pil", label="Input"),
236
- gr.Radio(["SwinV2", "ConvNext", "ViT"], value="SwinV2", label="Model"),
237
  gr.Slider(
238
  0,
239
  1,
@@ -258,7 +261,7 @@ def main():
258
  gr.Label(label="Output (tags)"),
259
  gr.HTML(),
260
  ],
261
- examples=[["power.jpg", "SwinV2", 0.35, 0.85]],
262
  title=TITLE,
263
  description=DESCRIPTION,
264
  allow_flagging="never",
 
26
  - [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
27
  - [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
28
  - [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
29
+ - [SmilingWolf/wd-v1-4-convnextv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2)
30
  Includes "ready to copy" prompt and a prompt analyzer.
31
 
32
  Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string)
 
40
  HF_TOKEN = os.environ["HF_TOKEN"]
41
  SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
42
  CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
43
+ CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
44
  VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
45
  MODEL_FILENAME = "model.onnx"
46
  LABEL_FILENAME = "selected_tags.csv"
 
72
  model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
73
  elif model_name == "ViT":
74
  model = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
75
+ elif model_name == "ConvNextV2":
76
+ model = load_model(CONV2_MODEL_REPO, MODEL_FILENAME)
77
 
78
  loaded_models[model_name] = model
79
  return loaded_models[model_name]
 
81
 
82
  def load_labels() -> list[str]:
83
  path = huggingface_hub.hf_hub_download(
84
+ CONV2_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
85
  )
86
  df = pd.read_csv(path)
87
 
 
216
 
217
  def main():
218
  global loaded_models
219
+ loaded_models = {"SwinV2": None, "ConvNext": None,"ConvNextV2": None, "ViT": None}
220
 
221
  args = parse_args()
222
 
223
+ change_model("ConvNextV2")
224
 
225
  tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
226
 
 
236
  fn=func,
237
  inputs=[
238
  gr.Image(type="pil", label="Input"),
239
+ gr.Radio(["SwinV2", "ConvNext","ConvNextV2", "ViT"], value="SwinV2", label="Model"),
240
  gr.Slider(
241
  0,
242
  1,
 
261
  gr.Label(label="Output (tags)"),
262
  gr.HTML(),
263
  ],
264
+ examples=[["power.jpg", "ConvNextV2", 0.35, 0.85]],
265
  title=TITLE,
266
  description=DESCRIPTION,
267
  allow_flagging="never",