import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
import json
from huggingface_hub import hf_hub_download
# Download the ONNX model and tag metadata from the Hugging Face Hub at startup
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"
META_FILE = "metadata.json"
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
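# Note: "CPUExecutionProvider" keeps this Space CPU-only; with a GPU build of
# onnxruntime installed, "CUDAExecutionProvider" could be listed first instead.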
with open(meta_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)

# Preprocess a PIL image into the model's expected (1, 3, 512, 512) float32 input
def preprocess_image(pil_image: Image.Image) -> np.ndarray:
    img = pil_image.convert("RGB").resize((512, 512))
    arr = np.array(img).astype(np.float32) / 255.0  # HWC, values scaled to [0, 1]
    arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
    arr = np.expand_dims(arr, 0)  # add batch dimension -> (1, 3, 512, 512)
    return arr
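
# Quick sanity check (hypothetical example; "example.jpg" is a placeholder path):
#   x = preprocess_image(Image.open("example.jpg"))
#   x.shape == (1, 3, 512, 512) and x.dtype == np.float32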

def to_prompt_tag(tag: str) -> str:
    # Underscores become spaces; parentheses get a doubled backslash so the
    # Markdown output box renders them as the prompt-escaped "\(" / "\)".
    return tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")

# Inference function with output format option
def tag_image(pil_image: Image.Image, output_format: str) -> str:
    # Run model inference
    input_tensor = preprocess_image(pil_image)
    input_name = session.get_inputs()[0].name
    initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
    probs = 1 / (1 + np.exp(-refined_logits))  # sigmoid converts logits to per-tag probabilities
    probs = probs[0]  # drop the batch dimension
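    # For intuition: a refined logit of 0.0 maps to probability 0.5, and +2.0 to about 0.88.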
    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})
    category_thresholds = metadata.get("category_thresholds", {})
    default_threshold = 0.35
    results_by_cat = {}  # tags per category (for the detailed output)
    artist_tags_with_probs = []
    character_tags_with_probs = []
    general_tags_with_probs = []
    all_artist_tags_probs = []  # all artist tags and their probabilities, even below threshold
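    # Note: idx_to_tag maps stringified indices ("0", "1", ...) to tag names,
    # which is why the loop below indexes it with str(idx).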
    # Collect tags above their per-category thresholds
    for idx, prob in enumerate(probs):
        tag = idx_to_tag[str(idx)]
        cat = tag_to_category.get(tag, "unknown")
        if cat == 'artist':
            all_artist_tags_probs.append((tag, float(prob)))  # track every artist tag
        thresh = category_thresholds.get(cat, default_threshold)
        if float(prob) >= thresh:
            # add to the per-category dictionary
            results_by_cat.setdefault(cat, []).append((tag, float(prob)))
            if cat == 'artist':
                artist_tags_with_probs.append((tag, float(prob)))
            elif cat == 'character':
                character_tags_with_probs.append((tag, float(prob)))
            elif cat == 'general':
                general_tags_with_probs.append((tag, float(prob)))
    if output_format == "Prompt-style Tags":
        artist_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
        character_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
        general_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
        artist_prompt_tags = [to_prompt_tag(tag) for tag, prob in artist_tags_with_probs]
        character_prompt_tags = [to_prompt_tag(tag) for tag, prob in character_tags_with_probs]
        general_prompt_tags = [to_prompt_tag(tag) for tag, prob in general_tags_with_probs]
        prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
        # Ensure at least one artist tag if any were predicted at all, even below threshold
        if not artist_prompt_tags and all_artist_tags_probs:
            best_artist_tag, _ = max(all_artist_tags_probs, key=lambda item: item[1])
            prompt_tags = [to_prompt_tag(best_artist_tag)] + prompt_tags
        if not prompt_tags:
            return "No tags predicted."
        return ", ".join(prompt_tags)
    else:  # Detailed output
        if not results_by_cat:
            return "No tags predicted for this image."
        # Ensure an artist tag appears in the detailed output even if below threshold
        if 'artist' not in results_by_cat and all_artist_tags_probs:
            best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1])
            results_by_cat['artist'] = [(best_artist_tag, best_artist_prob)]
        lines = []
        lines.append("**Predicted Tags by Category:**  \n")  # Markdown hard break: two trailing spaces + newline
        for cat, tag_list in results_by_cat.items():
            # sort tags in this category by probability, descending
            tag_list.sort(key=lambda x: x[1], reverse=True)
            lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
            for tag, prob in tag_list:
                tag_pretty = to_prompt_tag(tag)  # escaped parentheses survive Markdown rendering
                lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
            lines.append("")  # blank line between categories
        return "\n".join(lines)
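
# Calling tag_image directly, outside the UI (hypothetical example; "sample.jpg" is a placeholder):
#   print(tag_image(Image.open("sample.jpg"), "Prompt-style Tags"))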

# Build the Gradio Blocks UI
demo = gr.Blocks(theme="gradio/soft")  # built-in theme for nicer styling
with demo:
    # Header section
    gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
    gr.Markdown("*(Note: the model predicts a large number of tags across categories such as character, general, and artist. Choose a concise prompt-style output or a detailed category-wise breakdown.)*")
    # Input/Output section
    with gr.Row():
        # Left column: image input and format selection
        with gr.Column():
            image_in = gr.Image(type="pil", label="Input Image")
            format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
            tag_button = gr.Button("🔍 Tag Image")
        # Right column: output display
        with gr.Column():
            output_box = gr.Markdown("")  # displays the result as Markdown (supports bold, lists, etc.)
    # Wire the button click to the inference function
    tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
    # Footer/info
    gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • **ONNX Runtime:** for efficient CPU inference • *Demo built with Gradio Blocks.*")

# Launch the app (handled automatically on Spaces)
demo.launch()
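
# To run this demo locally (assuming the dependencies are not yet installed):
#   pip install gradio onnxruntime pillow numpy huggingface_hub
#   python app.py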