# app.py – encoder-only demo + pool-and-test prototype
# ----------------------------------------------------
# launch:  python app.py
# UI:      http://localhost:7860

import json, sys, math
from pathlib import Path, PurePosixPath

import torch, torch.nn.functional as F
import gradio as gr, spaces
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint
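# NOTE: bert_handler is assumed to be a local module shipped alongside this
# Space, exposing create_handler_from_checkpoint(path) -> (handler, model,
# tokenizer); adjust the import if your layout differs.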
# ------------------------------------------------------------------
# 0.  One-time patch of auto_map in config.json
# ------------------------------------------------------------------
REPO_ID    = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)

cfg_path = Path(LOCAL_CKPT) / "config.json"
cfg      = json.loads(cfg_path.read_text())

auto_map = cfg.get("auto_map", {})
changed  = False
for k, v in auto_map.items():
    if "--" in v:                                    # strip the "repo--" prefix
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        changed = True

if changed:
    cfg_path.write_text(json.dumps(cfg, indent=2))
    print("🛠️  Patched config.json → auto_map points to local modules")
# ------------------------------------------------------------------
# 1.  Load model + tokenizer with BERTHandler
# ------------------------------------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()

# pull the encoder & embedding stack out of the wrapper
encoder    = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_weight = embeddings.word_embeddings.weight   # static (input) embedding matrix
emb_ln     = full_model.bert.emb_ln
emb_drop   = full_model.bert.emb_drop
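# Optional sanity check (assumes the input-embedding matrix covers the
# tokenizer's vocab; drop this if the checkpoint pads its embedding rows):
assert emb_weight.shape[0] >= len(tokenizer), \
    f"embedding rows ({emb_weight.shape[0]}) < tokenizer vocab ({len(tokenizer)})"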
# ------------------------------------------------------------------
# 2.  Symbolic roles
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

# every role must already exist as a single token in the tokenizer's vocab
missing = [t for t in SYMBOLIC_ROLES
           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
if missing:
    sys.exit(f"❌ Tokenizer missing {missing}")

MASK_ID  = tokenizer.mask_token_id
MASK_TOK = tokenizer.mask_token
# ------------------------------------------------------------------
# helpers -----------------------------------------------------------
@torch.no_grad()
def contextual_vectors(ids, mask):
    """Run ids through embedding → encoder; return (S, H) hidden states."""
    x   = emb_drop(emb_ln(embeddings(ids)))                       # (1, S, H)
    ext = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
    return encoder(x, attention_mask=ext).squeeze(0)              # (S, H)

@torch.no_grad()
def pool_accuracy(ids, mask, pool_positions):
    """Mask every position in the pool, re-predict, and return the fraction
    of pooled tokens the MLM head restores correctly."""
    masked = ids.clone()
    masked[0, pool_positions] = MASK_ID
    logits  = full_model(masked, attention_mask=mask).logits[0]
    preds   = logits.argmax(-1)
    gold    = ids.squeeze(0)
    correct = (preds[pool_positions] == gold[pool_positions]).sum().item()
    return correct / len(pool_positions) if pool_positions else 0.0

# cosine utility
def cos(a, b):
    return F.cosine_similarity(a, b, dim=-1, eps=1e-8).item()
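# Illustrative use of the helpers above (hypothetical sentence; needs CUDA):
#   batch = tokenizer("a <subject> under <lighting>", return_tensors="pt").to("cuda")
#   acc = pool_accuracy(batch.input_ids, batch.attention_mask, [1])
#   # acc == 1.0 iff the MLM head restores the masked token at position 1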
# ------------------------------------------------------------------
# 3.  Core routine ---------------------------------------------------
@spaces.GPU          # request a ZeroGPU slot for the duration of the call
def encode_and_trace(text: str, picked_roles: list[str]):
    # -------- tokenise ----------
    batch = tokenizer(text, return_tensors="pt").to("cuda")
    ids, attn = batch.input_ids, batch.attention_mask
    hid = contextual_vectors(ids, attn)                            # (S, H)

    # -------- decide which roles we analyse ----------
    role_ids = {tokenizer.convert_tokens_to_ids(r) for r in SYMBOLIC_ROLES}
    present  = {tid: pos for pos, tid in enumerate(ids[0].tolist())
                if tid in role_ids}    # keyed by id: a repeated role keeps only its last position
    if picked_roles:
        present = {tid: pos for tid, pos in present.items()
                   if tokenizer.convert_ids_to_tokens([tid])[0] in picked_roles}
    if not present:
        return "No symbolic tokens in sentence", "", ""

    # -------- similarity scores ----------
    sims = []
    for tid, pos in present.items():
        rvec = emb_weight[tid]          # static embedding
        cvec = hid[pos]                 # contextual embedding
        sims.append((cos(cvec, rvec), tid, pos))
    sims.sort()                         # low → high

    # pools: bottom-2 and top-2, grown by two per round
    # (pools may overlap near the middle when few roles are present)
    low_pool, high_pool = sims[:2], sims[-2:]
    accepted = []
    for grow in range(1 + math.ceil(len(sims) / 2)):   # ≤ 14 rounds for 26 roles
        for tag, pool in [("low", low_pool), ("high", high_pool)]:
            pool_pos = [p for _, _, p in pool]
            acc = pool_accuracy(ids, attn, pool_pos)
            if acc >= 0.5:                             # category accepted
                roles = [tokenizer.convert_ids_to_tokens([tid])[0] for _, tid, _ in pool]
                accepted.append(f"{tag}:{roles} (acc {acc:.2f})")
        if accepted:
            break                                      # stop once something passed
        # grow pools by two (if any left)
        next_lo = sims[2 + grow*2 : 4 + grow*2]
        next_hi = sims[-4 - grow*2 : -2 - grow*2] if 4 + grow*2 <= len(sims) else []
        low_pool  += next_lo
        high_pool += next_hi

    if not accepted:
        accepted = ["(none hit 50%)"]

    return ", ".join(accepted), f"{len(present)} roles analysed", f"{text[:80]}…"
# ------------------------------------------------------------------
# 4.  UI -------------------------------------------------------------
def build_ui():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "## 🧠 Symbolic Encoder Inspector\n"
            "Select roles, paste text, and watch the pool-and-test prototype work."
        )
        with gr.Row():
            with gr.Column():
                txt   = gr.Textbox(lines=3, label="Input")
                roles = gr.CheckboxGroup(
                    SYMBOLIC_ROLES,
                    value=SYMBOLIC_ROLES,
                    label="Roles to consider (empty = all present)",
                )
                btn = gr.Button("Run")
            with gr.Column():
                out_cat     = gr.Textbox(label="Accepted categories")
                out_info    = gr.Textbox(label="Debug")
                out_excerpt = gr.Textbox(label="Excerpt")
        btn.click(encode_and_trace, [txt, roles], [out_cat, out_info, out_excerpt])
    return demo
if __name__ == "__main__":
    build_ui().launch()