# Use only the encoder from the loaded model so the symbolic role tokens can be inspected directly
import spaces
from bert_handler import create_handler_from_checkpoint
import torch
import gradio as gr
import re
from pathlib import Path
from huggingface_hub import snapshot_download
# Load checkpoint using BERTHandler (loads tokenizer and full model)
checkpoint_path = snapshot_download(
    repo_id="AbstractPhil/bert-beatrix-2048",
    revision="main",
    local_dir="bert-beatrix-2048",
    local_dir_use_symlinks=False
)
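# Note: newer huggingface_hub releases deprecate local_dir_use_symlinks; the argument is
# still accepted for compatibility but has no effect there.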
handler, model, tokenizer = create_handler_from_checkpoint(checkpoint_path)
model = model.eval().cuda()
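# .eval() puts the model in inference mode, so dropout layers (including emb_drop used below)
# act as identity ops during encoding.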
# Extract encoder only (NomicBertModel -> encoder)
encoder = model.bert.encoder
embeddings = model.bert.embeddings
emb_ln = model.bert.emb_ln
emb_drop = model.bert.emb_drop
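# NOTE: the attribute names above (model.bert.encoder / .embeddings / .emb_ln / .emb_drop)
# assume the NomicBert-style module layout exposed by bert_handler; other BERT variants nest
# these submodules differently, so adjust the lookups if the checkpoint changes.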
@spaces.GPU
def encode_and_predict(text: str, selected_roles: list[str]):
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt").to("cuda")
        input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask

        # Run embedding + encoder pipeline
        x = embeddings(input_ids)
        x = emb_ln(x)
        x = emb_drop(x)
        encoded = encoder(x, attention_mask=attention_mask.bool())

        # Map the selected role tokens to vocab ids and mark where they occur in the input
        symbolic_ids = [tokenizer.convert_tokens_to_ids(tok) for tok in selected_roles]
        symbolic_mask = torch.isin(
            input_ids,
            torch.tensor(symbolic_ids, device=input_ids.device, dtype=input_ids.dtype),
        )

        masked_tokens = [
            tokenizer.convert_ids_to_tokens(tid.item())
            for tid in input_ids[0]
            if tid.item() in symbolic_ids
        ]
        # Mean hidden state over all matched symbolic positions (zeros if nothing matched)
        role_reprs = (
            encoded[symbolic_mask].mean(dim=0)
            if symbolic_mask.any()
            else torch.zeros_like(encoded[0, 0])
        )

        # Return one value per output component, in the order wired up in build_interface
        return (
            ", ".join(masked_tokens),
            f"{role_reprs.norm().item():.4f}",
            str(symbolic_mask.sum().item()),
        )
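# Example (hypothetical input, assuming the role tokens exist in the tokenizer vocab):
#   text = "<subject> a knight standing in <lighting> moonlight with <emotion> calm resolve"
#   selected_roles = ["<subject>", "<lighting>", "<emotion>"]
# would report those three tokens, the mean norm of their hidden states, and a count of 3.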
symbolic_roles = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>"
]
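# These role markers are assumed to be registered as added tokens in the checkpoint's
# tokenizer; any that are missing from the vocab resolve to the unknown-token id in
# convert_tokens_to_ids and cannot be traced to a dedicated embedding.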
def build_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## 🧠 Symbolic Encoder Inspector")
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input with Symbolic Tokens", lines=3)
                selected_roles = gr.CheckboxGroup(
                    choices=symbolic_roles,
                    label="Which symbolic tokens should be traced?"
                )
                run_btn = gr.Button("Encode & Trace")
            with gr.Column():
                symbolic_tokens = gr.Textbox(label="Symbolic Tokens Found")
                embedding_norm = gr.Textbox(label="Mean Norm of Symbolic Embeddings")
                token_count = gr.Textbox(label="Count of Symbolic Tokens")
        run_btn.click(
            fn=encode_and_predict,
            inputs=[input_text, selected_roles],
            outputs=[symbolic_tokens, embedding_norm, token_count],
        )
    return demo
if __name__ == "__main__":
    demo = build_interface()
    demo.launch()