# app.py – encoder-only demo for bert-beatrix-2048
# -----------------------------------------------
# launch: python app.py
# (the Gradio UI appears at http://localhost:7860)

import torch
import gradio as gr
import spaces
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 1. Model / tokenizer ---------------------------------------------
# ------------------------------------------------------------------
# • We load the repo *once*, via its canonical name.
# • BERTHandler takes care of VRAM-safe cleanup and guarantees that the
#   tokenizer already contains all special tokens saved in the checkpoint.
REPO_ID = "AbstractPhil/bert-beatrix-2048"

handler, full_model, tokenizer = create_handler_from_checkpoint(REPO_ID)
full_model = full_model.eval().cuda()
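# NOTE (assumption about the runtime): on a ZeroGPU Space the GPU is attached
# per request, so moving the model to CUDA at startup is fine, but any function
# that actually does GPU work must be decorated with @spaces.GPU — see
# encode_and_trace below.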
# Grab the encoder + embedding stack only
encoder    = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln     = full_model.bert.emb_ln
emb_drop   = full_model.bert.emb_drop
# ------------------------------------------------------------------
# 2. Symbolic token set --------------------------------------------
# ------------------------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]
# Quick sanity check – none of these should *ever* resolve to [UNK]
missing = [tok for tok in SYMBOLIC_ROLES
           if tokenizer.convert_tokens_to_ids(tok) == tokenizer.unk_token_id]
if missing:
    raise RuntimeError(f"Tokenizer is missing {len(missing)} special tokens: {missing}")
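# If a future checkpoint ever ships without these tokens, one recovery path
# (a sketch, not exercised here, assuming full_model follows the standard
# transformers PreTrainedModel API) would be to register them and resize the
# embedding matrix before inference:
#
#     tokenizer.add_special_tokens({"additional_special_tokens": missing})
#     full_model.resize_token_embeddings(len(tokenizer))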
# ------------------------------------------------------------------
# 3. Encoder-only inference util -----------------------------------
# ------------------------------------------------------------------
@spaces.GPU  # required on ZeroGPU: the GPU is only attached inside this call
def encode_and_trace(text: str, selected_roles: list[str]):
    """
    • encodes `text`
    • pulls out the hidden states at every position holding one of the
      `selected_roles` tokens
    • returns some quick stats so we can verify everything’s wired up
    """
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        inp_ids, attn_mask = batch.input_ids, batch.attention_mask
        # --- embedding + LayerNorm / dropout ---
        x = embeddings(inp_ids)
        x = emb_drop(emb_ln(x))

        # --- proper *additive* attention mask ---
        ext_mask = full_model.bert.get_extended_attention_mask(
            attn_mask, x.shape[:-1]
        )
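        # For reference: get_extended_attention_mask broadcasts the 0/1 padding
        # mask to (B, 1, 1, S) and turns zeros into a large negative bias, so
        # masked positions vanish after softmax. Roughly (a sketch of the same
        # idea, not the exact transformers implementation):
        #
        #     ext = attn_mask[:, None, None, :].to(x.dtype)
        #     ext = (1.0 - ext) * torch.finfo(x.dtype).min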
        encoded = encoder(x, attention_mask=ext_mask)     # (B, S, H)
        # --- pick out the positions that match selected_roles ---
        sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
        ids_list = inp_ids.squeeze(0).tolist()            # plain python ints
        keep_mask = torch.tensor([tid in sel_ids for tid in ids_list],
                                 device=encoded.device)

        tokens_found = [tokenizer.convert_ids_to_tokens([tid])[0]
                        for tid in ids_list if tid in sel_ids]

        if keep_mask.any():
            repr_vec = encoded.squeeze(0)[keep_mask].mean(0)
            norm_val = f"{repr_vec.norm().item():.4f}"
        else:
            norm_val = "0.0000"
    # Return a 3-tuple: Gradio maps it onto the three output components in
    # order (a string-keyed dict would not match the `outputs` list below).
    tokens_str = ", ".join(tokens_found) or "(none)"
    count = int(keep_mask.sum().item())
    return tokens_str, norm_val, count
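# Quick smoke test (left commented out so the Space only runs the UI; the
# prompt below is made up for illustration):
#
#     toks, norm, count = encode_and_trace(
#         "A <subject> wearing <upper_body_clothing>", ["<subject>"])
#     print(toks, norm, count)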
# ------------------------------------------------------------------
# 4. Gradio UI -----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector\n"
                    "Paste some text containing the special `<role>` tokens and "
                    "inspect their encoder representations.")

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                    lines=3,
                )
                role_selector = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Trace these symbolic roles",
                )
                run_btn = gr.Button("Encode & Trace")

            with gr.Column():
                out_tokens = gr.Textbox(label="Symbolic Tokens Found")
                out_norm   = gr.Textbox(label="Mean Norm")
                out_count  = gr.Textbox(label="Token Count")

        run_btn.click(
            fn=encode_and_trace,
            inputs=[input_text, role_selector],
            outputs=[out_tokens, out_norm, out_count],
        )

    return demo
if __name__ == "__main__":
    demo = build_interface()
    demo.launch()