Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
ab6e44b
1
Parent(s):
a8a8e1b
Update
Browse files
app.py
CHANGED
@@ -31,18 +31,8 @@ def preprocess_image(image):
|
|
31 |
image = image.astype(np.float32)
|
32 |
return image
|
33 |
|
34 |
-
def process_image(
|
35 |
thresh = 0.35
|
36 |
-
try:
|
37 |
-
image = Image.open(image_path)
|
38 |
-
image = image.convert("RGB") if image.mode != "RGB" else image
|
39 |
-
image = preprocess_image(image)
|
40 |
-
except Exception as e:
|
41 |
-
print(f"画像を読み込めません: {image_path}, エラー: {e}")
|
42 |
-
return
|
43 |
-
|
44 |
-
img = np.array([image])
|
45 |
-
prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
|
46 |
|
47 |
# NSFW/SFW判定
|
48 |
tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
|
@@ -138,7 +128,11 @@ class webui:
|
|
138 |
character_tags = [row[1] for row in rows if row[2] == "4"]
|
139 |
general_tags = [row[1] for row in rows if row[2] == "0"]
|
140 |
|
141 |
-
|
|
|
|
|
|
|
|
|
142 |
return NSFW_flag, IP_flag, tag_text
|
143 |
|
144 |
|
@@ -147,7 +141,7 @@ class webui:
|
|
147 |
with gr.Row():
|
148 |
with gr.Column():
|
149 |
input_image = gr.Image(type='filepath', label="Analysis Image")
|
150 |
-
model_id = gr.Textbox(label="
|
151 |
output_0 = gr.Textbox(label="NSFW Flag")
|
152 |
output_1 = gr.Textbox(label="IP Flag")
|
153 |
output_2 = gr.Textbox(label="Tags")
|
|
|
31 |
image = image.astype(np.float32)
|
32 |
return image
|
33 |
|
34 |
+
def process_image(prob, rating_tags, character_tags, general_tags):
|
35 |
thresh = 0.35
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
# NSFW/SFW判定
|
38 |
tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
|
|
|
128 |
character_tags = [row[1] for row in rows if row[2] == "4"]
|
129 |
general_tags = [row[1] for row in rows if row[2] == "0"]
|
130 |
|
131 |
+
|
132 |
+
img = np.array([image])
|
133 |
+
prob = ort_sess.run(None, {ort_sess.get_inputs()[0].name: img})[0][0] # ONNXモデルからの出力
|
134 |
+
|
135 |
+
NSFW_flag, IP_flag, tag_text = process_image(prob, rating_tags, character_tags, general_tags)
|
136 |
return NSFW_flag, IP_flag, tag_text
|
137 |
|
138 |
|
|
|
141 |
with gr.Row():
|
142 |
with gr.Column():
|
143 |
input_image = gr.Image(type='filepath', label="Analysis Image")
|
144 |
+
model_id = gr.Textbox(label="MODEL ID", value="SmilingWolf/wd-vit-tagger-v3")
|
145 |
output_0 = gr.Textbox(label="NSFW Flag")
|
146 |
output_1 = gr.Textbox(label="IP Flag")
|
147 |
output_2 = gr.Textbox(label="Tags")
|