# app.py – encoder-only demo for bert-beatrix-2048
# -----------------------------------------------
# launch:  python app.py
import json
import sys
from pathlib import Path, PurePosixPath

import spaces
import torch
import gradio as gr
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint

# ------------------------------------------------------------------
# 1.  Download *once* and load locally  -----------------------------
# ------------------------------------------------------------------
LOCAL_CKPT = snapshot_download(
    repo_id="AbstractPhil/bert-beatrix-2048",
    revision="main",
    local_dir="bert-beatrix-2048",
    local_dir_use_symlinks=False,  # deprecated no-op on recent huggingface_hub; harmless on older versions
)

cfg_path = Path(LOCAL_CKPT) / "config.json"
with open(cfg_path) as f:
    cfg = json.load(f)

auto_map = cfg.get("auto_map", {})
changed = False
for k, v in auto_map.items():
    # v  looks like  "AbstractPhil/bert-beatrix-2048--modeling_hf_nomic_bert.…"
    if "--" in v:
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        changed = True

if changed:
    cfg["auto_map"] = auto_map
    with open(cfg_path, "w") as f:
        json.dump(cfg, f, indent=2)
    print("🔧 Patched auto_map → now points to local modules only")
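
# For illustration only: after patching, an auto_map entry such as
#   "AutoModel": "AbstractPhil/bert-beatrix-2048--modeling_hf_nomic_bert.NomicBertModel"
# becomes a plain local reference:
#   "AutoModel": "modeling_hf_nomic_bert.NomicBertModel"
# (hypothetical entry – the real keys and class names come from the checkpoint's config.json)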

# also drop any *previously* imported remote modules in this session
for name in list(sys.modules):
    if name.startswith("transformers_modules.AbstractPhil.bert-beatrix-2048"):
        del sys.modules[name]
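# (transformers caches trust_remote_code modules under transformers_modules;
#  purging them keeps stale hub copies from shadowing the patched local files)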

# ------------------------------------------------------------------
# 2.  Normal load via BERTHandler  ----------------------------------
# ------------------------------------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()

# --- pull encoder & embeddings only --------------------------------
encoder     = full_model.bert.encoder
embeddings  = full_model.bert.embeddings
emb_ln      = full_model.bert.emb_ln
emb_drop    = full_model.bert.emb_drop
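
# (encoder = the transformer stack, embeddings = token/positional lookup,
#  emb_ln/emb_drop = the post-embedding LayerNorm + dropout that this
#  checkpoint's modeling code applies before the first encoder layer)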

# ------------------------------------------------------------------
# 3.  Symbolic token list  ------------------------------------------
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>"
]

# Sanity-check: every role must be known by the tokenizer
missing = [t for t in SYMBOLIC_ROLES
           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
if missing:
    raise RuntimeError(f"Tokenizer is missing special tokens: {missing}")
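
# A sketch of a recovery path instead of raising (assumes training actually
# added these tokens – resizing to untrained embeddings would yield meaningless traces):
#   tokenizer.add_special_tokens({"additional_special_tokens": missing})
#   full_model.resize_token_embeddings(len(tokenizer))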

# ------------------------------------------------------------------
# 4.  Encoder-only inference util  ----------------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, mask = batch.input_ids, batch.attention_mask

        x = emb_drop(emb_ln(embeddings(ids)))

        ext_mask = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
        enc = encoder(x, attention_mask=ext_mask)               # (1, S, H)
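        # get_extended_attention_mask broadcasts the (1, S) padding mask to
        # (1, 1, 1, S) and converts its 1/0 entries into additive 0 / large-negative
        # biases that the attention layers add to their logits before softmax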

        want = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
        ids_list = ids[0].tolist()   # plain ints – set membership on tensors compares identity, not value
        keep = torch.tensor([tid in want for tid in ids_list], device=enc.device)

        found = [tokenizer.convert_ids_to_tokens(tid) for tid in ids_list if tid in want]
        if keep.any():
            vec = enc[0][keep].mean(0)
            norm = f"{vec.norm().item():.4f}"
        else:
            norm = "0.0000"

        # return positionally – click() maps these onto the three output boxes in order
        return ", ".join(found) or "(none)", norm, str(int(keep.sum().item()))
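
# Minimal smoke test (illustrative – requires a CUDA device and the checkpoint above):
#   print(encode_and_trace("a <subject> in <lighting>", ["<subject>", "<lighting>"]))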

# ------------------------------------------------------------------
# 5.  Gradio UI  -----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(label="Input with Symbolic Tokens", lines=3)
                chk = gr.CheckboxGroup(choices=SYMBOLIC_ROLES, label="Trace these roles")
                btn = gr.Button("Encode & Trace")
            with gr.Column():
                out_tok  = gr.Textbox(label="Symbolic Tokens Found")
                out_norm = gr.Textbox(label="Mean Norm")
                out_cnt  = gr.Textbox(label="Token Count")
        btn.click(encode_and_trace, [txt, chk], [out_tok, out_norm, out_cnt])
    return demo

if __name__ == "__main__":
    build_interface().launch()