tori29umai commited on
Commit
607a0f5
·
1 Parent(s): ab6e44b
Files changed (1) hide show
  1. app.py +67 -72
app.py CHANGED
@@ -31,76 +31,6 @@ def preprocess_image(image):
31
  image = image.astype(np.float32)
32
  return image
33
 
34
- def process_image(prob, rating_tags, character_tags, general_tags):
35
- thresh = 0.35
36
-
37
- # NSFW/SFW判定
38
- tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
39
- max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
40
- max_sfw_score = tag_confidences.get("general", 0)
41
- NSFW_flag = None
42
-
43
- if max_nsfw_score > max_sfw_score:
44
- NSFW_flag = "NSFWの可能性が高いです"
45
- else:
46
- NSFW_flag = "SFWの可能性が高いです"
47
-
48
- # 版権キャラクターの可能性を評価
49
- character_tags_with_probs = []
50
- for i, p in enumerate(prob[4:]):
51
- if p >= thresh and i >= len(general_tags):
52
- tag_index = i - len(general_tags)
53
- if tag_index < len(character_tags):
54
- tag_name = character_tags[tag_index]
55
- prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
56
- character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
57
-
58
- IP_flag = None
59
- if character_tags_with_probs:
60
- IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります"
61
- else:
62
- IP_flag = "版権キャラクターの可能性が低いと思われます"
63
-
64
- # タグを生成
65
- tag_freq = {}
66
- undesired_tags = []
67
- combined_tags = []
68
- general_tag_text = ""
69
- character_tag_text = ""
70
- remove_underscore = True
71
- caption_separator = ", "
72
- general_threshold = 0.35
73
- character_threshold = 0.35
74
-
75
- for i, p in enumerate(prob[4:]):
76
- if i < len(general_tags) and p >= general_threshold:
77
- tag_name = general_tags[i]
78
- if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
79
- tag_name = tag_name.replace("_", " ")
80
-
81
- if tag_name not in undesired_tags:
82
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
83
- general_tag_text += caption_separator + tag_name
84
- combined_tags.append(tag_name)
85
- elif i >= len(general_tags) and p >= character_threshold:
86
- tag_name = character_tags[i - len(general_tags)]
87
- if remove_underscore and len(tag_name) > 3:
88
- tag_name = tag_name.replace("_", " ")
89
-
90
- if tag_name not in undesired_tags:
91
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
92
- character_tag_text += caption_separator + tag_name
93
- combined_tags.append(tag_name)
94
-
95
- # 先頭のカンマを取る
96
- if len(general_tag_text) > 0:
97
- general_tag_text = general_tag_text[len(caption_separator) :]
98
- if len(character_tag_text) > 0:
99
- character_tag_text = character_tag_text[len(caption_separator) :]
100
- tag_text = caption_separator.join(combined_tags)
101
-
102
- return NSFW_flag, IP_flag, tag_text
103
-
104
 
105
  class webui:
106
  def __init__(self):
@@ -132,10 +62,75 @@ class webui:
132
  img = np.array([image])
133
  prob = ort_sess.run(None, {ort_sess.get_inputs()[0].name: img})[0][0] # ONNXモデルからの出力
134
 
135
- NSFW_flag, IP_flag, tag_text = process_image(prob, rating_tags, character_tags, general_tags)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  return NSFW_flag, IP_flag, tag_text
137
 
138
-
139
  def launch(self):
140
  with self.demo:
141
  with gr.Row():
 
31
  image = image.astype(np.float32)
32
  return image
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  class webui:
36
  def __init__(self):
 
62
  img = np.array([image])
63
  prob = ort_sess.run(None, {ort_sess.get_inputs()[0].name: img})[0][0] # ONNXモデルからの出力
64
 
65
+ thresh = 0.35
66
+
67
+ # NSFW/SFW判定
68
+ tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
69
+ max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
70
+ max_sfw_score = tag_confidences.get("general", 0)
71
+ NSFW_flag = None
72
+
73
+ if max_nsfw_score > max_sfw_score:
74
+ NSFW_flag = "NSFWの可能性が高いです"
75
+ else:
76
+ NSFW_flag = "SFWの可能性が高いです"
77
+
78
+ # 版権キャラクターの可能性を評価
79
+ character_tags_with_probs = []
80
+ for i, p in enumerate(prob[4:]):
81
+ if p >= thresh and i >= len(general_tags):
82
+ tag_index = i - len(general_tags)
83
+ if tag_index < len(character_tags):
84
+ tag_name = character_tags[tag_index]
85
+ prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
86
+ character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
87
+
88
+ IP_flag = None
89
+ if character_tags_with_probs:
90
+ IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります"
91
+ else:
92
+ IP_flag = "版権キャラクターの可能性が低いと思われます"
93
+
94
+ # タグを生成
95
+ tag_freq = {}
96
+ undesired_tags = []
97
+ combined_tags = []
98
+ general_tag_text = ""
99
+ character_tag_text = ""
100
+ remove_underscore = True
101
+ caption_separator = ", "
102
+ general_threshold = 0.35
103
+ character_threshold = 0.35
104
+
105
+ for i, p in enumerate(prob[4:]):
106
+ if i < len(general_tags) and p >= general_threshold:
107
+ tag_name = general_tags[i]
108
+ if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
109
+ tag_name = tag_name.replace("_", " ")
110
+
111
+ if tag_name not in undesired_tags:
112
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
113
+ general_tag_text += caption_separator + tag_name
114
+ combined_tags.append(tag_name)
115
+ elif i >= len(general_tags) and p >= character_threshold:
116
+ tag_name = character_tags[i - len(general_tags)]
117
+ if remove_underscore and len(tag_name) > 3:
118
+ tag_name = tag_name.replace("_", " ")
119
+
120
+ if tag_name not in undesired_tags:
121
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
122
+ character_tag_text += caption_separator + tag_name
123
+ combined_tags.append(tag_name)
124
+
125
+ # 先頭のカンマを取る
126
+ if len(general_tag_text) > 0:
127
+ general_tag_text = general_tag_text[len(caption_separator) :]
128
+ if len(character_tag_text) > 0:
129
+ character_tag_text = character_tag_text[len(caption_separator) :]
130
+ tag_text = caption_separator.join(combined_tags)
131
+
132
  return NSFW_flag, IP_flag, tag_text
133
 
 
134
  def launch(self):
135
  with self.demo:
136
  with gr.Row():