tori29umai commited on
Commit
ab6e44b
·
1 Parent(s): a8a8e1b
Files changed (1) hide show
  1. app.py +7 -13
app.py CHANGED
@@ -31,18 +31,8 @@ def preprocess_image(image):
31
  image = image.astype(np.float32)
32
  return image
33
 
34
- def process_image(image, ort_sess, input_name, rating_tags, character_tags, general_tags):
35
  thresh = 0.35
36
- try:
37
- image = Image.open(image_path)
38
- image = image.convert("RGB") if image.mode != "RGB" else image
39
- image = preprocess_image(image)
40
- except Exception as e:
41
- print(f"画像を読み込めません: {image_path}, エラー: {e}")
42
- return
43
-
44
- img = np.array([image])
45
- prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
46
 
47
  # NSFW/SFW判定
48
  tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
@@ -138,7 +128,11 @@ class webui:
138
  character_tags = [row[1] for row in rows if row[2] == "4"]
139
  general_tags = [row[1] for row in rows if row[2] == "0"]
140
 
141
- NSFW_flag, IP_flag, tag_text = process_image(image, ort_sess, ort_sess.get_inputs()[0].name, rating_tags, character_tags, general_tags)
 
 
 
 
142
  return NSFW_flag, IP_flag, tag_text
143
 
144
 
@@ -147,7 +141,7 @@ class webui:
147
  with gr.Row():
148
  with gr.Column():
149
  input_image = gr.Image(type='filepath', label="Analysis Image")
150
- model_id = gr.Textbox(label="NSFW Flag", value="SmilingWolf/wd-vit-tagger-v3")
151
  output_0 = gr.Textbox(label="NSFW Flag")
152
  output_1 = gr.Textbox(label="IP Flag")
153
  output_2 = gr.Textbox(label="Tags")
 
31
  image = image.astype(np.float32)
32
  return image
33
 
34
+ def process_image(prob, rating_tags, character_tags, general_tags):
35
  thresh = 0.35
 
 
 
 
 
 
 
 
 
 
36
 
37
  # NSFW/SFW判定
38
  tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
 
128
  character_tags = [row[1] for row in rows if row[2] == "4"]
129
  general_tags = [row[1] for row in rows if row[2] == "0"]
130
 
131
+
132
+ img = np.array([image])
133
+ prob = ort_sess.run(None, {ort_sess.get_inputs()[0].name: img})[0][0] # ONNXモデルからの出力
134
+
135
+ NSFW_flag, IP_flag, tag_text = process_image(prob, rating_tags, character_tags, general_tags)
136
  return NSFW_flag, IP_flag, tag_text
137
 
138
 
 
141
  with gr.Row():
142
  with gr.Column():
143
  input_image = gr.Image(type='filepath', label="Analysis Image")
144
+ model_id = gr.Textbox(label="MODEL ID", value="SmilingWolf/wd-vit-tagger-v3")
145
  output_0 = gr.Textbox(label="NSFW Flag")
146
  output_1 = gr.Textbox(label="IP Flag")
147
  output_2 = gr.Textbox(label="Tags")