# app.py – encoder-only demo for bert-beatrix-2048
# launch: python app.py
import json
import sys
from pathlib import Path, PurePosixPath
from itertools import islice
import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. Download & patch config.json --------------------------------
# ------------------------------------------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"
snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)
auto_map = cfg.get("auto_map", {})
patched = False
for k, v in auto_map.items():
    if "--" in v:  # "repo--module.Class"
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        patched = True
if patched:
    cfg["auto_map"] = auto_map
    with cfg_path.open("w") as f:
        json.dump(cfg, f, indent=2)
    print("🛠️ Patched config.json → auto_map paths fixed")
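# Illustrative example of the patch above (hypothetical entry, not taken from the
# real config): an auto_map value such as
#   "AutoModel": "AbstractPhil/bert-beatrix-2048--modeling_custom.BertModel"
# becomes
#   "AutoModel": "modeling_custom.BertModel"
# so transformers resolves the class from the local checkpoint folder instead of
# the "repo--module" form.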
# ------------------------------------------------------------------
# 1. Model / tokenizer -------------------------------------------
# ------------------------------------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()
encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
# ------------------------------------------------------------------
# 2. Symbolic token set ------------------------------------------
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]
unk = tokenizer.unk_token_id
missing = [t for t in SYMBOLIC_ROLES if tokenizer.convert_tokens_to_ids(t) == unk]
if missing:
    sys.exit(f"❌ Tokenizer is missing {missing}")
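# If this check ever fails, the missing roles would first have to be added to the
# tokenizer and the embedding matrix resized (sketch only, not executed here;
# assumes full_model exposes the usual PreTrainedModel API):
#   tokenizer.add_special_tokens({"additional_special_tokens": SYMBOLIC_ROLES})
#   full_model.resize_token_embeddings(len(tokenizer))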
# ------------------------------------------------------------------
# 3. helper: merge lowest + highest until 3 remain ----------------
# ------------------------------------------------------------------
def reduce_to_three(table):
    """
    table : list of dicts {role, token, score}
    Repeatedly remove the lowest- and highest-scoring rows and replace them
    with their average, until len(table) == 3.
    """
    working = table[:]
    working.sort(key=lambda x: x["score"])
    while len(working) > 3:
        low = working.pop(0)
        high = working.pop(-1)
        merged = {
            "role": f"{high['role']}|{low['role']}",
            "token": f"{high['token']}/{low['token']}",
            "score": (high["score"] + low["score"]) / 2.0,
        }
        working.append(merged)
        working.sort(key=lambda x: x["score"])
    # highest first for display
    working.sort(key=lambda x: x["score"], reverse=True)
    return working
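# Worked example with toy scores (not model output): five rows scored
# [0.10, 0.20, 0.30, 0.40, 0.50] merge as (0.10, 0.50) -> 0.30, then
# (0.20, 0.40) -> 0.30, leaving three composite buckets scored [0.30, 0.30, 0.30].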
# ------------------------------------------------------------------
# 4. Encoder-only inference util ---------------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    with torch.no_grad():
        if not text.strip():
            return "(no input)", "", ""
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, attn = batch.input_ids, batch.attention_mask
        # encoder forward (the hidden states are computed for tracing but not
        # used in the similarity calc below, which works on raw embeddings)
        x = emb_drop(emb_ln(embeddings(ids)))
        ext = full_model.bert.get_extended_attention_mask(attn, x.shape[:-1])
        hs = encoder(x, attention_mask=ext)  # (1, S, H)
        # token-level embeddings (before LN) for similarity calc
        token_emb = embeddings(ids).squeeze(0)  # (S, H)
        rows = []
        for role in selected_roles:
            rid = tokenizer.convert_tokens_to_ids(role)
            rvec = embeddings.weight[rid]  # (H,)
            # cosine similarity to every *input* token embedding
            sims = F.cosine_similarity(rvec.unsqueeze(0), token_emb, dim=-1)
            best = torch.argmax(sims).item()
            rows.append({
                "role": role,
                "token": tokenizer.convert_ids_to_tokens([ids[0, best].item()])[0],
                "score": sims[best].item(),
            })
        if not rows:
            return "(none selected)", "", ""
        final3 = reduce_to_three(rows)
        out_strs = [f"{r['role']} ↔ {r['token']} ({r['score']:+.2f})" for r in final3]
        # pad so we always return 3 strings
        while len(out_strs) < 3:
            out_strs.append("")
        return out_strs[0], out_strs[1], out_strs[2]
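# Quick local sanity check (sketch; assumes a CUDA device and the checkpoint above):
#   print(encode_and_trace("A <subject> wearing <upper_body_clothing>",
#                          ["<subject>", "<upper_body_clothing>"]))
# Each returned string has the form "role ↔ token (score)".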
# ------------------------------------------------------------------
# 5. Gradio UI ----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "### 🧠 Symbolic Encoder Inspector\n"
            "Paste text with `<role>` tokens, pick roles to track, then we\n"
            "• compute role ↔ token cosine scores\n"
            "• iteratively merge low+high pairs until **3 composite buckets** remain."
        )
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                    lines=3,
                )
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Trace these symbolic roles",
                )
                btn = gr.Button("Encode & Merge")
            with gr.Column():
                cat1 = gr.Textbox(label="Category 1 (highest)")
                cat2 = gr.Textbox(label="Category 2")
                cat3 = gr.Textbox(label="Category 3 (lowest)")
        btn.click(encode_and_trace, [txt, roles], [cat1, cat2, cat3])
    return demo
if __name__ == "__main__":
    build_interface().launch()