# app.py – encoder-only + masking accuracy demo for bert-beatrix-2048
# -----------------------------------------------------------------
# launch:  python app.py      (UI at http://localhost:7860)
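# deps:    torch, gradio, spaces, huggingface_hub, plus the repo-local
#          bert_handler module (see the imports below)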

import json, re, sys
from pathlib import Path, PurePosixPath

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint

# ------------------------------------------------------------------
# 0.  download repo + patch auto_map --------------------------------
REPO_ID  = "AbstractPhil/bert-beatrix-2048"
LOCAL_CK = "bert-beatrix-2048"
snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_CK, local_dir_use_symlinks=False)

cfg_p = Path(LOCAL_CK) / "config.json"
with cfg_p.open() as f:
    cfg = json.load(f)

# auto_map entries downloaded from the Hub can carry a "repo_id--module.Class"
# prefix; keep only the module path so the classes resolve from the local snapshot
patched = False
for k, v in cfg.get("auto_map", {}).items():
    if "--" in v:
        cfg["auto_map"][k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        patched = True
if patched:
    with cfg_p.open("w") as f:
        json.dump(cfg, f, indent=2)

# ------------------------------------------------------------------
# 1.  load model / tokenizer ---------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CK)
full_model = full_model.eval().cuda()

# sub-modules reused in step 3-D to run the encoder directly on the un-masked text
encoder    = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln     = full_model.bert.emb_ln
emb_drop   = full_model.bert.emb_drop

MASK = tokenizer.mask_token or "[MASK]"   # fall back to the literal string if the tokenizer defines no mask token

# ------------------------------------------------------------------
# 2.  symbolic role list -------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]
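# every role must already exist in the tokenizer vocabulary as a single token;
# any role that still maps to unk_token_id aborts the app below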
miss = [t for t in SYMBOLIC_ROLES
        if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
if miss:
    sys.exit(f"❌ Tokenizer missing {miss}")

# ------------------------------------------------------------------
# 3.  inference util  ----------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    # ----- 3-A. build masked version & encode original --------------
    sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}

    # tokenised “plain” text
    plain = tokenizer(text, return_tensors="pt").to("cuda")
    ids_plain = plain.input_ids

    # make masked string (re.escape keeps the angle-bracket role tokens literal)
    masked_txt = text
    for tok in selected_roles:
        masked_txt = re.sub(re.escape(tok), MASK, masked_txt)
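    # e.g. "A <subject> wearing <fabric>"  ->  "A [MASK] wearing [MASK]"   (illustrative)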

    masked = tokenizer(masked_txt, return_tensors="pt").to("cuda")
    ids_masked = masked.input_ids

    # ----- 3-B. run model on masked text ----------------------------
    with torch.no_grad():
        logits = full_model(**masked).logits[0]          # (S, V)
        preds  = logits.argmax(-1)                       # (S,)

    # ----- 3-C. gather stats per masked role ------------------------
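    # NOTE: positions are compared index-for-index, which assumes the plain and
    # masked strings tokenise to the same length (each role token and [MASK]
    # should each be a single token)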
    found_tokens, correct = [], 0
    mask_positions = []          # sequence indices of the masked role tokens
    for i, (orig_id, pred_id) in enumerate(zip(ids_plain[0], preds)):
        if orig_id.item() in sel_ids and ids_masked[0, i].item() == tokenizer.mask_token_id:
            found_tokens.append(tokenizer.convert_ids_to_tokens([orig_id])[0])
            correct += int(orig_id.item() == pred_id.item())
            mask_positions.append(i)

    total = len(mask_positions)
    acc   = correct / total if total else 0.0

    # ----- 3-D. encoder rep pooling for *all* selected roles --------
    with torch.no_grad():
        # embeddings -> normed reps
        x = emb_drop(emb_ln(embeddings(ids_plain)))
        attn = full_model.bert.get_extended_attention_mask(
            plain.attention_mask, x.shape[:-1]
        )
        enc = encoder(x, attention_mask=attn)            # (1,S,H)
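        # boolean mask over the sequence positions that hold a selected role token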
        mask_vec = torch.tensor(
            [tid in sel_ids for tid in ids_plain[0].tolist()], device=enc.device
        )
        if mask_vec.any():
            pooled = enc[0][mask_vec].mean(0)
            norm   = f"{pooled.norm().item():.4f}"
        else:
            norm   = "0.0000"

    tokens_str = ", ".join(found_tokens) or "(none)"
    return tokens_str, norm, f"{acc*100:.1f}%"

# ------------------------------------------------------------------
# 4.  gradio UI  ----------------------------------------------------
def app():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "## 🧠 Symbolic Encoder Inspector  \n"
            "1. Model side: we *mask* every chosen role token, run the LM, and report how often it recovers the original.  \n"
            "2. Encoder side: we also pool hidden-state vectors for those roles and give their mean L2-norm."
        )
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    lines=3,
                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                )
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    value=SYMBOLIC_ROLES,            # <- all pre-selected
                    label="Roles to mask & trace",
                )
                run = gr.Button("Run")
            with gr.Column():
                o_tok  = gr.Textbox(label="Masked-role tokens found")
                o_norm = gr.Textbox(label="L2-norm of mean-pooled hidden state")
                o_acc  = gr.Textbox(label="Recovery accuracy")

        run.click(encode_and_trace, [txt, roles], [o_tok, o_norm, o_acc])
    return demo

if __name__ == "__main__":
    app().launch()