CyberWaifu commited on
Commit
d13085d
·
verified ·
1 Parent(s): 0f866b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -32,12 +32,13 @@ def tag_image(pil_image: Image.Image, output_format: str) -> str:
32
  idx_to_tag = metadata["idx_to_tag"]
33
  tag_to_category = metadata.get("tag_to_category", {})
34
  category_thresholds = metadata.get("category_thresholds", {})
35
- default_threshold = 0.325
36
  results_by_cat = {} # to store tags per category (for verbose output)
37
- prompt_tags = [] # to store tags for prompt-style output
38
- # Collect tags above thresholds
39
  artist_tags_with_probs = []
40
- non_artist_tags_with_probs = []
 
 
 
41
  for idx, prob in enumerate(probs):
42
  tag = idx_to_tag[str(idx)]
43
  cat = tag_to_category.get(tag, "unknown")
@@ -47,16 +48,21 @@ def tag_image(pil_image: Image.Image, output_format: str) -> str:
47
  results_by_cat.setdefault(cat, []).append((tag, float(prob)))
48
  if cat == 'artist':
49
  artist_tags_with_probs.append((tag, float(prob)))
50
- else:
51
- non_artist_tags_with_probs.append((tag, float(prob)))
 
 
52
 
53
  if output_format == "Prompt-style Tags":
54
  artist_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
55
- non_artist_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
 
 
 
 
 
56
 
57
- artist_prompt_tags = [tag.replace("_", " ") for tag, prob in artist_tags_with_probs]
58
- non_artist_prompt_tags = [tag.replace("_", " ") for tag, prob in non_artist_tags_with_probs]
59
- prompt_tags = artist_prompt_tags + non_artist_prompt_tags
60
 
61
  if not prompt_tags:
62
  return "No tags predicted."
 
32
  idx_to_tag = metadata["idx_to_tag"]
33
  tag_to_category = metadata.get("tag_to_category", {})
34
  category_thresholds = metadata.get("category_thresholds", {})
35
+ default_threshold = 0.35
36
  results_by_cat = {} # to store tags per category (for verbose output)
 
 
37
  artist_tags_with_probs = []
38
+ character_tags_with_probs = []
39
+ general_tags_with_probs = []
40
+
41
+ # Collect tags above thresholds
42
  for idx, prob in enumerate(probs):
43
  tag = idx_to_tag[str(idx)]
44
  cat = tag_to_category.get(tag, "unknown")
 
48
  results_by_cat.setdefault(cat, []).append((tag, float(prob)))
49
  if cat == 'artist':
50
  artist_tags_with_probs.append((tag, float(prob)))
51
+ elif cat == 'character':
52
+ character_tags_with_probs.append((tag, float(prob)))
53
+ elif cat == 'general':
54
+ general_tags_with_probs.append((tag, float(prob)))
55
 
56
  if output_format == "Prompt-style Tags":
57
  artist_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
58
+ character_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
59
+ general_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
60
+
61
+ artist_prompt_tags = [tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)") for tag, prob in artist_tags_with_probs]
62
+ character_prompt_tags = [tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)") for tag, prob in character_tags_with_probs]
63
+ general_prompt_tags = [tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)") for tag, prob in general_tags_with_probs]
64
 
65
+ prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
 
 
66
 
67
  if not prompt_tags:
68
  return "No tags predicted."