# Hugging Face Space status banner captured with this file: "Running on Zero".
# app.py – encoder-only demo for bert-beatrix-2048
# launch: python app.py
# -----------------------------------------------
import json, re, sys, math | |
from pathlib import Path, PurePosixPath | |
import torch, torch.nn.functional as F | |
import gradio as gr | |
import spaces | |
from huggingface_hub import snapshot_download | |
from bert_handler import create_handler_from_checkpoint | |
# ------------------------------------------------------------------
# 0. Download & patch HF checkpoint --------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

# Fetch the "main" revision of the repo into a plain local directory.
# NOTE(review): `local_dir_use_symlinks` is reported as deprecated by recent
# huggingface_hub releases; kept here for compatibility with older ones.
snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)


def _strip_repo_prefix(ref):
    """Return *ref* without a leading ``<repo>--`` prefix.

    Strings containing ``--`` keep only the part after the first ``--``
    (normalised as a POSIX path); lists/tuples are handled element-wise;
    any other value (e.g. ``None``) is returned unchanged instead of
    raising ``TypeError`` on the ``in`` test.
    """
    if isinstance(ref, str):
        if "--" in ref:
            return PurePosixPath(ref.split("--", 1)[1]).as_posix()
        return ref
    if isinstance(ref, (list, tuple)):
        return [_strip_repo_prefix(item) for item in ref]
    return ref


# → strip repo prefix in auto_map (one-time)
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open(encoding="utf-8") as f:
    cfg = json.load(f)

amap = cfg.get("auto_map") or {}
patched_amap = {key: _strip_repo_prefix(value) for key, value in amap.items()}

# Only touch the file when a prefix was actually removed, so repeated
# launches do not rewrite an already-patched config.
if patched_amap != amap:
    cfg["auto_map"] = patched_amap
    with cfg_path.open("w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
# ------------------------------------------------------------------
# 1. Load model & components --------------------------------------
# The factory is unpacked into (handler, model, tokenizer) for the
# locally patched checkpoint directory.
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)

# Inference only: switch to eval mode, then move the weights to the GPU.
full_model = full_model.eval()
full_model = full_model.cuda()

# Direct handles on the sub-modules used by the manual forward pass below.
_bert = full_model.bert
embeddings = _bert.embeddings
emb_ln = _bert.emb_ln
emb_drop = _bert.emb_drop
encoder = _bert.encoder
mlm_head = full_model.cls  # prediction head
# ------------------------------------------------------------------
# 2. Symbolic roles -------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>",
    "<subject1>",
    "<subject2>",
    "<pose>",
    "<emotion>",
    "<surface>",
    "<lighting>",
    "<material>",
    "<accessory>",
    "<footwear>",
    "<upper_body_clothing>",
    "<hair_style>",
    "<hair_length>",
    "<headwear>",
    "<texture>",
    "<pattern>",
    "<grid>",
    "<zone>",
    "<offset>",
    "<object_left>",
    "<object_right>",
    "<relation>",
    "<intent>",
    "<style>",
    "<fabric>",
    "<jewelry>",
]

# Abort start-up as soon as one role token resolves to the tokenizer's
# unknown-token id, i.e. the vocabulary does not contain it.
for _role in SYMBOLIC_ROLES:
    if tokenizer.convert_tokens_to_ids(_role) == tokenizer.unk_token_id:
        sys.exit("❌ tokenizer missing special tokens")

# Quick helpers
MASK = tokenizer.mask_token
# ------------------------------------------------------------------
# 3. Encoder-plus-MLM logic ---------------------------------------
def cosine(a, b):
    """Return the cosine similarity of *a* and *b* along their last axis."""
    similarity = F.cosine_similarity(a, b, dim=-1)
    return similarity
def pool_accuracy(ids, logits, pool_mask):
    """
    Fraction of pool positions whose arg-max prediction equals the gold id.

    ids       : (S,)   gold token ids
    logits    : (S,V)  MLM logits
    pool_mask : bool (S,) flags the positions that belong to the pool
    Returns a Python float; 0.0 when no position is flagged.
    """
    positions = torch.nonzero(pool_mask).reshape(-1)
    if not positions.numel():
        return 0.0
    predicted = logits.argmax(dim=-1).index_select(0, positions)
    expected = ids.index_select(0, positions)
    hits = torch.eq(predicted, expected).float()
    return hits.mean().item()
def encode_and_trace(text, selected_roles):
    """Encode *text*, score tokens against role embeddings and probe MLM recall.

    Steps:
      1. Encode the prompt once and compute, per token, the maximum cosine
         similarity to the word embeddings of the selected role tokens.
      2. Split the tokens into a high-similarity half and a low-similarity half.
      3. For each half, greedily grow a pool of *visible* tokens (everything
         else is replaced by the mask token) and keep the pool giving the best
         MLM accuracy on the masked positions.

    Args:
        text: prompt string entered in the UI.
        selected_roles: role tokens ticked in the UI; an empty/None selection
            is treated as "all roles".

    Returns:
        A 3-tuple for the UI: (JSON report string, max cosine formatted with
        four decimals, number of roles used).
    """
    # NOTE(review): the file imports `spaces` but never uses it; if this runs
    # on ZeroGPU hardware this callback probably needs `@spaces.GPU` — confirm.

    # if user unchecked everything we treat as "all"
    if not selected_roles:
        selected_roles = SYMBOLIC_ROLES
    role_ids = sorted({tokenizer.convert_tokens_to_ids(t) for t in selected_roles})
    mask_id = tokenizer.mask_token_id

    # ---- Tokenise & encode once ----
    batch = tokenizer(text, return_tensors="pt").to("cuda")
    ids, att = batch.input_ids, batch.attention_mask
    seq_len = ids.shape[1]
    token_list = ids[0].tolist()  # plain ints, used for decoding pools

    def run_encoder(token_ids):
        """Embed *token_ids* and run the encoder with the prompt's attention mask."""
        hidden = emb_drop(emb_ln(embeddings(token_ids)))
        ext = full_model.bert.get_extended_attention_mask(att, hidden.shape[:-1])
        return encoder(hidden, attention_mask=ext)

    # ---- compute max-cos per token (F-0/F-1) ----
    # Inference only: no autograd graph is needed for any of this.
    with torch.no_grad():
        enc = run_encoder(ids)[0, :, :]                         # (S,H)
        role_mat = embeddings.word_embeddings(
            torch.tensor(role_ids, device=enc.device)
        )                                                       # (R,H)
        cos = cosine(enc.unsqueeze(1), role_mat.unsqueeze(0))   # (S,R)
        maxcos = cos.max(-1).values                             # (S,)

    # ---- split tokens into High / Low half (F-2) ----
    sort_idx = maxcos.argsort(descending=True)
    hi_idx = sort_idx[: seq_len // 2]
    lo_idx = sort_idx[seq_len // 2:]

    # container for summary text (one line per greedy trial)
    report_lines = []

    # ------------------------------------------------------------------
    # Greedy pool helper – tensor-safe version
    # ------------------------------------------------------------------
    def greedy_pool(index_tensor: torch.Tensor, which: str):
        """
        index_tensor – 1-D tensor of token indices, ordered by descending
                       max-cos (as produced by the split above)
        which        – "low"  → walk upward   (lowest similarity first)
                       "high" → walk downward (highest similarity first)
        Returns (best_pool:list[int], best_acc:float)
        """
        indices = index_tensor.tolist()
        # The incoming order is already top-down, so only the "low" walk has
        # to be reversed to start from the least similar token.
        if which == "low":
            indices = indices[::-1]

        best_pool: list[int] = []
        best_acc = 0.0
        for start in range(0, len(indices), 2):                 # 2 at a time
            trial = best_pool + indices[start:start + 2]        # grow pool

            # ---- build masked input: mask everything except the pool ----
            hidden_pos = torch.ones(seq_len, dtype=torch.bool, device=ids.device)
            hidden_pos[trial] = False
            if not hidden_pos.any():
                # Every token is visible: nothing to predict, and the mean
                # over zero positions would be NaN.
                continue
            masked_ids = ids.masked_fill(hidden_pos.unsqueeze(0), mask_id)

            # ---- second forward-pass ----
            with torch.no_grad():
                logits = mlm_head(run_encoder(masked_ids))[0]   # (S,V)
            pred = logits.argmax(-1)                            # (S,)
            acc = (pred[hidden_pos] == ids[0][hidden_pos]).float().mean().item()
            report_lines.append(f"[{which}] pool={trial} acc={acc:.2f}")

            if acc > best_acc:
                best_acc = acc
                best_pool = trial                               # accept improvement
                if best_acc >= 0.50:
                    break                                       # early exit
        return best_pool, best_acc

    pool_lo, acc_lo = greedy_pool(lo_idx, "low")
    pool_hi, acc_hi = greedy_pool(hi_idx, "high")

    # ---- package textual result ----
    res_json = {
        "Low-pool tokens": tokenizer.decode([token_list[i] for i in pool_lo]),
        "Low accuracy": f"{acc_lo:.2f}",
        "High-pool tokens": tokenizer.decode([token_list[i] for i in pool_hi]),
        "High accuracy": f"{acc_hi:.2f}",
        "Trace": "\n".join(report_lines),
    }
    # three outputs expected by UI
    return (
        json.dumps(res_json, indent=2),
        f"{maxcos.max().item():.4f}",
        len(selected_roles),
    )
# ------------------------------------------------------------------
# 4. Gradio UI -----------------------------------------------------
def build_interface():
    """Assemble and return the Gradio Blocks app.

    Left column: prompt box, role checkboxes (all pre-checked) and a Run
    button. Right column: three text boxes receiving the three values
    returned by ``encode_and_trace``.
    """
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")

        with gr.Row():
            # Inputs
            with gr.Column():
                prompt_box = gr.Textbox(label="Prompt", lines=3)
                role_picker = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Roles",
                    value=SYMBOLIC_ROLES,  # pre-checked
                )
                run_button = gr.Button("Run")

            # Outputs
            with gr.Column():
                result_box = gr.Textbox(label="Result JSON")
                max_cos_box = gr.Textbox(label="Max cos")
                role_count_box = gr.Textbox(label="# roles")

        # Wire the button to the analysis callback.
        run_button.click(
            fn=encode_and_trace,
            inputs=[prompt_box, role_picker],
            outputs=[result_box, max_cos_box, role_count_box],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then start the Gradio server.
    app = build_interface()
    app.launch()