Denliner commited on
Commit
cc00f5e
·
1 Parent(s): b7ecbff
Files changed (1) hide show
  1. app.py +20 -7
app.py CHANGED
@@ -23,6 +23,7 @@ This is an edited version of SmilingWolf's wd-1.4 taggs, which I have modified s
23
  https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags
24
 
25
  Demo for:
 
26
  - [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
27
  - [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
28
  - [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
@@ -38,6 +39,7 @@ Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
38
  """
39
 
40
  HF_TOKEN = os.environ["HF_TOKEN"]
 
41
  SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
42
  CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
43
  CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
@@ -65,8 +67,9 @@ def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
65
 
66
  def change_model(model_name):
67
  global loaded_models
68
-
69
- if model_name == "SwinV2":
 
70
  model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
71
  elif model_name == "ConvNext":
72
  model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
@@ -81,7 +84,7 @@ def change_model(model_name):
81
 
82
  def load_labels() -> list[str]:
83
  path = huggingface_hub.hf_hub_download(
84
- CONV2_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
85
  )
86
  df = pd.read_csv(path)
87
 
@@ -216,11 +219,17 @@ def predict(
216
 
217
  def main():
218
  global loaded_models
219
- loaded_models = {"SwinV2": None, "ConvNext": None,"ConvNextV2": None, "ViT": None}
 
 
 
 
 
 
220
 
221
  args = parse_args()
222
 
223
- change_model("ConvNextV2")
224
 
225
  tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
226
 
@@ -236,7 +245,11 @@ def main():
236
  fn=func,
237
  inputs=[
238
  gr.Image(type="pil", label="Input"),
239
- gr.Radio(["SwinV2", "ConvNext","ConvNextV2", "ViT"], value="SwinV2", label="Model"),
 
 
 
 
240
  gr.Slider(
241
  0,
242
  1,
@@ -261,7 +274,7 @@ def main():
261
  gr.Label(label="Output (tags)"),
262
  gr.HTML(),
263
  ],
264
- examples=[["power.jpg", "ConvNextV2", 0.35, 0.85]],
265
  title=TITLE,
266
  description=DESCRIPTION,
267
  allow_flagging="never",
 
23
  https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags
24
 
25
  Demo for:
26
+ - [SmilingWolf/wd-v1-4-moat-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-moat-tagger-v2)
27
  - [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
28
  - [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
29
  - [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
 
39
  """
40
 
41
  HF_TOKEN = os.environ["HF_TOKEN"]
42
+ MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
43
  SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
44
  CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
45
  CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
 
67
 
68
  def change_model(model_name):
69
  global loaded_models
70
+ if model_name == "MOAT":
71
+ model = load_model(MOAT_MODEL_REPO, MODEL_FILENAME)
72
+ elif model_name == "SwinV2":
73
  model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
74
  elif model_name == "ConvNext":
75
  model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
 
84
 
85
  def load_labels() -> list[str]:
86
  path = huggingface_hub.hf_hub_download(
87
+ MOAT_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
88
  )
89
  df = pd.read_csv(path)
90
 
 
219
 
220
  def main():
221
  global loaded_models
222
+ loaded_models = {
223
+ "MOAT": None,
224
+ "SwinV2": None,
225
+ "ConvNext": None,
226
+ "ConvNextV2": None,
227
+ "ViT": None,
228
+ }
229
 
230
  args = parse_args()
231
 
232
+ change_model("MOAT")
233
 
234
  tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
235
 
 
245
  fn=func,
246
  inputs=[
247
  gr.Image(type="pil", label="Input"),
248
+ gr.Radio(
249
+ ["MOAT", "SwinV2", "ConvNext", "ConvNextV2", "ViT"],
250
+ value="MOAT",
251
+ label="Model",
252
+ ),
253
  gr.Slider(
254
  0,
255
  1,
 
274
  gr.Label(label="Output (tags)"),
275
  gr.HTML(),
276
  ],
277
+ examples=[["power.jpg", "MOAT", 0.1, 0.85]],
278
  title=TITLE,
279
  description=DESCRIPTION,
280
  allow_flagging="never",