import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
import json
from huggingface_hub import hf_hub_download

# Load the ONNX model and tag metadata once at startup
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"
META_FILE = "metadata.json"
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
with open(meta_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)
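# The code below assumes metadata.json provides "idx_to_tag" (index -> tag name)
# and, optionally, "tag_to_category" and "category_thresholds"; missing keys fall
# back to the "unknown" category and a default threshold.
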
# Preprocess a PIL image into the model's expected NCHW float tensor
def preprocess_image(pil_image: Image.Image) -> np.ndarray:
    img = pil_image.convert("RGB").resize((512, 512))
    arr = np.array(img).astype(np.float32) / 255.0
    arr = np.transpose(arr, (2, 0, 1))
    arr = np.expand_dims(arr, 0)
    return arr
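
# Illustrative example: preprocess_image(Image.open("sample.jpg")) would return
# a float32 array of shape (1, 3, 512, 512) with values in [0, 1]
# ("sample.jpg" is a hypothetical filename).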

# Inference function with output format option
def tag_image(pil_image: Image.Image, output_format: str) -> str:
    if pil_image is None:  # Gradio passes None when no image has been uploaded
        return "Please upload an image first."
    # Run model inference
    input_tensor = preprocess_image(pil_image)
    input_name = session.get_inputs()[0].name
    initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
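    # The model exposes two output heads (initial and refined logits);
    # the refined head is used here for the final tag scores.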
    probs = 1 / (1 + np.exp(-refined_logits))  # sigmoid converts logits to probabilities
    probs = probs[0]  # drop the batch dimension
    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})
    category_thresholds = metadata.get("category_thresholds", {})
    default_threshold = 0.35
    results_by_cat = {}  # tags per category (used for the detailed output)
    artist_tags_with_probs = []
    character_tags_with_probs = []
    general_tags_with_probs = []
    all_artist_tags_probs = []  # all artist tags with probabilities (even below threshold)

    # Collect tags above thresholds
    for idx, prob in enumerate(probs):
        tag = idx_to_tag[str(idx)]
        cat = tag_to_category.get(tag, "unknown")
        if cat == 'artist':
            all_artist_tags_probs.append((tag, float(prob)))  # keep every artist tag for the fallback below
        thresh = category_thresholds.get(cat, default_threshold)
        if float(prob) >= thresh:
            # add to category dictionary
            results_by_cat.setdefault(cat, []).append((tag, float(prob)))
            if cat == 'artist':
                artist_tags_with_probs.append((tag, float(prob)))
            elif cat == 'character':
                character_tags_with_probs.append((tag, float(prob)))
            elif cat == 'general':
                general_tags_with_probs.append((tag, float(prob)))

    if output_format == "Prompt-style Tags":
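        # Sort each category by confidence, then format tags for prompt use:
        # underscores become spaces and parentheses are backslash-escaped so
        # they read as literal text (as is common in Stable Diffusion-style prompts).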
        artist_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
        character_tags_with_probs.sort(key=lambda x: x[1], reverse=True)
        general_tags_with_probs.sort(key=lambda x: x[1], reverse=True)

        artist_prompt_tags = [tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)") for tag, prob in artist_tags_with_probs]
        character_prompt_tags = [tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)") for tag, prob in character_tags_with_probs]
        general_prompt_tags = [tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)") for tag, prob in general_tags_with_probs]

        prompt_tags = artist_prompt_tags + character_prompt_tags + general_prompt_tags

        # If no artist tag passed its threshold, fall back to the single most likely one
        if not artist_prompt_tags and all_artist_tags_probs:
            best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1])
            prompt_tags = [best_artist_tag.replace("_", " ").replace("(", "\\(").replace(")", "\\)")] + prompt_tags

        if not prompt_tags:
            return "No tags predicted."
        return ", ".join(prompt_tags)
    else:  # Detailed output
        if not results_by_cat:
            return "No tags predicted for this image."
        lines = []
        lines.append("**Predicted Tags by Category:**  \n")  # (Markdown newline: two spaces + newline)
        for cat, tag_list in results_by_cat.items():
            # sort tags in this category by probability descending
            tag_list.sort(key=lambda x: x[1], reverse=True)
            lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
            for tag, prob in tag_list:
                tag_pretty = tag.replace("_", " ")
                lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
            lines.append("")  # blank line between categories
        return "\n".join(lines)

# Build the Gradio Blocks UI
with gr.Blocks(theme="gradio/soft") as demo:  # built-in theme for nicer styling
    # Header Section
    gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
    gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")
    # Input/Output Section
    with gr.Row():
        # Left column: Image input and format selection
        with gr.Column():
            image_in = gr.Image(type="pil", label="Input Image")
            format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
            tag_button = gr.Button("🔍 Tag Image")
        # Right column: Output display
        with gr.Column():
            output_box = gr.Markdown("")  # will display the result in Markdown (supports bold, lists, etc.)
    # Link the button click to the function
    tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
    # Footer/Info
    gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime)   •   **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags)   •   **ONNX Runtime:** for efficient CPU inference   •   *Demo built with Gradio Blocks.*")

# Launch the app (automatically handled in Spaces)
demo.launch()
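
# For quick local testing outside Spaces, Gradio can also create a temporary
# public URL (not needed on Hugging Face Spaces):
#   demo.launch(share=True)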