# app.py – encoder-only demo for bert-beatrix-2048
# ------------------------------------------------
# launch:  python app.py     →  http://localhost:7860

import json, sys
from pathlib import Path, PurePosixPath        # PurePosixPath used to normalise auto_map paths
import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint


# ------------------------------------------------------------------
# 0.  Download & patch config.json  ---------------------------------
# ------------------------------------------------------------------
REPO_ID     = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT  = "bert-beatrix-2048"              # cache dir name
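# snapshot_download (below) mirrors the repo into ./bert-beatrix-2048 so the
# custom modelling code referenced by config.json's auto_map can be imported
# from local files.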

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)

cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

auto_map = cfg.get("auto_map", {})
patched  = False
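# Hub-style auto_map entries look like "repo--module.Class"; once the files
# are local, only the "module.Class" part should remain.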
for k, v in auto_map.items():
    if "--" in v:                              # strip   repo--module.Class
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        patched = True

if patched:
    cfg["auto_map"] = auto_map
    with cfg_path.open("w") as f:
        json.dump(cfg, f, indent=2)
    print("πŸ› οΈ  Patched config.json β†’ auto_map now points to local modules")


# ------------------------------------------------------------------
# 1.  Load model / tokenizer  ---------------------------------------
# ------------------------------------------------------------------
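# create_handler_from_checkpoint returns (handler, model, tokenizer); only the
# raw model and tokenizer are used directly in this demo.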
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()

encoder     = full_model.bert.encoder
embeddings  = full_model.bert.embeddings
emb_ln      = full_model.bert.emb_ln
emb_drop    = full_model.bert.emb_drop
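# Keep direct handles to the embedding pipeline and encoder; encode_and_trace
# runs them manually instead of calling the full model.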


# ------------------------------------------------------------------
# 2.  Symbolic roles  ------------------------------------------------
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]
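# Every role token must already be in the tokenizer vocab; an unknown token
# would map to the unk id, so the check below bails out if any are missing.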

missing = [t for t in SYMBOLIC_ROLES
           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
if missing:
    sys.exit(f"❌ Tokenizer is missing {missing}")


# ------------------------------------------------------------------
# 3.  Encoder-only helper  ------------------------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    """
    Returns **exactly three scalars** matching the 3 gradio outputs:
        1) tokens_str  (comma-separated list found)
        2) norm_str    (mean L2-norm of those embeddings)
        3) count_int   (# tokens matched)
    """
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, attn = batch.input_ids, batch.attention_mask

        x = emb_drop(emb_ln(embeddings(ids)))
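        # Broadcast the 2-D padding mask to the additive (B, 1, 1, S) form
        # the encoder layers expect.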
        ext = full_model.bert.get_extended_attention_mask(attn, x.shape[:-1])
        enc = encoder(x, attention_mask=ext)                   # (1, S, H)

        role_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
        mask     = torch.tensor([tid in role_ids for tid in ids[0].tolist()],
                                device=enc.device, dtype=torch.bool)

        found = [tokenizer.convert_ids_to_tokens([tid])[0]
                 for tid in ids[0].tolist() if tid in role_ids]
        tokens_str = ", ".join(found) or "(none)"

        if mask.any():
            mean_vec = enc[0][mask].mean(0)
            norm_str = f"{mean_vec.norm().item():.4f}"
        else:
            norm_str = "0.0000"

        count_int = int(mask.sum().item())
        return tokens_str, norm_str, count_int       # ← three outputs!
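
# Hypothetical direct call (outside Gradio, assuming a CUDA device and that
# the role tokens are present in the vocab):
#   toks, norm, n = encode_and_trace("A <subject> wearing <fabric>", ["<subject>"])
#   # toks == "<subject>", n == 1; norm is the L2-norm of its mean hidden state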


# ------------------------------------------------------------------
# 4.  Gradio UI  ----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "### 🧠 Symbolic Encoder Inspector\n"
            "Paste text that includes the `<role>` tokens and inspect their "
            "hidden-state statistics."
        )

        with gr.Row():
            with gr.Column():
                txt  = gr.Textbox(label="Input", lines=3,
                                  placeholder="A <subject> wearing <upper_body_clothing> …")
                chk  = gr.CheckboxGroup(SYMBOLIC_ROLES, label="Roles to trace")
                run  = gr.Button("Encode & Trace")
            with gr.Column():
                out_tok  = gr.Textbox(label="Tokens found")
                out_norm = gr.Textbox(label="Mean norm")
                out_cnt  = gr.Textbox(label="Token count")

        run.click(encode_and_trace, [txt, chk], [out_tok, out_norm, out_cnt])
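        # The three values returned by encode_and_trace are wired 1-to-1 to
        # the three output textboxes above.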

    return demo


if __name__ == "__main__":
    build_interface().launch()