# app.py – encoder-only demo for bert-beatrix-2048
# -----------------------------------------------
# launch: python app.py
# (gradio UI appears at http://localhost:7860)
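#
# Assumed dependencies (inferred from the imports below, not pinned here):
#   torch, gradio, spaces (the Hugging Face ZeroGPU helper), and the local
#   bert_handler module bundled with this Space.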
import torch
import gradio as gr
import spaces
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 1. Model / tokenizer -------------------------------------------------
# ------------------------------------------------------------------
#
# • We load one repo *once*, via its canonical name.
# • BERTHandler handles the VRAM-safe cleanup & guarantees that the
# tokenizer already contains all special tokens saved in the checkpoint.
REPO_ID = "AbstractPhil/bert-beatrix-2048"
handler, full_model, tokenizer = create_handler_from_checkpoint(REPO_ID)
full_model = full_model.eval().cuda()
# Grab the encoder + embedding stack only
encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
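# NOTE: these four attributes come from this checkpoint's custom BERT variant
# (as exposed by BERTHandler); a stock transformers BertModel keeps the
# embedding LayerNorm and dropout inside `embeddings` rather than as separate
# `emb_ln` / `emb_drop` modules.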
# ------------------------------------------------------------------
# 2. Symbolic token set -------------------------------------------
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]
# Quick sanity check – should *never* be unk
missing = [tok for tok in SYMBOLIC_ROLES
           if tokenizer.convert_tokens_to_ids(tok) == tokenizer.unk_token_id]
if missing:
    raise RuntimeError(f"Tokenizer is missing {len(missing)} special tokens: {missing}")
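# If a missing role ever needed to be tolerated instead of fatal, the usual
# recovery would be roughly the sketch below (hypothetical here, and the newly
# added embedding rows would start out untrained):
#   tokenizer.add_special_tokens({"additional_special_tokens": SYMBOLIC_ROLES})
#   full_model.resize_token_embeddings(len(tokenizer))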
# ------------------------------------------------------------------
# 3. Encoder-only inference util ----------------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    """
    • encodes `text`
    • pulls out the hidden states for any of the `selected_roles`
    • returns some quick stats so we can verify everything’s wired up
    """
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        inp_ids, attn_mask = batch.input_ids, batch.attention_mask

        # --- embedding + LayerNorm/dropout ---
        x = embeddings(inp_ids)
        x = emb_drop(emb_ln(x))

        # --- proper *additive* attention mask ---
        ext_mask = full_model.bert.get_extended_attention_mask(
            attn_mask, x.shape[:-1]
        )
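        # get_extended_attention_mask broadcasts the 0/1 padding mask to
        # (B, 1, 1, S) and maps 1 -> 0.0 and 0 -> a large negative value, so it
        # can simply be added to the attention scores inside the encoder.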
        encoded = encoder(x, attention_mask=ext_mask)    # (B, S, H)

        # --- pick out the positions that match selected_roles ---
        sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
        ids_list = inp_ids.squeeze(0).tolist()           # python ints
        keep_mask = torch.tensor([tid in sel_ids for tid in ids_list],
                                 device=encoded.device)

        tokens_found = [tokenizer.convert_ids_to_tokens([tid])[0]
                        for tid in ids_list if tid in sel_ids]
        if keep_mask.any():
            repr_vec = encoded.squeeze(0)[keep_mask].mean(0)
            norm_val = f"{repr_vec.norm().item():.4f}"
        else:
            norm_val = "0.0000"

        # Return a plain tuple in the same order as the three output
        # components wired up in build_interface() below.
        return (
            ", ".join(tokens_found) or "(none)",
            norm_val,
            int(keep_mask.sum().item()),
        )
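
# Optional smoke test (assumes the checkpoint has already been downloaded and
# a GPU is attached), e.g.:
#   print(encode_and_trace("A <subject> wearing <upper_body_clothing>",
#                          ["<subject>", "<upper_body_clothing>"]))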
# ------------------------------------------------------------------
# 4. Gradio UI -----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector\n"
                    "Paste some text containing the special `<role>` tokens and "
                    "inspect their encoder representations.")

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                    lines=3,
                )
                role_selector = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Trace these symbolic roles",
                )
                run_btn = gr.Button("Encode & Trace")
            with gr.Column():
                out_tokens = gr.Textbox(label="Symbolic Tokens Found")
                out_norm = gr.Textbox(label="Mean Norm")
                out_count = gr.Textbox(label="Token Count")

        run_btn.click(
            fn=encode_and_trace,
            inputs=[input_text, role_selector],
            outputs=[out_tokens, out_norm, out_count],
        )

    return demo
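
# `demo.launch()` serves on http://localhost:7860 by default; standard Gradio
# options such as server_name="0.0.0.0" or share=True can be passed if the UI
# needs to be reachable from outside this machine (the defaults are fine on
# Spaces).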
if __name__ == "__main__":
    demo = build_interface()
    demo.launch()