# app.py – encoder-only + masking accuracy demo for bert-beatrix-2048
# -----------------------------------------------------------------
# launch: python app.py          (UI at http://localhost:7860)
import json, re, sys
from pathlib import Path, PurePosixPath

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint
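# NOTE: `bert_handler` is assumed to be a local module shipped with this Space
# (not a PyPI package); it provides create_handler_from_checkpoint(), which
# returns (handler, full_model, tokenizer) for a checkpoint directory.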
# ------------------------------------------------------------------
# 0. download repo + patch auto_map --------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CK = "bert-beatrix-2048"
snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_CK, local_dir_use_symlinks=False)

cfg_p = Path(LOCAL_CK) / "config.json"
with cfg_p.open() as f:
    cfg = json.load(f)

for k, v in cfg.get("auto_map", {}).items():
    if "--" in v:
        cfg["auto_map"][k] = PurePosixPath(v.split("--", 1)[1]).as_posix()

with cfg_p.open("w") as f:
    json.dump(cfg, f, indent=2)
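# Illustration of the patch above (class name is hypothetical): an auto_map
# entry such as
#   "AutoModel": "AbstractPhil/bert-beatrix-2048--modeling.BertBeatrix"
# is rewritten to
#   "AutoModel": "modeling.BertBeatrix"
# so transformers resolves the module from the local checkpoint directory
# instead of the "repo--module" remote-code path.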
# ------------------------------------------------------------------
# 1. load model / tokenizer ----------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CK)
full_model = full_model.eval().cuda()

# handles into the model internals, used in section 3-D to replay the
# embedding -> LayerNorm -> dropout -> encoder pipeline manually
encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop

MASK = tokenizer.mask_token or "[MASK]"
# ------------------------------------------------------------------
# 2. symbolic role list ---------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

miss = [t for t in SYMBOLIC_ROLES
        if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
if miss:
    sys.exit(f"❌ Tokenizer missing {miss}")
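# If the check above ever fails, a minimal (untested) recovery sketch would be
# to register the roles as extra tokens and grow the embedding matrix —
# assuming the custom model exposes the standard transformers resize hook:
#
#   tokenizer.add_tokens(miss)
#   full_model.resize_token_embeddings(len(tokenizer))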
# ------------------------------------------------------------------
# 3. inference util -------------------------------------------------
@spaces.GPU  # ZeroGPU: attach a GPU for the duration of each call (the Space runs on Zero)
def encode_and_trace(text: str, selected_roles: list[str]):
    # ----- 3-A. build masked version & encode original ---------------
    sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}

    # tokenised "plain" text
    plain = tokenizer(text, return_tensors="pt").to("cuda")
    ids_plain = plain.input_ids

    # make masked string (regex to avoid partial hits)
    masked_txt = text
    for tok in selected_roles:
        masked_txt = re.sub(re.escape(tok), MASK, masked_txt)
    masked = tokenizer(masked_txt, return_tensors="pt").to("cuda")
    ids_masked = masked.input_ids

    # ----- 3-B. run model on masked text ------------------------------
    with torch.no_grad():
        logits = full_model(**masked).logits[0]   # (S, V)
        preds = logits.argmax(-1)                 # (S,)

    # ----- 3-C. gather stats per masked role --------------------------
    # (each role token and [MASK] is a single token, so the plain and masked
    #  sequences line up position by position)
    found_tokens, correct = [], 0
    role_positions = []
    for i, (orig_id, pred_id) in enumerate(zip(ids_plain[0], preds)):
        if orig_id.item() in sel_ids and ids_masked[0, i].item() == tokenizer.mask_token_id:
            found_tokens.append(tokenizer.convert_ids_to_tokens([orig_id])[0])
            correct += int(orig_id.item() == pred_id.item())
            role_positions.append(i)
    total = len(role_positions)
    acc = correct / total if total else 0.0

    # ----- 3-D. encoder rep pooling for *all* selected roles ----------
    with torch.no_grad():
        # embeddings -> normed reps
        x = emb_drop(emb_ln(embeddings(ids_plain)))
        attn = full_model.bert.get_extended_attention_mask(
            plain.attention_mask, x.shape[:-1]
        )
        enc = encoder(x, attention_mask=attn)     # (1, S, H)

    mask_vec = torch.tensor(
        [tid in sel_ids for tid in ids_plain[0].tolist()], device=enc.device
    )
    if mask_vec.any():
        pooled = enc[0][mask_vec].mean(0)
        norm = f"{pooled.norm().item():.4f}"
    else:
        norm = "0.0000"

    tokens_str = ", ".join(found_tokens) or "(none)"
    return tokens_str, norm, f"{acc*100:.1f}%"
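# A minimal sketch of calling the tracer directly (outside the UI), assuming a
# CUDA device is available; it returns the matched role tokens, the pooled
# hidden-state L2-norm, and the [MASK]-recovery accuracy, all as strings:
#
#   toks, norm, acc = encode_and_trace(
#       "A <subject> wearing <upper_body_clothing>", ["<subject>"]
#   )
#   print(toks, norm, acc)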
# ------------------------------------------------------------------
# 4. gradio UI -------------------------------------------------------
def app():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "## 🧠 Symbolic Encoder Inspector \n"
            "1. Model side: we *mask* every chosen role token, run the LM, and report how often it recovers the original. \n"
            "2. Encoder side: we also pool hidden-state vectors for those roles and give their mean L2-norm."
        )
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    lines=3,
                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                )
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    value=SYMBOLIC_ROLES,  # <- all pre-selected
                    label="Roles to mask & trace",
                )
                run = gr.Button("Run")
            with gr.Column():
                o_tok = gr.Textbox(label="Masked-role tokens found")
                o_norm = gr.Textbox(label="Mean hidden-state L2-norm")
                o_acc = gr.Textbox(label="Recovery accuracy")

        run.click(encode_and_trace, [txt, roles], [o_tok, o_norm, o_acc])

    return demo


if __name__ == "__main__":
    app().launch()