tori29umai commited on
Commit
a8a8e1b
·
1 Parent(s): 98923c8
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -31,7 +31,7 @@ def preprocess_image(image):
31
  image = image.astype(np.float32)
32
  return image
33
 
34
- def process_image(image_path, input_name, ort_sess, rating_tags, character_tags, general_tags):
35
  thresh = 0.35
36
  try:
37
  image = Image.open(image_path)
@@ -121,23 +121,24 @@ class webui:
121
  print("Hugging Faceからモデルをダウンロード中")
122
  onnx_path = hf_hub_download(model_id, "model.onnx")
123
  csv_path = hf_hub_download(model_id, "selected_tags.csv")
 
124
 
125
  print("ONNXモデルを実行中")
126
  print(f"ONNXモデルのパス: {onnx_path}")
127
 
128
- ort_sess = ort.InferenceSession(onnx_path)
 
 
129
 
130
  with open(csv_path, "r", encoding="utf-8") as f:
131
  reader = csv.reader(f)
132
  header = next(reader)
133
  rows = list(reader)
134
- assert header == ["tag_id", "name", "category", "count"], f"CSVフォーマットが期待と異なります: {header}"
135
-
136
  rating_tags = [row[1] for row in rows if row[2] == "9"]
137
  character_tags = [row[1] for row in rows if row[2] == "4"]
138
- general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
139
 
140
- NSFW_flag, IP_flag, tag_text = process_image(image_path, ort_sess.get_inputs()[0].name, ort_sess, rating_tags, character_tags, general_tags)
141
  return NSFW_flag, IP_flag, tag_text
142
 
143
 
 
31
  image = image.astype(np.float32)
32
  return image
33
 
34
+ def process_image(image, ort_sess, input_name, rating_tags, character_tags, general_tags):
35
  thresh = 0.35
36
  try:
37
  image = Image.open(image_path)
 
121
  print("Hugging Faceからモデルをダウンロード中")
122
  onnx_path = hf_hub_download(model_id, "model.onnx")
123
  csv_path = hf_hub_download(model_id, "selected_tags.csv")
124
+ ort_sess = ort.InferenceSession(onnx_path)
125
 
126
  print("ONNXモデルを実行中")
127
  print(f"ONNXモデルのパス: {onnx_path}")
128
 
129
+ image = Image.open(image_path)
130
+ image = image.convert("RGB") if image.mode != "RGB" else image
131
+ image = preprocess_image(image)
132
 
133
  with open(csv_path, "r", encoding="utf-8") as f:
134
  reader = csv.reader(f)
135
  header = next(reader)
136
  rows = list(reader)
 
 
137
  rating_tags = [row[1] for row in rows if row[2] == "9"]
138
  character_tags = [row[1] for row in rows if row[2] == "4"]
139
+ general_tags = [row[1] for row in rows if row[2] == "0"]
140
 
141
+ NSFW_flag, IP_flag, tag_text = process_image(image, ort_sess, ort_sess.get_inputs()[0].name, rating_tags, character_tags, general_tags)
142
  return NSFW_flag, IP_flag, tag_text
143
 
144