tori29umai commited on
Commit
76ac8f3
·
1 Parent(s): aa2fbfe
Files changed (1) hide show
  1. app.py +60 -55
app.py CHANGED
@@ -31,64 +31,69 @@ def preprocess_image(image):
31
  image = image.astype(np.float32)
32
  return image
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  class webui:
35
  def __init__(self):
36
  self.demo = gr.Blocks()
37
 
38
- @spaces.GPU
39
- def main(self, image_path, model_id):
40
- print("Hugging Faceからモデルをダウンロード中")
41
- onnx_path = hf_hub_download(model_id, "model.onnx")
42
- csv_path = hf_hub_download(model_id, "selected_tags.csv")
43
-
44
- # ONNXモデルとCSVファイルの読み込み
45
- image = Image.open(image_path)
46
- image = image.convert("RGB") if image.mode != "RGB" else image
47
- image = preprocess_image(image)
48
- img = np.array([image])
49
-
50
- ort_sess = ort.InferenceSession(onnx_path) # セッションの生成をここで行う
51
- prob = ort_sess.run(None, {ort_sess.get_inputs()[0].name: img})[0][0]
52
-
53
- with open(csv_path, "r", encoding="utf-8") as f:
54
- reader = csv.reader(f)
55
- next(reader) # ヘッダーをスキップ
56
- rows = list(reader)
57
-
58
- rating_tags = [row[1] for row in rows if row[2] == "9"]
59
- character_tags = [row[1] for row in rows if row[2] == "4"]
60
- general_tags = [row[1] for row in rows if row[2] == "0"]
61
-
62
- # タグと評価
63
- NSFW_flag, IP_flag, tag_text = self.evaluate_tags(prob, rating_tags, character_tags, general_tags)
64
- return NSFW_flag, IP_flag, tag_text
65
-
66
- def evaluate_tags(self, prob, rating_tags, character_tags, general_tags):
67
- thresh = 0.35
68
- # NSFW/SFW判定
69
- tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
70
- max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
71
- max_sfw_score = tag_confidences.get("general", 0)
72
- NSFW_flag = "NSFWの可能性が高いです" if max_nsfw_score > max_sfw_score else "SFWの可能性が高いです"
73
-
74
- # 版権キャラクターの可能性を評価
75
- character_tags_with_probs = []
76
- for i, p in enumerate(prob[4:]):
77
- if p >= thresh and i >= len(general_tags):
78
- tag_index = i - len(general_tags)
79
- if tag_index < len(character_tags):
80
- tag_name = character_tags[tag_index]
81
- prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
82
- character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
83
-
84
- IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります" if character_tags_with_probs else "版権キャラクターの可能性が低いと思われます"
85
-
86
- # タグを生成
87
- general_tag_text = ", ".join([general_tags[i] for i in range(len(general_tags)) if prob[i] >= thresh])
88
- character_tag_text = ", ".join([character_tags[i - len(general_tags)] for i in range(len(general_tags), len(prob)) if prob[i] >= thresh])
89
- tag_text = f"{general_tag_text}, {character_tag_text}" if character_tag_text else general_tag_text
90
- return NSFW_flag, IP_flag, tag_text
91
-
92
  def launch(self):
93
  with self.demo:
94
  with gr.Row():
@@ -101,7 +106,7 @@ class webui:
101
  submit = gr.Button(value="Start Analysis")
102
 
103
  submit.click(
104
- self.main,
105
  inputs=[input_image, model_id],
106
  outputs=[output_0, output_1, output_2]
107
  )
 
31
  image = image.astype(np.float32)
32
  return image
33
 
34
+ @spaces.GPU
35
+ def main(image_path, model_id):
36
+ print("Hugging Faceからモデルをダウンロード中")
37
+ onnx_path = hf_hub_download(model_id, "model.onnx")
38
+ csv_path = hf_hub_download(model_id, "selected_tags.csv")
39
+
40
+ # ONNXモデルとCSVファイルの読み込み
41
+ image = Image.open(image_path)
42
+ image = image.convert("RGB") if image.mode != "RGB" else image
43
+ image = preprocess_image(image)
44
+ img = np.array([image])
45
+
46
+ ort_sess = ort.InferenceSession(onnx_path) # セッションの生成をここで行う
47
+ prob = ort_sess.run(None, {ort_sess.get_inputs()[0].name: img})[0][0]
48
+
49
+ with open(csv_path, "r", encoding="utf-8") as f:
50
+ reader = csv.reader(f)
51
+ next(reader) # ヘッダーをスキップ
52
+ rows = list(reader)
53
+
54
+ rating_tags = [row[1] for row in rows if row[2] == "9"]
55
+ character_tags = [row[1] for row in rows if row[2] == "4"]
56
+ general_tags = [row[1] for row in rows if row[2] == "0"]
57
+
58
+ # タグと評価
59
+ NSFW_flag, IP_flag, tag_text = self.evaluate_tags(prob, rating_tags, character_tags, general_tags)
60
+ return NSFW_flag, IP_flag, tag_text
61
+
62
+ def evaluate_tags(self, prob, rating_tags, character_tags, general_tags):
63
+ thresh = 0.35
64
+ # NSFW/SFW判定
65
+ tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
66
+ max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
67
+ max_sfw_score = tag_confidences.get("general", 0)
68
+ NSFW_flag = "NSFWの可能性が高いです" if max_nsfw_score > max_sfw_score else "SFWの可能性が高いです"
69
+
70
+ # 版権キャラクターの可能性を評価
71
+ character_tags_with_probs = []
72
+ for i, p in enumerate(prob[4:]):
73
+ if p >= thresh and i >= len(general_tags):
74
+ tag_index = i - len(general_tags)
75
+ if tag_index < len(character_tags):
76
+ tag_name = character_tags[tag_index]
77
+ prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
78
+ character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
79
+
80
+ IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります" if character_tags_with_probs else "版権キャラクターの可能性が低いと思われます"
81
+
82
+ # タグを生成
83
+ general_tag_text = ", ".join([general_tags[i] for i in range(len(general_tags)) if prob[i] >= thresh])
84
+ character_tag_text = ", ".join([character_tags[i - len(general_tags)] for i in range(len(general_tags), len(prob)) if prob[i] >= thresh])
85
+ tag_text = f"{general_tag_text}, {character_tag_text}" if character_tag_text else general_tag_text
86
+ return NSFW_flag, IP_flag, tag_text
87
+
88
+
89
+
90
+
91
+
92
+
93
  class webui:
94
  def __init__(self):
95
  self.demo = gr.Blocks()
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  def launch(self):
98
  with self.demo:
99
  with gr.Row():
 
106
  submit = gr.Button(value="Start Analysis")
107
 
108
  submit.click(
109
+ main,
110
  inputs=[input_image, model_id],
111
  outputs=[output_0, output_1, output_2]
112
  )