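# Gradio app (Hugging Face Spaces, ZeroGPU) that runs a WD tagger ONNX model
# downloaded from the Hub on an uploaded image and reports an NSFW/SFW
# estimate, possibly depicted copyrighted characters, and the predicted tags.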
import gradio as gr
import csv
import os
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort
from huggingface_hub import hf_hub_download
import spaces
# Image size setting
IMAGE_SIZE = 448
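# Note: 448x448 appears to be the input resolution expected by the SmilingWolf
# WD v3 taggers (e.g. wd-vit-tagger-v3); adjust this if a different model is used.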
def preprocess_image(image):
    image = np.array(image)
    image = image[:, :, ::-1]  # Convert RGB to BGR (the image comes from PIL, which is RGB)

    # Pad the image with white so it becomes square
    size = max(image.shape[0:2])
    pad_x = size - image.shape[1]
    pad_y = size - image.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2
    image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

    # Pick an interpolation method depending on whether we downscale or upscale
    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)

    image = image.astype(np.float32)
    return image
class webui:
    def __init__(self):
        self.demo = gr.Blocks()

    @spaces.GPU
    def main(self, image_path, model_id):
        print("Downloading the model from Hugging Face")
        onnx_path = hf_hub_download(model_id, "model.onnx")
        csv_path = hf_hub_download(model_id, "selected_tags.csv")
        ort_sess = ort.InferenceSession(onnx_path)
        print("Running the ONNX model")
        print(f"ONNX model path: {onnx_path}")

        image = Image.open(image_path)
        image = image.convert("RGB") if image.mode != "RGB" else image
        image = preprocess_image(image)
        with open(csv_path, "r", encoding="utf-8") as f:
            reader = csv.reader(f)
            header = next(reader)
            rows = list(reader)
        rating_tags = [row[1] for row in rows if row[2] == "9"]
        character_tags = [row[1] for row in rows if row[2] == "4"]
        general_tags = [row[1] for row in rows if row[2] == "0"]
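        # selected_tags.csv columns are assumed to be (tag_id, name, category, count),
        # with category 9 = rating, 4 = character and 0 = general, as filtered above.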
        img = np.array([image])
        prob = ort_sess.run(None, {ort_sess.get_inputs()[0].name: img})[0][0]  # Raw probabilities from the ONNX model
        thresh = 0.35
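        # The probability vector is assumed to follow the CSV row order:
        # the 4 rating entries first, then the general tags, then the character tags.
        # The prob[4:] indexing below relies on that ordering.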
        # NSFW / SFW judgement based on the rating tags
        tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
        max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
        max_sfw_score = tag_confidences.get("general", 0)
        if max_nsfw_score > max_sfw_score:
            NSFW_flag = "Likely NSFW"
        else:
            NSFW_flag = "Likely SFW"
        # Evaluate the likelihood that copyrighted (franchise) characters are depicted
        character_tags_with_probs = []
        for i, p in enumerate(prob[4:]):
            if p >= thresh and i >= len(general_tags):
                tag_index = i - len(general_tags)
                if tag_index < len(character_tags):
                    tag_name = character_tags[tag_index]
                    prob_percent = round(p * 100, 2)  # Convert the probability to a percentage
                    character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
        if character_tags_with_probs:
            IP_flag = f"Possible copyrighted characters: {character_tags_with_probs}"
        else:
            IP_flag = "Copyrighted characters seem unlikely"
        # Generate the caption tags
        tag_freq = {}
        undesired_tags = []
        combined_tags = []
        general_tag_text = ""
        character_tag_text = ""
        remove_underscore = True
        caption_separator = ", "
        general_threshold = 0.35
        character_threshold = 0.35
        for i, p in enumerate(prob[4:]):
            if i < len(general_tags) and p >= general_threshold:
                tag_name = general_tags[i]
                if remove_underscore and len(tag_name) > 3:  # ignore emoji tags like >_< and ^_^
                    tag_name = tag_name.replace("_", " ")
                if tag_name not in undesired_tags:
                    tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                    general_tag_text += caption_separator + tag_name
                    combined_tags.append(tag_name)
            elif i >= len(general_tags) and p >= character_threshold:
                tag_name = character_tags[i - len(general_tags)]
                if remove_underscore and len(tag_name) > 3:
                    tag_name = tag_name.replace("_", " ")
                if tag_name not in undesired_tags:
                    tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                    character_tag_text += caption_separator + tag_name
                    combined_tags.append(tag_name)
        # Strip the leading separator
        if len(general_tag_text) > 0:
            general_tag_text = general_tag_text[len(caption_separator):]
        if len(character_tag_text) > 0:
            character_tag_text = character_tag_text[len(caption_separator):]
        tag_text = caption_separator.join(combined_tags)
        return NSFW_flag, IP_flag, tag_text
    def launch(self):
        with self.demo:
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type='filepath', label="Analysis Image")
                    model_id = gr.Textbox(label="MODEL ID", value="SmilingWolf/wd-vit-tagger-v3")
                    output_0 = gr.Textbox(label="NSFW Flag")
                    output_1 = gr.Textbox(label="IP Flag")
                    output_2 = gr.Textbox(label="Tags")
                    submit = gr.Button(value="Start Analysis")

                    submit.click(
                        self.main,
                        inputs=[input_image, model_id],
                        outputs=[output_0, output_1, output_2]
                    )
        self.demo.launch()
if __name__ == "__main__":
    ui = webui()
    ui.launch()
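# To try this locally (assuming the file is saved as app.py and gradio, spaces,
# onnxruntime, opencv-python, pillow and huggingface_hub are installed):
#   python app.py
# On a Hugging Face ZeroGPU Space the @spaces.GPU decorator requests a GPU for
# each call to main(); outside Spaces it should simply have no effect.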