File size: 6,183 Bytes
fcde2f2
d5894b1
fcde2f2
d5894b1
4412065
d5894b1
fcde2f2
 
 
d5894b1
fcde2f2
 
 
 
d5894b1
fcde2f2
 
 
4412065
fcde2f2
d5894b1
fcde2f2
 
 
d5894b1
 
fcde2f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5894b1
 
fcde2f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5894b1
fcde2f2
 
d5894b1
fcde2f2
 
 
d5894b1
fcde2f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5894b1
 
fcde2f2
 
 
 
d5894b1
fcde2f2
 
 
 
 
 
d5894b1
 
 
fcde2f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5894b1
fcde2f2
 
 
 
d5894b1
 
fcde2f2
 
 
 
d5894b1
 
fcde2f2
 
 
d5894b1
fcde2f2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os, json
import gradio as gr
import huggingface_hub, numpy as np, onnxruntime as rt, pandas as pd
from PIL import Image
from huggingface_hub import login

from translator import translate_texts

# ------------------------------------------------------------------
# Model configuration
# ------------------------------------------------------------------
# Hugging Face repository and the two files fetched from it.
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Access token read from the environment; "" when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if not HF_TOKEN:
    # Only a warning: downloads are still attempted without a token.
    print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")
else:
    # Register the token with huggingface_hub at import time.
    login(token=HF_TOKEN)

# ------------------------------------------------------------------
# Tagger class
# ------------------------------------------------------------------
class Tagger:
    """Image tagger backed by the ONNX model published in ``MODEL_REPO``.

    Construction downloads ``LABEL_FILENAME`` and ``MODEL_FILENAME`` with
    ``huggingface_hub.hf_hub_download`` and opens an ONNX Runtime session,
    so instances are expensive to create and should be reused.
    """

    # Emoticon tags whose underscores are part of the glyph itself; they are
    # shown verbatim instead of having "_" replaced by a space (e.g. "^_^").
    KAOMOJIS = frozenset({
        "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=",
        ">_<", "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x",
        "|_|", "||_||",
    })

    def __init__(self):
        self.hf_token   = HF_TOKEN
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        """Download the label CSV and the model, then build lookup tables.

        Sets ``tag_names`` (row i of the CSV names model output i),
        ``categories`` (output indices grouped by tag category), ``model``
        (the ONNX Runtime session) and ``input_size`` (square input side).
        """
        # An empty HF_TOKEN would otherwise be sent as an explicit, blank
        # credential; None lets huggingface_hub fall back to its own lookup.
        token = self.hf_token or None
        label_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, LABEL_FILENAME, token=token
        )
        model_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, MODEL_FILENAME, token=token
        )

        tags_df           = pd.read_csv(label_path)
        self.tag_names    = tags_df["name"].tolist()
        # Category codes in the CSV: 9 = rating, 0 = general, 4 = character.
        self.categories   = {
            "rating":    np.where(tags_df["category"] == 9)[0],
            "general":   np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }
        self.model        = rt.InferenceSession(model_path)
        # _preprocess yields (H, W, 3) arrays and predict adds a batch axis, so
        # the input is NHWC and shape[1] is the square side length.
        self.input_size   = self.model.get_inputs()[0].shape[1]

    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: Image.Image) -> np.ndarray:
        """Return *img* as a float32 BGR array of shape (input_size, input_size, 3).

        Transparency (if any) is flattened onto white, the image is padded to
        a centred square with white, and it is resized with bicubic filtering
        when its side differs from ``input_size``. Values stay in 0-255.
        """
        has_alpha = img.mode in ("RGBA", "LA") or "transparency" in img.info
        if has_alpha:
            # A plain convert("RGB") would expose whatever colour is stored
            # under transparent pixels (often black); composite onto white so
            # it matches the white padding below.
            rgba = img.convert("RGBA")
            flat = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            flat.alpha_composite(rgba)
            img = flat.convert("RGB")
        elif img.mode != "RGB":
            img = img.convert("RGB")

        size   = max(img.size)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
        if size != self.input_size:
            canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
        # RGB -> BGR; presumably the channel order the model was trained with.
        return np.array(canvas)[:, :, ::-1].astype(np.float32)

    def _display_name(self, idx) -> str:
        """Return the human-readable label for model output index *idx*."""
        name = self.tag_names[idx]
        return name if name in self.KAOMOJIS else name.replace("_", " ")

    # --------------------------- predict --------------------------
    def predict(self, img: Image.Image,
                gen_th: float = 0.35,
                char_th: float = 0.85):
        """Run the model on *img* and group the scores by tag category.

        Args:
            img: image to tag.
            gen_th: general tags are kept only if their score is strictly
                greater than this threshold.
            char_th: same, for character tags.

        Returns:
            dict with keys "ratings" (every rating tag), "general" (kept
            general tags, sorted by descending score) and "characters" (kept
            character tags); each maps display name -> float score.
        """
        inp_name = self.model.get_inputs()[0].name
        # Add a batch axis of 1, then take the first row of the first output.
        scores = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]

        ratings = {self._display_name(i): float(scores[i])
                   for i in self.categories["rating"]}
        general = {self._display_name(i): float(scores[i])
                   for i in self.categories["general"] if scores[i] > gen_th}
        characters = {self._display_name(i): float(scores[i])
                      for i in self.categories["character"] if scores[i] > char_th}

        general = dict(sorted(general.items(), key=lambda kv: kv[1], reverse=True))
        return {"ratings": ratings, "general": general, "characters": characters}

# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器 + 翻译") as demo:
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown("上传图片自动识别标签,并可一键翻译成中文")

    with gr.Row():
        with gr.Column(scale=1):
            # Request RGBA (Gradio's default image_mode is "RGB") so any
            # transparency reaches the tagger instead of being dropped on upload.
            img_in = gr.Image(type="pil", image_mode="RGBA", label="上传图片")
            with gr.Accordion("⚙️ 高级设置", open=False):
                gen_slider  = gr.Slider(0, 1, 0.35,
                                        label="通用标签阈值", info="越高→标签更少更准")
                char_slider = gr.Slider(0, 1, 0.85,
                                        label="角色标签阈值", info="推荐保持较高阈值")
                lang_drop   = gr.Dropdown(["zh", "en"], value="zh",
                                          label="翻译目标语言",
                                          info="当前仅内置中 / 英")

            btn = gr.Button("开始分析", variant="primary")

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签 (英文)"):
                    out_general = gr.Label(label="General Tags")
                with gr.TabItem("👤 角色标签 (英文)"):
                    out_char    = gr.Label(label="Character Tags")
                with gr.TabItem("⭐ 评分标签 (英文)"):
                    out_rating  = gr.Label(label="Rating Tags")
                with gr.TabItem("🌐 翻译结果"):
                    out_trans   = gr.Textbox(label="翻译后的标签",
                                             placeholder="翻译结果显示在此处")

    # ----------------- Callbacks -----------------
    # Building a Tagger calls hf_hub_download twice and creates an ONNX
    # InferenceSession, so one instance is built lazily and shared across
    # requests instead of being rebuilt on every click.
    _tagger_cache = {}

    def _get_tagger():
        """Return the shared Tagger, constructing it on first use."""
        tagger = _tagger_cache.get("tagger")
        if tagger is None:
            tagger = _tagger_cache["tagger"] = Tagger()
        return tagger

    def process(img, g_th, c_th, tgt_lang):
        """Tag *img*, translate the kept tags and map results to the outputs.

        Args:
            img: uploaded PIL image, or None when nothing was uploaded.
            g_th: general-tag threshold from the slider.
            c_th: character-tag threshold from the slider.
            tgt_lang: target language code passed to translate_texts.

        Raises:
            gr.Error: if no image was uploaded (shown to the user by Gradio).
        """
        if img is None:
            raise gr.Error("请先上传图片")
        res = _get_tagger().predict(img, g_th, c_th)

        # Translate general tags followed by character tags; rating tags are
        # not translated.
        tags_to_translate = [*res["general"], *res["characters"]]
        if tags_to_translate:
            translations = translate_texts(tags_to_translate,
                                           src_lang="auto", tgt_lang=tgt_lang)
            trans_str = ", ".join(translations)
        else:
            # Nothing passed the thresholds; skip the translator call entirely.
            trans_str = ""

        return {
            out_general: res["general"],
            out_char:    res["characters"],
            out_rating:  res["ratings"],
            out_trans:   trans_str
        }

    btn.click(
        process,
        inputs=[img_in, gen_slider, char_slider, lang_drop],
        outputs=[out_general, out_char, out_rating, out_trans]
    )

# ------------------------------------------------------------------
# Entry point
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Serve the app on every network interface (0.0.0.0) at port 7860.
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )