Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,751 Bytes
ed080e6 aa28bbb ed080e6 aa28bbb ed080e6 872b08b 81a0ae4 535151e aa28bbb 81a0ae4 872b08b 096fe3a 872b08b 096fe3a aa28bbb 096fe3a ed080e6 872b08b 81a0ae4 ed080e6 9a08859 81a0ae4 ed080e6 872b08b 9a08859 ed080e6 9a08859 ed080e6 9a08859 415afa1 9a08859 415afa1 ed080e6 872b08b 9a08859 ed080e6 9a08859 aa28bbb 9a08859 ed080e6 096fe3a 535151e 096fe3a 8a2e372 096fe3a aa28bbb 096fe3a 785df91 872b08b 785df91 ed080e6 096fe3a 872b08b ed080e6 096fe3a ed080e6 096fe3a ed080e6 096fe3a ed080e6 aa28bbb ed080e6 aa28bbb ed080e6 aa28bbb ed080e6 aa28bbb ed080e6 872b08b 096fe3a ed080e6 096fe3a 785df91 096fe3a 872b08b ed080e6 872b08b 535151e aa28bbb ed080e6 aa28bbb ed080e6 aa28bbb ed080e6 535151e ed080e6 872b08b 535151e ea2994b 872b08b ea2994b 81a0ae4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
# app.py — encoder-only demo for bert-beatrix-2048
# launch: python app.py
import json
import sys
from pathlib import Path, PurePosixPath
from itertools import islice
import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. Download & patch config.json --------------------------------
# ------------------------------------------------------------------
# Hub repository to fetch and the local folder it is downloaded into.
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

# Download the "main" revision of the repository into LOCAL_CKPT.
snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)

cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

# auto_map values of the form "repo--module.Class" are reduced to the part
# after the first "--"; all other values are left as they are.
auto_map = cfg.get("auto_map", {})
prefixed_keys = [key for key, target in auto_map.items() if "--" in target]
for key in prefixed_keys:
    module_path = auto_map[key].split("--", 1)[1]
    auto_map[key] = PurePosixPath(module_path).as_posix()
patched = bool(prefixed_keys)

# config.json is rewritten only when at least one entry was changed.
if patched:
    cfg["auto_map"] = auto_map
    with cfg_path.open("w") as f:
        json.dump(cfg, f, indent=2)
    print("π οΈ Patched config.json β auto_map paths fixed")
# ------------------------------------------------------------------
# 1. Model / tokenizer -------------------------------------------
# ------------------------------------------------------------------
# Load handler, model and tokenizer from the local checkpoint folder.
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)

# Inference only: switch to eval mode and move the model to CUDA.
full_model = full_model.eval().cuda()

# Shortcuts to the sub-modules that encode_and_trace calls directly.
_bert = full_model.bert
encoder, embeddings = _bert.encoder, _bert.embeddings
emb_ln, emb_drop = _bert.emb_ln, _bert.emb_drop
# ------------------------------------------------------------------
# 2. Symbolic token set ------------------------------------------
# ------------------------------------------------------------------
# Role tokens offered in the UI; each one must exist in the tokenizer vocabulary.
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

# A role that maps to the unknown-token id is treated as absent from the
# vocabulary; the script stops at import time if any role is absent.
unk = tokenizer.unk_token_id
_role_ids = {
    role: tokenizer.convert_tokens_to_ids(role) for role in SYMBOLIC_ROLES
}
missing = [role for role, token_id in _role_ids.items() if token_id == unk]
if missing:
    sys.exit(f"β Tokenizer is missing {missing}")
# ------------------------------------------------------------------
# 3. helper: merge lowest + highest until 3 remain ----------------
# ------------------------------------------------------------------
def reduce_to_three(table):
    """
    Collapse a list of scored rows into at most three rows.

    table : list of dicts with keys "role", "token" and "score"

    While more than three rows remain, the lowest- and highest-scoring
    rows are taken out and replaced by one merged row: role "high|low",
    token "high/low", score = mean of the two scores.  The input list
    itself is not modified.

    Returns the remaining rows ordered from highest to lowest score;
    an input with three or fewer rows is only re-ordered.
    """
    def by_score(row):
        return row["score"]

    ranked = sorted(table, key=by_score)
    # each merge swaps two rows for one, so exactly len - 3 merges are needed
    # (range() of a non-positive number is empty, so short inputs skip the loop)
    for _ in range(len(ranked) - 3):
        lowest, *middle, highest = ranked
        combined = {
            "role": f"{highest['role']}|{lowest['role']}",
            "token": f"{highest['token']}/{lowest['token']}",
            "score": (highest["score"] + lowest["score"]) / 2.0,
        }
        # the merged row goes in last so that, on equal scores, it sorts
        # after the rows that were already present
        ranked = sorted([*middle, combined], key=by_score)
    # highest first for display
    return sorted(ranked, key=by_score, reverse=True)
# ------------------------------------------------------------------
# 4. Encoder-only inference util ---------------------------------
# ------------------------------------------------------------------
@spaces.GPU
def encode_and_trace(text: str, selected_roles: list[str]) -> tuple[str, str, str]:
    """
    Match each selected role token to its closest input token, then merge to three.

    text           : raw input string, tokenized with the module-level tokenizer
    selected_roles : role tokens to trace (the UI passes entries of SYMBOLIC_ROLES)

    For every role, the role's row of the embedding weight is compared by
    cosine similarity with the embedding of every input token and the
    best-scoring input token is kept.  The per-role rows are then collapsed
    with reduce_to_three and formatted for display.

    Returns exactly three strings, highest score first, padded with "" when
    fewer than three rows remain.  Missing/blank text returns
    ("(no input)", "", ""); a missing/empty selection returns
    ("(none selected)", "", "").
    """
    # Cheap guards first: a request that can only produce a placeholder must
    # not pay for tokenization and a GPU forward pass (and None must not raise).
    if not text or not text.strip():
        return "(no input)", "", ""
    if not selected_roles:
        return "(none selected)", "", ""

    with torch.no_grad():
        batch = tokenizer(text, return_tensors="pt").to("cuda")
        ids, attn = batch.input_ids, batch.attention_mask

        # Embed once; the same tensor feeds the encoder path and the
        # similarity calculation (the model is in eval mode, so a second
        # call would return the same values).
        raw_emb = embeddings(ids)

        # encoder forward
        x = emb_drop(emb_ln(raw_emb))
        ext = full_model.bert.get_extended_attention_mask(attn, x.shape[:-1])
        # NOTE(review): hs is not consumed anywhere below — either use the
        # encoder output for the scores or drop this forward pass.
        hs = encoder(x, attention_mask=ext)

        # token-level embeddings (before LN) for similarity calc,
        # batch dimension dropped
        token_emb = raw_emb.squeeze(0)

        rows = []
        for role in selected_roles:
            rid = tokenizer.convert_tokens_to_ids(role)
            rvec = embeddings.weight[rid]
            # cosine similarity to every *input* token embedding
            sims = F.cosine_similarity(rvec.unsqueeze(0), token_emb, dim=-1)
            best = torch.argmax(sims).item()
            rows.append({
                "role": role,
                "token": tokenizer.convert_ids_to_tokens([ids[0, best].item()])[0],
                "score": sims[best].item(),
            })

    final3 = reduce_to_three(rows)
    out_strs = [f"{r['role']} β {r['token']} ({r['score']:+.2f})" for r in final3]
    # pad so we always return 3 strings
    out_strs += [""] * (3 - len(out_strs))
    return out_strs[0], out_strs[1], out_strs[2]
# ------------------------------------------------------------------
# 5. Gradio UI ----------------------------------------------------
# ------------------------------------------------------------------
def build_interface():
    """
    Assemble the Gradio Blocks UI and return it without launching.

    Layout: one row with two columns.  The left column holds the text
    input, the role checkbox group (choices = SYMBOLIC_ROLES) and the
    trigger button; the right column holds three output text boxes.
    Clicking the button calls encode_and_trace with (text, roles) and
    writes its three return values into the output boxes.
    """
    intro_md = (
        "### π§ Symbolic Encoder Inspector\n"
        "Paste text with `<role>` tokens, pick roles to track, then we\n"
        "β’ compute role β token cosine scores\n"
        "β’ iteratively merge low+high pairs until **3 composite buckets** remain."
    )
    bucket_labels = (
        "Category 1 (highest)",
        "Category 2",
        "Category 3 (lowest)",
    )

    with gr.Blocks(title="π§ Symbolic Encoder Inspector") as demo:
        gr.Markdown(intro_md)

        with gr.Row():
            # left: inputs and trigger
            with gr.Column():
                prompt_box = gr.Textbox(
                    label="Input with Symbolic Tokens",
                    placeholder="Example: A <subject> wearing <upper_body_clothing> β¦",
                    lines=3,
                )
                role_picker = gr.CheckboxGroup(
                    choices=SYMBOLIC_ROLES,
                    label="Trace these symbolic roles",
                )
                run_button = gr.Button("Encode & Merge")
            # right: the three merged buckets, highest score first
            with gr.Column():
                bucket_boxes = [gr.Textbox(label=caption) for caption in bucket_labels]

        run_button.click(
            encode_and_trace,
            [prompt_box, role_picker],
            bucket_boxes,
        )
    return demo
if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    app = build_interface()
    app.launch()
|