# Spaces: Running  (page-header text captured along with the source)
import functools
import json
import os

import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from huggingface_hub import login
from PIL import Image

from translator import translate_texts
# ------------------------------------------------------------------
# Model configuration
# ------------------------------------------------------------------
# Hub repository and the two files fetched from it.
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Optional access token read from the environment; empty string when unset.
HF_TOKEN = os.getenv("HF_TOKEN", "")
if not HF_TOKEN:
    # The message warns that no HF_TOKEN was found and that private models
    # may fail to download.
    print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")
else:
    login(token=HF_TOKEN)
# ------------------------------------------------------------------
# Tagger class
# ------------------------------------------------------------------
class Tagger:
    """Score PIL images against the tag vocabulary of an ONNX tagging model.

    The model file and its label CSV are downloaded from ``MODEL_REPO`` on
    the Hugging Face Hub when the instance is created. Tags are grouped by
    the CSV's ``category`` column into rating, general and character tags.
    """

    def __init__(self):
        """Remember the configured token, then load the model and labels."""
        self.hf_token = HF_TOKEN
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        """Download the label CSV and ONNX model and build lookup tables.

        Sets ``tag_names``, ``categories``, ``model`` and ``input_size``.
        """
        # HF_TOKEN defaults to "" when the environment variable is unset.
        # Pass None in that case so the hub client is never handed a blank
        # credential as if it were an explicit token.
        token = self.hf_token or None
        label_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, LABEL_FILENAME, token=token
        )
        model_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, MODEL_FILENAME, token=token
        )

        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        # Row indices of each tag group, keyed by the CSV's category id.
        self.categories = {
            "rating": np.where(tags_df["category"] == 9)[0],
            "general": np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }

        self.model = rt.InferenceSession(model_path)
        # Axis 1 of the first model input is used as the square edge length
        # that images are resized to before inference.
        self.input_size = self.model.get_inputs()[0].shape[1]

    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: "Image.Image") -> np.ndarray:
        """Convert an image into the float32 BGR array fed to the model.

        The image is flattened to RGB (transparent areas become white),
        centred on a white square canvas, resized to
        ``input_size`` x ``input_size`` and returned with its channel
        order reversed.
        """
        if img.mode != "RGB":
            has_alpha = img.mode in ("RGBA", "LA") or "transparency" in img.info
            if has_alpha:
                # A bare convert("RGB") drops the alpha channel and exposes
                # whatever colour sits under transparent pixels. Composite
                # onto white so transparency matches the padding colour.
                rgba = img.convert("RGBA")
                white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
                img = Image.alpha_composite(white, rgba).convert("RGB")
            else:
                img = img.convert("RGB")

        # Pad to a square so the later resize does not distort the image.
        size = max(img.size)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
        if size != self.input_size:
            canvas = canvas.resize(
                (self.input_size, self.input_size), Image.BICUBIC
            )
        return np.array(canvas)[:, :, ::-1].astype(np.float32)  # to BGR

    def _collect(self, scores, category: str, threshold=None) -> dict:
        """Map display tag names of one category to their float scores.

        When ``threshold`` is given, only tags scoring strictly above it
        are kept. Underscores in tag names are replaced by spaces.
        """
        picked = {}
        for idx in self.categories[category]:
            if threshold is None or scores[idx] > threshold:
                picked[self.tag_names[idx].replace("_", " ")] = float(scores[idx])
        return picked

    # --------------------------- predict --------------------------
    def predict(self, img: "Image.Image",
                gen_th: float = 0.35,
                char_th: float = 0.85):
        """Tag one image.

        Args:
            img: The PIL image to analyse.
            gen_th: General tags must score strictly above this to be kept.
            char_th: Character tags must score strictly above this to be kept.

        Returns:
            A dict with keys ``"ratings"``, ``"general"`` and
            ``"characters"``, each mapping tag name to score. Ratings are
            never filtered; general and character tags are ordered from
            highest to lowest score.

        Raises:
            ValueError: If ``img`` is None.
        """
        if img is None:
            raise ValueError("predict() requires an image, got None")

        inp_name = self.model.get_inputs()[0].name
        batch = self._preprocess(img)[None, ...]  # add the batch axis
        outputs = self.model.run(None, {inp_name: batch})[0][0]

        res = {
            "ratings": self._collect(outputs, "rating"),
            "general": self._collect(outputs, "general", gen_th),
            "characters": self._collect(outputs, "character", char_th),
        }
        # Highest-confidence tags first, for both thresholded groups.
        for group in ("general", "characters"):
            res[group] = dict(sorted(res[group].items(),
                                     key=lambda kv: kv[1],
                                     reverse=True))
        return res
# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
@functools.lru_cache(maxsize=1)
def _get_tagger():
    """Return the shared Tagger, constructing it on first use.

    Building a Tagger reads the label CSV and creates an ONNX inference
    session, so one instance is reused across requests instead of being
    rebuilt on every click. A failed construction is not cached, so the
    next request retries.
    """
    return Tagger()


# NOTE(review): widget nesting was reconstructed from a flattened copy of the
# source; confirm that lang_drop belongs inside the accordion.
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器 + 翻译") as demo:
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown("上传图片自动识别标签,并可一键翻译成中文")

    with gr.Row():
        # Left column: image input, thresholds, target language, run button.
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="上传图片")
            with gr.Accordion("⚙️ 高级设置", open=False):
                gen_slider = gr.Slider(0, 1, 0.35,
                                       label="通用标签阈值", info="越高→标签更少更准")
                char_slider = gr.Slider(0, 1, 0.85,
                                        label="角色标签阈值", info="推荐保持较高阈值")
                lang_drop = gr.Dropdown(["zh", "en"], value="zh",
                                        label="翻译目标语言",
                                        info="当前仅内置中 / 英")
            btn = gr.Button("开始分析", variant="primary")

        # Right column: one tab per tag group plus the translated text.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签 (英文)"):
                    out_general = gr.Label(label="General Tags")
                with gr.TabItem("👤 角色标签 (英文)"):
                    out_char = gr.Label(label="Character Tags")
                with gr.TabItem("⭐ 评分标签 (英文)"):
                    out_rating = gr.Label(label="Rating Tags")
                with gr.TabItem("🌐 翻译结果"):
                    out_trans = gr.Textbox(label="翻译后的标签",
                                           placeholder="翻译结果显示在此处")

    # ----------------- click handler -----------------
    def process(img, g_th, c_th, tgt_lang):
        """Tag the uploaded image and translate the kept tags.

        Args:
            img: PIL image from the upload widget, or None if nothing
                was uploaded.
            g_th: Threshold for general tags.
            c_th: Threshold for character tags.
            tgt_lang: Target language code passed to ``translate_texts``.

        Returns:
            A dict mapping each output component to its new value.

        Raises:
            gr.Error: If no image was uploaded.
        """
        if img is None:
            # Show a readable message instead of crashing inside the tagger.
            raise gr.Error("请先上传图片")

        res = _get_tagger().predict(img, g_th, c_th)

        # General tags first, then character tags.
        tags_to_translate = [*res["general"], *res["characters"]]
        # Skip the translator when nothing passed the thresholds; how it
        # treats an empty list is not visible from this file.
        if tags_to_translate:
            translations = translate_texts(
                tags_to_translate, src_lang="auto", tgt_lang=tgt_lang
            )
        else:
            translations = []
        trans_str = ", ".join(translations)

        return {
            out_general: res["general"],
            out_char: res["characters"],
            out_rating: res["ratings"],
            out_trans: trans_str,
        }

    btn.click(
        process,
        inputs=[img_in, gen_slider, char_slider, lang_drop],
        outputs=[out_general, out_char, out_rating, out_trans],
    )
# ------------------------------------------------------------------
# Launch
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)