File size: 6,252 Bytes
c9dfb9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import gradio as gr
import csv
import os
from pathlib import Path

import cv2
import numpy as np
from PIL import Image
import onnxruntime as ort
from huggingface_hub import hf_hub_download
import spaces

# Side length, in pixels, of the square image that preprocess_image produces
# for the tagger model.
IMAGE_SIZE: int = 448

def preprocess_image(image):
    """Convert an RGB image into the square float32 BGR array fed to the tagger.

    The image is padded with white (255) to a centered square and then resized
    to ``IMAGE_SIZE`` x ``IMAGE_SIZE``.

    Args:
        image: An RGB ``PIL.Image`` (or anything ``np.array`` turns into an
            H x W x 3 array -- the 3-axis padding below requires a channel axis).

    Returns:
        ``np.ndarray`` of shape ``(IMAGE_SIZE, IMAGE_SIZE, 3)``, dtype float32,
        channels in BGR order. Values are not normalized (0-255 for 8-bit input).
    """
    image = np.array(image)
    # Reverse the channel axis. The input is RGB, so this yields BGR.
    image = image[:, :, ::-1]

    # Pad to a square with white pixels, keeping the original content centered.
    size = max(image.shape[0:2])
    pad_x = size - image.shape[1]
    pad_y = size - image.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2
    image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

    # INTER_AREA when shrinking (moire-free decimation), LANCZOS4 when enlarging.
    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
    return image.astype(np.float32)

@spaces.GPU
def process_image(image_path, input_name, ort_sess, rating_tags, character_tags, general_tags,
                  general_threshold=0.35, character_threshold=0.35):
    """Tag one image with the ONNX tagger and summarize the result.

    Args:
        image_path: Path of the image file to analyze.
        input_name: Name of the ONNX model's input tensor.
        ort_sess: ``onnxruntime.InferenceSession`` for the tagger model.
        rating_tags: Rating tag names, assumed to be in model output order.
        character_tags: Character tag names, assumed to be in model output order.
        general_tags: General tag names, assumed to be in model output order.
        general_threshold: Minimum probability for a general tag to be kept.
        character_threshold: Minimum probability for a character tag to be
            reported, used both for the IP flag and for the tag text.

    Returns:
        A ``(NSFW_flag, IP_flag, tag_text)`` tuple of strings. If the image
        cannot be loaded, the first element carries the error message and the
        other two are empty, so callers can always unpack three values.
    """
    try:
        # The context manager closes the file handle once the pixels are read.
        with Image.open(image_path) as pil_image:
            rgb_image = pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")
            image = preprocess_image(rgb_image)
    except Exception as e:
        print(f"画像を読み込めません: {image_path}, エラー: {e}")
        # Never return a bare None here: the caller unpacks three values.
        return f"画像を読み込めません: {e}", "", ""

    # Batch of one image; take the first output tensor's first (only) row.
    prob = ort_sess.run(None, {input_name: np.array([image])})[0][0]

    # The slicing below assumes the output vector is laid out like the tag
    # lists: all rating tags, then all general tags, then all character tags.
    n_ratings = len(rating_tags)
    n_general = len(general_tags)
    rating_probs = prob[:n_ratings]
    general_probs = prob[n_ratings:n_ratings + n_general]
    character_probs = prob[n_ratings + n_general:n_ratings + n_general + len(character_tags)]

    # NSFW/SFW: the stronger of "questionable"/"explicit" against "general".
    # Any other rating tags are ignored.
    rating_scores = dict(zip(rating_tags, rating_probs))
    nsfw_score = max(rating_scores.get("questionable", 0), rating_scores.get("explicit", 0))
    sfw_score = rating_scores.get("general", 0)
    NSFW_flag = "NSFWの可能性が高いです" if nsfw_score > sfw_score else "SFWの可能性が高いです"

    # Copyrighted characters: every character tag at or above the threshold.
    detected_characters = [
        (name, p) for name, p in zip(character_tags, character_probs) if p >= character_threshold
    ]
    if detected_characters:
        # Raw tag names paired with their probability as a percentage string.
        character_report = [(name, f"{round(float(p) * 100, 2)}%") for name, p in detected_characters]
        IP_flag = f"版権キャラクター: {character_report}の可能性があります"
    else:
        IP_flag = "版権キャラクターの可能性が低いと思われます"

    # Caption: general tags first, then character tags, each in output order.
    selected = [name for name, p in zip(general_tags, general_probs) if p >= general_threshold]
    selected += [name for name, _ in detected_characters]
    # Underscores become spaces, except in short emoticon tags like >_< and ^_^.
    tag_text = ", ".join(
        name.replace("_", " ") if len(name) > 3 else name for name in selected
    )

    return NSFW_flag, IP_flag, tag_text


class webui:
    """Gradio front end: upload an image, choose a tagger model, show the verdicts."""

    def __init__(self):
        self.demo = gr.Blocks()

    @spaces.GPU
    def main(self, image_path, model_id):
        """Download ``model_id`` from the Hugging Face Hub and analyze ``image_path``.

        Args:
            image_path: File path from the ``gr.Image`` input, or None when
                nothing has been uploaded.
            model_id: Hub repository id containing ``model.onnx`` and
                ``selected_tags.csv``.

        Returns:
            ``(NSFW_flag, IP_flag, tag_text)`` strings for the three output boxes.

        Raises:
            ValueError: If ``selected_tags.csv`` is empty or its header differs
                from ``tag_id,name,category,count``.
        """
        if image_path is None:
            # Nothing uploaded yet: skip the model download entirely.
            return "画像が選択されていません", "", ""

        print("Hugging Faceからモデルをダウンロード中")
        onnx_path = hf_hub_download(model_id, "model.onnx")
        csv_path = hf_hub_download(model_id, "selected_tags.csv")

        print("ONNXモデルを実行中")
        print(f"ONNXモデルのパス: {onnx_path}")

        # Pass providers explicitly. Prefer CUDA when this onnxruntime build
        # offers it and keep CPU as the fallback. If neither is listed, fall
        # back to onnxruntime's default selection.
        available = ort.get_available_providers()
        providers = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available]
        ort_sess = ort.InferenceSession(onnx_path, providers=providers or None)

        with open(csv_path, "r", encoding="utf-8") as f:
            reader = csv.reader(f)
            header = next(reader, None)  # None for an empty file
            rows = list(reader)
        # Explicit check rather than assert, which is stripped under `python -O`.
        if header != ["tag_id", "name", "category", "count"]:
            raise ValueError(f"CSVフォーマットが期待と異なります: {header}")

        # Category codes: 9 = rating, 4 = character, 0 = general. All data rows
        # are scanned; the header was already consumed above, so skipping
        # another row would drop a tag and shift every later index.
        rating_tags = [row[1] for row in rows if row[2] == "9"]
        character_tags = [row[1] for row in rows if row[2] == "4"]
        general_tags = [row[1] for row in rows if row[2] == "0"]

        result = process_image(image_path, ort_sess.get_inputs()[0].name, ort_sess, rating_tags, character_tags, general_tags)
        if result is None:
            # Guard against process_image signalling a failure with None, so
            # the three output boxes always receive three values.
            return "画像を読み込めません", "", ""
        NSFW_flag, IP_flag, tag_text = result
        return NSFW_flag, IP_flag, tag_text

    def launch(self):
        """Build the layout, wire the button to ``main``, and start the server."""
        with self.demo:
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type='filepath', label="Analysis Image")
                    # This box holds the Hub repository id, not a result.
                    model_id = gr.Textbox(label="Model ID", value="SmilingWolf/wd-vit-tagger-v3")
                    output_0 = gr.Textbox(label="NSFW Flag")
                    output_1 = gr.Textbox(label="IP Flag")
                    output_2 = gr.Textbox(label="Tags")
                    submit = gr.Button(value="Start Analysis")

                    submit.click(
                        self.main,
                        inputs=[input_image, model_id],
                        outputs=[output_0, output_1, output_2]
                    )

        self.demo.launch()

if __name__ == "__main__":
    # Script entry point: build the Gradio UI and start serving it.
    webui().launch()