# app.py – encoder-only demo for bert-beatrix-2048
# -----------------------------------------------
# launch:  python app.py
import spaces
import torch
import gradio as gr
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint

# ------------------------------------------------------------------
# 1.  Download *once* and load locally  -----------------------------
# ------------------------------------------------------------------
# Materialise the Hub snapshot as real files in a local folder; the returned
# path is what the handler factory loads from.
LOCAL_CKPT = snapshot_download(
    repo_id="AbstractPhil/bert-beatrix-2048",
    revision="main",
    local_dir="bert-beatrix-2048",
    local_dir_use_symlinks=False,
)

handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval()   # inference mode
full_model = full_model.cuda()   # move to the GPU

# --- keep direct handles on the encoder stack and embedding pipeline ---
_bert = full_model.bert
encoder = _bert.encoder
embeddings = _bert.embeddings
emb_ln = _bert.emb_ln
emb_drop = _bert.emb_drop

# ------------------------------------------------------------------
# 2.  Symbolic token list  ------------------------------------------
# ------------------------------------------------------------------
# Role tokens offered in the UI; each one must exist in the tokenizer vocab.
SYMBOLIC_ROLES = [
    "<subject>",
    "<subject1>",
    "<subject2>",
    "<pose>",
    "<emotion>",
    "<surface>",
    "<lighting>",
    "<material>",
    "<accessory>",
    "<footwear>",
    "<upper_body_clothing>",
    "<hair_style>",
    "<hair_length>",
    "<headwear>",
    "<texture>",
    "<pattern>",
    "<grid>",
    "<zone>",
    "<offset>",
    "<object_left>",
    "<object_right>",
    "<relation>",
    "<intent>",
    "<style>",
    "<fabric>",
    "<jewelry>",
]

# Fail fast at import time: a role that resolves to the tokenizer's unknown-token
# id is treated as absent from the vocabulary.
_unk_id = tokenizer.unk_token_id
missing = [
    role
    for role in SYMBOLIC_ROLES
    if tokenizer.convert_tokens_to_ids(role) == _unk_id
]
if missing:
    raise RuntimeError(f"Tokenizer is missing special tokens: {missing}")

# ------------------------------------------------------------------
# 3.  Encoder-only inference util  ----------------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    """Encode *text* and summarise the encoder output at the selected role tokens.

    The text is tokenized, passed through the embedding pipeline
    (embeddings -> emb_ln -> emb_drop) and the encoder. The hidden states at
    every position whose token id belongs to one of ``selected_roles`` are
    mean-pooled, and the L2 norm of that pooled vector is reported.

    Args:
        text: Raw input string, expected to contain symbolic role tokens.
        selected_roles: Role tokens to look for. ``None`` is treated as an
            empty selection.

    Returns:
        A dict with three entries:
            "Symbolic Tokens": comma-separated matched tokens in input order,
                or "(none)" when nothing matched.
            "Mean Norm": norm of the mean-pooled hidden state, formatted to
                four decimals ("0.0000" when nothing matched).
            "Token Count": number of matched positions.
    """
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, mask = batch.input_ids, batch.attention_mask

        x = emb_drop(emb_ln(embeddings(ids)))

        ext_mask = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
        enc = encoder(x, attention_mask=ext_mask)               # expected (1, S, H)

        want = {tokenizer.convert_tokens_to_ids(t) for t in (selected_roles or [])}

        # Work on plain Python ints: iterating the tensor yields 0-dim tensors,
        # which hash by identity and therefore never match the ints in `want`.
        token_ids = ids[0].tolist()
        hits = [tid in want for tid in token_ids]
        found = tokenizer.convert_ids_to_tokens(
            [tid for tid, hit in zip(token_ids, hits) if hit]
        )

        if found:
            keep = torch.tensor(hits, dtype=torch.bool, device=enc.device)
            vec = enc[0][keep].mean(0)
            norm = f"{vec.norm().item():.4f}"
        else:
            norm = "0.0000"

        return {
            "Symbolic Tokens": ", ".join(found) or "(none)",
            "Mean Norm": norm,
            "Token Count": len(found),
        }

# ------------------------------------------------------------------
# 4.  Gradio UI  -----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    """Build the Gradio Blocks UI for the symbolic encoder inspector.

    Layout: the left column holds the input textbox, the role checkbox group
    and the trigger button; the right column holds three read-only result
    boxes (tokens found, mean norm, token count).

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """

    def _trace_to_outputs(text, roles):
        # encode_and_trace returns a dict keyed by display label, but a click
        # handler bound to three output components must return one positional
        # value per component, in the same order as the outputs list.
        report = encode_and_trace(text, roles)
        return (
            report["Symbolic Tokens"],
            report["Mean Norm"],
            report["Token Count"],
        )

    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(label="Input with Symbolic Tokens", lines=3)
                chk = gr.CheckboxGroup(choices=SYMBOLIC_ROLES, label="Trace these roles")
                btn = gr.Button("Encode & Trace")
            with gr.Column():
                out_tok  = gr.Textbox(label="Symbolic Tokens Found")
                out_norm = gr.Textbox(label="Mean Norm")
                out_cnt  = gr.Textbox(label="Token Count")
        btn.click(_trace_to_outputs, [txt, chk], [out_tok, out_norm, out_cnt])
    return demo

if __name__ == "__main__":
    # Script entry point: construct the UI, then start serving it.
    demo = build_interface()
    demo.launch()