Spaces: Running on Zero
File size: 5,909 Bytes
# app.py – encoder-only + masking accuracy demo for bert-beatrix-2048
# -----------------------------------------------------------------
# launch: python app.py (UI at http://localhost:7860)
import json, re, sys
from pathlib import Path, PurePosixPath
import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. download repo + patch auto_map --------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CK = "bert-beatrix-2048"
snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_CK, local_dir_use_symlinks=False)
cfg_p = Path(LOCAL_CK) / "config.json"
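# The hub stores auto_map entries as "<repo_id>--module.Class"; strip the repo
# prefix so the classes resolve against the files in the local snapshot.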
with cfg_p.open() as f:
    cfg = json.load(f)
for k, v in cfg.get("auto_map", {}).items():
    if "--" in v:
        cfg["auto_map"][k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
with cfg_p.open("w") as f:
    json.dump(cfg, f, indent=2)
# ------------------------------------------------------------------
# 1. load model / tokenizer ---------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CK)
full_model = full_model.eval().cuda()
encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
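# Sub-modules are exposed separately so step 3-D can run the
# embedding -> layer-norm -> dropout -> encoder path by hand.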
MASK = tokenizer.mask_token or "[MASK]"
# ------------------------------------------------------------------
# 2. symbolic role list -------------------------------------------
SYMBOLIC_ROLES = [
"<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
"<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
"<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
"<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
"<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
"<fabric>", "<jewelry>",
]
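# Fail fast if the tokenizer does not know any of the role tokens.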
miss = [t for t in SYMBOLIC_ROLES
        if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
if miss:
    sys.exit(f"❌ Tokenizer missing {miss}")
# ------------------------------------------------------------------
# 3. inference util ----------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    # ----- 3-A. build masked version & encode original --------------
    sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
    # tokenised "plain" text
    plain = tokenizer(text, return_tensors="pt").to("cuda")
    ids_plain = plain.input_ids
    # make masked string (re.escape so each role string is replaced literally)
    masked_txt = text
    for tok in selected_roles:
        masked_txt = re.sub(re.escape(tok), MASK, masked_txt)
    masked = tokenizer(masked_txt, return_tensors="pt").to("cuda")
    ids_masked = masked.input_ids
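    # NOTE: each role symbol and the mask token are single vocab entries, so
    # the plain and masked sequences are assumed to stay position-aligned.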
    # ----- 3-B. run model on masked text ----------------------------
    with torch.no_grad():
        logits = full_model(**masked).logits[0]   # (S, V)
        preds = logits.argmax(-1)                 # (S,)
    # ----- 3-C. gather stats per masked role ------------------------
    found_tokens, correct = [], 0
    role_flags = []
    for i, (orig_id, pred_id) in enumerate(zip(ids_plain[0], preds)):
        if orig_id.item() in sel_ids and ids_masked[0, i].item() == tokenizer.mask_token_id:
            found_tokens.append(tokenizer.convert_ids_to_tokens([orig_id.item()])[0])
            correct += int(orig_id.item() == pred_id.item())
            role_flags.append(i)
    total = len(role_flags)
    acc = correct / total if total else 0.0
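    # Accuracy is computed only over positions that were both selected as a
    # role and actually replaced by the mask token.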
    # ----- 3-D. encoder rep pooling for *all* selected roles --------
    with torch.no_grad():
        # embeddings -> normed reps
        x = emb_drop(emb_ln(embeddings(ids_plain)))
        attn = full_model.bert.get_extended_attention_mask(
            plain.attention_mask, x.shape[:-1]
        )
        enc = encoder(x, attention_mask=attn)     # (1, S, H)
    mask_vec = torch.tensor(
        [tid in sel_ids for tid in ids_plain[0].tolist()], device=enc.device
    )
    if mask_vec.any():
        pooled = enc[0][mask_vec].mean(0)
        norm = f"{pooled.norm().item():.4f}"
    else:
        norm = "0.0000"
    tokens_str = ", ".join(found_tokens) or "(none)"
    return tokens_str, norm, f"{acc*100:.1f}%"
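# Quick smoke test without the UI (hypothetical example text; any prompt that
# contains the role tokens above will do):
#   toks, norm, acc = encode_and_trace(
#       "A <subject> wearing <upper_body_clothing> under soft <lighting>.",
#       ["<subject>", "<upper_body_clothing>", "<lighting>"],
#   )
#   print(toks, norm, acc)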
# ------------------------------------------------------------------
# 4. gradio UI ----------------------------------------------------
def app():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "## 🧠 Symbolic Encoder Inspector \n"
            "1. Model side: we *mask* every chosen role token, run the LM, and report how often it recovers the original. \n"
            "2. Encoder side: we also pool hidden-state vectors for those roles and give their mean L2-norm."
        )
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    lines=3,
                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                )
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    value=SYMBOLIC_ROLES,  # <- all pre-selected
                    label="Roles to mask & trace",
                )
                run = gr.Button("Run")
            with gr.Column():
                o_tok = gr.Textbox(label="Masked-role tokens found")
                o_norm = gr.Textbox(label="Mean hidden-state L2-norm")
                o_acc = gr.Textbox(label="Recovery accuracy")
        run.click(encode_and_trace, [txt, roles], [o_tok, o_norm, o_acc])
    return demo
if __name__ == "__main__":
    app().launch()