Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
a8a8e1b
1
Parent(s):
98923c8
Update
Browse files
app.py
CHANGED
@@ -31,7 +31,7 @@ def preprocess_image(image):
|
|
31 |
image = image.astype(np.float32)
|
32 |
return image
|
33 |
|
34 |
-
def process_image(
|
35 |
thresh = 0.35
|
36 |
try:
|
37 |
image = Image.open(image_path)
|
@@ -121,23 +121,24 @@ class webui:
|
|
121 |
print("Hugging Faceからモデルをダウンロード中")
|
122 |
onnx_path = hf_hub_download(model_id, "model.onnx")
|
123 |
csv_path = hf_hub_download(model_id, "selected_tags.csv")
|
|
|
124 |
|
125 |
print("ONNXモデルを実行中")
|
126 |
print(f"ONNXモデルのパス: {onnx_path}")
|
127 |
|
128 |
-
|
|
|
|
|
129 |
|
130 |
with open(csv_path, "r", encoding="utf-8") as f:
|
131 |
reader = csv.reader(f)
|
132 |
header = next(reader)
|
133 |
rows = list(reader)
|
134 |
-
assert header == ["tag_id", "name", "category", "count"], f"CSVフォーマットが期待と異なります: {header}"
|
135 |
-
|
136 |
rating_tags = [row[1] for row in rows if row[2] == "9"]
|
137 |
character_tags = [row[1] for row in rows if row[2] == "4"]
|
138 |
-
general_tags = [row[1] for row in rows
|
139 |
|
140 |
-
NSFW_flag, IP_flag, tag_text = process_image(
|
141 |
return NSFW_flag, IP_flag, tag_text
|
142 |
|
143 |
|
|
|
31 |
image = image.astype(np.float32)
|
32 |
return image
|
33 |
|
34 |
+
def process_image(image, ort_sess, input_name, rating_tags, character_tags, general_tags):
|
35 |
thresh = 0.35
|
36 |
try:
|
37 |
image = Image.open(image_path)
|
|
|
121 |
print("Hugging Faceからモデルをダウンロード中")
|
122 |
onnx_path = hf_hub_download(model_id, "model.onnx")
|
123 |
csv_path = hf_hub_download(model_id, "selected_tags.csv")
|
124 |
+
ort_sess = ort.InferenceSession(onnx_path)
|
125 |
|
126 |
print("ONNXモデルを実行中")
|
127 |
print(f"ONNXモデルのパス: {onnx_path}")
|
128 |
|
129 |
+
image = Image.open(image_path)
|
130 |
+
image = image.convert("RGB") if image.mode != "RGB" else image
|
131 |
+
image = preprocess_image(image)
|
132 |
|
133 |
with open(csv_path, "r", encoding="utf-8") as f:
|
134 |
reader = csv.reader(f)
|
135 |
header = next(reader)
|
136 |
rows = list(reader)
|
|
|
|
|
137 |
rating_tags = [row[1] for row in rows if row[2] == "9"]
|
138 |
character_tags = [row[1] for row in rows if row[2] == "4"]
|
139 |
+
general_tags = [row[1] for row in rows if row[2] == "0"]
|
140 |
|
141 |
+
NSFW_flag, IP_flag, tag_text = process_image(image, ort_sess, ort_sess.get_inputs()[0].name, rating_tags, character_tags, general_tags)
|
142 |
return NSFW_flag, IP_flag, tag_text
|
143 |
|
144 |
|