# app.py - encoder-only demo for bert-beatrix-2048
# ------------------------------------------------
# launch: python app.py  ->  http://localhost:7860
import json, re, sys
from pathlib import Path, PurePosixPath  # PurePosixPath used in the auto_map patch below
import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. Download & patch config.json ---------------------------------
# ------------------------------------------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048" # cache dir name
snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

auto_map = cfg.get("auto_map", {})
patched = False
for k, v in auto_map.items():
    if "--" in v:  # strip repo--module.Class
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        patched = True

if patched:
    cfg["auto_map"] = auto_map
    with cfg_path.open("w") as f:
        json.dump(cfg, f, indent=2)
    print("Patched config.json: auto_map now points to local modules")
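# Sketch of what the patch above does (module/class names are illustrative,
# not taken from the actual checkpoint):
#   before: "auto_map": {"AutoModel": "AbstractPhil/bert-beatrix-2048--modeling.BertModelWrapper"}
#   after:  "auto_map": {"AutoModel": "modeling.BertModelWrapper"}
# The "repo--" prefix tells transformers to fetch remote code from the Hub;
# stripping it points the auto classes at the files snapshot_download placed locally.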
# ------------------------------------------------------------------
# 1. Load model / tokenizer ---------------------------------------
# ------------------------------------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()
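# Grab the sub-modules needed for an encoder-only forward pass
# (embeddings -> LayerNorm -> dropout -> encoder); encode_and_trace() below
# replays this exact chain instead of calling full_model() directly.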
encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
# ------------------------------------------------------------------
# 2. Symbolic roles ------------------------------------------------
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]
missing = [t for t in SYMBOLIC_ROLES
           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
if missing:
    sys.exit(f"Tokenizer is missing {missing}")
# ------------------------------------------------------------------
# 3. Encoder-only helper ------------------------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    """
    Returns exactly three values matching the 3 gradio outputs:
      1) tokens_str (comma-separated list of role tokens found)
      2) norm_str   (L2-norm of the mean hidden state of those tokens)
      3) count_int  (# tokens matched)
    """
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, attn = batch.input_ids, batch.attention_mask

        x = emb_drop(emb_ln(embeddings(ids)))
        ext = full_model.bert.get_extended_attention_mask(attn, x.shape[:-1])
        enc = encoder(x, attention_mask=ext)  # (1, S, H)

        role_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
        mask = torch.tensor([tid in role_ids for tid in ids[0].tolist()],
                            device=enc.device, dtype=torch.bool)

        found = [tokenizer.convert_ids_to_tokens([tid])[0]
                 for tid in ids[0].tolist() if tid in role_ids]
        tokens_str = ", ".join(found) or "(none)"

        if mask.any():
            mean_vec = enc[0][mask].mean(0)
            norm_str = f"{mean_vec.norm().item():.4f}"
        else:
            norm_str = "0.0000"

        count_int = int(mask.sum().item())

    return tokens_str, norm_str, count_int  # three outputs
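# Rough usage sketch (return values are illustrative, not actual model output):
#   encode_and_trace("A <subject> wearing <fabric>", ["<subject>", "<fabric>"])
#   -> ("<subject>, <fabric>", "11.2874", 2)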
# ------------------------------------------------------------------
# 4. Gradio UI ----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    with gr.Blocks(title="Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "### Symbolic Encoder Inspector\n"
            "Paste text that includes the `<role>` tokens and inspect their "
            "hidden-state statistics."
        )
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(label="Input", lines=3,
                                 placeholder="A <subject> wearing <upper_body_clothing> …")
                chk = gr.CheckboxGroup(SYMBOLIC_ROLES, label="Roles to trace")
                run = gr.Button("Encode & Trace")
            with gr.Column():
                out_tok = gr.Textbox(label="Tokens found")
                out_norm = gr.Textbox(label="Mean norm")
                out_cnt = gr.Textbox(label="Token count")

        run.click(encode_and_trace, [txt, chk], [out_tok, out_norm, out_cnt])
    return demo
if __name__ == "__main__":
    build_interface().launch()