# app.py – encoder-only demo for bert-beatrix-2048
# launch: python app.py
# -----------------------------------------------
import json, sys
from pathlib import Path, PurePosixPath
import torch, torch.nn.functional as F
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. Download & patch HF checkpoint --------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"
snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)
# → strip repo prefix in auto_map (one-time)
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)
amap = cfg.get("auto_map", {})
for k, v in amap.items():
    if "--" in v:
        amap[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
cfg["auto_map"] = amap
with cfg_path.open("w") as f:
    json.dump(cfg, f, indent=2)
# ------------------------------------------------------------------
# 1. Load model & components --------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()
encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
mlm_head = full_model.cls # prediction head
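# The pieces above split a standard BERT forward pass apart so each stage can
# be driven manually below: embeddings -> LayerNorm -> dropout -> encoder
# stack, with the MLM prediction head applied on top of the encoder output.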
# ------------------------------------------------------------------
# 2. Symbolic roles -------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]
if any(tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id
       for t in SYMBOLIC_ROLES):
    sys.exit("❌ tokenizer missing special tokens")
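# If the checkpoint ever shipped without these roles, they could be registered
# instead of bailing out. A sketch, assuming full_model follows the usual
# PreTrainedModel API (resize_token_embeddings may not exist on a custom class):
#   tokenizer.add_special_tokens({"additional_special_tokens": SYMBOLIC_ROLES})
#   full_model.resize_token_embeddings(len(tokenizer))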
# Quick helpers
MASK    = tokenizer.mask_token
MASK_ID = tokenizer.mask_token_id     # id used when masking tokens below
# ------------------------------------------------------------------
# 3. Encoder-plus-MLM logic ---------------------------------------
def cosine(a, b):
    return F.cosine_similarity(a, b, dim=-1)

def pool_accuracy(ids, logits, pool_mask):
    """
    ids       : (S,)  gold token ids
    logits    : (S,V) MLM logits
    pool_mask : (S,)  bool – which tokens belong to the candidate pool
    Returns accuracy over the pooled positions only (0.0 if the pool is empty).
    """
    idx = pool_mask.nonzero(as_tuple=False).flatten()
    if idx.numel() == 0:
        return 0.0
    preds = logits.argmax(-1)[idx]
    gold  = ids[idx]
    return (preds == gold).float().mean().item()
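# Toy check of pool_accuracy (values made up for illustration): with one-hot
# logits at the gold ids, every pooled prediction matches, so accuracy is 1.0.
#   ids    = torch.tensor([1, 2, 3])
#   logits = F.one_hot(torch.tensor([1, 2, 3]), 5).float()
#   pool   = torch.tensor([False, False, True])
#   pool_accuracy(ids, logits, pool)   # -> 1.0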
@spaces.GPU
def encode_and_trace(text, selected_roles):
    # if the user unchecked everything, treat it as "all"
    if not selected_roles:
        selected_roles = SYMBOLIC_ROLES
    sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}

    # ---- Tokenise & encode once ----
    batch = tokenizer(text, return_tensors="pt").to("cuda")
    ids, att = batch.input_ids, batch.attention_mask
    x = emb_drop(emb_ln(embeddings(ids)))
    ext = full_model.bert.get_extended_attention_mask(att, x.shape[:-1])
    enc = encoder(x, attention_mask=ext)[0, :, :]            # (S,H)

    # ---- compute max-cos per token (F-0/F-1) ----
    role_mat = embeddings.word_embeddings(
        torch.tensor(sorted(sel_ids), device=enc.device)
    )                                                        # (R,H)
    cos = cosine(enc.unsqueeze(1), role_mat.unsqueeze(0))    # (S,R)
    maxcos, argrole = cos.max(-1)                            # (S,)

    # ---- split tokens into High / Low half (F-2) ----
    S = len(ids[0])
    sort_idx = maxcos.argsort(descending=True)
    hi_idx = sort_idx[: S // 2]
    lo_idx = sort_idx[S // 2:]

    # container for summary text
    report_lines = []
    # ------------------------------------------------------------------
    # Greedy pool helper – tensor-safe version
    # ------------------------------------------------------------------
    def greedy_pool(index_tensor: torch.Tensor, which: str):
        """
        index_tensor – 1-D tensor of token indices (already on CUDA)
        which        – "low"  → walk upward
                       "high" → walk downward
        Returns (best_pool: list[int], best_acc: float)
        """
        # ---- make everything vanilla Python ints ---------------------
        indices = index_tensor.tolist()              # e.g. [7, 10, 13, …]
        if which == "high":
            indices = indices[::-1]                  # reverse for top-down

        best_pool: list[int] = []
        best_acc = 0.0

        for i in range(0, len(indices), 2):          # grow two at a time
            cand = indices[i : i + 2]                # plain list[int]
            trial = best_pool + cand                 # candidate pool
            # ---- build masked input ----------------------------------
            mask_flags = torch.ones_like(ids).bool() # mask everything …
            mask_flags[0, trial] = False             # … except the pool
            masked_ids = ids.masked_fill(mask_flags, MASK_ID)
            # ---- second forward pass ---------------------------------
            with torch.no_grad():
                x_m = emb_drop(emb_ln(embeddings(masked_ids)))
                ext_m = full_model.bert.get_extended_attention_mask(att, x_m.shape[:-1])
                enc_m = encoder(x_m, attention_mask=ext_m)
                logits = mlm_head(enc_m)[0]          # (S,V)
            pred  = logits.argmax(-1)                # (S,)
            flags = mask_flags[0]                    # (S,) bool
            corr  = (pred[flags] == ids[0][flags]).float().mean().item()
            if corr > best_acc:
                best_acc = corr
                best_pool = trial                    # accept improvement
            if best_acc >= 0.50:
                break                                # early exit
        return best_pool, best_acc
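    # Illustration of the walk: with indices [7, 10, 13, 4] the candidate pools
    # tried are [7, 10] and then [7, 10, 13, 4]; each step is kept only if the
    # masked-token accuracy improves, and the search stops once it hits 0.50.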
    pool_lo, acc_lo = greedy_pool(lo_idx, "low")
    pool_hi, acc_hi = greedy_pool(hi_idx, "high")

    # ---- package textual result ----
    res_json = {
        "Low-pool tokens":  tokenizer.decode(ids[0, pool_lo]),
        "Low accuracy":     f"{acc_lo:.2f}",
        "High-pool tokens": tokenizer.decode(ids[0, pool_hi]),
        "High accuracy":    f"{acc_hi:.2f}",
        "Trace":            "\n".join(report_lines),
    }
    # three outputs expected by the UI
    return json.dumps(res_json, indent=2), f"{maxcos.max():.4f}", len(selected_roles)
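# Ad-hoc smoke test (assumes a local CUDA device; uncomment to try):
#   print(encode_and_trace("a <subject> wearing <fabric>", ["<subject>", "<fabric>"]))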
# ------------------------------------------------------------------
# 4. Gradio UI -----------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
        with gr.Row():
            with gr.Column():
                txt   = gr.Textbox(label="Prompt", lines=3)
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES, label="Roles",
                    value=SYMBOLIC_ROLES,        # pre-checked
                )
                btn = gr.Button("Run")
            with gr.Column():
                out_json = gr.Textbox(label="Result JSON")
                out_max  = gr.Textbox(label="Max cos")
                out_cnt  = gr.Textbox(label="# roles")
        btn.click(encode_and_trace, [txt, roles], [out_json, out_max, out_cnt])
    return demo

if __name__ == "__main__":
    build_interface().launch()