File size: 6,957 Bytes
aa28bbb
 
 
 
 
 
 
872b08b
 
81a0ae4
535151e
aa28bbb
81a0ae4
872b08b
096fe3a
872b08b
535151e
096fe3a
aa28bbb
096fe3a
aa28bbb
 
872b08b
 
 
81a0ae4
aa28bbb
9a08859
81a0ae4
 
aa28bbb
872b08b
9a08859
 
 
415afa1
9a08859
aa28bbb
9a08859
415afa1
9a08859
415afa1
872b08b
9a08859
aa28bbb
9a08859
 
 
aa28bbb
9a08859
aa28bbb
096fe3a
535151e
096fe3a
 
 
 
8a2e372
872b08b
096fe3a
aa28bbb
096fe3a
 
785df91
 
 
 
 
872b08b
785df91
aa28bbb
 
096fe3a
872b08b
 
096fe3a
 
aa28bbb
096fe3a
 
 
415afa1
aa28bbb
 
 
 
 
 
415afa1
096fe3a
 
aa28bbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
096fe3a
aa28bbb
b60c583
aa28bbb
096fe3a
872b08b
096fe3a
872b08b
096fe3a
785df91
096fe3a
872b08b
aa28bbb
 
 
 
 
872b08b
 
535151e
 
aa28bbb
 
 
 
 
 
 
 
 
 
535151e
aa28bbb
 
 
 
 
 
 
 
 
 
872b08b
535151e
ea2994b
872b08b
ea2994b
81a0ae4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# app.py – encoder-only demo for bert-beatrix-2048 + role-probe
# ------------------------------------------------------------
# launch:  python app.py
# (gradio UI appears at http://localhost:7860)

import json, sys
from pathlib import Path, PurePosixPath

import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint


# ------------------------------------------------------------------
# 0.  Download & patch config.json  --------------------------------
# ------------------------------------------------------------------
REPO_ID    = "AbstractPhil/bert-beatrix-2048"
LOCAL_DIR  = "bert-beatrix-2048"              # local cache dir

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_DIR,
    local_dir_use_symlinks=False,
)


def _strip_repo_prefix(ref):
    """Remove a ``"<repo>--"`` prefix from an ``auto_map`` reference.

    ``auto_map`` values are either a single ``"module.Class"`` string or a
    list of them (entries may be ``None``).  Any string containing ``"--"``
    keeps only the part after the first ``"--"``.

    Returns:
        ``(new_ref, changed)``: the rewritten reference and whether anything
        was modified.  Unrecognised values are returned unchanged.
    """
    if isinstance(ref, str):
        if "--" in ref:                       # e.g.  "repo--module.Class"
            return ref.split("--", 1)[1], True
        return ref, False
    if isinstance(ref, list):
        results = [_strip_repo_prefix(r) for r in ref]
        if any(changed for _, changed in results):
            return [new for new, _ in results], True
        return ref, False
    return ref, False


cfg_path = Path(LOCAL_DIR) / "config.json"
with cfg_path.open(encoding="utf-8") as f:
    cfg = json.load(f)

# Rewrite repo-qualified class references so they resolve against the
# locally downloaded module files.
auto_map = cfg.get("auto_map", {})
patched  = False
for k, v in auto_map.items():
    new_v, changed = _strip_repo_prefix(v)
    if changed:
        auto_map[k] = new_v                   # same key: safe during iteration
        patched = True

if patched:
    with cfg_path.open("w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
    print("🛠️  Patched config.json → auto_map fixed.")


# ------------------------------------------------------------------
# 1.  Model / tokenizer  -------------------------------------------
# ------------------------------------------------------------------
# create_handler_from_checkpoint (project-local) yields a 3-tuple; all three
# are kept as module globals.
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_DIR)
full_model = full_model.eval().cuda()   # eval mode, then moved to the GPU

# Backbone sub-modules, bound once so the embedding -> LayerNorm -> dropout
# -> encoder path can be driven step by step.
_bert = full_model.bert
encoder, embeddings, emb_ln, emb_drop = (
    _bert.encoder,
    _bert.embeddings,
    _bert.emb_ln,
    _bert.emb_drop,
)


# ------------------------------------------------------------------
# 2.  Symbolic token set  ------------------------------------------
# ------------------------------------------------------------------
# Role marker tokens that can be probed (also offered as UI checkbox choices).
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

# role token string -> vocabulary id, resolved once at start-up.
ROLE_ID = dict(zip(SYMBOLIC_ROLES, map(tokenizer.convert_tokens_to_ids, SYMBOLIC_ROLES)))

# A role that maps to the unknown-token id is absent from the vocabulary;
# abort rather than probe with the <unk> embedding.
_unk_id = tokenizer.unk_token_id
missing = [tok for tok in SYMBOLIC_ROLES if ROLE_ID[tok] == _unk_id]
if missing:
    sys.exit(f"❌ Tokenizer is missing {missing}")


# ------------------------------------------------------------------
# 3.  Encoder-only + role-similarity probe  ------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    """
    Encode *text* (encoder only) and probe it for each selected role.

    For each role in *selected_roles*:
        • find the contextual token whose encoder hidden state has the highest
          cosine similarity to that role's own input embedding (its row of
          ``embeddings.word_embeddings.weight``)
        • report it as “role → token (sim)”. This works even when the prompt
          contains no explicit <role> markers.
    Also reports string-match diagnostics for selected role tokens that
    literally appear in the tokenized text.

    Parameters
    ----------
    text : str
        Prompt to encode.
    selected_roles : list[str]
        Role tokens to probe; each must be a key of ``ROLE_ID``.
        ``None`` is treated as an empty selection.

    Returns
    -------
    tuple[str, str, str, int]
        (selected role tokens found verbatim, comma-separated or "(none)";
         "role → token (sim)" list or "(no roles selected)";
         norm of the mean hidden state over those same tokens, "%.4f";
         number of those tokens)
    """
    selected_roles = list(selected_roles or [])

    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, mask = batch.input_ids, batch.attention_mask            # (1, S)

        # ---------- encoder ----------
        # Manual embeddings -> LayerNorm -> dropout -> encoder pass.
        x   = emb_drop(emb_ln(embeddings(ids)))
        msk = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
        h   = encoder(x, attention_mask=msk).squeeze(0)              # (S, H)

        # L2-normalise hidden states once so a dot product is cosine sim.
        h_norm = F.normalize(h, dim=-1)                              # (S, H)

        # Python-side token strings, shared by every diagnostic below.
        id_list = ids[0].tolist()
        tokens  = tokenizer.convert_ids_to_tokens(id_list)

        # ---------- probe each selected role -----------------------
        matches = []
        for role in selected_roles:
            role_vec = embeddings.word_embeddings.weight[ROLE_ID[role]].to(h.device)
            role_vec = F.normalize(role_vec, dim=-1)                 # (H)

            sims = h_norm @ role_vec                                 # (S)
            best_idx = int(sims.argmax().item())
            best_sim = float(sims[best_idx])

            matches.append(f"{role} → {tokens[best_idx]} ({best_sim:.2f})")

        match_str = ", ".join(matches) if matches else "(no roles selected)"

        # ---------- string-match diagnostics -----------------------
        selected = set(selected_roles)
        present_pos = [i for i, tok in enumerate(tokens) if tok in selected]
        present = [tokens[i] for i in present_pos]
        present_str = ", ".join(present) or "(none)"
        count = len(present)

        # ---------- hidden-state norm of the explicit role tokens ----
        # Uses exactly the positions counted above, so norm and count agree.
        if count:
            norm_val = f"{h[present_pos].mean(0).norm().item():.4f}"
        else:
            norm_val = "0.0000"

        return present_str, match_str, norm_val, count


# ------------------------------------------------------------------
# 4.  Gradio UI  ----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    """Assemble the Gradio Blocks app and return it without launching.

    The left column holds the text input, the role checkboxes and a run
    button.  The right column holds four read-only textboxes that receive
    the four values returned by ``encode_and_trace``, in order.
    """
    intro_md = (
        "## 🧠 Symbolic Encoder Inspector  \n"
        "Select one or more symbolic *roles* on the left.  "
        "The tool shows which regular tokens (if any) the model thinks "
        "best fit each role — even when your text doesn’t contain the "
        "explicit `<role>` marker."
    )

    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(intro_md)

        with gr.Row():
            with gr.Column():
                text_in = gr.Textbox(
                    label="Input text",
                    lines=3,
                    placeholder="Example: A small child in bright red boots jumps over a muddy puddle…",
                )
                role_picker = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Roles to probe",
                )
                run_btn = gr.Button("Run encoder probe")
            with gr.Column():
                # Order must match encode_and_trace's return tuple.
                result_boxes = [
                    gr.Textbox(label="Explicit role tokens found"),
                    gr.Textbox(label="Role → Best-Match Token (cos θ)"),
                    gr.Textbox(label="Mean hidden-state norm (explicit)"),
                    gr.Textbox(label="# explicit role tokens"),
                ]

        run_btn.click(
            encode_and_trace,
            inputs=[text_in, role_picker],
            outputs=result_boxes,
        )

    return demo


if __name__ == "__main__":
    # Build the Blocks app, then serve it (see header: http://localhost:7860).
    app = build_interface()
    app.launch()