File size: 6,751 Bytes
ed080e6
aa28bbb
 
ed080e6
 
aa28bbb
ed080e6
872b08b
 
81a0ae4
535151e
aa28bbb
81a0ae4
872b08b
096fe3a
872b08b
096fe3a
aa28bbb
096fe3a
ed080e6
 
872b08b
 
 
81a0ae4
ed080e6
9a08859
81a0ae4
 
ed080e6
872b08b
9a08859
 
ed080e6
 
9a08859
ed080e6
9a08859
415afa1
9a08859
415afa1
ed080e6
872b08b
9a08859
ed080e6
9a08859
 
aa28bbb
9a08859
ed080e6
096fe3a
535151e
096fe3a
 
 
 
8a2e372
096fe3a
aa28bbb
096fe3a
 
785df91
 
 
 
 
872b08b
785df91
ed080e6
 
 
096fe3a
872b08b
 
ed080e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
096fe3a
 
ed080e6
096fe3a
 
 
 
ed080e6
 
 
096fe3a
ed080e6
aa28bbb
ed080e6
 
 
 
aa28bbb
ed080e6
 
aa28bbb
ed080e6
aa28bbb
ed080e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
872b08b
096fe3a
ed080e6
096fe3a
785df91
096fe3a
872b08b
ed080e6
 
 
 
872b08b
 
535151e
 
aa28bbb
ed080e6
 
aa28bbb
 
 
 
ed080e6
aa28bbb
ed080e6
535151e
ed080e6
 
 
 
 
872b08b
535151e
ea2994b
872b08b
ea2994b
81a0ae4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# app.py – encoder-only demo for bert-beatrix-2048
# launch:  python app.py

import json
import sys
from pathlib import Path, PurePosixPath
from itertools import islice

import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint

# ------------------------------------------------------------------
# 0.  Download & patch config.json  --------------------------------
# ------------------------------------------------------------------
REPO_ID     = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT  = "bert-beatrix-2048"

# Fetch the checkpoint files into a plain local directory.
snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)


def _strip_repo_prefix(ref):
    """Drop a leading ``"<repo>--"`` qualifier from one ``auto_map`` entry.

    ``"repo--module.Class"`` becomes ``"module.Class"``.  Lists (such as a
    ``[slow, fast]`` tokenizer pair) are handled element-wise; ``None`` and
    any other non-string value is returned unchanged.
    """
    if isinstance(ref, str):
        return ref.split("--", 1)[1] if "--" in ref else ref
    if isinstance(ref, list):
        return [_strip_repo_prefix(item) for item in ref]
    return ref


cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open(encoding="utf-8") as f:
    cfg = json.load(f)

# Rewrite "repo--module.Class" references to plain "module.Class" ones.
# `or {}` also covers an explicit `"auto_map": null` in the config.
auto_map  = cfg.get("auto_map") or {}
patched   = False
for k, v in auto_map.items():
    fixed = _strip_repo_prefix(v)
    if fixed != v:
        # Replacing values (not keys) while iterating items() is safe.
        auto_map[k] = fixed
        patched = True

if patched:
    cfg["auto_map"] = auto_map
    with cfg_path.open("w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
    print("🛠️  Patched config.json → auto_map paths fixed")

# ------------------------------------------------------------------
# 1.  Model / tokenizer  -------------------------------------------
# ------------------------------------------------------------------
# Load the checkpoint, put the model in eval mode and move it to the GPU.
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval()
full_model = full_model.cuda()

# Module-level handles on sub-modules of the `bert` backbone.
_bert      = full_model.bert
encoder    = _bert.encoder
embeddings = _bert.embeddings
emb_ln     = _bert.emb_ln
emb_drop   = _bert.emb_drop

# ------------------------------------------------------------------
# 2.  Symbolic token set  ------------------------------------------
# ------------------------------------------------------------------
# Special role tokens the UI offers for tracing; each must exist in the
# tokenizer vocabulary (checked right after this list).  Order matters:
# it is the order of the checkbox choices.
SYMBOLIC_ROLES = [
    "<subject>",
    "<subject1>",
    "<subject2>",
    "<pose>",
    "<emotion>",
    "<surface>",
    "<lighting>",
    "<material>",
    "<accessory>",
    "<footwear>",
    "<upper_body_clothing>",
    "<hair_style>",
    "<hair_length>",
    "<headwear>",
    "<texture>",
    "<pattern>",
    "<grid>",
    "<zone>",
    "<offset>",
    "<object_left>",
    "<object_right>",
    "<relation>",
    "<intent>",
    "<style>",
    "<fabric>",
    "<jewelry>",
]

# Abort start-up if any role token is unknown to the tokenizer: a token
# that maps to the unk id is not in the vocabulary.
unk = tokenizer.unk_token_id
missing = [
    role_tok
    for role_tok in SYMBOLIC_ROLES
    if tokenizer.convert_tokens_to_ids(role_tok) == unk
]
if missing:
    sys.exit(f"❌ Tokenizer is missing {missing}")

# ------------------------------------------------------------------
# 3.  helper: merge lowest + highest until 3 remain ----------------
# ------------------------------------------------------------------
def reduce_to_three(table, *, target=3):
    """
    Collapse a role/token score table into ``target`` composite rows.

    table  : iterable of dicts ``{role, token, score}``.  Neither the input
             nor its dicts are modified.
    target : number of rows to stop at.  It is keyword-only, defaults to 3
             (the number of output boxes in the UI) and must be >= 1.

    While more than ``target`` rows remain, the lowest- and highest-scoring
    rows are removed and replaced by one merged row:
      - ``role`` and ``token`` join both labels, high first, separated by
        ``|`` and ``/`` respectively;
      - ``score`` is the mean of the two scores.
    Each merge shrinks the table by one, so a longer table ends with
    exactly ``target`` rows.  A table that is already short enough is
    returned unmerged.

    Returns a new list sorted by score, highest first.  All sorts are
    stable, so equal scores keep their relative order.

    Raises ValueError if ``target`` < 1, because a merge needs two rows
    to pop.
    """
    if target < 1:
        raise ValueError(f"target must be >= 1, got {target!r}")

    def by_score(row):
        return row["score"]

    working = sorted(table, key=by_score)
    while len(working) > target:
        low  = working.pop(0)
        high = working.pop()
        working.append({
            "role":  f"{high['role']}|{low['role']}",
            "token": f"{high['token']}/{low['token']}",
            "score": (high["score"] + low["score"]) / 2.0,
        })
        # Stable re-sort: the merged row lands after any equal scores.
        working.sort(key=by_score)
    # highest first for display
    working.sort(key=by_score, reverse=True)
    return working

# ------------------------------------------------------------------
# 4.  Encoder-only inference util  ---------------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    """
    Match each selected symbolic role to its most similar input token.

    Each role's row of ``embeddings.weight`` is compared, by cosine
    similarity, with ``embeddings(ids)`` for every token of the tokenized
    ``text``.  The best-scoring token is kept per role.  The per-role rows
    are then collapsed with ``reduce_to_three`` and formatted for the three
    output boxes.

    Parameters
    ----------
    text : str
        Raw input text, which may contain ``<role>`` tokens.  ``None`` or
        whitespace-only input is treated as empty.
    selected_roles : list[str]
        Role tokens to trace.  ``None`` is treated as an empty selection.

    Returns
    -------
    tuple[str, str, str]
        Strings of the form ``"role ↔ token  (+score)"``, highest score
        first, padded with ``""``.  Special cases:
          - ``("(no input)", "", "")`` for empty text;
          - ``("(none selected)", "", "")`` when no roles are chosen.
    """
    # Cheap argument checks first, before any tokenization or GPU work.
    if not text or not text.strip():
        return "(no input)", "", ""
    if not selected_roles:
        return "(none selected)", "", ""

    with torch.no_grad():
        ids = tokenizer(text, return_tensors="pt").to("cuda").input_ids

        # Token-level embeddings (before LN) for the similarity calc.
        # Scores come from the embedding layer only, so no encoder pass is
        # run here.  ids[0, i] indexing below assumes a single sequence
        # (batch of 1), giving (S, H) after squeeze — TODO confirm.
        token_emb = embeddings(ids).squeeze(0)

        rows = []
        for role in selected_roles:
            rid  = tokenizer.convert_tokens_to_ids(role)
            # NOTE(review): if embeddings(ids) adds position/type terms, this
            # raw table row is not in quite the same space.  If it does not,
            # a role token that appears in `text` will match itself (~1.0).
            rvec = embeddings.weight[rid]
            sims = F.cosine_similarity(rvec.unsqueeze(0), token_emb, dim=-1)
            best = torch.argmax(sims).item()
            rows.append({
                "role" : role,
                "token": tokenizer.convert_ids_to_tokens([ids[0, best].item()])[0],
                "score": sims[best].item(),
            })

    final3 = reduce_to_three(rows)
    out_strs = [f"{r['role']} ↔ {r['token']}  ({r['score']:+.2f})" for r in final3]
    # Gradio expects exactly three outputs; pad with empty strings.
    out_strs += [""] * (3 - len(out_strs))
    return out_strs[0], out_strs[1], out_strs[2]

# ------------------------------------------------------------------
# 5.  Gradio UI  ----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    """
    Build the Gradio UI for the symbolic encoder inspector.

    Left column: an input textbox, a checkbox group over ``SYMBOLIC_ROLES``
    and a button.  Right column: three read-out textboxes.  Clicking the
    button calls ``encode_and_trace(text, roles)`` and puts its three
    return values into the read-outs, in order.

    Returns the ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        # Intro text.  The blank line before the bullets makes Markdown
        # render them as a list instead of one run-on line.
        gr.Markdown(
            "### 🧠 Symbolic Encoder Inspector\n"
            "Paste text with `<role>` tokens, pick roles to track, then we:\n"
            "\n"
            "- compute role ↔ token cosine scores\n"
            "- iteratively merge low+high pairs until **3 composite buckets** remain."
        )

        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                    lines=3,
                )
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Trace these symbolic roles",
                )
                btn = gr.Button("Encode & Merge")
            with gr.Column():
                cat1 = gr.Textbox(label="Category 1 (highest)")
                cat2 = gr.Textbox(label="Category 2")
                cat3 = gr.Textbox(label="Category 3 (lowest)")

        btn.click(encode_and_trace, [txt, roles], [cat1, cat2, cat3])

    return demo


if __name__ == "__main__":
    # Script entry point: build the Blocks app, then start the server.
    app = build_interface()
    app.launch()