Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,212 Bytes
5e20a2a fd4a12a aa28bbb 872b08b fd4a12a 0a14990 81a0ae4 5e20a2a 096fe3a 872b08b 5e20a2a 096fe3a fd4a12a 5e20a2a fd4a12a 5e20a2a fd4a12a 5e20a2a fd4a12a 5e20a2a fd4a12a 0a14990 fd4a12a 9a08859 fd4a12a 096fe3a 535151e fd4a12a 0a14990 5e20a2a fd4a12a 096fe3a 785df91 872b08b 785df91 fd4a12a 5e20a2a fd4a12a 5e20a2a 096fe3a fd4a12a 5e20a2a fd4a12a 5e20a2a fd4a12a 5e20a2a fd4a12a b235205 6cdb3e9 b235205 fd4a12a b235205 fd4a12a b235205 6cdb3e9 b235205 6cdb3e9 b235205 6cdb3e9 b235205 6cdb3e9 b235205 6cdb3e9 fd4a12a b235205 fd4a12a b235205 fd4a12a 6cdb3e9 b235205 20331bc 45ddfa3 096fe3a fd4a12a 5e20a2a 096fe3a fd4a12a 5e20a2a 535151e fd4a12a 0a14990 fd4a12a 535151e fd4a12a 872b08b fd4a12a 535151e ea2994b 5e20a2a fd4a12a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
# app.py – encoder-only demo for bert-beatrix-2048
# launch: python app.py
# -----------------------------------------------
import json, re, sys, math
from pathlib import Path, PurePosixPath
import torch, torch.nn.functional as F
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. Download & patch HF checkpoint --------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

# Fetch the "main" revision of the repo into LOCAL_CKPT as regular files
# (symlinks disabled) so config.json can be edited in place below.
snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)

# Strip the repo prefix (everything up to and including the first "--")
# from each auto_map entry.  The patch is idempotent: once the prefixes are
# gone nothing matches, and the file is not rewritten.
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open(encoding="utf-8") as f:
    cfg = json.load(f)

amap = cfg.get("auto_map", {})
_auto_map_patched = False
for k, v in amap.items():
    # Only string entries can carry a "<repo>--" prefix; leave anything
    # else untouched instead of mis-handling it.
    if isinstance(v, str) and "--" in v:
        amap[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        _auto_map_patched = True
cfg["auto_map"] = amap

# Avoid touching config.json on every launch when there is nothing to fix.
if _auto_map_patched:
    with cfg_path.open("w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
# ------------------------------------------------------------------
# 1. Load model & components --------------------------------------
# Build the handler, model and tokenizer from the local checkpoint folder.
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)

# Switch to eval mode, then move the model to the GPU.
full_model = full_model.eval()
full_model = full_model.cuda()

# Sub-modules used directly by the encoding path further down the file.
encoder, embeddings = full_model.bert.encoder, full_model.bert.embeddings
# Post-embedding layers; the encode step applies each only when it is truthy.
emb_ln, emb_drop = full_model.bert.emb_ln, full_model.bert.emb_drop
mlm_head = full_model.cls  # prediction head
# ------------------------------------------------------------------
# 2. Symbolic roles -------------------------------------------------
# Special role tokens the tokenizer is expected to know; they double as the
# choices offered in the UI.
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

# Fail fast at import time if any role token maps to the unknown-token id,
# and name the offending tokens so the problem can be fixed directly.
_missing_roles = [
    t for t in SYMBOLIC_ROLES
    if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id
]
if _missing_roles:
    sys.exit("tokenizer missing special tokens: " + ", ".join(_missing_roles))

# Quick helpers
MASK = tokenizer.mask_token
# ------------------------------------------------------------------
# 3. Encoder-plus-MLM logic ---------------------------------------
def cosine(a, b):
    """Return the cosine similarity of ``a`` and ``b`` along the last dim.

    Both arguments are tensors; the result has the inputs' shape with the
    last dimension reduced away.
    """
    return F.cosine_similarity(a, b, dim=-1)
def pool_accuracy(ids, logits, pool_mask):
    """
    Accuracy of the argmax predictions, restricted to the pool positions.

    ids       : (S,)   gold token ids
    logits    : (S, V) MLM logits
    pool_mask : bool (S,) which tokens belong to the candidate pool

    Returns a Python float: the fraction of selected positions whose
    highest-scoring logit index equals the gold id, or 0.0 when
    ``pool_mask`` selects no position.
    """
    idx = pool_mask.nonzero(as_tuple=False).flatten()
    if idx.numel() == 0:
        # A mean over zero elements would be NaN; report 0.0 instead.
        return 0.0
    # Select the pool rows first so argmax only runs over what is needed.
    preds = logits[idx].argmax(-1)
    gold = ids[idx]
    return (preds == gold).float().mean().item()
@spaces.GPU
def encode_and_trace(text, selected_roles):
    """Assign each prompt token the most similar symbolic role.

    The prompt and the selected role tokens are each pushed through the
    embeddings (plus the optional post-embedding layers) and the encoder;
    every prompt token is then matched to the role whose encoded vector has
    the highest cosine similarity with it.

    Parameters
    ----------
    text : str
        Prompt to tokenize and encode.
    selected_roles : list[str]
        Role tokens to compare against; an empty/None selection falls back
        to all of ``SYMBOLIC_ROLES``.

    Returns
    -------
    tuple
        ``(report, best, count)`` where ``report`` is a JSON string with the
        prompt, per-token predicted roles and a readable per-token trace,
        ``best`` is the highest per-token similarity formatted to four
        decimals, and ``count`` is the number of roles compared.
    """
    if not selected_roles:
        selected_roles = SYMBOLIC_ROLES

    # Convert symbolic role tokens to IDs
    sel_ids = [tokenizer.convert_tokens_to_ids(t) for t in selected_roles]
    sel_ids_tensor = torch.tensor(sel_ids, device="cuda").unsqueeze(0)  # (1, R)

    # Tokenize user prompt
    batch = tokenizer(text, return_tensors="pt").to("cuda")
    input_ids, attention_mask = batch.input_ids, batch.attention_mask

    def encode(ids, attn_mask):
        """Run ids through embeddings -> optional LN/dropout -> encoder."""
        x = embeddings(ids)                                   # (B, S, H)
        if emb_ln:
            x = emb_ln(x)
        if emb_drop:
            x = emb_drop(x)
        ext = full_model.bert.get_extended_attention_mask(attn_mask, ids.shape)
        return encoder(x, attention_mask=ext)[0]              # (B, S, H)

    # Pure inference: skip autograd bookkeeping for both forward passes.
    with torch.no_grad():
        # Encode prompt
        encoded_prompt = encode(input_ids, attention_mask)[0]        # (S, H)

        # Encode the role tokens (as one sequence) through the same pipeline
        symbolic_attn = torch.ones_like(sel_ids_tensor)
        encoded_roles = encode(sel_ids_tensor, symbolic_attn)[0]     # (R, H)

        # Compare each token to each symbolic role → shape: (S, R)
        n_tokens, n_roles = encoded_prompt.size(0), encoded_roles.size(0)
        token_exp = encoded_prompt.unsqueeze(1).expand(-1, n_roles, -1)   # (S, R, H)
        role_exp = encoded_roles.unsqueeze(0).expand(n_tokens, -1, -1)    # (S, R, H)
        sim = F.cosine_similarity(token_exp, role_exp, dim=-1)           # (S, R)

        # One reduction yields both the best score and the winning role index.
        max_scores, argmax_ids = sim.max(dim=-1)                          # (S,), (S,)

    predicted_roles = [selected_roles[i] for i in argmax_ids.tolist()]
    decoded_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    # Build readable trace
    role_trace = [
        f"{tok:<15} → {role:<22} score={score:.4f}"
        for tok, role, score in zip(decoded_tokens, predicted_roles, max_scores.tolist())
    ]

    best_score = f"{max_scores.max().item():.4f}"

    # Final output
    res_json = {
        "Prompt": text,
        "Predicted symbolic roles": predicted_roles,
        "Max alignment score": best_score,
        "Per-token classification": role_trace,
    }
    # ensure_ascii=False keeps the arrow and any non-ASCII tokens readable
    # instead of emitting \uXXXX escapes into the result textbox.
    report = json.dumps(res_json, indent=2, ensure_ascii=False)
    return report, best_score, len(selected_roles)
# ------------------------------------------------------------------
# 4. Gradio UI -----------------------------------------------------
def build_interface():
    """Build the Gradio Blocks UI and return it without launching.

    Layout: the left column holds the prompt box, the role checkboxes (all
    pre-checked) and the Run button; the right column holds three text
    outputs.  Clicking Run calls ``encode_and_trace`` with the prompt and the
    checked roles and writes its three return values to the outputs.
    """
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(label="Prompt", lines=3)
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Roles",
                    value=SYMBOLIC_ROLES,  # pre-checked
                )
                btn = gr.Button("Run")
            with gr.Column():
                out_json = gr.Textbox(label="Result JSON")
                out_max = gr.Textbox(label="Max cos")
                out_cnt = gr.Textbox(label="# roles")
        # Event wiring is registered while the Blocks context is still open.
        btn.click(encode_and_trace, [txt, roles], [out_json, out_max, out_cnt])
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    build_interface().launch()