# app.py – encoder-only demo for bert-beatrix-2048
# -----------------------------------------------
# launch: python app.py
import json
import sys
from pathlib import Path, PurePosixPath

import spaces
import torch
import gradio as gr
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 1. Download *once* and load locally -----------------------------
# ------------------------------------------------------------------
LOCAL_CKPT = snapshot_download(
    repo_id="AbstractPhil/bert-beatrix-2048",
    revision="main",
    local_dir="bert-beatrix-2048",
    local_dir_use_symlinks=False,
)
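# config.json ships an auto_map whose entries reference the remote repo
# ("AbstractPhil/bert-beatrix-2048--modeling_…"); rewrite them to bare module
# paths so the dynamic-module loader resolves against the local snapshot.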
cfg_path = Path(LOCAL_CKPT) / "config.json"
with open(cfg_path) as f:
    cfg = json.load(f)

auto_map = cfg.get("auto_map", {})
changed = False
for k, v in auto_map.items():
    # v looks like "AbstractPhil/bert-beatrix-2048--modeling_hf_nomic_bert.…"
    if "--" in v:
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        changed = True

if changed:
    cfg["auto_map"] = auto_map
    with open(cfg_path, "w") as f:
        json.dump(cfg, f, indent=2)
    print("🔧 Patched auto_map → now points to local modules only")

# also drop any *previously* imported remote modules in this session
for name in list(sys.modules):
    if name.startswith("transformers_modules.AbstractPhil.bert-beatrix-2048"):
        del sys.modules[name]
# ------------------------------------------------------------------
# 2. Normal load via BERTHandler ----------------------------------
# ------------------------------------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()
# --- pull encoder & embeddings only --------------------------------
encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
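# NOTE: these attribute names assume the checkpoint's NomicBERT-style layout;
# running embeddings → emb_ln → emb_drop → encoder by hand mirrors the first
# half of the full model's forward pass, without the pooling / MLM heads.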
# ------------------------------------------------------------------
# 3. Symbolic token list ------------------------------------------
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]
# Sanity-check: every role must be known by the tokenizer
missing = [t for t in SYMBOLIC_ROLES
           if tokenizer.convert_tokens_to_ids(t) == tokenizer.unk_token_id]
if missing:
    raise RuntimeError(f"Tokenizer is missing special tokens: {missing}")
# ------------------------------------------------------------------
# 4. Encoder-only inference util ----------------------------------
# ------------------------------------------------------------------
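# @spaces.GPU requests a ZeroGPU device only for the duration of each call.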
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, mask = batch.input_ids, batch.attention_mask

        x = emb_drop(emb_ln(embeddings(ids)))
        ext_mask = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
        enc = encoder(x, attention_mask=ext_mask)          # (1, S, H)

        # compare plain Python ints (not 0-d tensors) against the wanted-id set
        want = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
        token_ids = ids[0].tolist()
        keep = torch.tensor([tid in want for tid in token_ids], device=enc.device)
        found = [tokenizer.convert_ids_to_tokens([tid])[0] for tid in token_ids if tid in want]

        if keep.any():
            vec = enc[0][keep].mean(0)
            norm = f"{vec.norm().item():.4f}"
        else:
            norm = "0.0000"

        # return a tuple so the three Gradio outputs are filled positionally
        return ", ".join(found) or "(none)", norm, int(keep.sum().item())
# ------------------------------------------------------------------
# 5. Gradio UI -----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(label="Input with Symbolic Tokens", lines=3)
                chk = gr.CheckboxGroup(choices=SYMBOLIC_ROLES, label="Trace these roles")
                btn = gr.Button("Encode & Trace")
            with gr.Column():
                out_tok = gr.Textbox(label="Symbolic Tokens Found")
                out_norm = gr.Textbox(label="Mean Norm")
                out_cnt = gr.Textbox(label="Token Count")
        btn.click(encode_and_trace, [txt, chk], [out_tok, out_norm, out_cnt])
    return demo
if __name__ == "__main__":
    build_interface().launch()