Spaces:
Running
Running
Add MOAT
Browse files
app.py
CHANGED
@@ -23,6 +23,7 @@ This is an edited version of SmilingWolf's wd-1.4 taggs, which I have modified s
|
|
23 |
https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags
|
24 |
|
25 |
Demo for:
|
|
|
26 |
- [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
27 |
- [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
28 |
- [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
|
@@ -38,6 +39,7 @@ Example image by [ほし☆☆☆](https://www.pixiv.net/en/users/43565085)
|
|
38 |
"""
|
39 |
|
40 |
HF_TOKEN = os.environ["HF_TOKEN"]
|
|
|
41 |
SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
|
42 |
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
43 |
CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
|
@@ -65,8 +67,9 @@ def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
|
|
65 |
|
66 |
def change_model(model_name):
|
67 |
global loaded_models
|
68 |
-
|
69 |
-
|
|
|
70 |
model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
|
71 |
elif model_name == "ConvNext":
|
72 |
model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
|
@@ -81,7 +84,7 @@ def change_model(model_name):
|
|
81 |
|
82 |
def load_labels() -> list[str]:
|
83 |
path = huggingface_hub.hf_hub_download(
|
84 |
-
|
85 |
)
|
86 |
df = pd.read_csv(path)
|
87 |
|
@@ -216,11 +219,17 @@ def predict(
|
|
216 |
|
217 |
def main():
|
218 |
global loaded_models
|
219 |
-
loaded_models = {
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
args = parse_args()
|
222 |
|
223 |
-
change_model("
|
224 |
|
225 |
tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
|
226 |
|
@@ -236,7 +245,11 @@ def main():
|
|
236 |
fn=func,
|
237 |
inputs=[
|
238 |
gr.Image(type="pil", label="Input"),
|
239 |
-
gr.Radio(
|
|
|
|
|
|
|
|
|
240 |
gr.Slider(
|
241 |
0,
|
242 |
1,
|
@@ -261,7 +274,7 @@ def main():
|
|
261 |
gr.Label(label="Output (tags)"),
|
262 |
gr.HTML(),
|
263 |
],
|
264 |
-
examples=[["power.jpg", "
|
265 |
title=TITLE,
|
266 |
description=DESCRIPTION,
|
267 |
allow_flagging="never",
|
|
|
23 |
https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags
|
24 |
|
25 |
Demo for:
|
26 |
+
- [SmilingWolf/wd-v1-4-moat-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-moat-tagger-v2)
|
27 |
- [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
28 |
- [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
|
29 |
- [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
|
|
|
39 |
"""
|
40 |
|
41 |
HF_TOKEN = os.environ["HF_TOKEN"]
|
42 |
+
MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
|
43 |
SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
|
44 |
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
45 |
CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
|
|
|
67 |
|
68 |
def change_model(model_name):
|
69 |
global loaded_models
|
70 |
+
if model_name == "MOAT":
|
71 |
+
model = load_model(MOAT_MODEL_REPO, MODEL_FILENAME)
|
72 |
+
elif model_name == "SwinV2":
|
73 |
model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
|
74 |
elif model_name == "ConvNext":
|
75 |
model = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
|
|
|
84 |
|
85 |
def load_labels() -> list[str]:
|
86 |
path = huggingface_hub.hf_hub_download(
|
87 |
+
MOAT_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN
|
88 |
)
|
89 |
df = pd.read_csv(path)
|
90 |
|
|
|
219 |
|
220 |
def main():
|
221 |
global loaded_models
|
222 |
+
loaded_models = {
|
223 |
+
"MOAT": None,
|
224 |
+
"SwinV2": None,
|
225 |
+
"ConvNext": None,
|
226 |
+
"ConvNextV2": None,
|
227 |
+
"ViT": None,
|
228 |
+
}
|
229 |
|
230 |
args = parse_args()
|
231 |
|
232 |
+
change_model("MOAT")
|
233 |
|
234 |
tag_names, rating_indexes, general_indexes, character_indexes = load_labels()
|
235 |
|
|
|
245 |
fn=func,
|
246 |
inputs=[
|
247 |
gr.Image(type="pil", label="Input"),
|
248 |
+
gr.Radio(
|
249 |
+
["MOAT", "SwinV2", "ConvNext", "ConvNextV2", "ViT"],
|
250 |
+
value="MOAT",
|
251 |
+
label="Model",
|
252 |
+
),
|
253 |
gr.Slider(
|
254 |
0,
|
255 |
1,
|
|
|
274 |
gr.Label(label="Output (tags)"),
|
275 |
gr.HTML(),
|
276 |
],
|
277 |
+
examples=[["power.jpg", "MOAT", 0.1, 0.85]],
|
278 |
title=TITLE,
|
279 |
description=DESCRIPTION,
|
280 |
allow_flagging="never",
|