tori29umai commited on
Commit
14a932e
·
1 Parent(s): c9dfb9e
Files changed (1) hide show
  1. app.py +86 -80
app.py CHANGED
@@ -31,92 +31,94 @@ def preprocess_image(image):
31
  image = image.astype(np.float32)
32
  return image
33
 
34
- @spaces.GPU
35
- def process_image(image_path, input_name, ort_sess, rating_tags, character_tags, general_tags):
36
- thresh = 0.35
37
- try:
38
- image = Image.open(image_path)
39
- image = image.convert("RGB") if image.mode != "RGB" else image
40
- image = preprocess_image(image)
41
- except Exception as e:
42
- print(f"画像を読み込めません: {image_path}, エラー: {e}")
43
- return
44
-
45
- img = np.array([image])
46
- prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
47
-
48
- # NSFW/SFW判定
49
- tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
50
- max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
51
- max_sfw_score = tag_confidences.get("general", 0)
52
- NSFW_flag = None
53
-
54
- if max_nsfw_score > max_sfw_score:
55
- NSFW_flag = "NSFWの可能性が高いです"
56
- else:
57
- NSFW_flag = "SFWの可能性が高いです"
58
-
59
- # 版権キャラクターの可能性を評価
60
- character_tags_with_probs = []
61
- for i, p in enumerate(prob[4:]):
62
- if p >= thresh and i >= len(general_tags):
63
- tag_index = i - len(general_tags)
64
- if tag_index < len(character_tags):
65
- tag_name = character_tags[tag_index]
66
- prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
67
- character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
68
-
69
- IP_flag = None
70
- if character_tags_with_probs:
71
- IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります"
72
- else:
73
- IP_flag = "版権キャラクターの可能性が低いと思われます"
74
-
75
- # タグを生成
76
- tag_freq = {}
77
- undesired_tags = []
78
- combined_tags = []
79
- general_tag_text = ""
80
- character_tag_text = ""
81
- remove_underscore = True
82
- caption_separator = ", "
83
- general_threshold = 0.35
84
- character_threshold = 0.35
85
-
86
- for i, p in enumerate(prob[4:]):
87
- if i < len(general_tags) and p >= general_threshold:
88
- tag_name = general_tags[i]
89
- if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
90
- tag_name = tag_name.replace("_", " ")
91
-
92
- if tag_name not in undesired_tags:
93
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
94
- general_tag_text += caption_separator + tag_name
95
- combined_tags.append(tag_name)
96
- elif i >= len(general_tags) and p >= character_threshold:
97
- tag_name = character_tags[i - len(general_tags)]
98
- if remove_underscore and len(tag_name) > 3:
99
- tag_name = tag_name.replace("_", " ")
100
-
101
- if tag_name not in undesired_tags:
102
- tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
103
- character_tag_text += caption_separator + tag_name
104
- combined_tags.append(tag_name)
105
-
106
- # 先頭のカンマを取る
107
- if len(general_tag_text) > 0:
108
- general_tag_text = general_tag_text[len(caption_separator) :]
109
- if len(character_tag_text) > 0:
110
- character_tag_text = character_tag_text[len(caption_separator) :]
111
- tag_text = caption_separator.join(combined_tags)
112
-
113
- return NSFW_flag, IP_flag, tag_text
114
 
115
 
116
  class webui:
117
  def __init__(self):
118
  self.demo = gr.Blocks()
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  @spaces.GPU
121
  def main(self, image_path, model_id):
122
  print("Hugging Faceからモデルをダウンロード中")
@@ -142,6 +144,10 @@ class webui:
142
  return NSFW_flag, IP_flag, tag_text
143
 
144
 
 
 
 
 
145
  def launch(self):
146
  with self.demo:
147
  with gr.Row():
 
31
  image = image.astype(np.float32)
32
  return image
33
 
34
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  class webui:
38
  def __init__(self):
39
  self.demo = gr.Blocks()
40
 
41
+ @spaces.GPU
42
+ def process_image(self, image_path, input_name, ort_sess, rating_tags, character_tags, general_tags):
43
+ thresh = 0.35
44
+ try:
45
+ image = Image.open(image_path)
46
+ image = image.convert("RGB") if image.mode != "RGB" else image
47
+ image = preprocess_image(image)
48
+ except Exception as e:
49
+ print(f"画像を読み込めません: {image_path}, エラー: {e}")
50
+ return
51
+
52
+ img = np.array([image])
53
+ prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
54
+
55
+ # NSFW/SFW判定
56
+ tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
57
+ max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
58
+ max_sfw_score = tag_confidences.get("general", 0)
59
+ NSFW_flag = None
60
+
61
+ if max_nsfw_score > max_sfw_score:
62
+ NSFW_flag = "NSFWの可能性が高いです"
63
+ else:
64
+ NSFW_flag = "SFWの可能性が高いです"
65
+
66
+ # 版権キャラクターの可能性を評価
67
+ character_tags_with_probs = []
68
+ for i, p in enumerate(prob[4:]):
69
+ if p >= thresh and i >= len(general_tags):
70
+ tag_index = i - len(general_tags)
71
+ if tag_index < len(character_tags):
72
+ tag_name = character_tags[tag_index]
73
+ prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
74
+ character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
75
+
76
+ IP_flag = None
77
+ if character_tags_with_probs:
78
+ IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります"
79
+ else:
80
+ IP_flag = "版権キャラクターの可能性が低いと思われます"
81
+
82
+ # タグを生成
83
+ tag_freq = {}
84
+ undesired_tags = []
85
+ combined_tags = []
86
+ general_tag_text = ""
87
+ character_tag_text = ""
88
+ remove_underscore = True
89
+ caption_separator = ", "
90
+ general_threshold = 0.35
91
+ character_threshold = 0.35
92
+
93
+ for i, p in enumerate(prob[4:]):
94
+ if i < len(general_tags) and p >= general_threshold:
95
+ tag_name = general_tags[i]
96
+ if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
97
+ tag_name = tag_name.replace("_", " ")
98
+
99
+ if tag_name not in undesired_tags:
100
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
101
+ general_tag_text += caption_separator + tag_name
102
+ combined_tags.append(tag_name)
103
+ elif i >= len(general_tags) and p >= character_threshold:
104
+ tag_name = character_tags[i - len(general_tags)]
105
+ if remove_underscore and len(tag_name) > 3:
106
+ tag_name = tag_name.replace("_", " ")
107
+
108
+ if tag_name not in undesired_tags:
109
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
110
+ character_tag_text += caption_separator + tag_name
111
+ combined_tags.append(tag_name)
112
+
113
+ # 先頭のカンマを取る
114
+ if len(general_tag_text) > 0:
115
+ general_tag_text = general_tag_text[len(caption_separator) :]
116
+ if len(character_tag_text) > 0:
117
+ character_tag_text = character_tag_text[len(caption_separator) :]
118
+ tag_text = caption_separator.join(combined_tags)
119
+
120
+ return NSFW_flag, IP_flag, tag_text
121
+
122
  @spaces.GPU
123
  def main(self, image_path, model_id):
124
  print("Hugging Faceからモデルをダウンロード中")
 
144
  return NSFW_flag, IP_flag, tag_text
145
 
146
 
147
+
148
+
149
+
150
+
151
  def launch(self):
152
  with self.demo:
153
  with gr.Row():