# app.py · encoder-only demo for bert-beatrix-2048
# ------------------------------------------------
# launch:  python app.py   →   http://localhost:7860
import json, re, sys
from pathlib import Path, PurePosixPath          # ← PurePosixPath import

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0.  Download & patch config.json  ---------------------------------
# ------------------------------------------------------------------
REPO_ID    = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"                 # cache dir name

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

auto_map = cfg.get("auto_map", {})
patched = False
for k, v in auto_map.items():
    if "--" in v:                                # strip "repo--module.Class"
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        patched = True

if patched:
    cfg["auto_map"] = auto_map
    with cfg_path.open("w") as f:
        json.dump(cfg, f, indent=2)
    print("🛠️  Patched config.json → auto_map now points to local modules")
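# Illustration of the patch above (hypothetical values, not read from the
# actual checkpoint): an auto_map entry of the form
#     "AutoModel": "AbstractPhil/bert-beatrix-2048--modeling.BertModel"
# would be rewritten to
#     "AutoModel": "modeling.BertModel"
# so transformers resolves the class from the local snapshot instead of
# re-fetching the remote-code files from the Hub repo.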
# ------------------------------------------------------------------
# 1.  Load model / tokenizer  ---------------------------------------
# ------------------------------------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()

encoder    = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln     = full_model.bert.emb_ln
emb_drop   = full_model.bert.emb_drop
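# NOTE: the attribute names above (bert.encoder, bert.embeddings, bert.emb_ln,
# bert.emb_drop) assume the custom BERT module shipped with this checkpoint;
# a stock transformers BertModel keeps the LayerNorm and Dropout inside
# bert.embeddings rather than as separate emb_ln / emb_drop sub-modules.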
# ------------------------------------------------------------------
# 2.  Symbolic roles  -----------------------------------------------
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

missing = [t for t in SYMBOLIC_ROLES
           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
if missing:
    sys.exit(f"❌ Tokenizer is missing {missing}")
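# Optional sanity check (illustrative only): print the id each role token maps
# to, confirming the roles exist as real vocabulary entries, e.g.
#     for t in SYMBOLIC_ROLES[:3]:
#         print(t, tokenizer.convert_tokens_to_ids(t))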
# ------------------------------------------------------------------
# 3.  Encoder-only helper  ------------------------------------------
# ------------------------------------------------------------------
def encode_and_trace(text: str, selected_roles: list[str]):
    """
    Returns exactly three values, matching the three Gradio outputs:
      1) tokens_str  (comma-separated list of role tokens found)
      2) norm_str    (L2-norm of the mean hidden state of those tokens)
      3) count_int   (# tokens matched)
    """
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, attn = batch.input_ids, batch.attention_mask

        x = emb_drop(emb_ln(embeddings(ids)))
        # HF helper: expands the (B, S) padding mask into the additive
        # (B, 1, 1, S) mask expected by the encoder layers
        ext = full_model.bert.get_extended_attention_mask(attn, x.shape[:-1])
        enc = encoder(x, attention_mask=ext)     # (1, S, H)

        role_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
        mask = torch.tensor([tid in role_ids for tid in ids[0].tolist()],
                            device=enc.device, dtype=torch.bool)

        found = [tokenizer.convert_ids_to_tokens([tid])[0]
                 for tid in ids[0].tolist() if tid in role_ids]
        tokens_str = ", ".join(found) or "(none)"

        if mask.any():
            mean_vec = enc[0][mask].mean(0)
            norm_str = f"{mean_vec.norm().item():.4f}"
        else:
            norm_str = "0.0000"

        count_int = int(mask.sum().item())

    return tokens_str, norm_str, count_int       # ← three outputs!
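# Example call (return values shown are illustrative, not actual outputs):
#     tokens, norm, count = encode_and_trace(
#         "A <subject> wearing <upper_body_clothing>",
#         ["<subject>", "<upper_body_clothing>"],
#     )
#     # tokens -> "<subject>, <upper_body_clothing>", count -> 2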
# ------------------------------------------------------------------
# 4.  Gradio UI  -----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "### 🧠 Symbolic Encoder Inspector\n"
            "Paste text that includes the `<role>` tokens and inspect their "
            "hidden-state statistics."
        )
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(label="Input", lines=3,
                                 placeholder="A <subject> wearing <upper_body_clothing> …")
                chk = gr.CheckboxGroup(SYMBOLIC_ROLES, label="Roles to trace")
                run = gr.Button("Encode & Trace")
            with gr.Column():
                out_tok  = gr.Textbox(label="Tokens found")
                out_norm = gr.Textbox(label="Mean norm")
                out_cnt  = gr.Textbox(label="Token count")

        run.click(encode_and_trace, [txt, chk], [out_tok, out_norm, out_cnt])
    return demo
if __name__ == "__main__":
    build_interface().launch()
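# Launch options (an aside, not part of the original app): gr.Blocks.launch
# also accepts server_name / server_port / share if the demo needs to be
# reachable from outside localhost, e.g.
#     build_interface().launch(server_name="0.0.0.0", server_port=7860, share=True)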