# app.py – encoder-only demo for bert-beatrix-2048
# launch: python app.py
import json
import sys
from pathlib import Path, PurePosixPath

import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. Download & patch config.json  --------------------------------
# ------------------------------------------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,  # deprecated no-op on recent huggingface_hub; kept for older versions
)
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

auto_map = cfg.get("auto_map", {})
patched = False
for k, v in auto_map.items():
    if "--" in v:  # "repo--module.Class"
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        patched = True

if patched:
    cfg["auto_map"] = auto_map
    with cfg_path.open("w") as f:
        json.dump(cfg, f, indent=2)
    print("🛠️  Patched config.json – auto_map paths fixed")
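# For illustration, the patch above turns a hub-prefixed auto_map entry into a
# plain module path; the class name below is hypothetical, not read from the
# actual checkpoint:
#
#   before: {"AutoModelForMaskedLM": "AbstractPhil/bert-beatrix-2048--modeling.BertBeatrix"}
#   after:  {"AutoModelForMaskedLM": "modeling.BertBeatrix"}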
# ------------------------------------------------------------------
# 1. Model / tokenizer  --------------------------------------------
# ------------------------------------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()

encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
# ------------------------------------------------------------------
# 2. Symbolic token set  -------------------------------------------
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

unk = tokenizer.unk_token_id
missing = [t for t in SYMBOLIC_ROLES if tokenizer.convert_tokens_to_ids(t) == unk]
if missing:
    sys.exit(f"❌ Tokenizer is missing {missing}")
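# If a tokenizer were missing some roles, they could be registered instead of
# aborting. A minimal sketch, assuming the custom model exposes the standard
# transformers resize_token_embeddings API (not verified for this checkpoint):
#
#   tokenizer.add_special_tokens({"additional_special_tokens": missing})
#   full_model.resize_token_embeddings(len(tokenizer))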
# ------------------------------------------------------------------
# 3. helper: merge lowest + highest until 3 remain -----------------
# ------------------------------------------------------------------
def reduce_to_three(table):
    """
    table: list of dicts {role, token, score}.
    Repeatedly pop the lowest- and highest-scoring rows and replace
    them with a single averaged row, until len(table) == 3.
    """
    working = table[:]
    working.sort(key=lambda x: x["score"])
    while len(working) > 3:
        low = working.pop(0)
        high = working.pop(-1)
        merged = {
            "role": f"{high['role']}|{low['role']}",
            "token": f"{high['token']}/{low['token']}",
            "score": (high["score"] + low["score"]) / 2.0,
        }
        working.append(merged)
        working.sort(key=lambda x: x["score"])
    # highest first for display
    working.sort(key=lambda x: x["score"], reverse=True)
    return working
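# Worked example (hypothetical scores, not from the model): five rows scoring
# [0.10, 0.20, 0.30, 0.40, 0.50] merge as follows:
#   pass 1: 0.10 + 0.50 -> 0.30, leaving [0.20, 0.30, 0.30, 0.40]
#   pass 2: 0.20 + 0.40 -> 0.30, leaving [0.30, 0.30, 0.30]  (3 rows, done)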
# ------------------------------------------------------------------
# 4. Encoder-only inference util  ----------------------------------
# ------------------------------------------------------------------
@spaces.GPU  # acquire a ZeroGPU slice for the call (the Space runs on Zero)
def encode_and_trace(text: str, selected_roles: list[str]):
    with torch.no_grad():
        if not text.strip():
            return "(no input)", "", ""

        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, attn = batch.input_ids, batch.attention_mask

        # encoder forward (hidden states are computed for completeness;
        # the similarity below uses the raw token embeddings instead)
        x = emb_drop(emb_ln(embeddings(ids)))
        ext = full_model.bert.get_extended_attention_mask(attn, x.shape[:-1])
        hs = encoder(x, attention_mask=ext)  # (1, S, H)

        # token-level embeddings (before LN) for the similarity calc
        token_emb = embeddings(ids).squeeze(0)  # (S, H)

        rows = []
        for role in selected_roles:
            rid = tokenizer.convert_tokens_to_ids(role)
            rvec = embeddings.weight[rid]  # (H,)
            # cosine similarity to every *input* token embedding
            sims = F.cosine_similarity(rvec.unsqueeze(0), token_emb, dim=-1)
            best = torch.argmax(sims).item()
            rows.append({
                "role": role,
                "token": tokenizer.convert_ids_to_tokens([ids[0, best].item()])[0],
                "score": sims[best].item(),
            })

        if not rows:
            return "(none selected)", "", ""

        final3 = reduce_to_three(rows)
        out_strs = [f"{r['role']} → {r['token']} ({r['score']:+.2f})" for r in final3]
        # pad so we always return 3 strings
        while len(out_strs) < 3:
            out_strs.append("")
        return out_strs[0], out_strs[1], out_strs[2]
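# Example invocation outside the UI, for quick sanity checks (exact tokens and
# scores depend on the checkpoint, so the shapes below are illustrative only):
#
#   a, b, c = encode_and_trace(
#       "A <subject> wearing <upper_body_clothing>",
#       ["<subject>", "<upper_body_clothing>"],
#   )
#   print(a, b, c)  # two "role → token (score)" strings plus one empty pad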
# ------------------------------------------------------------------
# 5. Gradio UI  ----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "### 🧠 Symbolic Encoder Inspector\n"
            "Paste text with `<role>` tokens, pick roles to track, then we\n"
            "• compute role → token cosine scores\n"
            "• iteratively merge low+high pairs until **3 composite buckets** remain."
        )
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                    lines=3,
                )
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Trace these symbolic roles",
                )
                btn = gr.Button("Encode & Merge")
            with gr.Column():
                cat1 = gr.Textbox(label="Category 1 (highest)")
                cat2 = gr.Textbox(label="Category 2")
                cat3 = gr.Textbox(label="Category 3 (lowest)")

        btn.click(encode_and_trace, [txt, roles], [cat1, cat2, cat3])
    return demo
if __name__ == "__main__":
    build_interface().launch()