Spaces:
Running
Running
File size: 6,213 Bytes
086766e 7ec5b17 086766e 7ec5b17 086766e 7ec5b17 086766e 7ec5b17 086766e 7ec5b17 086766e 7ec5b17 c24087d 7ec5b17 1676c6e d13085d c24087d d13085d 086766e c24087d 7ec5b17 1676c6e d13085d 1676c6e 7ec5b17 1676c6e d13085d 1676c6e d13085d 1676c6e c24087d 7ec5b17 5b400e8 086766e 7ec5b17 1676c6e 086766e 7ec5b17 1676c6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
import json
from huggingface_hub import hf_hub_download
# --- Model and metadata loading (runs once at startup) ---
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"
META_FILE = "metadata.json"
# Download (or reuse locally cached copies of) the ONNX model and its tag metadata.
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
# CPU-only inference session (no GPU assumed on the hosting environment).
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
# Use a context manager so the metadata file handle is closed deterministically
# (the original `json.load(open(...))` leaked the handle).
with open(meta_path, "r", encoding="utf-8") as _meta_file:
    metadata = json.load(_meta_file)
# Image preprocessing: PIL image -> NCHW float32 batch tensor.
def preprocess_image(pil_image: Image.Image) -> np.ndarray:
    """Convert a PIL image to a (1, 3, 512, 512) float32 array scaled to [0, 1]."""
    rgb = pil_image.convert("RGB").resize((512, 512))
    pixels = np.asarray(rgb).astype(np.float32) / 255.0  # HWC, values in [0, 1]
    chw = pixels.transpose(2, 0, 1)                      # HWC -> CHW
    return chw[np.newaxis, ...]                          # prepend batch dimension
# Inference function with output format option
def tag_image(pil_image: Image.Image, output_format: str) -> str:
    """Run the tagger on *pil_image* and format the predicted tags.

    Args:
        pil_image: Input image (converted to RGB internally).
        output_format: "Prompt-style Tags" for a comma-separated prompt
            string; anything else yields a detailed Markdown breakdown
            grouped by tag category.

    Returns:
        A formatted string of predictions (prompt-style or Markdown).
    """
    def _escape(tag: str) -> str:
        # Prompt formatting: underscores become spaces; parentheses are
        # escaped so prompt parsers don't treat them as attention/weight syntax.
        return tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)")

    # Run model inference
    input_tensor = preprocess_image(pil_image)
    input_name = session.get_inputs()[0].name
    initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
    # Sigmoid over the refined logits -> independent per-tag probabilities.
    probs = 1 / (1 + np.exp(-refined_logits))
    probs = probs[0]  # drop the batch dimension
    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})
    category_thresholds = metadata.get("category_thresholds", {})
    default_threshold = 0.35
    results_by_cat = {}  # category -> [(tag, prob)] above threshold (verbose output)
    # Per-category lists used to order the prompt: artist, character, general.
    artist_tags_with_probs = []
    character_tags_with_probs = []
    general_tags_with_probs = []
    all_artist_tags_probs = []  # every artist tag, even below threshold
    # Collect tags above their (per-category) thresholds.
    for idx, prob in enumerate(probs):
        tag = idx_to_tag[str(idx)]  # metadata JSON keys are string indices
        cat = tag_to_category.get(tag, "unknown")
        if cat == 'artist':
            all_artist_tags_probs.append((tag, float(prob)))  # store all artist tags
        thresh = category_thresholds.get(cat, default_threshold)
        if float(prob) >= thresh:
            # add to category dictionary
            results_by_cat.setdefault(cat, []).append((tag, float(prob)))
            if cat == 'artist':
                artist_tags_with_probs.append((tag, float(prob)))
            elif cat == 'character':
                character_tags_with_probs.append((tag, float(prob)))
            elif cat == 'general':
                general_tags_with_probs.append((tag, float(prob)))
    if output_format == "Prompt-style Tags":
        # Prompt order: artists first, then characters, then general tags,
        # each group sorted by descending confidence.
        artist_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
        character_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
        general_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
        artist_prompt_tags = [_escape(tag) for tag, prob in artist_tags_with_probs]
        character_prompt_tags = [_escape(tag) for tag, prob in character_tags_with_probs]
        general_prompt_tags = [_escape(tag) for tag, prob in general_tags_with_probs]
        prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags
        # Ensure at least one artist tag if any artist tags were predicted at all,
        # even below threshold.
        if not artist_prompt_tags and all_artist_tags_probs:
            best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1])
            prompt_tags = [_escape(best_artist_tag)] + prompt_tags
        if not prompt_tags:
            return "No tags predicted."
        return ", ".join(prompt_tags)
    else:  # Detailed output
        if not results_by_cat:
            return "No tags predicted for this image."
        lines = []
        # Two trailing spaces before "\n" produce a Markdown hard line break
        # (the original had only one space, which Markdown ignores).
        lines.append("**Predicted Tags by Category:**  \n")
        for cat, tag_list in results_by_cat.items():
            # sort tags in this category by probability descending
            tag_list.sort(key=lambda x: x[1], reverse=True)
            lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
            for tag, prob in tag_list:
                tag_pretty = tag.replace("_", " ")
                lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
            lines.append("")  # blank line between categories
        return "\n".join(lines)
# Build the Gradio Blocks UI
demo = gr.Blocks(theme="gradio/soft")  # using a built-in theme for nicer styling
with demo:
    # Header Section
    gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
    gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")
    # Input/Output Section
    with gr.Row():
        # Left column: Image input and format selection
        with gr.Column():
            image_in = gr.Image(type="pil", label="Input Image")
            format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
            tag_button = gr.Button("🔍 Tag Image")
        # Right column: Output display
        with gr.Column():
            output_box = gr.Markdown("")  # displays the result in Markdown (supports bold, lists, etc.)
    # Link the button click to the inference function
    tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
    # Footer/Info (a stray ":contentReference[oaicite]" citation artifact was
    # removed from the original footer string — it rendered as literal junk).
    gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • **ONNX Runtime:** for efficient CPU inference • *Demo built with Gradio Blocks.*")
# Launch the app (automatically handled in Spaces)
demo.launch()