AbstractPhil's picture
Update app.py
aa28bbb verified
raw
history blame
6.96 kB
# app.py – encoder-only demo for bert-beatrix-2048 + role-probe
# ------------------------------------------------------------
# launch: python app.py
# (gradio UI appears at http://localhost:7860)
import json, sys
from pathlib import Path, PurePosixPath
import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. Download & patch config.json --------------------------------
# ------------------------------------------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_DIR = "bert-beatrix-2048"  # local cache dir
snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_DIR,
    local_dir_use_symlinks=False,
)


def _strip_repo_prefix(ref):
    """Drop everything up to and including the first ``"--"`` in *ref*.

    e.g. ``"repo--module.Class"`` -> ``"module.Class"``. Values that are not
    strings, or that contain no ``"--"``, are returned unchanged.
    """
    if isinstance(ref, str) and "--" in ref:
        return PurePosixPath(ref.split("--", 1)[1]).as_posix()
    return ref


# Rewrite repo-qualified auto_map references in the downloaded config.json.
# An auto_map value may be a single string or a list of them (presumably a
# tokenizer [slow, fast] pair whose slots may be null — TODO confirm for this
# repo), so list values are patched element-wise and non-strings pass through.
cfg_path = Path(LOCAL_DIR) / "config.json"
with cfg_path.open(encoding="utf-8") as f:
    cfg = json.load(f)
auto_map = cfg.get("auto_map", {})
patched = False
for k, v in auto_map.items():
    if isinstance(v, list):
        fixed = [_strip_repo_prefix(item) for item in v]
    else:
        fixed = _strip_repo_prefix(v)
    if fixed != v:
        # Replacing the value of an existing key is safe during iteration.
        auto_map[k] = fixed
        patched = True
if patched:
    with cfg_path.open("w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
    print("🛠️ Patched config.json → auto_map fixed.")
# ------------------------------------------------------------------
# 1. Model / tokenizer -------------------------------------------
# ------------------------------------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_DIR)
# Inference only: put the model in eval mode (e.g. dropout off) and on the GPU.
full_model = full_model.eval().cuda()

# Pieces of the BERT backbone that encode_and_trace calls step by step.
encoder, embeddings, emb_ln, emb_drop = (
    full_model.bert.encoder,
    full_model.bert.embeddings,
    full_model.bert.emb_ln,
    full_model.bert.emb_drop,
)
# ------------------------------------------------------------------
# 2. Symbolic token set ------------------------------------------
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

# Role token -> tokenizer id, in SYMBOLIC_ROLES order.
ROLE_ID = dict(
    zip(SYMBOLIC_ROLES, map(tokenizer.convert_tokens_to_ids, SYMBOLIC_ROLES))
)

# A role whose id equals the unk id is treated as absent from the vocabulary;
# abort at start-up rather than probing with the unk embedding.
missing = [
    role for role in SYMBOLIC_ROLES
    if ROLE_ID[role] == tokenizer.unk_token_id
]
if missing:
    raise SystemExit(f"❌ Tokenizer is missing {missing}")
# ------------------------------------------------------------------
# 3. Encoder-only + role-similarity probe ------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    """Encode *text* and probe which token best matches each selected role.

    For each role in *selected_roles*:
      • take that role's row of ``embeddings.word_embeddings.weight``,
        L2-normalise it, and find the position whose L2-normalised encoder
        hidden state has the highest cosine similarity with it;
      • report it as ``"role → token (sim)"`` — this works even when the
        prompt contains no ``<role>`` markers at all.
    Also reports which selected role tokens appear literally in the prompt.

    Args:
        text: Prompt to encode.
        selected_roles: Role tokens (keys of ``ROLE_ID``) to probe. ``None``
            is treated as an empty selection.

    Returns:
        ``(present_str, match_str, norm_val, count)``:
          present_str – comma-joined selected role tokens found literally in
                        the tokenised prompt, or ``"(none)"``;
          match_str   – comma-joined ``"role → token (sim)"`` entries, or
                        ``"(no roles selected)"``;
          norm_val    – norm (4 d.p.) of the mean hidden state over those
                        literal occurrences, or ``"0.0000"`` if none;
          count       – number of literal occurrences.
    """
    selected_roles = list(selected_roles or [])
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, mask = batch.input_ids, batch.attention_mask  # (1, S)

        # ---------- encoder ----------
        x = emb_drop(emb_ln(embeddings(ids)))
        msk = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
        h = encoder(x, attention_mask=msk).squeeze(0)  # (S, H)
        h_norm = F.normalize(h, dim=-1)  # L2-normalise hidden states once

        # Bring token ids/strings to the host once, instead of comparing
        # CUDA scalars element by element (one device sync each).
        id_list = ids[0].tolist()
        tok_list = tokenizer.convert_ids_to_tokens(id_list)

        # ---------- probe each selected role -----------------------
        word_emb = embeddings.word_embeddings.weight
        matches = []
        for role in selected_roles:
            role_vec = F.normalize(word_emb[ROLE_ID[role]].to(h.device), dim=-1)  # (H)
            sims = h_norm @ role_vec  # (S)
            best_idx = int(sims.argmax().item())
            best_sim = float(sims[best_idx])
            matches.append(f"{role} → {tok_list[best_idx]} ({best_sim:.2f})")
        match_str = ", ".join(matches) if matches else "(no roles selected)"

        # ---------- string-match diagnostics -----------------------
        selected = set(selected_roles)
        present = [tok for tok in tok_list if tok in selected]
        present_str = ", ".join(present) or "(none)"
        count = len(present)

        # ---------- hidden-state norm of *explicit* role tokens ----
        # Mask exactly the positions counted above (selected roles only), so
        # the reported norm and count describe the same tokens.
        if count:
            exp_mask = torch.tensor(
                [tok in selected for tok in tok_list], device=h.device
            )
            norm_val = f"{h[exp_mask].mean(0).norm().item():.4f}"
        else:
            norm_val = "0.0000"
    return present_str, match_str, norm_val, count
# ------------------------------------------------------------------
# 4. Gradio UI ----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    """Assemble the Gradio Blocks UI for the role probe.

    Left column: prompt textbox, role checkboxes and a run button.
    Right column: four textboxes filled from ``encode_and_trace``'s return
    tuple, in order.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    intro_md = (
        "## 🧠 Symbolic Encoder Inspector \n"
        "Select one or more symbolic *roles* on the left. "
        "The tool shows which regular tokens (if any) the model thinks "
        "best fit each role — even when your text doesn’t contain the "
        "explicit `<role>` marker."
    )
    result_labels = (
        "Explicit role tokens found",
        "Role → Best-Match Token (cos θ)",
        "Mean hidden-state norm (explicit)",
        "# explicit role tokens",
    )

    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(intro_md)

        with gr.Row():
            # inputs
            with gr.Column():
                prompt_box = gr.Textbox(
                    label="Input text",
                    lines=3,
                    placeholder="Example: A small child in bright red boots jumps over a muddy puddle…",
                )
                role_picker = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Roles to probe",
                )
                run_button = gr.Button("Run encoder probe")

            # outputs, in the same order as encode_and_trace's return tuple
            with gr.Column():
                result_boxes = [gr.Textbox(label=lbl) for lbl in result_labels]

        run_button.click(
            encode_and_trace,
            inputs=[prompt_box, role_picker],
            outputs=result_boxes,
        )

    return demo
if __name__ == "__main__":
    # Build the Blocks app and start the Gradio server.
    app = build_interface()
    app.launch()