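"""Gradio Space: inspect how bert-beatrix-2048 encodes its symbolic <role> tokens."""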
import json
import re
import sys
from pathlib import Path, PurePosixPath

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint
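
# --------------------------------------------------------------------------
# 1. Download the checkpoint snapshot from the Hub
# --------------------------------------------------------------------------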
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)
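
# --------------------------------------------------------------------------
# 2. Patch config.json so auto_map points at the local modules
# --------------------------------------------------------------------------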
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

# auto_map entries copied from the Hub look like "<repo>--<module.Class>";
# strip the repo prefix so the custom classes resolve from the local snapshot.
auto_map = cfg.get("auto_map", {})
changed = False
for k, v in auto_map.items():
    if "--" in v:
        auto_map[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
        changed = True

if changed:
    cfg["auto_map"] = auto_map
    with cfg_path.open("w") as f:
        json.dump(cfg, f, indent=2)
    print("🛠️  Patched config.json → auto_map now points at local modules")
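
# --------------------------------------------------------------------------
# 3. Load model + tokenizer through bert_handler
# --------------------------------------------------------------------------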
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()

# Grab the sub-modules needed to run the embedding + encoder stack by hand.
encoder = full_model.bert.encoder
embeddings = full_model.bert.embeddings
emb_ln = full_model.bert.emb_ln
emb_drop = full_model.bert.emb_drop
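
# --------------------------------------------------------------------------
# 4. Symbolic role tokens the tokenizer must already know
# --------------------------------------------------------------------------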
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

# Abort early if the checkpoint's tokenizer is missing any role token.
missing = [tok for tok in SYMBOLIC_ROLES
           if tokenizer.convert_tokens_to_ids(tok) == tokenizer.unk_token_id]
if missing:
    sys.exit(f"❌ Tokenizer is missing {missing}")
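
# --------------------------------------------------------------------------
# 5. Encode text and trace the selected symbolic roles
# --------------------------------------------------------------------------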
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]):
    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, mask = batch.input_ids, batch.attention_mask

        # Run the embedding stack manually (embeddings → LayerNorm → dropout),
        # then push the result through the encoder.
        x = emb_drop(emb_ln(embeddings(ids)))
        ext_mask = full_model.bert.get_extended_attention_mask(mask, x.shape[:-1])
        enc = encoder(x, attention_mask=ext_mask)

        # Mark which positions hold one of the selected role tokens.
        sel_ids = {tokenizer.convert_tokens_to_ids(t) for t in selected_roles}
        flags = torch.tensor([tid in sel_ids for tid in ids[0].tolist()],
                             device=enc.device)

        found = [tokenizer.convert_ids_to_tokens([tid])[0]
                 for tid in ids[0].tolist() if tid in sel_ids]

        # Mean-pool the encoder states at the flagged positions and report the norm.
        if flags.any():
            vec = enc[0][flags].mean(0)
            norm = f"{vec.norm().item():.4f}"
        else:
            norm = "0.0000"

        # Return a tuple so the values map onto the three output Textboxes:
        # tokens found, mean-pooled norm, and token count.
        return (
            ", ".join(found) or "(none)",
            norm,
            int(flags.sum().item()),
        )
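
# --------------------------------------------------------------------------
# 6. Gradio interface
# --------------------------------------------------------------------------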
def build_interface():
    with gr.Blocks(title="🧠 Symbolic Encoder Inspector") as demo:
        gr.Markdown(
            "## 🧠 Symbolic Encoder Inspector\n"
            "Paste some text containing the special `<role>` tokens and "
            "inspect their encoder representations."
        )

        with gr.Row():
            with gr.Column():
                txt = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    placeholder="Example: A <subject> wearing <upper_body_clothing> …",
                    lines=3,
                )
                roles = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Trace these symbolic roles",
                )
                btn = gr.Button("Encode & Trace")
            with gr.Column():
                out_tok = gr.Textbox(label="Symbolic Tokens Found")
                out_norm = gr.Textbox(label="Mean Norm")
                out_cnt = gr.Textbox(label="Token Count")

        btn.click(encode_and_trace, [txt, roles], [out_tok, out_norm, out_cnt])

    return demo

if __name__ == "__main__":
    build_interface().launch()