File size: 6,927 Bytes
da8b548
 
aa28bbb
da8b548
 
aa28bbb
872b08b
da8b548
 
81a0ae4
872b08b
096fe3a
872b08b
096fe3a
da8b548
096fe3a
da8b548
 
872b08b
 
 
81a0ae4
ed080e6
9a08859
81a0ae4
 
ed080e6
da8b548
 
 
9a08859
da8b548
9a08859
da8b548
 
 
 
9a08859
 
da8b548
9a08859
ed080e6
096fe3a
535151e
da8b548
096fe3a
 
da8b548
096fe3a
 
8a2e372
096fe3a
da8b548
096fe3a
 
785df91
 
 
 
 
872b08b
785df91
da8b548
 
096fe3a
da8b548
872b08b
da8b548
 
096fe3a
 
da8b548
 
 
 
 
 
 
 
 
 
 
096fe3a
da8b548
 
 
 
 
 
 
 
872b08b
096fe3a
da8b548
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
096fe3a
da8b548
 
096fe3a
872b08b
da8b548
 
872b08b
535151e
 
da8b548
aa28bbb
da8b548
 
 
aa28bbb
da8b548
535151e
da8b548
 
 
872b08b
da8b548
535151e
ea2994b
 
da8b548
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# app.py – encoder-only demo + pool-and-test prototype
# ----------------------------------------------------
# launch:  python app.py
# UI: http://localhost:7860
import json, re, sys, math
from pathlib import Path, PurePosixPath

import torch, torch.nn.functional as F
import gradio as gr, spaces
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint

# ------------------------------------------------------------------
# 0. Fetch the checkpoint, then rewrite remote auto_map refs in
#    config.json so they point at the locally downloaded modules
# ------------------------------------------------------------------
REPO_ID   = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)


def _localize_auto_map_ref(ref):
    """Strip the ``"<repo>--"`` prefix from an auto_map reference.

    Strings containing ``"--"`` keep only the part after the first ``"--"``.
    Lists are processed element-wise, since some auto_map entries are lists
    (e.g. a tokenizer's ``[slow, fast]`` pair, which may contain ``None``).
    Any other value is returned unchanged.
    """
    if isinstance(ref, list):
        return [_localize_auto_map_ref(item) for item in ref]
    if isinstance(ref, str) and "--" in ref:
        return PurePosixPath(ref.split("--", 1)[1]).as_posix()
    return ref


cfg_path = Path(LOCAL_CKPT) / "config.json"
cfg      = json.loads(cfg_path.read_text(encoding="utf-8"))
auto_map = cfg.get("auto_map", {})
changed  = False
for key, ref in auto_map.items():
    local_ref = _localize_auto_map_ref(ref)
    if local_ref != ref:                  # only rewrite when something was stripped
        auto_map[key] = local_ref         # same key: safe while iterating
        changed = True
if changed:
    cfg_path.write_text(json.dumps(cfg, indent=2), encoding="utf-8")
    print("🛠️  Patched config.json → auto_map points to local modules")

# ------------------------------------------------------------------
# 1. Load model + tokenizer with BERTHandler
# ------------------------------------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
# eval mode, then move to the GPU (both calls return the module itself)
full_model = full_model.eval().cuda()

# Expose the backbone pieces individually so the embedding stack and the
# encoder can be run step by step (see contextual_vectors).
encoder, embeddings = full_model.bert.encoder, full_model.bert.embeddings
emb_weight = embeddings.word_embeddings.weight   # static token-embedding table
emb_ln, emb_drop = full_model.bert.emb_ln, full_model.bert.emb_drop

# ------------------------------------------------------------------
# 2. Symbolic roles
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]
# Each role must be a vocabulary entry of its own; mapping to the UNK id
# means the tokenizer does not know it.
missing = [t for t in SYMBOLIC_ROLES
           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
if missing:
    sys.exit(f"❌ Tokenizer missing {missing}")

MASK_ID   = tokenizer.mask_token_id
MASK_TOK  = tokenizer.mask_token
# pool_accuracy writes MASK_ID into the input ids. Without a mask token
# that would only fail on the first request, so fail at startup instead.
if MASK_ID is None:
    sys.exit("❌ Tokenizer has no mask token")

# ------------------------------------------------------------------
# helpers -----------------------------------------------------------
def contextual_vectors(ids, mask):
    """Return the encoder's hidden states for one tokenised sequence.

    The ids go through the backbone's embedding module, then ``emb_ln`` and
    ``emb_drop`` (presumably LayerNorm + dropout — confirm against the model
    code), then the encoder with an extended attention mask built from
    *mask*. The leading batch axis is squeezed away, so a (1, S) input
    yields an (S, H) result.
    """
    hidden = embeddings(ids)                          # (1, S, H)
    hidden = emb_drop(emb_ln(hidden))
    ext_mask = full_model.bert.get_extended_attention_mask(
        mask, hidden.shape[:-1]
    )
    encoded = encoder(hidden, attention_mask=ext_mask)
    return encoded.squeeze(0)                         # (S, H)

def pool_accuracy(ids, mask, pool_positions):
    """Mask the pool positions, re-predict them, and return the hit rate.

    ids: token-id tensor indexed as ``ids[0, pos]`` (expected shape (1, S)).
    mask: attention mask passed straight through to ``full_model``.
    pool_positions: list of sequence positions to replace with ``MASK_ID``.

    Returns the fraction of masked positions whose argmax prediction equals
    the original token. An empty pool returns 0.0 without running the model.
    """
    if not pool_positions:
        return 0.0                         # nothing to score: skip the forward pass
    masked = ids.clone()                   # keep the caller's ids intact
    masked[0, pool_positions] = MASK_ID
    with torch.no_grad():
        logits = full_model(masked, attention_mask=mask).logits[0]
    preds = logits.argmax(-1)
    gold = ids.squeeze(0)
    hits = (preds[pool_positions] == gold[pool_positions]).sum().item()
    return hits / len(pool_positions)

def cos(a, b):
    """Return the cosine similarity of *a* and *b* (last axis) as a float.

    ``eps=1e-8`` guards the normalisation for near-zero vectors. ``.item()``
    requires the similarity to reduce to a single element.
    """
    similarity = F.cosine_similarity(a, b, dim=-1, eps=1e-8)
    return similarity.item()

# ------------------------------------------------------------------
# 3. Core routine ---------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, picked_roles: list[str]):
    """Run the pool-and-test prototype on *text*.

    1. Tokenise *text* and compute contextual hidden states.
    2. Collect the symbolic-role tokens present, restricted to *picked_roles*
       when that list is non-empty.
    3. Rank them by cosine similarity between each role's static embedding
       row and its contextual vector.
    4. Starting with the two least- and two most-similar roles, mask each
       pool and measure masked-token recovery accuracy. A pool is accepted
       at >= 0.5. Pools grow by two roles per round until something is
       accepted or no more roles can be added.

    Returns three strings for the UI: accepted categories, debug info, and
    an excerpt of the input (at most 80 characters, "…" only if truncated).
    """
    # -------- tokenise + contextual states (inference only) ----------
    batch = tokenizer(text, return_tensors="pt").to("cuda")
    ids, attn = batch.input_ids, batch.attention_mask
    with torch.no_grad():
        hid = contextual_vectors(ids, attn)                  # (S,H)

    # -------- decide which roles we analyse ----------
    # Build the role-id set once instead of once per token.
    role_ids = {tokenizer.convert_tokens_to_ids(r) for r in SYMBOLIC_ROLES}
    # Keyed by token id: a role repeated in the text is analysed at its
    # last position only.
    present = {tid: pos for pos, tid in enumerate(ids[0].tolist())
               if tid in role_ids}
    if picked_roles:
        present = {tid: pos for tid, pos in present.items()
                   if tokenizer.convert_ids_to_tokens([tid])[0] in picked_roles}
    if not present:
        return "No symbolic tokens in sentence", "", ""

    # -------- similarity scores: contextual vs static, low → high ----------
    with torch.no_grad():
        sims = sorted(
            (cos(hid[pos], emb_weight[tid]), tid, pos)
            for tid, pos in present.items()
        )

    # pools: bottom-2, top-2, grown by two per round below
    low_pool, high_pool = sims[:2], sims[-2:]
    accepted = []

    # At most 1 + ceil(n/2) rounds, two masked forward passes per round.
    for grow in range(1 + math.ceil(len(sims) / 2)):
        for tag, pool in (("low", low_pool), ("high", high_pool)):
            pool_pos = [p for _, _, p in pool]
            acc = pool_accuracy(ids, attn, pool_pos)
            if acc >= 0.5:        # category accepted
                roles = [tokenizer.convert_ids_to_tokens([tid])[0]
                         for _, tid, _ in pool]
                accepted.append(f"{tag}:{roles} (acc {acc:.2f})")
        if accepted:
            break                 # stop once something passed
        # grow pools by two (if any left)
        next_lo = sims[2 + grow * 2 : 4 + grow * 2]
        if 4 + grow * 2 <= len(sims):
            next_hi = sims[-4 - grow * 2 : -2 - grow * 2]
        else:
            next_hi = []
        if not next_lo and not next_hi:
            # Pools can no longer grow (and never will in later rounds), so
            # re-testing the same pools would only repeat forward passes.
            break
        low_pool += next_lo
        high_pool += next_hi

    if not accepted:
        accepted = ["(none hit 50 %)"]

    # Only mark truncation when the text was actually cut.
    excerpt = text if len(text) <= 80 else f"{text[:80]}…"
    return ", ".join(accepted), f"{len(present)} roles analysed", excerpt

# ------------------------------------------------------------------
# 4. UI -------------------------------------------------------------
def build_ui():
    """Assemble and return the Gradio Blocks app (not launched).

    Left column: input textbox, role checkboxes (all pre-selected) and a
    Run button. Right column: three textboxes that receive the three
    strings returned by ``encode_and_trace``, in order.
    """
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as app:
        gr.Markdown(
            value=(
                "## 🧠 Symbolic Encoder Inspector  \n"
                "Select roles, paste text, and watch the pool-and-test prototype work."
            )
        )
        with gr.Row():
            with gr.Column():
                input_box = gr.Textbox(label="Input", lines=3)
                role_picker = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    value=SYMBOLIC_ROLES,
                    label="Roles to consider (else all present)",
                )
                run_button = gr.Button(value="Run")
            with gr.Column():
                result_boxes = [
                    gr.Textbox(label="Accepted categories"),
                    gr.Textbox(label="Debug"),
                    gr.Textbox(label="Excerpt"),
                ]

        run_button.click(
            fn=encode_and_trace,
            inputs=[input_box, role_picker],
            outputs=result_boxes,
        )
    return app

if __name__ == "__main__":
    # Build the Gradio app and start serving it.
    ui = build_ui()
    ui.launch()