# app.py - encoder-only demo for bert-beatrix-2048
# launch: python app.py
# -----------------------------------------------
import json, re, sys, math
from pathlib import Path, PurePosixPath
import torch, torch.nn.functional as F
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. Download & patch HF checkpoint --------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"
snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)
# -- strip repo prefix in auto_map (one-time)
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)
amap = cfg.get("auto_map", {})
for k, v in amap.items():
    if "--" in v:
        amap[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
cfg["auto_map"] = amap
with cfg_path.open("w") as f:
    json.dump(cfg, f, indent=2)
# ------------------------------------------------------------------
# 1. Load model & components --------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()
encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
mlm_head = full_model.cls # prediction head
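# Quick shape probe (an addition; assumes full_model exposes a standard
# HF config): the MLM head consumes encoder states of width hidden_size
# and emits vocabulary logits of width vocab_size.
print(f"hidden_size={full_model.config.hidden_size}  "
      f"vocab_size={full_model.config.vocab_size}")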
# ------------------------------------------------------------------
# 2. Symbolic roles -------------------------------------------------
SYMBOLIC_ROLES = [
"<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
"<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
"<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
"<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
"<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
"<fabric>", "<jewelry>",
]
if any(tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id
       for t in SYMBOLIC_ROLES):
    sys.exit("❌ tokenizer missing special tokens")
# Quick helpers
MASK = tokenizer.mask_token
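# Illustrative only (an addition): the kind of prompt encode_and_trace()
# expects, mixing symbolic role tags with the tokenizer's mask token.
EXAMPLE_PROMPT = f"a <subject> with <emotion> wearing {MASK}"  # hypothetical sample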
# ------------------------------------------------------------------
# 3. Encoder-plus-MLM logic ---------------------------------------
def cosine(a, b):
    return F.cosine_similarity(a, b, dim=-1)
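# Minimal self-check of the cosine helper on dummy tensors (an addition):
# broadcasting (S,1,H) against (1,R,H) yields an (S,R) similarity grid,
# exactly the shape contract encode_and_trace() relies on below.
assert cosine(torch.randn(4, 1, 8), torch.randn(1, 3, 8)).shape == (4, 3)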
def pool_accuracy(ids, logits, pool_mask):
    """
    ids       : (S,)  gold token ids
    logits    : (S,V) MLM logits
    pool_mask : bool (S,) which tokens belong to the candidate pool
    returns accuracy over the pooled positions only (0.0 if the pool is empty)
    """
    idx = pool_mask.nonzero(as_tuple=False).flatten()
    if idx.numel() == 0:
        return 0.0
    preds = logits.argmax(-1)[idx]
    gold = ids[idx]
    return (preds == gold).float().mean().item()
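# Tiny self-check for pool_accuracy on dummy data (an addition): with
# logits that peak at the gold ids, accuracy over the pool must be 1.0.
_ids = torch.tensor([5, 7, 9])
_logits = torch.full((3, 12), -10.0)
_logits[torch.arange(3), _ids] = 10.0            # argmax == gold everywhere
_pool = torch.tensor([True, False, True])
assert pool_accuracy(_ids, _logits, _pool) == 1.0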
@spaces.GPU
def encode_and_trace(text, selected_roles):
    # if the user unchecked everything we treat it as "all"
    if not selected_roles:
        selected_roles = SYMBOLIC_ROLES
    sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
    # ---- Tokenise & encode once ----
    batch = tokenizer(text, return_tensors="pt").to("cuda")
    ids, att = batch.input_ids, batch.attention_mask
    x = emb_drop(emb_ln(embeddings(ids)))
    ext = full_model.bert.get_extended_attention_mask(att, x.shape[:-1])
    enc = encoder(x, attention_mask=ext)[0, :, :]            # (S,H)
    # ---- compute max-cos per token (F-0/F-1) ----
    role_mat = embeddings.word_embeddings(
        torch.tensor(sorted(sel_ids), device=enc.device)
    )                                                        # (R,H)
    cos = cosine(enc.unsqueeze(1), role_mat.unsqueeze(0))    # (S,R)
    maxcos, argrole = cos.max(-1)                            # (S,)
    # ---- split tokens into High / Low half (F-2) ----
    S = ids.shape[1]
    sort_idx = maxcos.argsort(descending=True)
    hi_idx = sort_idx[: S // 2]
    lo_idx = sort_idx[S // 2:]
    # container for summary text
    report_lines = []
    # ----------------------------------------------------------------
    # Encoder-only inference util (FIXED)
    # ----------------------------------------------------------------
    MASK_ID = tokenizer.mask_token_id or tokenizer.convert_tokens_to_ids("[MASK]")  # <- NEW
    def greedy_pool(idx_order, tag):
        """
        idx_order : tensor of token indices sorted hi->lo or lo->hi
        tag       : "high" | "low" (for the debug print)
        returns   : (best_pool_indices, best_accuracy)
        """
        best_pool, best_acc = [], 0.0
        ptr = 0
        while ptr < len(idx_order):
            cand = idx_order[ptr : ptr + 2]                  # 2-at-a-time
            pool = best_pool + cand.tolist()                 # grow pool
            ptr += 2
            # --- mask everything NOT in the pool ---------------------
            mask_flags = torch.zeros_like(ids, dtype=torch.bool)
            mask_flags[0, pool] = True                       # keep these un-masked
            masked_ids = torch.where(mask_flags, ids, torch.full_like(ids, MASK_ID))
            # re-encode the masked sequence & score it with the MLM head
            # (the head is full_model.cls, which returns raw scores for
            # BERT-style models, so no .logits attribute here)
            x_m = emb_drop(emb_ln(embeddings(masked_ids)))
            enc_m = encoder(x_m, attention_mask=ext)[0]      # (S,H)
            logits = mlm_head(enc_m)                         # (S,V)
            preds = logits.argmax(-1)
            masked_pos = ~mask_flags[0]                      # score masked slots only
            if not masked_pos.any():                         # pool covers everything
                break
            acc = (preds[masked_pos] == ids[0][masked_pos]).float().mean().item()
            if acc > best_acc:                               # accept pool only on gain
                best_pool, best_acc = pool, acc
            if acc >= 0.50:                                  # early-stop rule
                break
        print(f"{tag:>4s}-pool {best_pool} acc={best_acc:.3f}")
        return best_pool, best_acc
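    # Note (added): pools grow greedily two tokens at a time and a pair is
    # kept only when it improves masked-token accuracy, so the returned
    # pool is a greedy subset, not a globally optimal one.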
    pool_lo, acc_lo = greedy_pool(lo_idx, "low")
    pool_hi, acc_hi = greedy_pool(hi_idx, "high")
    # ---- package textual result ----
    res_json = {
        "Low-pool tokens": tokenizer.decode(ids[0, pool_lo]),
        "Low accuracy": f"{acc_lo:.2f}",
        "High-pool tokens": tokenizer.decode(ids[0, pool_hi]),
        "High accuracy": f"{acc_hi:.2f}",
        "Trace": "\n".join(report_lines),
    }
    # three outputs expected by UI
    return json.dumps(res_json, indent=2), f"{maxcos.max():.4f}", len(selected_roles)
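# Illustrative smoke test (an addition; needs a CUDA device, so it is
# left commented out rather than run at import time):
#
#   demo_json, demo_max, n = encode_and_trace(
#       "a <subject> in <lighting>", ["<subject>", "<lighting>"]
#   )
#   print(demo_max, n)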
# ------------------------------------------------------------------
# 4. Gradio UI -----------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(label="Prompt", lines=3)
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES, label="Roles",
                    value=SYMBOLIC_ROLES,        # pre-checked
                )
                btn = gr.Button("Run")
            with gr.Column():
                out_json = gr.Textbox(label="Result JSON")
                out_max = gr.Textbox(label="Max cos")
                out_cnt = gr.Textbox(label="# roles")
        btn.click(encode_and_trace, [txt, roles], [out_json, out_max, out_cnt])
    return demo
if __name__ == "__main__":
    build_interface().launch()
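# Note (an addition): launch() accepts standard Gradio options, e.g.
#   build_interface().launch(share=True)   # public link from a notebook/VM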