Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,075 Bytes
096fe3a 81a0ae4 535151e ea2994b 81a0ae4 096fe3a 535151e 096fe3a 81a0ae4 096fe3a 81a0ae4 096fe3a 535151e 81a0ae4 096fe3a 8a2e372 096fe3a 81a0ae4 096fe3a 785df91 81a0ae4 096fe3a 81a0ae4 096fe3a 81a0ae4 096fe3a 81a0ae4 096fe3a 81a0ae4 096fe3a 81a0ae4 096fe3a 81a0ae4 096fe3a 81a0ae4 096fe3a 81a0ae4 096fe3a 785df91 096fe3a 81a0ae4 535151e 81a0ae4 535151e 81a0ae4 535151e ea2994b 81a0ae4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
# app.py – encoder-only demo for bert-beatrix-2048
# -----------------------------------------------
# launch: python app.py
import spaces
import torch
import gradio as gr
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 1. Download *once* and load locally -----------------------------
# ------------------------------------------------------------------
# Fetch the checkpoint repo into a fixed local folder; the returned path is
# what the handler factory loads from.
LOCAL_CKPT = snapshot_download(
    "AbstractPhil/bert-beatrix-2048",
    revision="main",
    local_dir="bert-beatrix-2048",
    local_dir_use_symlinks=False,
)

# The factory hands back three objects; only the model and tokenizer are used
# further down in this file.
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()

# --- keep direct handles on the encoder stack and embedding pipeline ------
_bert = full_model.bert
encoder, embeddings = _bert.encoder, _bert.embeddings
emb_ln, emb_drop = _bert.emb_ln, _bert.emb_drop
# ------------------------------------------------------------------
# 2. Symbolic token list ------------------------------------------
# ------------------------------------------------------------------
# Role tokens offered in the UI, in display order (one entry per line so that
# additions and removals produce minimal diffs).
SYMBOLIC_ROLES = [
    "<subject>",
    "<subject1>",
    "<subject2>",
    "<pose>",
    "<emotion>",
    "<surface>",
    "<lighting>",
    "<material>",
    "<accessory>",
    "<footwear>",
    "<upper_body_clothing>",
    "<hair_style>",
    "<hair_length>",
    "<headwear>",
    "<texture>",
    "<pattern>",
    "<grid>",
    "<zone>",
    "<offset>",
    "<object_left>",
    "<object_right>",
    "<relation>",
    "<intent>",
    "<style>",
    "<fabric>",
    "<jewelry>",
]
# Fail fast at import time: a role that the tokenizer maps to its unknown-token
# id is not a registered token, so tracing it would be meaningless.
_unk_id = tokenizer.unk_token_id
missing = [
    role
    for role in SYMBOLIC_ROLES
    if tokenizer.convert_tokens_to_ids(role) == _unk_id
]
if missing:
    raise RuntimeError(f"Tokenizer is missing special tokens: {missing}")
# ------------------------------------------------------------------
# 3. Encoder-only inference util ----------------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    """Encode *text* and summarise the encoder output at the selected role tokens.

    The text is tokenized, pushed through the embedding pipeline
    (embeddings -> emb_ln -> emb_drop) and the encoder, and the encoder
    outputs at positions whose token id belongs to one of *selected_roles*
    are averaged.

    Args:
        text: Raw input string, expected to contain symbolic role tokens.
        selected_roles: Role token strings to look for; ``None`` is treated
            as an empty selection.

    Returns:
        A dict with three entries:
        ``"Symbolic Tokens"`` - comma-separated matched tokens in input
        order, or ``"(none)"``;
        ``"Mean Norm"`` - norm of the mean matched vector formatted to four
        decimals (``"0.0000"`` when nothing matched);
        ``"Token Count"`` - number of matched positions as an int.
    """
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, mask = batch.input_ids, batch.attention_mask
        x = emb_drop(emb_ln(embeddings(ids)))
        ext_mask = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
        enc = encoder(x, attention_mask=ext_mask)  # (1, S, H)

    # Work with plain Python ints. Iterating a tensor yields 0-dim tensors,
    # and torch.Tensor hashes by object identity, so `tensor in set_of_ints`
    # never matches -- which made every lookup below come back empty.
    token_ids = ids[0].tolist()
    want = {tokenizer.convert_tokens_to_ids(t) for t in (selected_roles or [])}

    hits = [tid in want for tid in token_ids]
    # Explicit bool dtype keeps this a valid mask even for an empty sequence.
    keep = torch.tensor(hits, dtype=torch.bool, device=enc.device)
    found = [
        tokenizer.convert_ids_to_tokens([tid])[0]
        for tid, hit in zip(token_ids, hits)
        if hit
    ]

    if keep.any():
        vec = enc[0][keep].mean(0)
        norm = f"{vec.norm().item():.4f}"
    else:
        norm = "0.0000"

    return {
        "Symbolic Tokens": ", ".join(found) or "(none)",
        "Mean Norm": norm,
        "Token Count": int(keep.sum().item()),
    }
# ------------------------------------------------------------------
# 4. Gradio UI -----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    """Build and return the Gradio Blocks app for the encoder inspector.

    Layout: a two-column row. The left column holds the text input, the
    role checkbox group and the trigger button; the right column holds the
    three result textboxes. The returned object is not launched here.
    """

    def _trace_to_outputs(text, roles):
        # encode_and_trace returns a dict with string keys. Gradio maps a
        # returned dict to outputs only when its keys are components, so
        # unpack the three values positionally in the order of the outputs
        # list passed to btn.click below.
        result = encode_and_trace(text, roles)
        return (
            result["Symbolic Tokens"],
            result["Mean Norm"],
            result["Token Count"],
        )

    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(label="Input with Symbolic Tokens", lines=3)
                chk = gr.CheckboxGroup(choices=SYMBOLIC_ROLES, label="Trace these roles")
                btn = gr.Button("Encode & Trace")
            with gr.Column():
                out_tok = gr.Textbox(label="Symbolic Tokens Found")
                out_norm = gr.Textbox(label="Mean Norm")
                out_cnt = gr.Textbox(label="Token Count")
        # Registered inside the Blocks context so the event is attached to demo.
        btn.click(_trace_to_outputs, [txt, chk], [out_tok, out_norm, out_cnt])
    return demo
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    demo = build_interface()
    demo.launch()
|