# app.py – encoder-only demo for bert-beatrix-2048
# -----------------------------------------------
# launch:  python app.py
# (gradio UI appears at http://localhost:7860)

import json
import sys
from pathlib import Path, PurePosixPath

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint


# ------------------------------------------------------------------
# 0.  Download & patch config.json  --------------------------------
# ------------------------------------------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"          # cached dir name

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)
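# snapshot_download mirrors the repo into ./bert-beatrix-2048/ (config.json,
# model weights, tokenizer files) so it can be patched and loaded locally.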

# ── one-time patch: strip the "repo--" prefix that confuses AutoModel ──
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

auto_map = cfg.get("auto_map", {})
changed = False
for k, v in auto_map.items():
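    # e.g. "AbstractPhil/bert-beatrix-2048--modeling_custom.BertClass" would
    # become "modeling_custom.BertClass" (illustrative names; the real
    # auto_map entries depend on the checkpoint).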
    if "--" in v:                         # v looks like  "repo--module.Class"
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        changed = True

if changed:
    cfg["auto_map"] = auto_map
    with cfg_path.open("w") as f:
        json.dump(cfg, f, indent=2)
    print("πŸ› οΈ  Patched config.json β†’ auto_map now points at local modules")


# ------------------------------------------------------------------
# 1.  Model / tokenizer  -------------------------------------------
# ------------------------------------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()

# Grab encoder + embedding stack only
encoder     = full_model.bert.encoder
embeddings  = full_model.bert.embeddings
emb_ln      = full_model.bert.emb_ln
emb_drop    = full_model.bert.emb_drop
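# Together these replicate the model's own embedding → encoder path
# (token embeddings → LayerNorm → dropout → transformer stack), while
# skipping whatever task head full_model applies on top.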


# ------------------------------------------------------------------
# 2.  Symbolic token set  ------------------------------------------
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]
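
# These <role> markers are assumed to have been added to the tokenizer as
# special tokens at training time; the check below aborts if any of them
# resolve to the unknown-token id.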

# quick sanity check
missing = [tok for tok in SYMBOLIC_ROLES
           if tokenizer.convert_tokens_to_ids(tok) == tokenizer.unk_token_id]
if missing:
    sys.exit(f"❌ Tokenizer is missing {missing}")


# ------------------------------------------------------------------
# 3.  Encoder-only inference util  ---------------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, mask = batch.input_ids, batch.attention_mask

        x = emb_drop(emb_ln(embeddings(ids)))

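        # get_extended_attention_mask turns the (1, S) padding mask into the
        # additive (1, 1, 1, S) form the self-attention layers expect
        # (0 for real tokens, a large negative bias for padding).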
        ext_mask = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
        enc = encoder(x, attention_mask=ext_mask)              # (1, S, H)

        sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
        flags   = torch.tensor([tid in sel_ids for tid in ids[0].tolist()],
                               device=enc.device)

        found = [tokenizer.convert_ids_to_tokens([tid])[0]
                 for tid in ids[0].tolist() if tid in sel_ids]

        if flags.any():
            vec = enc[0][flags].mean(0)
            norm = f"{vec.norm().item():.4f}"
        else:
            norm = "0.0000"

        # Return positionally so Button.click can route the values onto
        # [out_tok, out_norm, out_cnt]; a dict keyed by strings would fail,
        # since Gradio only accepts dicts keyed by component objects.
        return (
            ", ".join(found) or "(none)",   # symbolic tokens found
            norm,                           # mean-pooled embedding norm
            str(int(flags.sum().item())),   # symbolic token count
        )
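
# Minimal smoke test (run manually; assumes a CUDA device is available and
# that the example roles exist in the tokenizer):
#   toks, norm, count = encode_and_trace(
#       "A <subject> wearing <upper_body_clothing>",
#       ["<subject>", "<upper_body_clothing>"],
#   )
#   print(toks, norm, count)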


# ------------------------------------------------------------------
# 4.  Gradio UI  ----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "## 🧠 Symbolic Encoder Inspector\n"
            "Paste some text containing the special `<role>` tokens and "
            "inspect their encoder representations."
        )

        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                    lines=3,
                )
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Trace these symbolic roles",
                )
                btn = gr.Button("Encode & Trace")
            with gr.Column():
                out_tok  = gr.Textbox(label="Symbolic Tokens Found")
                out_norm = gr.Textbox(label="Mean Norm")
                out_cnt  = gr.Textbox(label="Token Count")

        btn.click(encode_and_trace, [txt, roles], [out_tok, out_norm, out_cnt])

    return demo


if __name__ == "__main__":
    build_interface().launch()