# app.py – encoder-only + masking accuracy demo for bert-beatrix-2048
# -----------------------------------------------------------------
# launch: python app.py (UI at http://localhost:7860)
import json, re, sys
from pathlib import Path, PurePosixPath
import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. download repo + patch auto_map --------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CK = "bert-beatrix-2048"
snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_CK, local_dir_use_symlinks=False)
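# auto_map entries downloaded from the Hub may carry a "<repo_id>--" prefix;
# strip that prefix so the custom model code resolves against the local snapshot.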
cfg_p = Path(LOCAL_CK) / "config.json"
with cfg_p.open() as f:
    cfg = json.load(f)
for k, v in cfg.get("auto_map", {}).items():
    if "--" in v:
        cfg["auto_map"][k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
with cfg_p.open("w") as f:
    json.dump(cfg, f, indent=2)
# ------------------------------------------------------------------
# 1. load model / tokenizer ---------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CK)
full_model = full_model.eval().cuda()
encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
MASK = tokenizer.mask_token or "[MASK]"
# ------------------------------------------------------------------
# 2. symbolic role list -------------------------------------------
SYMBOLIC_ROLES = [
"<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
"<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
"<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
"<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
"<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
"<fabric>", "<jewelry>",
]
miss = [t for t in SYMBOLIC_ROLES
if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
if miss:
sys.exit(f"❌ Tokenizer missing {miss}")
# ------------------------------------------------------------------
# 3. inference util ----------------------------------------------
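# @spaces.GPU lets Hugging Face ZeroGPU Spaces attach a GPU for the duration
# of each call; on a dedicated-GPU Space the decorator is effectively a no-op.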
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    # ----- 3-A. build masked version & encode original --------------
    sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
    # tokenised “plain” text
    plain = tokenizer(text, return_tensors="pt").to("cuda")
    ids_plain = plain.input_ids
    # build the masked string (re.escape ensures each role token is matched literally)
    masked_txt = text
    for tok in selected_roles:
        masked_txt = re.sub(re.escape(tok), MASK, masked_txt)
    masked = tokenizer(masked_txt, return_tensors="pt").to("cuda")
    ids_masked = masked.input_ids
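    # Assumption: every selected role token and the mask token each occupy a
    # single position, so the plain and masked sequences stay index-aligned;
    # if a role ever split into several word-pieces, the zip in 3-C would drift.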
    # ----- 3-B. run model on masked text ----------------------------
    with torch.no_grad():
        logits = full_model(**masked).logits[0]   # (S, V)
        preds = logits.argmax(-1)                 # (S,)
    # ----- 3-C. gather stats per masked role ------------------------
    found_tokens, correct = [], 0
    role_flags = []
    for i, (orig_id, pred_id) in enumerate(zip(ids_plain[0], preds)):
        if orig_id.item() in sel_ids and ids_masked[0, i].item() == tokenizer.mask_token_id:
            found_tokens.append(tokenizer.convert_ids_to_tokens([orig_id])[0])
            correct += int(orig_id.item() == pred_id.item())
            role_flags.append(i)
    total = len(role_flags)
    acc = correct / total if total else 0.0
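    # Assumption: the checkpoint's encoder returns a plain (1, S, H) tensor
    # rather than a ModelOutput wrapper, so enc[0] below is the (S, H)
    # hidden-state matrix for the single input sample.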
    # ----- 3-D. encoder rep pooling for *all* selected roles --------
    with torch.no_grad():
        # embeddings -> normed reps
        x = emb_drop(emb_ln(embeddings(ids_plain)))
        attn = full_model.bert.get_extended_attention_mask(
            plain.attention_mask, x.shape[:-1]
        )
        enc = encoder(x, attention_mask=attn)     # (1, S, H)
    mask_vec = torch.tensor(
        [tid in sel_ids for tid in ids_plain[0].tolist()], device=enc.device
    )
    if mask_vec.any():
        pooled = enc[0][mask_vec].mean(0)
        norm = f"{pooled.norm().item():.4f}"
    else:
        norm = "0.0000"
    tokens_str = ", ".join(found_tokens) or "(none)"
    return tokens_str, norm, f"{acc*100:.1f}%"
# ------------------------------------------------------------------
# 4. gradio UI ----------------------------------------------------
def app():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "## 🧠 Symbolic Encoder Inspector \n"
            "1. Model side: we *mask* every chosen role token, run the LM, and report how often it recovers the original. \n"
            "2. Encoder side: we also pool hidden-state vectors for those roles and give their mean L2-norm."
        )
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    lines=3,
                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                )
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    value=SYMBOLIC_ROLES,   # <- all pre-selected
                    label="Roles to mask & trace",
                )
                run = gr.Button("Run")
            with gr.Column():
                o_tok = gr.Textbox(label="Masked-role tokens found")
                o_norm = gr.Textbox(label="Mean hidden-state L2-norm")
                o_acc = gr.Textbox(label="Recovery accuracy")
        run.click(encode_and_trace, [txt, roles], [o_tok, o_norm, o_acc])
    return demo
if __name__ == "__main__":
    app().launch()