Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
76ac8f3
1
Parent(s):
aa2fbfe
Update
Browse files
app.py
CHANGED
@@ -31,64 +31,69 @@ def preprocess_image(image):
|
|
31 |
image = image.astype(np.float32)
|
32 |
return image
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
class webui:
|
35 |
def __init__(self):
|
36 |
self.demo = gr.Blocks()
|
37 |
|
38 |
-
@spaces.GPU
|
39 |
-
def main(self, image_path, model_id):
|
40 |
-
print("Hugging Faceからモデルをダウンロード中")
|
41 |
-
onnx_path = hf_hub_download(model_id, "model.onnx")
|
42 |
-
csv_path = hf_hub_download(model_id, "selected_tags.csv")
|
43 |
-
|
44 |
-
# ONNXモデルとCSVファイルの読み込み
|
45 |
-
image = Image.open(image_path)
|
46 |
-
image = image.convert("RGB") if image.mode != "RGB" else image
|
47 |
-
image = preprocess_image(image)
|
48 |
-
img = np.array([image])
|
49 |
-
|
50 |
-
ort_sess = ort.InferenceSession(onnx_path) # セッションの生成をここで行う
|
51 |
-
prob = ort_sess.run(None, {ort_sess.get_inputs()[0].name: img})[0][0]
|
52 |
-
|
53 |
-
with open(csv_path, "r", encoding="utf-8") as f:
|
54 |
-
reader = csv.reader(f)
|
55 |
-
next(reader) # ヘッダーをスキップ
|
56 |
-
rows = list(reader)
|
57 |
-
|
58 |
-
rating_tags = [row[1] for row in rows if row[2] == "9"]
|
59 |
-
character_tags = [row[1] for row in rows if row[2] == "4"]
|
60 |
-
general_tags = [row[1] for row in rows if row[2] == "0"]
|
61 |
-
|
62 |
-
# タグと評価
|
63 |
-
NSFW_flag, IP_flag, tag_text = self.evaluate_tags(prob, rating_tags, character_tags, general_tags)
|
64 |
-
return NSFW_flag, IP_flag, tag_text
|
65 |
-
|
66 |
-
def evaluate_tags(self, prob, rating_tags, character_tags, general_tags):
|
67 |
-
thresh = 0.35
|
68 |
-
# NSFW/SFW判定
|
69 |
-
tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
|
70 |
-
max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
|
71 |
-
max_sfw_score = tag_confidences.get("general", 0)
|
72 |
-
NSFW_flag = "NSFWの可能性が高いです" if max_nsfw_score > max_sfw_score else "SFWの可能性が高いです"
|
73 |
-
|
74 |
-
# 版権キャラクターの可能性を評価
|
75 |
-
character_tags_with_probs = []
|
76 |
-
for i, p in enumerate(prob[4:]):
|
77 |
-
if p >= thresh and i >= len(general_tags):
|
78 |
-
tag_index = i - len(general_tags)
|
79 |
-
if tag_index < len(character_tags):
|
80 |
-
tag_name = character_tags[tag_index]
|
81 |
-
prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
|
82 |
-
character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
|
83 |
-
|
84 |
-
IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります" if character_tags_with_probs else "版権キャラクターの可能性が低いと思われます"
|
85 |
-
|
86 |
-
# タグを生成
|
87 |
-
general_tag_text = ", ".join([general_tags[i] for i in range(len(general_tags)) if prob[i] >= thresh])
|
88 |
-
character_tag_text = ", ".join([character_tags[i - len(general_tags)] for i in range(len(general_tags), len(prob)) if prob[i] >= thresh])
|
89 |
-
tag_text = f"{general_tag_text}, {character_tag_text}" if character_tag_text else general_tag_text
|
90 |
-
return NSFW_flag, IP_flag, tag_text
|
91 |
-
|
92 |
def launch(self):
|
93 |
with self.demo:
|
94 |
with gr.Row():
|
@@ -101,7 +106,7 @@ class webui:
|
|
101 |
submit = gr.Button(value="Start Analysis")
|
102 |
|
103 |
submit.click(
|
104 |
-
|
105 |
inputs=[input_image, model_id],
|
106 |
outputs=[output_0, output_1, output_2]
|
107 |
)
|
|
|
31 |
image = image.astype(np.float32)
|
32 |
return image
|
33 |
|
34 |
+
@spaces.GPU
|
35 |
+
def main(image_path, model_id):
|
36 |
+
print("Hugging Faceからモデルをダウンロード中")
|
37 |
+
onnx_path = hf_hub_download(model_id, "model.onnx")
|
38 |
+
csv_path = hf_hub_download(model_id, "selected_tags.csv")
|
39 |
+
|
40 |
+
# ONNXモデルとCSVファイルの読み込み
|
41 |
+
image = Image.open(image_path)
|
42 |
+
image = image.convert("RGB") if image.mode != "RGB" else image
|
43 |
+
image = preprocess_image(image)
|
44 |
+
img = np.array([image])
|
45 |
+
|
46 |
+
ort_sess = ort.InferenceSession(onnx_path) # セッションの生成をここで行う
|
47 |
+
prob = ort_sess.run(None, {ort_sess.get_inputs()[0].name: img})[0][0]
|
48 |
+
|
49 |
+
with open(csv_path, "r", encoding="utf-8") as f:
|
50 |
+
reader = csv.reader(f)
|
51 |
+
next(reader) # ヘッダーをスキップ
|
52 |
+
rows = list(reader)
|
53 |
+
|
54 |
+
rating_tags = [row[1] for row in rows if row[2] == "9"]
|
55 |
+
character_tags = [row[1] for row in rows if row[2] == "4"]
|
56 |
+
general_tags = [row[1] for row in rows if row[2] == "0"]
|
57 |
+
|
58 |
+
# タグと評価
|
59 |
+
NSFW_flag, IP_flag, tag_text = self.evaluate_tags(prob, rating_tags, character_tags, general_tags)
|
60 |
+
return NSFW_flag, IP_flag, tag_text
|
61 |
+
|
62 |
+
def evaluate_tags(self, prob, rating_tags, character_tags, general_tags):
|
63 |
+
thresh = 0.35
|
64 |
+
# NSFW/SFW判定
|
65 |
+
tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
|
66 |
+
max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
|
67 |
+
max_sfw_score = tag_confidences.get("general", 0)
|
68 |
+
NSFW_flag = "NSFWの可能性が高いです" if max_nsfw_score > max_sfw_score else "SFWの可能性が高いです"
|
69 |
+
|
70 |
+
# 版権キャラクターの可能性を評価
|
71 |
+
character_tags_with_probs = []
|
72 |
+
for i, p in enumerate(prob[4:]):
|
73 |
+
if p >= thresh and i >= len(general_tags):
|
74 |
+
tag_index = i - len(general_tags)
|
75 |
+
if tag_index < len(character_tags):
|
76 |
+
tag_name = character_tags[tag_index]
|
77 |
+
prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
|
78 |
+
character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
|
79 |
+
|
80 |
+
IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります" if character_tags_with_probs else "版権キャラクターの可能性が低いと思われます"
|
81 |
+
|
82 |
+
# タグを生成
|
83 |
+
general_tag_text = ", ".join([general_tags[i] for i in range(len(general_tags)) if prob[i] >= thresh])
|
84 |
+
character_tag_text = ", ".join([character_tags[i - len(general_tags)] for i in range(len(general_tags), len(prob)) if prob[i] >= thresh])
|
85 |
+
tag_text = f"{general_tag_text}, {character_tag_text}" if character_tag_text else general_tag_text
|
86 |
+
return NSFW_flag, IP_flag, tag_text
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
class webui:
|
94 |
def __init__(self):
|
95 |
self.demo = gr.Blocks()
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
def launch(self):
|
98 |
with self.demo:
|
99 |
with gr.Row():
|
|
|
106 |
submit = gr.Button(value="Start Analysis")
|
107 |
|
108 |
submit.click(
|
109 |
+
main,
|
110 |
inputs=[input_image, model_id],
|
111 |
outputs=[output_0, output_1, output_2]
|
112 |
)
|