# Updating the app to use only the encoder from the model, ensuring symbolic support
import spaces
from bert_handler import create_handler_from_checkpoint
import torch
import gradio as gr
import re
from pathlib import Path
from huggingface_hub import snapshot_download

# Load checkpoint using BERTHandler (loads tokenizer and full model)
checkpoint_path = snapshot_download(
    repo_id="AbstractPhil/bert-beatrix-2048",
    revision="main",
    local_dir="bert-beatrix-2048",
    local_dir_use_symlinks=False,
)
handler, model, tokenizer = create_handler_from_checkpoint(checkpoint_path)
model = model.eval().cuda()

# Extract encoder only (NomicBertModel -> encoder)
encoder = model.bert.encoder
embeddings = model.bert.embeddings
emb_ln = model.bert.emb_ln
emb_drop = model.bert.emb_drop


@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
def encode_and_predict(text: str, selected_roles: list[str]):
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt").to("cuda")
        input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask

        # Run embedding + encoder pipeline
        x = embeddings(input_ids)
        x = emb_ln(x)
        x = emb_drop(x)
        encoded = encoder(x, attention_mask=attention_mask.bool())

        # Mask the positions whose token id matches one of the selected symbolic roles
        symbolic_ids = [tokenizer.convert_tokens_to_ids(tok) for tok in selected_roles]
        symbolic_mask = torch.isin(input_ids, torch.tensor(symbolic_ids, device=input_ids.device))
        masked_tokens = [
            tokenizer.convert_ids_to_tokens(int(tid))
            for tid in input_ids[0]
            if int(tid) in symbolic_ids
        ]
        role_reprs = (
            encoded[symbolic_mask].mean(dim=0)
            if symbolic_mask.any()
            else torch.zeros_like(encoded[0, 0])
        )

    # Return a tuple so the values map onto the three output Textboxes in order
    return (
        ", ".join(masked_tokens),
        f"{role_reprs.norm().item():.4f}",
        str(symbolic_mask.sum().item()),
    )


symbolic_roles = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]


def build_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input with Symbolic Tokens", lines=3)
                selected_roles = gr.CheckboxGroup(
                    choices=symbolic_roles,
                    label="Which symbolic tokens should be traced?",
                )
                run_btn = gr.Button("Encode & Trace")
            with gr.Column():
                symbolic_tokens = gr.Textbox(label="Symbolic Tokens Found")
                embedding_norm = gr.Textbox(label="Mean Norm of Symbolic Embeddings")
                token_count = gr.Textbox(label="Count of Symbolic Tokens")
        run_btn.click(
            fn=encode_and_predict,
            inputs=[input_text, selected_roles],
            outputs=[symbolic_tokens, embedding_norm, token_count],
        )
    return demo


if __name__ == "__main__":
    demo = build_interface()
    demo.launch()
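
# Usage sketch (illustrative assumption, not taken from the model card): enter
# text that already contains symbolic role tokens and tick the roles to trace,
# e.g.
#   Input: "<subject> a knight <pose> kneeling <lighting> moonlit courtyard"
#   Roles: <subject>, <pose>, <lighting>
# The app then reports the matched tokens, the mean L2 norm of their encoder
# states, and how many symbolic tokens were found in the input.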