Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -32,12 +32,13 @@ def tag_image(pil_image: Image.Image, output_format: str) -> str:
|
|
32 |
idx_to_tag = metadata["idx_to_tag"]
|
33 |
tag_to_category = metadata.get("tag_to_category", {})
|
34 |
category_thresholds = metadata.get("category_thresholds", {})
|
35 |
-
default_threshold = 0.
|
36 |
results_by_cat = {} # to store tags per category (for verbose output)
|
37 |
-
prompt_tags = [] # to store tags for prompt-style output
|
38 |
-
# Collect tags above thresholds
|
39 |
artist_tags_with_probs = []
|
40 |
-
|
|
|
|
|
|
|
41 |
for idx, prob in enumerate(probs):
|
42 |
tag = idx_to_tag[str(idx)]
|
43 |
cat = tag_to_category.get(tag, "unknown")
|
@@ -47,16 +48,21 @@ def tag_image(pil_image: Image.Image, output_format: str) -> str:
|
|
47 |
results_by_cat.setdefault(cat, []).append((tag, float(prob)))
|
48 |
if cat == 'artist':
|
49 |
artist_tags_with_probs.append((tag, float(prob)))
|
50 |
-
|
51 |
-
|
|
|
|
|
52 |
|
53 |
if output_format == "Prompt-style Tags":
|
54 |
artist_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
-
|
58 |
-
non_artist_prompt_tags = [tag.replace("_", " ") for tag, prob in non_artist_tags_with_probs]
|
59 |
-
prompt_tags = artist_prompt_tags + non_artist_prompt_tags
|
60 |
|
61 |
if not prompt_tags:
|
62 |
return "No tags predicted."
|
|
|
32 |
idx_to_tag = metadata["idx_to_tag"]
|
33 |
tag_to_category = metadata.get("tag_to_category", {})
|
34 |
category_thresholds = metadata.get("category_thresholds", {})
|
35 |
+
default_threshold = 0.35
|
36 |
results_by_cat = {} # to store tags per category (for verbose output)
|
|
|
|
|
37 |
artist_tags_with_probs = []
|
38 |
+
character_tags_with_probs = []
|
39 |
+
general_tags_with_probs = []
|
40 |
+
|
41 |
+
# Collect tags above thresholds
|
42 |
for idx, prob in enumerate(probs):
|
43 |
tag = idx_to_tag[str(idx)]
|
44 |
cat = tag_to_category.get(tag, "unknown")
|
|
|
48 |
results_by_cat.setdefault(cat, []).append((tag, float(prob)))
|
49 |
if cat == 'artist':
|
50 |
artist_tags_with_probs.append((tag, float(prob)))
|
51 |
+
elif cat == 'character':
|
52 |
+
character_tags_with_probs.append((tag, float(prob)))
|
53 |
+
elif cat == 'general':
|
54 |
+
general_tags_with_probs.append((tag, float(prob)))
|
55 |
|
56 |
if output_format == "Prompt-style Tags":
|
57 |
artist_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
|
58 |
+
character_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
|
59 |
+
general_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
|
60 |
+
|
61 |
+
artist_prompt_tags = [tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)") for tag, prob in artist_tags_with_probs]
|
62 |
+
character_prompt_tags = [tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)") for tag, prob in character_tags_with_probs]
|
63 |
+
general_prompt_tags = [tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)") for tag, prob in general_tags_with_probs]
|
64 |
|
65 |
+
prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
|
|
|
|
|
66 |
|
67 |
if not prompt_tags:
|
68 |
return "No tags predicted."
|