Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,927 Bytes
da8b548 aa28bbb da8b548 aa28bbb 872b08b da8b548 81a0ae4 872b08b 096fe3a 872b08b 096fe3a da8b548 096fe3a da8b548 872b08b 81a0ae4 ed080e6 9a08859 81a0ae4 ed080e6 da8b548 9a08859 da8b548 9a08859 da8b548 9a08859 da8b548 9a08859 ed080e6 096fe3a 535151e da8b548 096fe3a da8b548 096fe3a 8a2e372 096fe3a da8b548 096fe3a 785df91 872b08b 785df91 da8b548 096fe3a da8b548 872b08b da8b548 096fe3a da8b548 096fe3a da8b548 872b08b 096fe3a da8b548 096fe3a da8b548 096fe3a 872b08b da8b548 872b08b 535151e da8b548 aa28bbb da8b548 aa28bbb da8b548 535151e da8b548 872b08b da8b548 535151e ea2994b da8b548 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
# app.py – encoder-only demo + pool-and-test prototype
# ----------------------------------------------------
# launch: python app.py
# UI: http://localhost:7860
import json, re, sys, math
from pathlib import Path, PurePosixPath
import torch, torch.nn.functional as F
import gradio as gr, spaces
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. One-time patch of auto_map in config.json
# ------------------------------------------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

# Fetch the checkpoint files into LOCAL_CKPT so config.json can be edited on disk.
snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)

cfg_path = Path(LOCAL_CKPT) / "config.json"
cfg = json.loads(cfg_path.read_text())
auto_map = cfg.get("auto_map", {})

# Entries shaped like "<prefix>--<module path>" keep only the part after the
# first "--"; every other entry is left exactly as it was.
patched = {
    key: PurePosixPath(target.split("--", 1)[1]).as_posix()
    for key, target in auto_map.items()
    if "--" in target
}
auto_map.update(patched)
changed = bool(patched)

# Rewrite config.json only when at least one entry was actually altered.
if changed:
    cfg_path.write_text(json.dumps(cfg, indent=2))
    print("🛠️ Patched config.json → auto_map points to local modules")
# ------------------------------------------------------------------
# 1. Load model + tokenizer with BERTHandler
# ------------------------------------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)

# Switch to evaluation mode, then move the model to the GPU.
full_model.eval()
full_model = full_model.cuda()

# Pull the encoder and the embedding stack out of the wrapped model so the
# helpers below can run them step by step.
_bert = full_model.bert
encoder = _bert.encoder
embeddings = _bert.embeddings
emb_weight = embeddings.word_embeddings.weight  # static token-embedding matrix
# NOTE(review): names suggest embedding LayerNorm / dropout modules — confirm
# against the checkpoint's modeling code.
emb_ln = _bert.emb_ln
emb_drop = _bert.emb_drop
# ------------------------------------------------------------------
# 2. Symbolic roles
# ------------------------------------------------------------------
# Role tokens the demo looks for in the input text; every one of them must be
# a real entry in the tokenizer vocabulary.
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>",
    "<emotion>", "<surface>", "<lighting>", "<material>",
    "<accessory>", "<footwear>", "<upper_body_clothing>", "<hair_style>",
    "<hair_length>", "<headwear>", "<texture>", "<pattern>",
    "<grid>", "<zone>", "<offset>", "<object_left>",
    "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

# A role that resolves to the unknown-token id is treated as absent from the
# vocabulary; abort start-up rather than analyse the wrong token later.
_unk_id = tokenizer.unk_token_id
missing = [
    role for role in SYMBOLIC_ROLES
    if tokenizer.convert_tokens_to_ids(role) == _unk_id
]
if missing:
    sys.exit(f"❌ Tokenizer missing {missing}")

MASK_ID = tokenizer.mask_token_id
MASK_TOK = tokenizer.mask_token
# ------------------------------------------------------------------
# helpers -----------------------------------------------------------
def contextual_vectors(ids, mask):
    """Push ``ids`` through embeddings → LN → dropout → encoder.

    ``mask`` is expanded with the model's own
    ``get_extended_attention_mask`` before it is handed to the encoder.
    The encoder output has its leading dimension squeezed away, so for a
    single-sequence batch the result is laid out as (S, H) per the original
    shape notes.  Gradient tracking is NOT disabled here; callers decide.
    """
    embedded = embeddings(ids)
    normed = emb_ln(embedded)
    x = emb_drop(normed)                         # (1,S,H) per original notes
    input_shape = x.shape[:-1]                   # everything but the hidden dim
    ext = full_model.bert.get_extended_attention_mask(mask, input_shape)
    hidden = encoder(x, attention_mask=ext)
    return hidden.squeeze(0)                     # (S,H)
def pool_accuracy(ids, mask, pool_positions):
    """Mask the pooled positions, re-predict them, and return the hit rate.

    Args:
        ids: token-id tensor for a single sequence; row 0 is the one that is
            masked and scored (assumes a leading batch dimension of 1 —
            confirm against callers).
        mask: attention mask forwarded unchanged to ``full_model``.
        pool_positions: list of sequence positions to replace with
            ``MASK_ID`` and then score.

    Returns:
        Fraction (0.0–1.0) of pooled positions whose arg-max prediction
        equals the original token id; ``0.0`` for an empty pool.
    """
    # Nothing to score: answer immediately instead of paying for a clone and
    # a full model forward pass whose result would be discarded.
    if not pool_positions:
        return 0.0

    masked = ids.clone()                 # never mutate the caller's tensor
    masked[0, pool_positions] = MASK_ID
    with torch.no_grad():
        logits = full_model(masked, attention_mask=mask).logits[0]
    preds = logits.argmax(-1)
    gold = ids[0]                        # same row that was masked / scored
    correct = (preds[pool_positions] == gold[pool_positions]).sum().item()
    return correct / len(pool_positions)
# cosine utility
def cos(a, b):
    """Return the cosine similarity of ``a`` and ``b`` as a Python float.

    The similarity is taken along the last dimension with ``eps=1e-8``.
    ``.item()`` is called on the result, so the inputs must reduce to a
    single value (e.g. two 1-D vectors).
    """
    similarity = F.cosine_similarity(a, b, dim=-1, eps=1e-8)
    return similarity.item()
# ------------------------------------------------------------------
# 3. Core routine ---------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, picked_roles: list[str]):
    """Run the pool-and-test prototype on ``text``.

    Steps:
      1. Tokenise ``text`` and find which ``SYMBOLIC_ROLES`` tokens occur in
         it (optionally restricted to ``picked_roles``).  If a role token
         occurs more than once only its last position is kept.
      2. For each such role, compare its contextual hidden state with its
         static input embedding via cosine similarity and sort ascending.
      3. Form a "low" pool (least similar) and a "high" pool (most similar),
         starting with two roles each.  Mask each pool, let the model
         re-predict it, and accept a pool whose accuracy is >= 0.5.  If
         neither pool passes, grow both by up to two roles and retry until
         every role has been included.  With fewer than four roles the two
         pools overlap.

    Args:
        text: raw input sentence.
        picked_roles: role tokens to consider; empty/None means "all roles
            present in the text".

    Returns:
        A 3-tuple of strings: accepted categories, a debug summary, and an
        excerpt of the input.
    """
    # -------- tokenise ----------
    batch = tokenizer(text, return_tensors="pt").to("cuda")
    ids, attn = batch.input_ids, batch.attention_mask

    # -------- decide which roles we analyse ----------
    # Build the id → role lookup once instead of once per token.
    role_by_id = {tokenizer.convert_tokens_to_ids(r): r for r in SYMBOLIC_ROLES}
    wanted = set(picked_roles) if picked_roles else None
    present = {}
    for pos, tid in enumerate(ids[0].tolist()):
        role = role_by_id.get(tid)
        if role is None:
            continue
        if wanted is not None and role not in wanted:
            continue
        present[tid] = pos                     # last occurrence wins
    if not present:
        # Bail out before spending an encoder pass on nothing.
        return "No symbolic tokens in sentence", "", ""

    # -------- similarity scores ----------
    # Pure inference: keep autograd from recording the encoder pass.
    with torch.no_grad():
        hid = contextual_vectors(ids, attn)    # (S,H)
        sims = sorted(
            (cos(hid[pos], emb_weight[tid]), tid, pos)  # contextual vs static
            for tid, pos in present.items()
        )                                      # low → high

    # pools: bottom-2, top-2 (expanded below)
    total = len(sims)
    low_pool, high_pool = sims[:2], sims[-2:]
    accepted = []
    # ceil(total / 2) rounds are enough for both pools to reach every role;
    # an extra round would only re-test identical pools.
    for grow in range(math.ceil(total / 2)):
        for tag, pool in (("low", low_pool), ("high", high_pool)):
            pool_pos = [p for _, _, p in pool]
            acc = pool_accuracy(ids, attn, pool_pos)
            if acc >= 0.5:                     # category accepted
                roles = [role_by_id[tid] for _, tid, _ in pool]
                accepted.append(f"{tag}:{roles} (acc {acc:.2f})")
        if accepted:
            break                              # stop once something passed

        # Grow both pools by (up to) two roles.  Positive, clamped indices
        # ensure the final partial chunk is not lost when the number of
        # roles is odd or smaller than four.
        lo_start = 2 + 2 * grow
        hi_end = max(total - lo_start, 0)
        low_pool = low_pool + sims[lo_start:lo_start + 2]
        high_pool = high_pool + sims[max(hi_end - 2, 0):hi_end]

    if not accepted:
        accepted = ["(none hit 50 %)"]
    return ", ".join(accepted), f"{len(present)} roles analysed", f"{text[:80]}…"
# ------------------------------------------------------------------
# 4. UI -------------------------------------------------------------
def build_ui():
    """Assemble the Gradio Blocks interface and return it without launching.

    Layout: a heading, then one row with an input column (text box, role
    check-boxes, Run button) and an output column (three text boxes).  The
    Run button is wired to ``encode_and_trace``.
    """
    intro_md = (
        "## 🧠 Symbolic Encoder Inspector \n"
        "Select roles, paste text, and watch the pool-and-test prototype work."
    )
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(intro_md)
        with gr.Row():
            # Left: everything the user supplies.
            with gr.Column():
                text_in = gr.Textbox(lines=3, label="Input")
                role_picker = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    value=SYMBOLIC_ROLES,
                    label="Roles to consider (else all present)",
                )
                run_btn = gr.Button("Run")
            # Right: the three strings returned by encode_and_trace.
            with gr.Column():
                accepted_out = gr.Textbox(label="Accepted categories")
                debug_out = gr.Textbox(label="Debug")
                excerpt_out = gr.Textbox(label="Excerpt")
        run_btn.click(
            fn=encode_and_trace,
            inputs=[text_in, role_picker],
            outputs=[accepted_out, debug_out, excerpt_out],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the interface, then start serving it.
    demo = build_ui()
    demo.launch()
|