Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -32,16 +32,19 @@ def tag_image(pil_image: Image.Image, output_format: str) -> str:
|
|
32 |
idx_to_tag = metadata["idx_to_tag"]
|
33 |
tag_to_category = metadata.get("tag_to_category", {})
|
34 |
category_thresholds = metadata.get("category_thresholds", {})
|
35 |
-
default_threshold = 0.
|
36 |
results_by_cat = {} # to store tags per category (for verbose output)
|
37 |
artist_tags_with_probs = []
|
38 |
character_tags_with_probs = []
|
39 |
general_tags_with_probs = []
|
|
|
40 |
|
41 |
# Collect tags above thresholds
|
42 |
for idx, prob in enumerate(probs):
|
43 |
tag = idx_to_tag[str(idx)]
|
44 |
cat = tag_to_category.get(tag, "unknown")
|
|
|
|
|
45 |
thresh = category_thresholds.get(cat, default_threshold)
|
46 |
if float(prob) >= thresh:
|
47 |
# add to category dictionary
|
@@ -64,6 +67,12 @@ def tag_image(pil_image: Image.Image, output_format: str) -> str:
|
|
64 |
|
65 |
prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
if not prompt_tags:
|
68 |
return "No tags predicted."
|
69 |
return ", ".join(prompt_tags)
|
|
|
32 |
idx_to_tag = metadata["idx_to_tag"]
|
33 |
tag_to_category = metadata.get("tag_to_category", {})
|
34 |
category_thresholds = metadata.get("category_thresholds", {})
|
35 |
+
default_threshold = 0.35
|
36 |
results_by_cat = {} # to store tags per category (for verbose output)
|
37 |
artist_tags_with_probs = []
|
38 |
character_tags_with_probs = []
|
39 |
general_tags_with_probs = []
|
40 |
+
all_artist_tags_probs = [] # Store all artist tags and their probabilities
|
41 |
|
42 |
# Collect tags above thresholds
|
43 |
for idx, prob in enumerate(probs):
|
44 |
tag = idx_to_tag[str(idx)]
|
45 |
cat = tag_to_category.get(tag, "unknown")
|
46 |
+
if cat == 'artist':
|
47 |
+
all_artist_tags_probs.append((tag, float(prob))) # Store all artist tags
|
48 |
thresh = category_thresholds.get(cat, default_threshold)
|
49 |
if float(prob) >= thresh:
|
50 |
# add to category dictionary
|
|
|
67 |
|
68 |
prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
|
69 |
|
70 |
+
# Ensure at least one artist tag if any artist tags were predicted at all, even below threshold
|
71 |
+
if not artist_prompt_tags and all_artist_tags_probs:
|
72 |
+
best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1])
|
73 |
+
prompt_tags = [best_artist_tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)")] + prompt_tags
|
74 |
+
|
75 |
+
|
76 |
if not prompt_tags:
|
77 |
return "No tags predicted."
|
78 |
return ", ".join(prompt_tags)
|