Image_Inversion / app.py
IdlecloudX's picture
Update app.py
fcde2f2 verified
raw
history blame
6.18 kB
import os, json
import gradio as gr
import huggingface_hub, numpy as np, onnxruntime as rt, pandas as pd
from PIL import Image
from huggingface_hub import login
from translator import translate_texts
# ------------------------------------------------------------------
# Model configuration
# ------------------------------------------------------------------
# Hugging Face repo and the two files downloaded from it by Tagger.
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Optional access token read from the environment; "" when unset.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

if not HF_TOKEN:
    # No token: downloads of private repos are expected to fail.
    print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")
else:
    # Log in once at import time so later hub calls are authenticated.
    login(token=HF_TOKEN)
# ------------------------------------------------------------------
# Tagger class
# ------------------------------------------------------------------
class Tagger:
    """Wrap the WD tagger ONNX model together with its tag list.

    Construction downloads ``LABEL_FILENAME`` and ``MODEL_FILENAME`` from
    ``MODEL_REPO`` and opens an onnxruntime session, so it is expensive:
    create one instance and reuse it for many predictions.
    """

    def __init__(self):
        self.hf_token = HF_TOKEN
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        """Download the label CSV and model, index tag categories, open the session."""
        # An empty string would be sent as an explicit token and override
        # huggingface_hub's default token lookup; None lets it fall back.
        token = self.hf_token or None
        label_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, LABEL_FILENAME, token=token
        )
        model_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, MODEL_FILENAME, token=token
        )

        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        # Row indices of each tag group, keyed by the CSV "category" code
        # (9 = rating, 0 = general, 4 = character).
        self.categories = {
            "rating": np.where(tags_df["category"] == 9)[0],
            "general": np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }

        self.model = rt.InferenceSession(model_path)
        model_input = self.model.get_inputs()[0]
        self.input_name = model_input.name
        # _preprocess produces (H, W, 3) and predict adds a batch axis, so the
        # model input is NHWC and shape[1] is the square side length.
        self.input_size = model_input.shape[1]

    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: Image.Image) -> np.ndarray:
        """Pad *img* to a white square, resize it, and return a float32 BGR array.

        The result has shape (input_size, input_size, 3) with raw 0-255 pixel
        values (no normalization).
        """
        has_alpha = img.mode in ("RGBA", "LA", "PA") or (
            img.mode == "P" and "transparency" in img.info
        )
        if has_alpha:
            # A plain convert("RGB") would keep whatever colour transparent
            # pixels hold (often black); composite onto white instead so
            # transparency matches the white padding below.
            rgba = img.convert("RGBA")
            background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            background.alpha_composite(rgba)
            img = background.convert("RGB")
        elif img.mode != "RGB":
            img = img.convert("RGB")

        # Center the image on a white square canvas to preserve aspect ratio.
        size = max(img.size)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
        if size != self.input_size:
            canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)
        # Reverse the channel axis: RGB -> BGR.
        return np.array(canvas)[:, :, ::-1].astype(np.float32)

    # --------------------------- predict --------------------------
    def predict(self, img: Image.Image,
                gen_th: float = 0.35,
                char_th: float = 0.85):
        """Run the tagger on *img* and group the scores by category.

        Args:
            img: input PIL image.
            gen_th: general tags scoring strictly above this are kept.
            char_th: character tags scoring strictly above this are kept.

        Returns:
            ``{"ratings": {...}, "general": {...}, "characters": {...}}``
            mapping tag name (underscores replaced by spaces) to score.
            All rating tags are included. General and character tags are
            filtered by their threshold and ordered by descending score.

        Raises:
            ValueError: if *img* is None.
        """
        if img is None:
            raise ValueError("predict() requires an image, got None")

        batch = self._preprocess(img)[None, ...]
        scores = self.model.run(None, {self.input_name: batch})[0][0]

        def name(idx):
            return self.tag_names[idx].replace("_", " ")

        def by_score_desc(tags):
            return dict(sorted(tags.items(), key=lambda kv: kv[1], reverse=True))

        ratings = {name(i): float(scores[i]) for i in self.categories["rating"]}
        general = {
            name(i): float(scores[i])
            for i in self.categories["general"]
            if scores[i] > gen_th
        }
        characters = {
            name(i): float(scores[i])
            for i in self.categories["character"]
            if scores[i] > char_th
        }
        return {
            "ratings": ratings,
            "general": by_score_desc(general),
            "characters": by_score_desc(characters),
        }
# ------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------
# Shared tagger instance. Building a Tagger downloads files and opens an ONNX
# session, so it is created on first use and then reused across requests.
_tagger = None


def _get_tagger():
    """Return the shared Tagger, constructing it on the first call."""
    global _tagger
    if _tagger is None:
        _tagger = Tagger()
    return _tagger


with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器 + 翻译") as demo:
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown("上传图片自动识别标签,并可一键翻译成中文")
    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="上传图片")
            with gr.Accordion("⚙️ 高级设置", open=False):
                gen_slider = gr.Slider(0, 1, 0.35,
                                       label="通用标签阈值", info="越高→标签更少更准")
                char_slider = gr.Slider(0, 1, 0.85,
                                        label="角色标签阈值", info="推荐保持较高阈值")
                lang_drop = gr.Dropdown(["zh", "en"], value="zh",
                                        label="翻译目标语言",
                                        info="当前仅内置中 / 英")
            btn = gr.Button("开始分析", variant="primary")
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签 (英文)"):
                    out_general = gr.Label(label="General Tags")
                with gr.TabItem("👤 角色标签 (英文)"):
                    out_char = gr.Label(label="Character Tags")
                with gr.TabItem("⭐ 评分标签 (英文)"):
                    out_rating = gr.Label(label="Rating Tags")
                with gr.TabItem("🌐 翻译结果"):
                    out_trans = gr.Textbox(label="翻译后的标签",
                                           placeholder="翻译结果显示在此处")

    # ----------------- Processing callback -----------------
    def process(img, g_th, c_th, tgt_lang):
        """Tag *img*, translate the general + character tags, and fill the outputs.

        Raises gr.Error (shown to the user) when no image was uploaded.
        """
        if img is None:
            raise gr.Error("请先上传图片")
        res = _get_tagger().predict(img, g_th, c_th)

        # Translate general tags followed by character tags.
        tags_to_translate = list(res["general"]) + list(res["characters"])
        # Skip the translator entirely when no tag passed the thresholds.
        if tags_to_translate:
            translations = translate_texts(tags_to_translate,
                                           src_lang="auto", tgt_lang=tgt_lang)
        else:
            translations = []
        trans_str = ", ".join(translations)

        return {
            out_general: res["general"],
            out_char: res["characters"],
            out_rating: res["ratings"],
            out_trans: trans_str,
        }

    btn.click(
        process,
        inputs=[img_in, gen_slider, char_slider, lang_drop],
        outputs=[out_general, out_char, out_rating, out_trans],
    )
# ------------------------------------------------------------------
# Entry point
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Bind to all network interfaces on port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)