CyberWaifu commited on
Commit
c24087d
·
verified ·
1 Parent(s): 6655490

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -32,16 +32,19 @@ def tag_image(pil_image: Image.Image, output_format: str) -> str:
32
  idx_to_tag = metadata["idx_to_tag"]
33
  tag_to_category = metadata.get("tag_to_category", {})
34
  category_thresholds = metadata.get("category_thresholds", {})
35
- default_threshold = 0.325
36
  results_by_cat = {} # to store tags per category (for verbose output)
37
  artist_tags_with_probs = []
38
  character_tags_with_probs = []
39
  general_tags_with_probs = []
 
40
 
41
  # Collect tags above thresholds
42
  for idx, prob in enumerate(probs):
43
  tag = idx_to_tag[str(idx)]
44
  cat = tag_to_category.get(tag, "unknown")
 
 
45
  thresh = category_thresholds.get(cat, default_threshold)
46
  if float(prob) >= thresh:
47
  # add to category dictionary
@@ -64,6 +67,12 @@ def tag_image(pil_image: Image.Image, output_format: str) -> str:
64
 
65
  prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
66
 
 
 
 
 
 
 
67
  if not prompt_tags:
68
  return "No tags predicted."
69
  return ", ".join(prompt_tags)
 
32
  idx_to_tag = metadata["idx_to_tag"]
33
  tag_to_category = metadata.get("tag_to_category", {})
34
  category_thresholds = metadata.get("category_thresholds", {})
35
+ default_threshold = 0.35
36
  results_by_cat = {} # to store tags per category (for verbose output)
37
  artist_tags_with_probs = []
38
  character_tags_with_probs = []
39
  general_tags_with_probs = []
40
+ all_artist_tags_probs = [] # Store all artist tags and their probabilities
41
 
42
  # Collect tags above thresholds
43
  for idx, prob in enumerate(probs):
44
  tag = idx_to_tag[str(idx)]
45
  cat = tag_to_category.get(tag, "unknown")
46
+ if cat == 'artist':
47
+ all_artist_tags_probs.append((tag, float(prob))) # Store all artist tags
48
  thresh = category_thresholds.get(cat, default_threshold)
49
  if float(prob) >= thresh:
50
  # add to category dictionary
 
67
 
68
  prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
69
 
70
+ # Ensure at least one artist tag if any artist tags were predicted at all, even below threshold
71
+ if not artist_prompt_tags and all_artist_tags_probs:
72
+ best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1])
73
+ prompt_tags = [best_artist_tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)")] + prompt_tags
74
+
75
+
76
  if not prompt_tags:
77
  return "No tags predicted."
78
  return ", ".join(prompt_tags)