tori29umai commited on
Commit
98923c8
·
1 Parent(s): 14a932e
Files changed (1) hide show
  1. app.py +79 -86
app.py CHANGED
@@ -31,94 +31,91 @@ def preprocess_image(image):
31
  image = image.astype(np.float32)
32
  return image
33
 
34
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  class webui:
38
  def __init__(self):
39
  self.demo = gr.Blocks()
40
 
41
- @spaces.GPU
42
- def process_image(self, image_path, input_name, ort_sess, rating_tags, character_tags, general_tags):
43
- thresh = 0.35
44
- try:
45
- image = Image.open(image_path)
46
- image = image.convert("RGB") if image.mode != "RGB" else image
47
- image = preprocess_image(image)
48
- except Exception as e:
49
- print(f"画像を読み込めません: {image_path}, エラー: {e}")
50
- return
51
-
52
- img = np.array([image])
53
- prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
54
-
55
- # NSFW/SFW判定
56
- tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
57
- max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
58
- max_sfw_score = tag_confidences.get("general", 0)
59
- NSFW_flag = None
60
-
61
- if max_nsfw_score > max_sfw_score:
62
- NSFW_flag = "NSFWの可能性が高いです"
63
- else:
64
- NSFW_flag = "SFWの可能性が高いです"
65
-
66
- # 版権キャラクターの可能性を評価
67
- character_tags_with_probs = []
68
- for i, p in enumerate(prob[4:]):
69
- if p >= thresh and i >= len(general_tags):
70
- tag_index = i - len(general_tags)
71
- if tag_index < len(character_tags):
72
- tag_name = character_tags[tag_index]
73
- prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
74
- character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
75
-
76
- IP_flag = None
77
- if character_tags_with_probs:
78
- IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります"
79
- else:
80
- IP_flag = "版権キャラクターの可能性が低いと思われます"
81
-
82
- # タグを生成
83
- tag_freq = {}
84
- undesired_tags = []
85
- combined_tags = []
86
- general_tag_text = ""
87
- character_tag_text = ""
88
- remove_underscore = True
89
- caption_separator = ", "
90
- general_threshold = 0.35
91
- character_threshold = 0.35
92
-
93
- for i, p in enumerate(prob[4:]):
94
- if i < len(general_tags) and p >= general_threshold:
95
- tag_name = general_tags[i]
96
- if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
97
- tag_name = tag_name.replace("_", " ")
98
-
99
- if tag_name not in undesired_tags:
100
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
101
- general_tag_text += caption_separator + tag_name
102
- combined_tags.append(tag_name)
103
- elif i >= len(general_tags) and p >= character_threshold:
104
- tag_name = character_tags[i - len(general_tags)]
105
- if remove_underscore and len(tag_name) > 3:
106
- tag_name = tag_name.replace("_", " ")
107
-
108
- if tag_name not in undesired_tags:
109
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
110
- character_tag_text += caption_separator + tag_name
111
- combined_tags.append(tag_name)
112
-
113
- # 先頭のカンマを取る
114
- if len(general_tag_text) > 0:
115
- general_tag_text = general_tag_text[len(caption_separator) :]
116
- if len(character_tag_text) > 0:
117
- character_tag_text = character_tag_text[len(caption_separator) :]
118
- tag_text = caption_separator.join(combined_tags)
119
-
120
- return NSFW_flag, IP_flag, tag_text
121
-
122
  @spaces.GPU
123
  def main(self, image_path, model_id):
124
  print("Hugging Faceからモデルをダウンロード中")
@@ -144,10 +141,6 @@ class webui:
144
  return NSFW_flag, IP_flag, tag_text
145
 
146
 
147
-
148
-
149
-
150
-
151
  def launch(self):
152
  with self.demo:
153
  with gr.Row():
 
31
  image = image.astype(np.float32)
32
  return image
33
 
34
+ def process_image(image_path, input_name, ort_sess, rating_tags, character_tags, general_tags):
35
+ thresh = 0.35
36
+ try:
37
+ image = Image.open(image_path)
38
+ image = image.convert("RGB") if image.mode != "RGB" else image
39
+ image = preprocess_image(image)
40
+ except Exception as e:
41
+ print(f"画像を読み込めません: {image_path}, エラー: {e}")
42
+ return
43
+
44
+ img = np.array([image])
45
+ prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
46
+
47
+ # NSFW/SFW判定
48
+ tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
49
+ max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
50
+ max_sfw_score = tag_confidences.get("general", 0)
51
+ NSFW_flag = None
52
+
53
+ if max_nsfw_score > max_sfw_score:
54
+ NSFW_flag = "NSFWの可能性が高いです"
55
+ else:
56
+ NSFW_flag = "SFWの可能性が高いです"
57
+
58
+ # 版権キャラクターの可能性を評価
59
+ character_tags_with_probs = []
60
+ for i, p in enumerate(prob[4:]):
61
+ if p >= thresh and i >= len(general_tags):
62
+ tag_index = i - len(general_tags)
63
+ if tag_index < len(character_tags):
64
+ tag_name = character_tags[tag_index]
65
+ prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
66
+ character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
67
+
68
+ IP_flag = None
69
+ if character_tags_with_probs:
70
+ IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります"
71
+ else:
72
+ IP_flag = "版権キャラクターの可能性が低いと思われます"
73
+
74
+ # タグを生成
75
+ tag_freq = {}
76
+ undesired_tags = []
77
+ combined_tags = []
78
+ general_tag_text = ""
79
+ character_tag_text = ""
80
+ remove_underscore = True
81
+ caption_separator = ", "
82
+ general_threshold = 0.35
83
+ character_threshold = 0.35
84
+
85
+ for i, p in enumerate(prob[4:]):
86
+ if i < len(general_tags) and p >= general_threshold:
87
+ tag_name = general_tags[i]
88
+ if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
89
+ tag_name = tag_name.replace("_", " ")
90
+
91
+ if tag_name not in undesired_tags:
92
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
93
+ general_tag_text += caption_separator + tag_name
94
+ combined_tags.append(tag_name)
95
+ elif i >= len(general_tags) and p >= character_threshold:
96
+ tag_name = character_tags[i - len(general_tags)]
97
+ if remove_underscore and len(tag_name) > 3:
98
+ tag_name = tag_name.replace("_", " ")
99
+
100
+ if tag_name not in undesired_tags:
101
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
102
+ character_tag_text += caption_separator + tag_name
103
+ combined_tags.append(tag_name)
104
+
105
+ # 先頭のカンマを取る
106
+ if len(general_tag_text) > 0:
107
+ general_tag_text = general_tag_text[len(caption_separator) :]
108
+ if len(character_tag_text) > 0:
109
+ character_tag_text = character_tag_text[len(caption_separator) :]
110
+ tag_text = caption_separator.join(combined_tags)
111
+
112
+ return NSFW_flag, IP_flag, tag_text
113
 
114
 
115
  class webui:
116
  def __init__(self):
117
  self.demo = gr.Blocks()
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  @spaces.GPU
120
  def main(self, image_path, model_id):
121
  print("Hugging Faceからモデルをダウンロード中")
 
141
  return NSFW_flag, IP_flag, tag_text
142
 
143
 
 
 
 
 
144
  def launch(self):
145
  with self.demo:
146
  with gr.Row():