# app.py – encoder-only demo for bert-beatrix-2048 + role-probe
# ------------------------------------------------------------
# launch:  python app.py
# (gradio UI appears at http://localhost:7860)

import json, sys
from pathlib import Path, PurePosixPath

import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint

# ------------------------------------------------------------------
# 0.  Download & patch config.json ---------------------------------
# ------------------------------------------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_DIR = "bert-beatrix-2048"          # local cache dir

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_DIR,
    local_dir_use_symlinks=False,
)

cfg_path = Path(LOCAL_DIR) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

auto_map = cfg.get("auto_map", {})
patched = False
for k, v in auto_map.items():
    if "--" in v:                        # e.g. "repo--module.Class"
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        patched = True

if patched:
    with cfg_path.open("w") as f:
        json.dump(cfg, f, indent=2)
    print("🛠️  Patched config.json → auto_map fixed.")
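# (Illustrative only: with the patch above, a hypothetical auto_map entry such
#  as  "AutoModel": "AbstractPhil/bert-beatrix-2048--modeling.BertModel"
#  would become "AutoModel": "modeling.BertModel", so the class is resolved
#  from the freshly downloaded local copy rather than from the remote repo.)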

# ------------------------------------------------------------------
# 1.  Model / tokenizer --------------------------------------------
# ------------------------------------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_DIR)
full_model = full_model.eval().cuda()

encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
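# The probe below re-creates the model's forward pass by hand:
# token ids → embeddings → LayerNorm → dropout → encoder stack.
# Keeping the pieces separate also lets us reuse the raw word-embedding
# matrix for the role vectors.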

# ------------------------------------------------------------------
# 2.  Symbolic token set -------------------------------------------
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

ROLE_ID = {tok: tokenizer.convert_tokens_to_ids(tok) for tok in SYMBOLIC_ROLES}
missing = [tok for tok, tid in ROLE_ID.items() if tid == tokenizer.unk_token_id]
if missing:
    sys.exit(f"❌ Tokenizer is missing {missing}")
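# (If the tokenizer ever lacked these markers, the usual fix would be to add
#  them and resize the embedding table before training. A sketch only, and it
#  assumes full_model follows the transformers PreTrainedModel API:
#
#      tokenizer.add_tokens(missing)
#      full_model.resize_token_embeddings(len(tokenizer))
#
#  This script instead treats a missing role token as a hard error above.)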

# ------------------------------------------------------------------
# 3.  Encoder-only + role-similarity probe -------------------------
# ------------------------------------------------------------------
@spaces.GPU  # ZeroGPU: request a GPU slice for the duration of each call
def encode_and_trace(text: str, selected_roles: list[str]):
    """
    For each *selected* role:
      • find the contextual token whose hidden state is most similar to that
        role’s own embedding (cosine similarity)
      • report “role → token (sim)”; this works even when the prompt
        contains no explicit <role> markers at all.
    Also keeps the older string-match diagnostics.
    """
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, mask = batch.input_ids, batch.attention_mask   # (1, S)

        # ---------- encoder ----------
        x = emb_drop(emb_ln(embeddings(ids)))
        msk = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
        h = encoder(x, attention_mask=msk).squeeze(0)        # (S, H)

        # L2-normalise hidden states once
        h_norm = F.normalize(h, dim=-1)                      # (S, H)
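        # Both the hidden states and the role embeddings below are unit-length,
        # so the plain dot product h_norm @ role_vec is exactly the cosine
        # similarity cos θ = (h · r) / (‖h‖ ‖r‖).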
        # ---------- probe each selected role -----------------------
        matches = []
        for role in selected_roles:
            role_vec = embeddings.word_embeddings.weight[ROLE_ID[role]].to(h.device)
            role_vec = F.normalize(role_vec, dim=-1)         # (H)

            sims = h_norm @ role_vec                         # (S)
            best_idx = int(sims.argmax().item())
            best_sim = float(sims[best_idx])
            match_tok = tokenizer.convert_ids_to_tokens(int(ids[0, best_idx]))
            matches.append(f"{role} → {match_tok} ({best_sim:.2f})")

        match_str = ", ".join(matches) if matches else "(no roles selected)"
        # ---------- string-match diagnostics -----------------------
        tokens = tokenizer.convert_ids_to_tokens(ids[0].tolist())
        present = [tok for tok in tokens if tok in selected_roles]
        present_str = ", ".join(present) or "(none)"
        count = len(present)

        # ---------- hidden-state norm of *explicit* role tokens ----
        if count:
            exp_mask = torch.tensor(
                [tid in ROLE_ID.values() for tid in ids[0].tolist()],
                device=h.device,
            )
            norm_val = f"{h[exp_mask].mean(0).norm().item():.4f}"
        else:
            norm_val = "0.0000"

    return present_str, match_str, norm_val, count
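
# ------------------------------------------------------------------
# Example of calling the probe outside the UI (a minimal sketch; it assumes
# the checkpoint above downloaded successfully and that a CUDA device is
# visible; outside a ZeroGPU Space the @spaces.GPU decorator should simply
# pass the call through):
#
#   present, match, norm, count = encode_and_trace(
#       "A small child in bright red boots jumps over a muddy puddle",
#       ["<subject>", "<footwear>"],
#   )
#   print(match)    # e.g. "<subject> → child (…), <footwear> → boots (…)"
# ------------------------------------------------------------------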

# ------------------------------------------------------------------
# 4.  Gradio UI -----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "## 🧠 Symbolic Encoder Inspector\n"
            "Select one or more symbolic *roles* on the left. "
            "The tool shows which regular tokens (if any) the model thinks "
            "best fit each role, even when your text doesn’t contain the "
            "explicit `<role>` marker."
        )

        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(
                    label="Input text",
                    lines=3,
                    placeholder="Example: A small child in bright red boots jumps over a muddy puddle…",
                )
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Roles to probe",
                )
                btn = gr.Button("Run encoder probe")
            with gr.Column():
                out_present = gr.Textbox(label="Explicit role tokens found")
                out_match = gr.Textbox(label="Role → Best-Match Token (cos θ)")
                out_norm = gr.Textbox(label="Mean hidden-state norm (explicit)")
                out_count = gr.Textbox(label="# explicit role tokens")

        btn.click(
            encode_and_trace,
            inputs=[txt, roles],
            outputs=[out_present, out_match, out_norm, out_count],
        )

    return demo

if __name__ == "__main__":
    build_interface().launch()