# Page metadata copied from the hosting file view (not code):
# author: AbstractPhil · commit aaae56c "Update app.py" (verified) · 7.73 kB
# app.py – encoder-only demo for bert-beatrix-2048
# launch: python app.py
# -----------------------------------------------
import json, re, sys, math
from pathlib import Path, PurePosixPath
import torch, torch.nn.functional as F
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. Download & patch HF checkpoint --------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"
snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    # NOTE(review): recent huggingface_hub releases deprecate/ignore this
    # flag; kept so older versions still copy real files into local_dir.
    local_dir_use_symlinks=False,
)


def _strip_repo_prefix(ref):
    """Drop a leading "<repo>--" qualifier from an auto_map reference.

    "org/repo--modeling_x.Cls" becomes "modeling_x.Cls". Strings without
    "--" and non-string values (e.g. None) are returned unchanged. Some
    auto_map entries may be lists (e.g. a pair of tokenizer classes); those
    are handled element-wise and returned as a list.
    """
    if isinstance(ref, (list, tuple)):
        return [_strip_repo_prefix(r) for r in ref]
    if isinstance(ref, str) and "--" in ref:
        return PurePosixPath(ref.split("--", 1)[1]).as_posix()
    return ref


# → strip the repo prefix from every auto_map entry so the local checkpoint
# resolves its custom code from its own directory. This runs on every launch.
# Entries without "--" are left untouched, and config.json is only rewritten
# when something actually changed.
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open(encoding="utf-8") as f:
    cfg = json.load(f)
amap = {k: _strip_repo_prefix(v) for k, v in cfg.get("auto_map", {}).items()}
if amap != cfg.get("auto_map", {}):
    cfg["auto_map"] = amap
    with cfg_path.open("w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
# ------------------------------------------------------------------
# 1. Load model & components --------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval()   # inference behaviour
full_model = full_model.cuda()   # weights live on the GPU from here on

# Sub-modules pulled out of the backbone so encode_and_trace can run the
# embedding → encoder → MLM-head pipeline step by step.
encoder, embeddings, emb_ln, emb_drop = (
    full_model.bert.encoder,
    full_model.bert.embeddings,
    full_model.bert.emb_ln,
    full_model.bert.emb_drop,
)
mlm_head = full_model.cls  # masked-LM prediction head
# ------------------------------------------------------------------
# 2. Symbolic roles -------------------------------------------------
# Role marker tokens: they populate the UI checkbox group and their
# word-embedding rows are compared against encoder outputs.
SYMBOLIC_ROLES = [
    "<subject>",
    "<subject1>",
    "<subject2>",
    "<pose>",
    "<emotion>",
    "<surface>",
    "<lighting>",
    "<material>",
    "<accessory>",
    "<footwear>",
    "<upper_body_clothing>",
    "<hair_style>",
    "<hair_length>",
    "<headwear>",
    "<texture>",
    "<pattern>",
    "<grid>",
    "<zone>",
    "<offset>",
    "<object_left>",
    "<object_right>",
    "<relation>",
    "<intent>",
    "<style>",
    "<fabric>",
    "<jewelry>",
]
# Abort at start-up if any role token resolves to the unknown-token id,
# i.e. the tokenizer was not extended with these special tokens.
if any(
    tokenizer.convert_tokens_to_ids(role_tok) == tokenizer.unk_token_id
    for role_tok in SYMBOLIC_ROLES
):
    sys.exit("❌ tokenizer missing special tokens")
# Quick helpers
MASK = tokenizer.mask_token  # NOTE(review): not referenced elsewhere in the visible code
# ------------------------------------------------------------------
# 3. Encoder-plus-MLM logic ---------------------------------------
def cosine(a, b):
    """Return the cosine similarity of *a* and *b* along their last axis.

    Inputs broadcast against each other as in torch, so an (S, 1, H) tensor
    paired with a (1, R, H) tensor yields an (S, R) similarity matrix.
    """
    similarity = torch.nn.functional.cosine_similarity(a, b, dim=-1)
    return similarity
def pool_accuracy(ids, logits, pool_mask):
    """Top-1 accuracy of *logits* against *ids*, restricted to *pool_mask*.

    ids       : (S,) gold token ids
    logits    : (S, V) MLM logits
    pool_mask : (S,) bool, True at the positions to score
    Returns the fraction of selected positions whose argmax prediction equals
    the gold id, as a Python float; 0.0 when pool_mask selects nothing.

    NOTE(review): not called anywhere in the visible code.
    """
    positions = torch.flatten(torch.nonzero(pool_mask))
    if not positions.numel():
        return 0.0
    predicted = logits.argmax(dim=-1).index_select(0, positions)
    expected = ids.index_select(0, positions)
    return predicted.eq(expected).to(torch.float32).mean().item()
def _encode_hidden(input_ids, attention_mask):
    """Run embeddings → emb_ln → emb_drop → encoder on a batch of token ids.

    Shared by the plain pass and every masked pass so the two are computed
    identically. Both inputs carry a batch dimension ((1, S) in this app).
    Returns whatever ``encoder`` returns; callers here index it as (B, S, H).
    """
    x = emb_drop(emb_ln(embeddings(input_ids)))
    ext = full_model.bert.get_extended_attention_mask(attention_mask, x.shape[:-1])
    return encoder(x, attention_mask=ext)


def _greedy_pool(ids, att, index_tensor, which, mask_id, trace, target_acc=0.50):
    """Greedily grow a pool of visible tokens and score MLM recovery of the rest.

    Candidates from ``index_tensor`` are tried two at a time. For each trial
    pool, every other position is replaced by ``mask_id`` and the sequence is
    re-encoded. Accuracy is the fraction of masked positions whose argmax MLM
    prediction equals the original token. A trial is kept only if it beats
    the best accuracy so far, and the walk stops once ``target_acc`` is hit.

    ids, att     : (1, S) input ids / attention mask
    index_tensor : 1-D tensor of candidate positions
    which        : "low" walks index_tensor in its given order,
                   "high" walks it in reverse
    mask_id      : token id written into masked positions
    trace        : list[str]; one line per scored trial is appended
    target_acc   : early-exit threshold
    Returns (best_pool: list[int], best_acc: float).
    """
    order = index_tensor.tolist()
    if which == "high":
        order.reverse()

    best_pool: list[int] = []
    best_acc = 0.0
    for start in range(0, len(order), 2):
        trial = best_pool + order[start:start + 2]

        # True = masked (to be predicted); the trial pool stays visible.
        hidden = torch.ones_like(ids[0], dtype=torch.bool)  # (S,)
        hidden[trial] = False
        if not bool(hidden.any()):
            # Nothing left to predict; the mean over zero positions would be NaN.
            continue

        masked_ids = ids.masked_fill(hidden.unsqueeze(0), mask_id)
        logits = mlm_head(_encode_hidden(masked_ids, att))[0]  # (S, V)
        pred = logits.argmax(-1)
        acc = (pred[hidden] == ids[0][hidden]).float().mean().item()
        trace.append(f"{which}: pool={trial} acc={acc:.2f}")

        if acc > best_acc:
            best_acc, best_pool = acc, trial
            if best_acc >= target_acc:
                break  # early exit
    return best_pool, best_acc


@spaces.GPU
def encode_and_trace(text, selected_roles):
    """Compare how well the high- vs. low-similarity halves of a prompt let
    the MLM recover the remaining tokens.

    1. Encode *text* once. For each token, take the max cosine similarity
       between its encoder output and the word-embedding rows of the
       selected role tokens.
    2. Sort positions by that score and split them into a high and a low half.
    3. For each half, greedily grow a visible pool (see ``_greedy_pool``)
       and record MLM accuracy on the masked remainder.

    text           : prompt string
    selected_roles : role tokens chosen in the UI; empty/None means all of
                     SYMBOLIC_ROLES
    Returns (result JSON string, max cosine formatted "%.4f", number of roles).
    Raises RuntimeError if the tokenizer defines no mask token.
    """
    # If the user unchecked everything, treat it as "all roles".
    if not selected_roles:
        selected_roles = SYMBOLIC_ROLES
    sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}

    mask_id = tokenizer.mask_token_id
    if mask_id is None:
        raise RuntimeError("tokenizer has no mask token; cannot run the masked passes")

    report_lines: list[str] = []
    # Pure inference: build no autograd graph for any forward pass.
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, att = batch.input_ids, batch.attention_mask  # (1, S) each
        enc = _encode_hidden(ids, att)[0, :, :]  # (S, H)

        # ---- max cosine per token vs. selected role embeddings ----
        role_mat = embeddings.word_embeddings(
            torch.tensor(sorted(sel_ids), device=enc.device)
        )  # (R, H)
        cos = cosine(enc.unsqueeze(1), role_mat.unsqueeze(0))  # (S, R)
        maxcos = cos.max(-1).values  # (S,)

        # ---- split positions into high / low halves ----
        # sort_idx is in descending-maxcos order. The "low" walk therefore
        # starts at the median and moves down, and the reversed "high" walk
        # starts at the median and moves up.
        seq_len = ids.shape[1]
        sort_idx = maxcos.argsort(descending=True)
        hi_idx = sort_idx[: seq_len // 2]
        lo_idx = sort_idx[seq_len // 2:]

        pool_lo, acc_lo = _greedy_pool(ids, att, lo_idx, "low", mask_id, report_lines)
        pool_hi, acc_hi = _greedy_pool(ids, att, hi_idx, "high", mask_id, report_lines)

    # Decode pools in sentence (position) order so the text reads naturally.
    tok_ids = ids[0].tolist()
    res_json = {
        "Low-pool tokens": tokenizer.decode([tok_ids[i] for i in sorted(pool_lo)]),
        "Low accuracy": f"{acc_lo:.2f}",
        "High-pool tokens": tokenizer.decode([tok_ids[i] for i in sorted(pool_hi)]),
        "High accuracy": f"{acc_hi:.2f}",
        "Trace": "\n".join(report_lines),
    }
    # Three outputs, in the order the UI wires them.
    return json.dumps(res_json, indent=2), f"{maxcos.max().item():.4f}", len(selected_roles)
# ------------------------------------------------------------------
# 4. Gradio UI -----------------------------------------------------
def build_interface():
    """Assemble the Gradio UI: inputs in the left column, results in the right.

    The Run button calls encode_and_trace(prompt, roles). Its three return
    values fill the three output textboxes in order.
    """
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
        with gr.Row():
            # -- inputs --
            with gr.Column():
                prompt_box = gr.Textbox(label="Prompt", lines=3)
                role_picker = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Roles",
                    value=SYMBOLIC_ROLES,  # every role starts checked
                )
                run_button = gr.Button("Run")
            # -- outputs --
            with gr.Column():
                result_box = gr.Textbox(label="Result JSON")
                maxcos_box = gr.Textbox(label="Max cos")
                count_box = gr.Textbox(label="# roles")
        run_button.click(
            fn=encode_and_trace,
            inputs=[prompt_box, role_picker],
            outputs=[result_box, maxcos_box, count_box],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    demo = build_interface()
    demo.launch()