# app.py – encoder-only demo for bert-beatrix-2048
# -----------------------------------------------
# launch:  python app.py
# (gradio UI appears at http://localhost:7860)

import torch
import gradio as gr
import spaces
from bert_handler import create_handler_from_checkpoint

# ------------------------------------------------------------------
# 1.  Model / tokenizer  -------------------------------------------------
# ------------------------------------------------------------------
#
# • We load one repo *once*, via its canonical name.
# • BERTHandler handles the VRAM-safe cleanup & guarantees that the
#   tokenizer already contains all special tokens saved in the checkpoint.

REPO_ID = "AbstractPhil/bert-beatrix-2048"

handler, full_model, tokenizer = create_handler_from_checkpoint(REPO_ID)
full_model = full_model.eval().cuda()
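# (Assumption note: .cuda() presumes a GPU is available, as on a Spaces GPU runtime.
#  For a CPU-only local sanity run one could instead do, roughly:
#      device = "cuda" if torch.cuda.is_available() else "cpu"
#      full_model = full_model.eval().to(device)
#  and move the tokenized batch to `device` inside encode_and_trace below.)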

# Grab the encoder + embedding stack only
encoder     = full_model.bert.encoder
embeddings  = full_model.bert.embeddings
emb_ln      = full_model.bert.emb_ln
emb_drop    = full_model.bert.emb_drop
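# (These attribute paths (bert.encoder, bert.embeddings, bert.emb_ln, bert.emb_drop)
#  follow this checkpoint's custom module layout as exposed by BERTHandler; they are
#  not the stock HF BertModel names. Only the embedding + encoder path is kept, so no
#  task head runs during tracing.)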

# ------------------------------------------------------------------
# 2.  Symbolic token set  -------------------------------------------
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>"
]

# Quick sanity check – should *never* be unk
missing = [tok for tok in SYMBOLIC_ROLES
           if tokenizer.convert_tokens_to_ids(tok) == tokenizer.unk_token_id]
if missing:
    raise RuntimeError(f"Tokenizer is missing {len(missing)} special tokens: {missing}")
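
# (Hedged alternative, not used here because the handler guarantees the tokens:
#  if a checkpoint did *not* ship the role tokens, one could register them and
#  resize the embedding matrix with the standard HF calls, e.g.
#      tokenizer.add_special_tokens({"additional_special_tokens": SYMBOLIC_ROLES})
#      full_model.resize_token_embeddings(len(tokenizer))
#  assuming full_model exposes the usual PreTrainedModel resize API; the freshly
#  initialised rows would of course be untrained.)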

# ------------------------------------------------------------------
# 3.  Encoder-only inference util  ----------------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    """
    • encodes `text`
    • pulls out the hidden states for any of the `selected_roles`
    • returns some quick stats so we can verify everything’s wired up
    """
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        inp_ids, attn_mask = batch.input_ids, batch.attention_mask

        # --- embedding + LayerNorm/dropout ---
        x = embeddings(inp_ids)
        x = emb_drop(emb_ln(x))

        # --- proper *additive* attention mask ---
        ext_mask = full_model.bert.get_extended_attention_mask(
            attn_mask, x.shape[:-1]
        )
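        # get_extended_attention_mask turns the (B, S) 1/0 padding mask into an
        # additive bias of shape (B, 1, 1, S): 0 where attention is allowed and a
        # large negative value where it is masked, so it can simply be added to
        # the attention scores inside each layer.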

        encoded = encoder(x, attention_mask=ext_mask)          # (B, S, H)

        # --- pick out the positions that match selected_roles ---
        sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
        ids_list = inp_ids.squeeze(0).tolist()                 # python ints
        keep_mask = torch.tensor([tid in sel_ids for tid in ids_list],
                                 device=encoded.device)

        tokens_found = [tokenizer.convert_ids_to_tokens([tid])[0]
                        for tid in ids_list if tid in sel_ids]
        if keep_mask.any():
            repr_vec = encoded.squeeze(0)[keep_mask].mean(0)
            norm_val = f"{repr_vec.norm().item():.4f}"
        else:
            norm_val = "0.0000"

        # Return a plain tuple: with a list of output components, Gradio maps the
        # returned values onto them positionally (a dict would need the component
        # objects themselves as keys, not strings).
        return (
            ", ".join(tokens_found) or "(none)",        # symbolic tokens found
            norm_val,                                   # norm of mean hidden state
            str(int(keep_mask.sum().item())),           # symbolic token count
        )
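
# Quick manual smoke test (illustrative prompt only; uncomment to run once at
# startup while debugging outside the Gradio UI):
# print(encode_and_trace("A <subject> wearing <upper_body_clothing>.",
#                        ["<subject>", "<upper_body_clothing>"]))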

# ------------------------------------------------------------------
# 4.  Gradio UI  -----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:

        gr.Markdown("## 🧠 Symbolic Encoder Inspector\n"
                    "Paste some text containing the special `<role>` tokens and "
                    "inspect their encoder representations.")

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                    lines=3,
                )
                role_selector = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Trace these symbolic roles"
                )
                run_btn = gr.Button("Encode & Trace")
            with gr.Column():
                out_tokens  = gr.Textbox(label="Symbolic Tokens Found")
                out_norm    = gr.Textbox(label="Norm of Mean Hidden State")
                out_count   = gr.Textbox(label="Symbolic Token Count")

        run_btn.click(
            fn=encode_and_trace,
            inputs=[input_text, role_selector],
            outputs=[out_tokens, out_norm, out_count],
        )

    return demo


if __name__ == "__main__":
    demo = build_interface()
    demo.launch()
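    # (demo.launch() binds to localhost:7860 by default, matching the header note.
    #  If the UI must be reachable from outside the machine, e.g. inside a container,
    #  Gradio's standard arguments can be passed instead:
    #      demo.launch(server_name="0.0.0.0", server_port=7860))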