# app.py – encoder-only demo for bert-beatrix-2048
# launch: python app.py
# -----------------------------------------------
import json, re, sys, math
from pathlib import Path, PurePosixPath
import torch, torch.nn.functional as F
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
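# Note: bert_handler is a companion module assumed to live alongside app.py in
# this Space; create_handler_from_checkpoint is expected to return
# (handler, model, tokenizer) for a local checkpoint directory (see step 1 below).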
# ------------------------------------------------------------------
# 0. Download & patch HF checkpoint --------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"
snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)
# → strip repo prefix in auto_map (one-time)
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f: cfg = json.load(f)
amap = cfg.get("auto_map", {})
for k, v in amap.items():
    if "--" in v:
        amap[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
cfg["auto_map"] = amap
with cfg_path.open("w") as f: json.dump(cfg,f,indent=2)
# ------------------------------------------------------------------
# 1. Load model & components --------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()
encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
mlm_head = full_model.cls # prediction head
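# Pulling out these sub-modules lets encode_and_trace run the embedding →
# LayerNorm → dropout → encoder pipeline by hand and score tokens with the
# MLM head, instead of calling full_model's forward directly.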
# ------------------------------------------------------------------
# 2. Symbolic roles -------------------------------------------------
SYMBOLIC_ROLES = [
"<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
"<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
"<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
"<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
"<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
"<fabric>", "<jewelry>",
]
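# Every role must already exist as a dedicated token in the tokenizer's vocab;
# anything that maps to [UNK] aborts the app below.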
if any(tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id
       for t in SYMBOLIC_ROLES):
    sys.exit("❌ tokenizer missing special tokens")
# Quick helpers
MASK = tokenizer.mask_token
# ------------------------------------------------------------------
# 3. Encoder-plus-MLM logic ---------------------------------------
def cosine(a, b):
    return torch.nn.functional.cosine_similarity(a, b, dim=-1)
def pool_accuracy(ids, logits, pool_mask):
"""
ids : (S,) gold token ids
logits : (S,V) MLM logits
pool_mask : bool (S,) which tokens belong to the candidate pool
returns accuracy over masked positions only (if none, return 0)
"""
idx = pool_mask.nonzero(as_tuple=False).flatten()
if idx.numel()==0: return 0.0
preds = logits.argmax(-1)[idx]
gold = ids[idx]
return (preds==gold).float().mean().item()
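# (pool_accuracy is a standalone helper; encode_and_trace below computes its
# pool accuracies inline via evaluate_pool.)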
@spaces.GPU
def encode_and_trace(text, selected_roles):
    if not selected_roles:
        selected_roles = SYMBOLIC_ROLES
    sel_ids = [tokenizer.convert_tokens_to_ids(t) for t in selected_roles]
    sel_ids_tensor = torch.tensor(sel_ids, device="cuda")
    # Tokenize input
    batch = tokenizer(text, return_tensors="pt").to("cuda")
    ids, attn = batch.input_ids, batch.attention_mask
    S = ids.shape[1]
    # Safe encoder forward
    def encode(input_ids, attn_mask):
        x = embeddings(input_ids)
        if emb_ln: x = emb_ln(x)
        if emb_drop: x = emb_drop(x)
        ext = full_model.bert.get_extended_attention_mask(attn_mask, x.shape[:-1])
        return encoder(x, attention_mask=ext)[0]
    encoded = encode(ids, attn)                                   # (1, S, H)
    # Raw (non-contextual) embeddings of the selected role tokens
    symbolic_embeds = embeddings.word_embeddings(sel_ids_tensor)  # (R, H)
    sim = cosine(encoded[0].unsqueeze(1), symbolic_embeds.unsqueeze(0))  # (S, R)
    maxcos, argrole = sim.max(-1)                                 # (S,)
    top_roles = [selected_roles[i] for i in argrole.tolist()]
    sort_idx = maxcos.argsort(descending=True)
    hi_idx = sort_idx[:S // 2]
    lo_idx = sort_idx[S // 2:]
    MASK_ID = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]")
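    # Greedy pool search: starting from the most (or least) role-aligned tokens,
    # grow a pool of visible tokens two at a time; every other position is
    # replaced with [MASK] and must be reconstructed by the MLM head. The pool
    # with the best reconstruction accuracy is kept.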
    def evaluate_pool(idx_order, label, ids):
        best_pool, best_acc = [], 0.0
        ptr = 0
        while ptr < len(idx_order):
            cand = idx_order[ptr:ptr + 2]
            pool = best_pool + cand.tolist()
            ptr += 2
            mask_flags = torch.zeros_like(ids, dtype=torch.bool)
            mask_flags[0, pool] = True
            masked_input = ids.where(mask_flags, MASK_ID)   # keep pool tokens, mask the rest
            encoded_m = encode(masked_input, attn)
            logits = mlm_head(encoded_m)[0]                  # (S, V)
            preds = logits.argmax(-1)
            masked_positions = (~mask_flags[0]).nonzero(as_tuple=False).squeeze(-1)
            if masked_positions.numel() == 0:
                continue
            gold = ids[0][masked_positions]
            correct = (preds[masked_positions] == gold).float()
            acc = correct.mean().item()
            if acc > best_acc:
                best_pool, best_acc = pool, acc
            if acc >= 0.5:   # early stop once half the masked tokens are recovered
                break
        return best_pool, best_acc
    # Run both pool evaluations
    pool_hi, acc_hi = evaluate_pool(hi_idx, "high", ids)
    pool_lo, acc_lo = evaluate_pool(lo_idx, "low", ids)
    # Per-token symbolic trace
    decoded_tokens = tokenizer.convert_ids_to_tokens(ids[0])
    role_trace = [
        f"{tok:<15} → {role} cos={score:.4f}"
        for tok, role, score in zip(decoded_tokens, top_roles, maxcos.tolist())
    ]
    # Output JSON
    res_json = {
        "High-pool tokens": tokenizer.decode(ids[0, pool_hi]),
        "High accuracy": f"{acc_hi:.3f}",
        "Low-pool tokens": tokenizer.decode(ids[0, pool_lo]),
        "Low accuracy": f"{acc_lo:.3f}",
        "Token–Symbolic Role Alignment": role_trace,
    }
    return json.dumps(res_json, indent=2), f"{maxcos.max():.4f}", len(selected_roles)
# ------------------------------------------------------------------
# 4. Gradio UI -----------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(label="Prompt", lines=3)
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES, label="Roles",
                    value=SYMBOLIC_ROLES,  # pre-checked
                )
                btn = gr.Button("Run")
            with gr.Column():
                out_json = gr.Textbox(label="Result JSON")
                out_max = gr.Textbox(label="Max cos")
                out_cnt = gr.Textbox(label="# roles")
        btn.click(encode_and_trace, [txt, roles], [out_json, out_max, out_cnt])
    return demo
if __name__ == "__main__":
    build_interface().launch()
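# ------------------------------------------------------------------
# Usage sketch (not part of the Space UI): encode_and_trace can be called
# directly for a quick smoke test on a CUDA machine. The prompt and role
# subset below are illustrative placeholders, not values from the demo.
#
#   result_json, max_cos, n_roles = encode_and_trace(
#       "a dancer spinning under warm light", ["<subject>", "<lighting>"]
#   )
#   print(result_json)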