"""Gradio app: tag an image with an ONNX tagger model and optionally show the tags in Chinese.

NOTE(review): this file was recovered from a copy whose line structure and
embedded HTML/JS markup had been lost. Block nesting in the UI layout and the
markup emitted by ``format_tags_html`` / ``copy_summary`` were rebuilt from the
component order and from the class names declared in ``custom_css`` -- confirm
both against the intended design.
"""

import functools
import json
import os
from html import escape

import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from huggingface_hub import login
from PIL import Image

from translator import translate_texts

# ------------------------------------------------------------------
# Model configuration
# ------------------------------------------------------------------
MODEL_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

HF_TOKEN = os.environ.get("HF_TOKEN", "")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    print("⚠️ 未检测到 HF_TOKEN,私有模型可能下载失败")

# Text shown in the summary box when no tag ends up selected.
NO_SUMMARY_TEXT = "请选择要汇总的标签类别"

# Dropdown label -> string used to join the summary tags.
SEPARATORS = {"逗号": ", ", "换行": "\n", "空格": " "}


# ------------------------------------------------------------------
# Tagger
# ------------------------------------------------------------------
class Tagger:
    """Wrap the ONNX model and the label table downloaded from ``MODEL_REPO``."""

    def __init__(self):
        self.hf_token = HF_TOKEN
        self._load_model_and_labels()

    def _load_model_and_labels(self):
        """Download the label CSV and the model, then build the inference session."""
        # Pass None instead of an empty string so the hub client treats the
        # request as "no explicit token" rather than as an empty credential.
        token = self.hf_token or None
        label_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, LABEL_FILENAME, token=token
        )
        model_path = huggingface_hub.hf_hub_download(
            MODEL_REPO, MODEL_FILENAME, token=token
        )

        tags_df = pd.read_csv(label_path)
        self.tag_names = tags_df["name"].tolist()
        # Row indices of each tag group, selected by the CSV "category" column.
        self.categories = {
            "rating": np.where(tags_df["category"] == 9)[0],
            "general": np.where(tags_df["category"] == 0)[0],
            "character": np.where(tags_df["category"] == 4)[0],
        }

        self.model = rt.InferenceSession(model_path)
        self.input_size = self.model.get_inputs()[0].shape[1]

    # ------------------------- preprocess -------------------------
    def _preprocess(self, img: Image.Image) -> np.ndarray:
        """Pad *img* to a white square, resize to the model input size, return BGR float32."""
        if img.mode != "RGB":
            # Composite onto white before dropping alpha; a bare convert("RGB")
            # discards the alpha channel and keeps whatever colour the
            # transparent pixels happen to carry.
            rgba = img.convert("RGBA")
            white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(white, rgba).convert("RGB")

        size = max(img.size)
        canvas = Image.new("RGB", (size, size), (255, 255, 255))
        canvas.paste(img, ((size - img.width) // 2, (size - img.height) // 2))
        if size != self.input_size:
            canvas = canvas.resize((self.input_size, self.input_size), Image.BICUBIC)

        return np.array(canvas)[:, :, ::-1].astype(np.float32)  # RGB -> BGR

    # --------------------------- predict --------------------------
    def predict(self, img: Image.Image, gen_th: float = 0.35, char_th: float = 0.85):
        """Run the model on *img* and return the scored tags.

        Returns a dict with keys ``"ratings"`` (every rating tag),
        ``"general"`` (score > *gen_th*, sorted by descending score) and
        ``"characters"`` (score > *char_th*). Each value maps the tag name,
        with underscores replaced by spaces, to its score as a float.
        """
        inp_name = self.model.get_inputs()[0].name
        batch = self._preprocess(img)[None, ...]
        outputs = self.model.run(None, {inp_name: batch})[0][0]

        def collect(indices, threshold=None):
            return {
                self.tag_names[idx].replace("_", " "): float(outputs[idx])
                for idx in indices
                if threshold is None or outputs[idx] > threshold
            }

        general = collect(self.categories["general"], gen_th)
        return {
            "ratings": collect(self.categories["rating"]),
            "general": dict(
                sorted(general.items(), key=lambda kv: kv[1], reverse=True)
            ),
            "characters": collect(self.categories["character"], char_th),
        }


@functools.lru_cache(maxsize=1)
def _get_tagger():
    """Return one shared Tagger, created on first use.

    Building a Tagger downloads files and creates an ONNX session, so it is
    done once instead of on every button click.
    """
    return Tagger()


# ------------------------------------------------------------------
# Gradio UI: styles
# ------------------------------------------------------------------
custom_css = """
.label-container {
    max-height: 300px;
    overflow-y: auto;
    border: 1px solid #ddd;
    padding: 10px;
    border-radius: 5px;
    background-color: #f9f9f9;
}
.tag-item {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin: 2px 0;
    padding: 2px 5px;
    border-radius: 3px;
    background-color: #fff;
    cursor: pointer;
    transition: background-color 0.2s;
}
.tag-item:hover {
    background-color: #e8f4ff;
}
.tag-item:active {
    background-color: #bde0ff;
}
.tag-content {
    display: flex;
    align-items: center;
    gap: 10px;
    flex: 1;
}
.tag-text {
    font-weight: bold;
    color: #333;
}
.tag-score {
    color: #999;
    font-size: 0.9em;
}
.copy-container {
    position: relative;
    margin-bottom: 5px;
}
.copy-button {
    position: absolute;
    top: 5px;
    right: 5px;
    padding: 4px 8px;
    font-size: 12px;
    background-color: #f0f0f0;
    border: 1px solid #ddd;
    border-radius: 4px;
    cursor: pointer;
    transition: all 0.2s;
}
.copy-button:hover {
    background-color: #e0e0e0;
}
.copy-button:active {
    background-color: #d0d0d0;
}
.toast {
    position: fixed;
    top: 20px;
    right: 20px;
    padding: 10px 20px;
    background-color: #4CAF50;
    color: white;
    border-radius: 4px;
    opacity: 0;
    transition: opacity 0.3s;
    z-index: 1000;
}
.toast.show {
    opacity: 1;
}
"""


# ------------------------------------------------------------------
# Callbacks and helpers
# ------------------------------------------------------------------
def _display_names(tags_dict, translations, current_lang):
    """Return the text to show for each tag in *tags_dict*, in order.

    In "zh" mode a tag uses the translation at the same position when one
    exists and is non-empty; otherwise the original tag name is used.
    """
    names = list(tags_dict)
    if current_lang != "zh":
        return names
    return [
        str(translations[i]) if i < len(translations) and translations[i] else tag
        for i, tag in enumerate(names)
    ]


def _build_summary(tags, translations, current_lang,
                   sum_gen, sum_char, sum_rat, sep_type):
    """Join the tags of the enabled categories (characters, general, ratings)."""
    selection = (
        ("characters", sum_char),
        ("general", sum_gen),
        ("ratings", sum_rat),
    )
    names = []
    for category, enabled in selection:
        if enabled:
            names.extend(
                _display_names(
                    tags.get(category) or {},
                    translations.get(category) or [],
                    current_lang,
                )
            )
    if not names:
        return NO_SUMMARY_TEXT
    return SEPARATORS.get(sep_type, ", ").join(names)


def format_tags_html(tags_dict, translations, category_key, current_lang):
    """Render one tag category as HTML.

    NOTE(review): the element structure is rebuilt from the class names in
    ``custom_css`` (label-container, tag-item, tag-content, tag-text,
    tag-score); the click-to-copy handler stands in for a copy script whose
    text was not recoverable. Confirm against the intended markup.
    """
    if not tags_dict:
        return '<div class="label-container"><p>暂无标签</p></div>'

    category = escape(str(category_key))
    names = _display_names(tags_dict, translations, current_lang)
    rows = []
    for name, score in zip(names, tags_dict.values()):
        # Tag and translation text is external data: escape before embedding.
        label = escape(str(name))
        rows.append(
            f'<div class="tag-item" data-category="{category}" data-tag="{label}" '
            'onclick="navigator.clipboard.writeText(this.dataset.tag)">'
            f'<div class="tag-content"><span class="tag-text">{label}</span></div>'
            f'<span class="tag-score">{score:.3f}</span>'
            "</div>"
        )
    return '<div class="label-container">' + "".join(rows) + "</div>"


def process(img, g_th, c_th, sum_gen, sum_char, sum_rat, sep_type,
            current_lang, prev_tags, prev_translations):
    """Analyse *img* and stream UI updates.

    Yields 9-tuples matching the ``btn.click`` outputs: button update, status
    update, general/character/rating HTML, summary text, language, tag data,
    translation data. ``prev_tags`` and ``prev_translations`` are accepted
    because they are wired as inputs but are not used.
    """
    idle_button = gr.update(interactive=True, value="开始分析")

    if img is None:
        # Fail early with a readable message instead of an AttributeError text.
        yield (
            idle_button,
            gr.update(visible=True, value="❌ 处理失败: 请先上传图片"),
            "", "", "", "", current_lang, {}, {},
        )
        return

    # Processing started: lock the button and clear previous results.
    yield (
        gr.update(interactive=False, value="处理中..."),
        gr.update(visible=True, value="🔄 正在分析图像..."),
        "", "", "", "", current_lang, {}, {},
    )

    try:
        res = _get_tagger().predict(img, g_th, c_th)

        # Translate every tag in one batch, then slice the result per category.
        categories = ("general", "characters", "ratings")
        tag_lists = {category: list(res[category]) for category in categories}
        all_tags = [tag for category in categories for tag in tag_lists[category]]

        if all_tags:
            translations = translate_texts(all_tags, src_lang="auto", tgt_lang="zh")
        else:
            translations = []

        translations_dict = {}
        offset = 0
        for category in categories:
            count = len(tag_lists[category])
            translations_dict[category] = translations[offset:offset + count]
            offset += count

        general_html = format_tags_html(
            res["general"], translations_dict["general"], "general", current_lang
        )
        char_html = format_tags_html(
            res["characters"], translations_dict["characters"], "characters", current_lang
        )
        rating_html = format_tags_html(
            res["ratings"], translations_dict["ratings"], "ratings", current_lang
        )
        summary_text = _build_summary(
            res, translations_dict, current_lang,
            sum_gen, sum_char, sum_rat, sep_type,
        )

        yield (
            idle_button,
            gr.update(visible=False),
            general_html, char_html, rating_html, summary_text,
            current_lang, res, translations_dict,
        )
    except Exception as e:
        # UI boundary: report the failure in the status line and re-enable the button.
        yield (
            idle_button,
            gr.update(visible=True, value=f"❌ 处理失败: {e}"),
            "", "", "", "", current_lang, {}, {},
        )


def toggle_language(current_lang, tags, translations,
                    sum_gen=True, sum_char=True, sum_rat=False, sep_type="逗号"):
    """Switch between "en" and "zh" and re-render the tag panes and summary.

    The summary is rebuilt from the stored tag data and the current category /
    separator selections, so toggling keeps the summary in sync instead of
    relying on the Textbox component's construction-time value.

    Returns ``(new_lang, general_html, char_html, rating_html, summary_text)``.
    """
    new_lang = "zh" if current_lang == "en" else "en"
    tags = tags or {}
    translations = translations or {}

    general_html = format_tags_html(
        tags.get("general", {}), translations.get("general", []), "general", new_lang
    )
    char_html = format_tags_html(
        tags.get("characters", {}), translations.get("characters", []), "characters", new_lang
    )
    rating_html = format_tags_html(
        tags.get("ratings", {}), translations.get("ratings", []), "ratings", new_lang
    )

    if any(tags.values()):
        summary_text = _build_summary(
            tags, translations, new_lang, sum_gen, sum_char, sum_rat, sep_type
        )
    else:
        # Nothing analysed yet: leave the summary box empty.
        summary_text = ""

    return new_lang, general_html, char_html, rating_html, summary_text


def copy_summary(text):
    """Return an update carrying a script that copies *text* to the clipboard.

    NOTE(review): the original script body was not recoverable; this is a
    reconstruction. Whether a script placed in a (hidden) gr.HTML component is
    executed depends on the Gradio version -- verify in the deployed version.
    """
    # json.dumps yields a valid JS string literal; "</" is broken up so the
    # text cannot terminate the surrounding <script> element.
    payload = json.dumps(text or "").replace("</", "<\\/")
    copy_js = f"<script>navigator.clipboard.writeText({payload});</script>"
    return gr.update(value=copy_js)


# ------------------------------------------------------------------
# Gradio UI: layout and event wiring
# ------------------------------------------------------------------
# NOTE(review): nesting below is inferred from component order; confirm.
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css) as demo:
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown("上传图片自动识别标签,并可一键翻译成中文")

    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="上传图片")
            with gr.Accordion("⚙️ 高级设置", open=True):
                gen_slider = gr.Slider(0, 1, 0.35, label="通用标签阈值", info="越高→标签更少更准")
                char_slider = gr.Slider(0, 1, 0.85, label="角色标签阈值", info="推荐保持较高阈值")
                gr.Markdown("### 汇总设置")
                with gr.Row():
                    sum_general = gr.Checkbox(True, label="通用标签")
                    sum_char = gr.Checkbox(True, label="角色标签")
                    sum_rating = gr.Checkbox(False, label="评分标签")
                sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="分隔符")

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    out_general = gr.HTML(label="General Tags")
                with gr.TabItem("👤 角色标签"):
                    out_char = gr.HTML(label="Character Tags")
                with gr.TabItem("⭐ 评分标签"):
                    out_rating = gr.HTML(label="Rating Tags")
            gr.Markdown("### 标签汇总")
            with gr.Row():
                lang_btn = gr.Button("中/EN", variant="secondary", scale=0)
                copy_btn = gr.Button("📋 复制", variant="secondary", scale=0)
            out_summary = gr.Textbox(
                label="标签汇总",
                placeholder="选择需要汇总的标签类别...",
                lines=3,
                interactive=False,
            )

    with gr.Row():
        processing_info = gr.Markdown("", visible=False)
        btn = gr.Button("开始分析", variant="primary", scale=0)

    # Hidden sink that receives the markup returned by copy_summary.
    copy_sink = gr.HTML(visible=False)

    # Per-session state
    lang_state = gr.State("en")  # display language, English by default
    tags_data = gr.State({})  # last prediction result
    translations_data = gr.State({})  # translations for the last result

    # Event bindings
    btn.click(
        process,
        inputs=[
            img_in, gen_slider, char_slider,
            sum_general, sum_char, sum_rating, sum_sep,
            lang_state, tags_data, translations_data,
        ],
        outputs=[
            btn, processing_info,
            out_general, out_char, out_rating, out_summary,
            lang_state, tags_data, translations_data,
        ],
        show_progress=True,
    )

    lang_btn.click(
        toggle_language,
        inputs=[
            lang_state, tags_data, translations_data,
            sum_general, sum_char, sum_rating, sum_sep,
        ],
        outputs=[lang_state, out_general, out_char, out_rating, out_summary],
    )

    copy_btn.click(
        copy_summary,
        inputs=[out_summary],
        outputs=[copy_sink],
    )

# ------------------------------------------------------------------
# Launch
# ------------------------------------------------------------------
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)