import os, json
import gradio as gr
import huggingface_hub, numpy as np, onnxruntime as rt, pandas as pd
from PIL import Image
from huggingface_hub import login
from translator import translate_texts
# ------------------------------------------------------------------
# Model configuration
# ------------------------------------------------------------------
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    print("⚠️ HF_TOKEN not found; private models may fail to download")
# ------------------------------------------------------------------
# Tagger class
# ------------------------------------------------------------------
class Tagger:
    def __init__(self):
        self.hf_token = HF_TOKEN
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        # Download the label CSV and ONNX weights (cached by huggingface_hub).
        label_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, LABEL_FILENAME, token=self.hf_token
        )
        model_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, MODEL_FILENAME, token=self.hf_token
        )
        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        # Category ids in selected_tags.csv: 9 = rating, 0 = general, 4 = character.
        self.categories = {
            "rating": np.where(tags_df["category"] == 9)[0],
            "general": np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }
        self.model = rt.InferenceSession(model_path)
        # The model input is NHWC, so shape[1] is the expected image height/width.
        self.input_size = self.model.get_inputs()[0].shape[1]
    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: Image.Image) -> np.ndarray:
        if img.mode != "RGB":
            img = img.convert("RGB")
        # Pad to a white square, then resize to the model's input resolution.
        size = max(img.size)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
        if size != self.input_size:
            canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
        return np.array(canvas)[:, :, ::-1].astype(np.float32)  # RGB -> BGR
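    # Note: as read from the code above, the model is fed a single float32 BGR
    # image with raw 0-255 pixel values (no normalization) in NHWC layout;
    # predict() adds the batch dimension with [None, ...]. This is inferred from
    # the preprocessing, not taken from the model card.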
    # --------------------------- predict --------------------------
    def predict(self, img: Image.Image,
                gen_th: float = 0.35,
                char_th: float = 0.85):
        inp_name = self.model.get_inputs()[0].name
        outputs = self.model.run(None, {inp_name: self._preprocess(img)[None, ...]})[0][0]

        res = {"ratings": {}, "general": {}, "characters": {}}
        # Ratings are always reported; general/character tags are thresholded.
        for idx in self.categories["rating"]:
            res["ratings"][self.tag_names[idx].replace("_", " ")] = float(outputs[idx])
        for idx in self.categories["general"]:
            if outputs[idx] > gen_th:
                res["general"][self.tag_names[idx].replace("_", " ")] = float(outputs[idx])
        for idx in self.categories["character"]:
            if outputs[idx] > char_th:
                res["characters"][self.tag_names[idx].replace("_", " ")] = float(outputs[idx])
        # Sort general tags by confidence, highest first.
        res["general"] = dict(sorted(res["general"].items(),
                                     key=lambda kv: kv[1],
                                     reverse=True))
        return res
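# Hypothetical standalone usage of the Tagger class (not part of the app flow;
# "example.jpg" is a placeholder path, not a file in this repo):
#
#     tagger = Tagger()
#     result = tagger.predict(Image.open("example.jpg"), gen_th=0.35, char_th=0.85)
#     print(result["general"])     # e.g. {"1girl": 0.98, "solo": 0.95, ...}
#     print(result["characters"])  # character tags scoring above char_th
#     print(result["ratings"])     # scores for all rating tags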
# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="AI Image Tag Analyzer + Translation") as demo:
    gr.Markdown("# 🖼️ AI Image Tag Analyzer")
    gr.Markdown("Upload an image to detect tags automatically, with one-click translation into Chinese")

    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="Upload image")
            with gr.Accordion("⚙️ Advanced settings", open=False):
                gen_slider = gr.Slider(0, 1, 0.35,
                                       label="General tag threshold",
                                       info="Higher → fewer but more accurate tags")
                char_slider = gr.Slider(0, 1, 0.85,
                                        label="Character tag threshold",
                                        info="Keeping a high threshold is recommended")
                lang_drop = gr.Dropdown(["zh", "en"], value="zh",
                                        label="Translation target language",
                                        info="Only Chinese / English are built in for now")
            btn = gr.Button("Analyze", variant="primary")

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ General tags (English)"):
                    out_general = gr.Label(label="General Tags")
                with gr.TabItem("👤 Character tags (English)"):
                    out_char = gr.Label(label="Character Tags")
                with gr.TabItem("⭐ Rating tags (English)"):
                    out_rating = gr.Label(label="Rating Tags")
                with gr.TabItem("🌐 Translation"):
                    out_trans = gr.Textbox(label="Translated tags",
                                           placeholder="Translation results appear here")
    # ----------------- Callback -----------------
    def process(img, g_th, c_th, tgt_lang):
        tagger = Tagger()  # model files are cached by huggingface_hub, but the session is rebuilt per request
        res = tagger.predict(img, g_th, c_th)

        # Collect the tags to translate
        tags_to_translate = list(res["general"].keys()) + list(res["characters"].keys())
        translations = translate_texts(tags_to_translate, src_lang="auto", tgt_lang=tgt_lang)
        # Join into a single string
        trans_str = ", ".join(translations)

        return {
            out_general: res["general"],
            out_char: res["characters"],
            out_rating: res["ratings"],
            out_trans: trans_str
        }
    btn.click(
        process,
        inputs=[img_in, gen_slider, char_slider, lang_drop],
        outputs=[out_general, out_char, out_rating, out_trans]
    )
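# translate_texts comes from the local translator module of this Space. Based on
# the call in process() above, it is assumed to take a list of strings plus
# src_lang / tgt_lang and return the translated strings in the same order; its
# actual signature is not shown in this file.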
# ------------------------------------------------------------------
# Launch
# ------------------------------------------------------------------
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
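# Hypothetical local run (the file name app.py is the usual Space entry point
# and an assumption here; port 7860 is the port Hugging Face Spaces serves on):
#
#     python app.py
#     # then open http://localhost:7860 in a browser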