"""Gradio app: tag an uploaded image with the WD SwinV2 tagger and translate the tags.

The ONNX model and its label CSV are downloaded from the Hugging Face Hub the
first time a prediction is requested. Tags scoring above the configured
thresholds are shown per category, and the general + character tags are passed
to ``translator.translate_texts`` for the "translation" tab.
"""
import os
import json  # NOTE(review): not used anywhere in this file
from functools import lru_cache

import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image
from huggingface_hub import login

from translator import translate_texts

# ------------------------------------------------------------------
# Model configuration
# ------------------------------------------------------------------
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

HF_TOKEN = os.environ.get("HF_TOKEN", "")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")


# ------------------------------------------------------------------
# Tagger
# ------------------------------------------------------------------
class Tagger:
    """Wrap the ONNX tagger model together with its tag-label table."""

    def __init__(self):
        # Use None (not "") when no token is configured, so huggingface_hub
        # falls back to its normal anonymous / cached-credential handling
        # instead of sending an empty bearer token.
        self.hf_token = HF_TOKEN or None
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        """Download the label CSV and ONNX model, then build the inference session."""
        label_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, LABEL_FILENAME, token=self.hf_token
        )
        model_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, MODEL_FILENAME, token=self.hf_token
        )

        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        # Display form of every tag (underscores -> spaces), computed once
        # instead of on every prediction.
        self.display_names = [name.replace("_", " ") for name in self.tag_names]
        # Row indices of each tag category, keyed by the CSV "category" code.
        self.categories = {
            "rating": np.where(tags_df["category"] == 9)[0],
            "general": np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }

        # onnxruntime >= 1.9 raises if a build exposes several execution
        # providers (e.g. onnxruntime-gpu) and none is chosen explicitly.
        self.model = rt.InferenceSession(
            model_path, providers=rt.get_available_providers()
        )
        model_input = self.model.get_inputs()[0]
        self.input_name = model_input.name
        # Assumes an NHWC input of shape (batch, size, size, 3), so index 1 is
        # the square side length — TODO confirm if MODEL_REPO is changed.
        self.input_size = model_input.shape[1]

    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: Image.Image) -> np.ndarray:
        """Pad *img* to a white square, resize it, and return a float32 BGR array.

        The returned array has shape (input_size, input_size, 3).
        """
        has_alpha = img.mode in ("RGBA", "LA", "PA") or (
            img.mode == "P" and "transparency" in img.info
        )
        if has_alpha:
            # Composite transparent pixels onto white. A plain convert("RGB")
            # would keep whatever colour the transparent pixels happen to store.
            rgba = img.convert("RGBA")
            background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, rgba)
        if img.mode != "RGB":
            img = img.convert("RGB")

        size = max(img.size)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
        if size != self.input_size:
            canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
        return np.array(canvas)[:, :, ::-1].astype(np.float32)  # RGB -> BGR

    @staticmethod
    def _by_score_desc(tags: dict) -> dict:
        """Return *tags* re-ordered by score, highest first."""
        return dict(sorted(tags.items(), key=lambda kv: kv[1], reverse=True))

    # --------------------------- predict --------------------------
    def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
        """Tag *img*.

        Args:
            img: Input image (any PIL mode).
            gen_th: Keep general tags whose score is strictly above this value.
            char_th: Keep character tags whose score is strictly above this value.

        Returns:
            ``{"ratings": {...}, "general": {...}, "characters": {...}}``, each
            mapping a display tag name to its float score. All rating tags are
            included; general and character tags are filtered by threshold and
            sorted by score, highest first.
        """
        batch = self._preprocess(img)[None, ...]
        scores = self.model.run(None, {self.input_name: batch})[0][0]
        names = self.display_names

        ratings = {names[i]: float(scores[i]) for i in self.categories["rating"]}
        general = {
            names[i]: float(scores[i])
            for i in self.categories["general"]
            if scores[i] > gen_th
        }
        characters = {
            names[i]: float(scores[i])
            for i in self.categories["character"]
            if scores[i] > char_th
        }
        return {
            "ratings": ratings,
            "general": self._by_score_desc(general),
            "characters": self._by_score_desc(characters),
        }


@lru_cache(maxsize=1)
def get_tagger() -> Tagger:
    """Return the shared Tagger, loading the model only on the first call."""
    return Tagger()


# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器 + 翻译") as demo:
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown("上传图片自动识别标签,并可一键翻译成中文")

    with gr.Row():
        with gr.Column(scale=1):
            # RGBA so transparency reaches Tagger._preprocess and is composited
            # onto white rather than being dropped by the component.
            img_in = gr.Image(type="pil", image_mode="RGBA", label="上传图片")
            with gr.Accordion("⚙️ 高级设置", open=False):
                gen_slider = gr.Slider(0, 1, 0.35, label="通用标签阈值", info="越高→标签更少更准")
                char_slider = gr.Slider(0, 1, 0.85, label="角色标签阈值", info="推荐保持较高阈值")
                lang_drop = gr.Dropdown(
                    ["zh", "en"], value="zh", label="翻译目标语言", info="当前仅内置中 / 英"
                )
            btn = gr.Button("开始分析", variant="primary")

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签 (英文)"):
                    out_general = gr.Label(label="General Tags")
                with gr.TabItem("👤 角色标签 (英文)"):
                    out_char = gr.Label(label="Character Tags")
                with gr.TabItem("⭐ 评分标签 (英文)"):
                    out_rating = gr.Label(label="Rating Tags")
                with gr.TabItem("🌐 翻译结果"):
                    out_trans = gr.Textbox(label="翻译后的标签", placeholder="翻译结果显示在此处")

    # ----------------- Callback -----------------
    def process(img, g_th, c_th, tgt_lang):
        """Tag the uploaded image and translate its general + character tags."""
        if img is None:
            raise gr.Error("请先上传图片")

        # Reuse the cached model instead of rebuilding the ONNX session per click.
        res = get_tagger().predict(img, g_th, c_th)

        # =========== Build translation input ===========
        tags_to_translate = list(res["general"]) + list(res["characters"])
        if tags_to_translate:
            translations = translate_texts(
                tags_to_translate, src_lang="auto", tgt_lang=tgt_lang
            )
            # Join into one comma-separated string for the textbox.
            trans_str = ", ".join(translations)
        else:
            # Nothing passed the thresholds; skip the translator call.
            trans_str = ""

        return {
            out_general: res["general"],
            out_char: res["characters"],
            out_rating: res["ratings"],
            out_trans: trans_str,
        }

    btn.click(
        process,
        inputs=[img_in, gen_slider, char_slider, lang_drop],
        outputs=[out_general, out_char, out_rating, out_trans],
    )

# ------------------------------------------------------------------
# Launch
# ------------------------------------------------------------------
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)