# app.py - FIXED encoder-only demo for bert-beatrix-2048
# launch: python app.py
# -----------------------------------------------
import json, re, sys, math
from pathlib import Path, PurePosixPath

import torch, torch.nn.functional as F
import gradio as gr
import spaces
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. Download & patch HF checkpoint --------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)
# strip repo prefix in auto_map (one-time)
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

amap = cfg.get("auto_map", {})
for k, v in amap.items():
    if "--" in v:
        amap[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
cfg["auto_map"] = amap

with cfg_path.open("w") as f:
    json.dump(cfg, f, indent=2)
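# Illustrative before/after for the auto_map rewrite above (the module and class
# names here are hypothetical, not read from the actual config):
#   "AutoModelForMaskedLM": "AbstractPhil/bert-beatrix-2048--modeling_bert_beatrix.BertBeatrixForMaskedLM"
#     becomes
#   "AutoModelForMaskedLM": "modeling_bert_beatrix.BertBeatrixForMaskedLM"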
# ------------------------------------------------------------------
# 1. Load model & components ---------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()

# ------------------------------------------------------------------
# 2. Symbolic roles -------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]
# Verify all symbolic tokens exist in the tokenizer
missing_tokens = []
symbolic_token_ids = {}
for token in SYMBOLIC_ROLES:
    token_id = tokenizer.convert_tokens_to_ids(token)
    if token_id == tokenizer.unk_token_id:
        missing_tokens.append(token)
    else:
        symbolic_token_ids[token] = token_id

if missing_tokens:
    print(f"⚠️ Missing symbolic tokens: {missing_tokens}")
    print("Available tokens will be used for classification")

MASK = tokenizer.mask_token
MASK_ID = tokenizer.mask_token_id

print(f"✅ Loaded {len(symbolic_token_ids)} symbolic tokens")
# ------------------------------------------------------------------
# 3. FIXED MLM-based symbolic classification -----------------------
def get_symbolic_predictions(input_ids, attention_mask, mask_positions, selected_roles):
    """
    Proper MLM-based prediction for symbolic tokens at masked positions.

    Args:
        input_ids: (B, S) token IDs with [MASK] at the positions to classify
        attention_mask: (B, S) attention mask
        mask_positions: list of positions that are masked
        selected_roles: list of symbolic role tokens to consider

    Returns:
        list of dicts, one per masked position, each holding the ranked
        symbolic-token predictions and their probabilities
    """
    # Get MLM logits from the model (this is what it was trained for)
    with torch.no_grad():
        outputs = full_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # (B, S, V)

    # Filter to only the selected symbolic role token IDs
    selected_token_ids = [symbolic_token_ids[role] for role in selected_roles
                          if role in symbolic_token_ids]
    if not selected_token_ids:
        return []

    results = []
    for pos in mask_positions:
        # Logits for this masked position
        pos_logits = logits[0, pos]  # (V,)

        # Extract logits for the symbolic tokens only
        symbolic_logits = pos_logits[selected_token_ids]  # (num_symbolic,)

        # Softmax over the symbolic subset to get probabilities
        symbolic_probs = F.softmax(symbolic_logits, dim=-1)

        # Rank predictions by probability
        top_indices = torch.argsort(symbolic_probs, descending=True)

        pos_results = []
        for i in top_indices:
            token_idx = selected_token_ids[i]
            token = tokenizer.convert_ids_to_tokens([token_idx])[0]
            prob = symbolic_probs[i].item()
            pos_results.append({
                "token": token,
                "probability": prob,
                "token_id": token_idx
            })

        results.append({
            "position": pos,
            "predictions": pos_results
        })

    return results
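# Illustrative usage of get_symbolic_predictions (a minimal sketch; the masked
# position is discovered at runtime rather than hard-coded):
#   batch = tokenizer("a portrait of a [MASK] near the window", return_tensors="pt")
#   mask_pos = (batch.input_ids[0] == MASK_ID).nonzero(as_tuple=True)[0].tolist()
#   preds = get_symbolic_predictions(batch.input_ids.cuda(), batch.attention_mask.cuda(),
#                                    mask_pos, ["<subject>", "<lighting>"])
#   # preds[0]["predictions"][0] -> most probable symbolic role for the masked slot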
def create_strategic_masks(text, tokenizer, strategy="content_words"):
    """
    Create strategic mask positions based on different strategies.

    Args:
        text: input text
        tokenizer: tokenizer
        strategy: masking strategy

    Returns:
        input_ids with masks, attention_mask, original_tokens, mask_positions
    """
    # Tokenize the original text
    batch = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    input_ids = batch.input_ids[0]            # (S,)
    attention_mask = batch.attention_mask[0]  # (S,)

    # Get original tokens for reference
    original_tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Find positions to mask based on the chosen strategy
    mask_positions = []

    if strategy == "content_words":
        # Mask content words (avoid special tokens, punctuation, common words)
        skip_tokens = {
            tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token,
            ".", ",", "!", "?", ":", ";", "'", '"', "-", "(", ")", "[", "]",
            "the", "a", "an", "and", "or", "but", "in", "on", "at", "to",
            "for", "of", "with", "by", "is", "are", "was", "were", "be", "been"
        }
        for i, token in enumerate(original_tokens):
            if (token not in skip_tokens and
                    not token.startswith("##") and  # avoid subword tokens
                    len(token) > 2 and
                    token.isalpha()):
                mask_positions.append(i)

    elif strategy == "every_nth":
        # Mask every 3rd token (skipping [CLS] and [SEP])
        for i in range(1, len(original_tokens) - 1, 3):
            mask_positions.append(i)

    elif strategy == "random":
        # Randomly mask 15% of tokens (skipping [CLS] and [SEP])
        import random
        candidates = list(range(1, len(original_tokens) - 1))
        num_to_mask = max(1, int(len(candidates) * 0.15))
        mask_positions = random.sample(candidates, min(num_to_mask, len(candidates)))
        mask_positions.sort()

    elif strategy == "manual":
        # For manual specification - positions are entered in the UI instead
        pass

    # Limit to a reasonable number of masks (max 10 for UI clarity)
    mask_positions = mask_positions[:10]

    # Create the masked input
    masked_input_ids = input_ids.clone()
    for pos in mask_positions:
        masked_input_ids[pos] = MASK_ID

    return masked_input_ids.unsqueeze(0), attention_mask.unsqueeze(0), original_tokens, mask_positions
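# Illustrative usage (a minimal sketch; the chosen positions depend on the tokenizer):
#   ids, attn, toks, positions = create_strategic_masks(
#       "a young woman wearing a blue dress", tokenizer, strategy="content_words")
#   preds = get_symbolic_predictions(ids.cuda(), attn.cuda(), positions, SYMBOLIC_ROLES)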
@spaces.GPU  # ZeroGPU: allocate a GPU for this handler (assumed intent of the `spaces` import above)
def symbolic_classification_analysis(text, selected_roles, masking_strategy="content_words", num_predictions=5):
    """
    Perform symbolic classification analysis using MLM prediction.
    FIXED: now tests what the model actually learned.
    """
    if not selected_roles:
        selected_roles = list(symbolic_token_ids.keys())

    if not text.strip():
        return "Please enter some text to analyze.", "", 0

    try:
        # Detect whether the input already follows the training pattern or needs conversion
        if any(role in text for role in symbolic_token_ids.keys()):
            # Input already has symbolic tokens - test descriptive prediction
            return test_descriptive_prediction(text, selected_roles, num_predictions)
        else:
            # Convert the input to a training-style format and test
            return test_with_context_injection(text, selected_roles, num_predictions)
    except Exception as e:
        error_msg = f"Error during analysis: {str(e)}"
        print(error_msg)
        return error_msg, "", 0
def test_descriptive_prediction(text, selected_roles, num_predictions):
    """
    Test which descriptive words the model predicts after symbolic tokens.
    This matches the actual training objective.
    """
    # Tokenize and find positions that follow symbolic tokens
    tokens = tokenizer.tokenize(text, add_special_tokens=True)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    # Find symbolic token positions
    symbolic_positions = []
    for i, token in enumerate(tokens):
        if token in symbolic_token_ids:
            # Mask the next 1-3 positions after the symbolic token
            for offset in range(1, min(4, len(tokens) - i)):
                if i + offset < len(tokens) and tokens[i + offset] not in ['[SEP]', '[PAD]']:
                    symbolic_positions.append({
                        'mask_pos': i + offset,
                        'symbolic_token': token,
                        'original_token': tokens[i + offset]
                    })

    if not symbolic_positions:
        return "No symbolic tokens found in input. Try a format like: '<subject> a young woman'", "", 0

    # Create masked versions and collect predictions
    results = []
    for pos_info in symbolic_positions[:5]:  # Limit to 5 positions
        masked_ids = token_ids.copy()
        masked_ids[pos_info['mask_pos']] = MASK_ID

        # Get MLM predictions
        masked_input = torch.tensor([masked_ids]).to("cuda")
        attention_mask = torch.ones_like(masked_input)

        with torch.no_grad():
            outputs = full_model(input_ids=masked_input, attention_mask=attention_mask)
            logits = outputs.logits[0, pos_info['mask_pos']]  # logits at the masked position

        # Top predictions from the full vocabulary
        probs = F.softmax(logits, dim=-1)
        top_indices = torch.argsort(probs, descending=True)[:num_predictions]

        predictions = []
        for idx in top_indices:
            token_text = tokenizer.convert_ids_to_tokens([idx.item()])[0]
            prob = probs[idx].item()
            predictions.append({
                "token": token_text,
                "probability": prob
            })

        results.append({
            "symbolic_context": pos_info['symbolic_token'],
            "position": pos_info['mask_pos'],
            "original_token": pos_info['original_token'],
            "predictions": predictions
        })

    # Format results
    analysis = {
        "input_text": text,
        "test_type": "descriptive_prediction",
        "explanation": "Testing which descriptive words the model predicts after symbolic tokens",
        "results": results
    }

    summary_lines = ["🎯 Testing Descriptive Prediction (what the model actually learned)\n"]
    for result in results:
        ctx = result["symbolic_context"]
        orig = result["original_token"]
        top_pred = result["predictions"][0]
        summary_lines.append(
            f"After {ctx}: '{orig}' → '{top_pred['token']}' ({top_pred['probability']:.4f})"
        )
    summary = "\n".join(summary_lines)

    return json.dumps(analysis, indent=2), summary, len(results)
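# Illustrative input for the function above (a minimal sketch; predictions vary by checkpoint):
#   test_descriptive_prediction("<subject> a young woman wearing a blue dress",
#                               list(symbolic_token_ids), num_predictions=5)
#   # masks the tokens right after "<subject>" and reports the model's top fillers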
def test_with_context_injection(text, selected_roles, num_predictions):
    """
    Inject symbolic context and test which descriptive words are predicted.
    """
    results = []

    # Test each selected symbolic role as context
    for role in selected_roles[:3]:  # Limit to 3 roles for speed
        # Create a training-style context
        context_text = f"{role} {text}"

        # Tokenize and find good positions to mask
        tokens = tokenizer.tokenize(context_text, add_special_tokens=True)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)

        # Find the role position and mask the next content word
        role_pos = None
        for i, token in enumerate(tokens):
            if token == role:
                role_pos = i
                break

        if role_pos is None or role_pos + 2 >= len(tokens):
            continue

        # Mask the position after the role (skipping articles like "a", "the")
        mask_pos = role_pos + 1
        skip_words = {'a', 'an', 'the', 'some', 'this', 'that'}
        while mask_pos < len(tokens) - 1:
            current_token = tokens[mask_pos].lower()
            if current_token not in skip_words and len(current_token) > 2:
                break
            mask_pos += 1

        if mask_pos >= len(tokens):
            continue

        # Create the masked input
        masked_ids = token_ids.copy()
        original_token = tokens[mask_pos]
        masked_ids[mask_pos] = MASK_ID

        # Get predictions
        masked_input = torch.tensor([masked_ids]).to("cuda")
        attention_mask = torch.ones_like(masked_input)

        with torch.no_grad():
            outputs = full_model(input_ids=masked_input, attention_mask=attention_mask)
            logits = outputs.logits[0, mask_pos]

        # Top predictions
        probs = F.softmax(logits, dim=-1)
        top_indices = torch.argsort(probs, descending=True)[:num_predictions]

        predictions = []
        for idx in top_indices:
            token_text = tokenizer.convert_ids_to_tokens([idx.item()])[0]
            prob = probs[idx].item()
            predictions.append({
                "token": token_text,
                "probability": prob
            })

        results.append({
            "symbolic_context": role,
            "position": mask_pos,
            "original_token": original_token,
            "context_text": context_text,
            "predictions": predictions
        })

    # Format results
    analysis = {
        "input_text": text,
        "test_type": "context_injection",
        "explanation": "Injected symbolic tokens and tested descriptive predictions",
        "results": results
    }

    summary_lines = ["🎯 Testing with Symbolic Context Injection\n"]
    for result in results:
        role = result["symbolic_context"]
        orig = result["original_token"]
        top_pred = result["predictions"][0]
        summary_lines.append(
            f"{role} context: '{orig}' → '{top_pred['token']}' ({top_pred['probability']:.4f})"
        )
    summary = "\n".join(summary_lines)

    return json.dumps(analysis, indent=2), summary, len(results)
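# Illustrative usage (a minimal sketch; the role list here is arbitrary):
#   test_with_context_injection("young woman wearing a dress",
#                               ["<subject>", "<emotion>"], num_predictions=5)
#   # prepends each role, masks the first content word after it, and reports the fillers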
@spaces.GPU  # ZeroGPU: allocate a GPU for this handler as well
def create_manual_mask_analysis(text, mask_positions_str, selected_roles):
    """
    Allow manual specification of mask positions.
    """
    try:
        # Parse mask positions
        mask_positions = [int(x.strip()) for x in mask_positions_str.split(",") if x.strip().isdigit()]
        if not mask_positions:
            return "Please specify valid mask positions (comma-separated numbers)", "", 0

        # Tokenize the text
        batch = tokenizer(text, return_tensors="pt", add_special_tokens=True)
        input_ids = batch.input_ids[0]
        attention_mask = batch.attention_mask[0]
        original_tokens = tokenizer.convert_ids_to_tokens(input_ids)

        # Validate positions
        valid_positions = [pos for pos in mask_positions if 0 <= pos < len(input_ids)]
        if not valid_positions:
            return f"Invalid positions. Text has {len(input_ids)} tokens (0-{len(input_ids)-1})", "", 0

        # Create the masked input
        masked_input_ids = input_ids.clone()
        for pos in valid_positions:
            masked_input_ids[pos] = MASK_ID

        # Run the analysis
        masked_input_ids = masked_input_ids.unsqueeze(0).to("cuda")
        attention_mask = attention_mask.unsqueeze(0).to("cuda")

        predictions = get_symbolic_predictions(
            masked_input_ids, attention_mask, valid_positions, selected_roles
        )

        # Format results
        results = []
        for pred_data in predictions:
            pos = pred_data["position"]
            original = original_tokens[pos]
            top_pred = pred_data["predictions"][0] if pred_data["predictions"] else None
            if top_pred:
                results.append(
                    f"Pos {pos}: '{original}' → {top_pred['token']} ({top_pred['probability']:.4f})"
                )

        return "\n".join(results), f"Analyzed {len(valid_positions)} positions", len(valid_positions)

    except Exception as e:
        return f"Error: {str(e)}", "", 0
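# Illustrative usage (a minimal sketch; position 3 is hypothetical, use the
# Token Inspector tab to find the real indices for your text):
#   create_manual_mask_analysis("a young woman wearing a dress", "3", ["<subject>", "<pose>"])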
# ------------------------------------------------------------------
# 4. Gradio UI ------------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 MLM Symbolic Classifier", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 MLM-Based Symbolic Classification")
        gr.Markdown("Analyze text using masked language modeling to predict symbolic roles at specific positions.")

        with gr.Tab("Automatic Analysis"):
            with gr.Row():
                with gr.Column():
                    txt_input = gr.Textbox(
                        label="Input Text",
                        lines=4,
                        placeholder="Try: '<subject> a young woman wearing elegant dress' or just 'young woman wearing dress'"
                    )
                    with gr.Row():
                        masking_strategy = gr.Dropdown(
                            choices=["content_words", "every_nth", "random"],
                            value="content_words",
                            label="Masking Strategy"
                        )
                        num_predictions = gr.Slider(
                            minimum=1, maximum=10, value=5, step=1,
                            label="Top Predictions per Position"
                        )
                    roles_selection = gr.CheckboxGroup(
                        choices=list(symbolic_token_ids.keys()),
                        value=list(symbolic_token_ids.keys()),
                        label="Symbolic Roles to Consider"
                    )
                    analyze_btn = gr.Button("🔍 Analyze", variant="primary")

                with gr.Column():
                    summary_output = gr.Textbox(
                        label="Analysis Summary",
                        lines=10,
                        max_lines=15
                    )
                    with gr.Row():
                        positions_analyzed = gr.Number(label="Positions Analyzed", precision=0)
                        max_confidence = gr.Textbox(label="Best Prediction", max_lines=1)
                    detailed_output = gr.JSON(label="Detailed Results")

        with gr.Tab("Manual Masking"):
            with gr.Row():
                with gr.Column():
                    manual_text = gr.Textbox(
                        label="Input Text",
                        lines=3,
                        placeholder="Enter text for manual analysis..."
                    )
                    mask_positions_input = gr.Textbox(
                        label="Mask Positions (comma-separated)",
                        placeholder="e.g., 2,5,8,12",
                        info="Specify token positions to mask (0-based indexing)"
                    )
                    manual_roles = gr.CheckboxGroup(
                        choices=list(symbolic_token_ids.keys()),
                        value=list(symbolic_token_ids.keys())[:10],  # default to the first 10
                        label="Symbolic Roles"
                    )
                    manual_analyze_btn = gr.Button("🎯 Analyze Specific Positions")

                with gr.Column():
                    manual_results = gr.Textbox(
                        label="Manual Analysis Results",
                        lines=8
                    )
                    manual_summary = gr.Textbox(label="Summary")
                    manual_count = gr.Number(label="Positions", precision=0)

        with gr.Tab("Token Inspector"):
            with gr.Row():
                with gr.Column():
                    inspect_text = gr.Textbox(
                        label="Text to Inspect",
                        lines=2,
                        placeholder="Enter text to see tokenization..."
                    )
                    # Add example patterns button
                    example_patterns = gr.Button("📋 Load Image Caption Examples")
                    inspect_btn = gr.Button("🔍 Inspect Tokens")

                with gr.Column():
                    token_breakdown = gr.Textbox(
                        label="Token Breakdown",
                        lines=8,
                        info="Shows how text is tokenized with position indices"
                    )

        with gr.Tab("Caption Examples"):
            gr.Markdown("### 🖼️ Test with Training-Style Patterns")
            gr.Markdown("""
**The model was trained to predict descriptive words AFTER symbolic tokens.**

Test with patterns like:
- `<subject> a young woman wearing elegant dress`
- `<lighting> soft natural illumination on the scene`
- `<emotion> happy expression while posing confidently`
""")
            example_captions = [
                "<subject> a young woman wearing a blue dress",
                "<lighting> soft natural illumination in the scene",
                "<emotion> happy expression while posing confidently",
                "<pose> standing gracefully near the window",
                "<upper_body_clothing> elegant silk blouse with intricate patterns",
                "<material> luxurious velvet fabric with rich texture",
                "<accessory> delicate silver jewelry catching the light",
                "<surface> polished marble floor reflecting ambient glow"
            ]
            for caption in example_captions:
                with gr.Row():
                    gr.Textbox(value=caption, label="Training-Style Example", interactive=False, scale=3)
                    copy_btn = gr.Button("📋 Test This", scale=1)

        # Event handlers
        analyze_btn.click(
            symbolic_classification_analysis,
            inputs=[txt_input, roles_selection, masking_strategy, num_predictions],
            outputs=[detailed_output, summary_output, positions_analyzed]
        )

        manual_analyze_btn.click(
            create_manual_mask_analysis,
            inputs=[manual_text, mask_positions_input, manual_roles],
            outputs=[manual_results, manual_summary, manual_count]
        )

        def load_examples():
            return "a young woman wearing a blue dress"

        def inspect_tokens(text):
            if not text.strip():
                return "Enter text to inspect tokenization"
            tokens = tokenizer.tokenize(text, add_special_tokens=True)
            result_lines = []
            for i, token in enumerate(tokens):
                result_lines.append(f"{i:2d}: '{token}'")
            return "\n".join(result_lines)

        example_patterns.click(
            load_examples,
            outputs=[inspect_text]
        )

        inspect_btn.click(
            inspect_tokens,
            inputs=[inspect_text],
            outputs=[token_breakdown]
        )

    return demo
if __name__ == "__main__":
    print("🚀 Starting MLM Symbolic Classifier...")
    print(f"✅ Model loaded with {len(symbolic_token_ids)} symbolic tokens")
    print(f"🎯 Available symbolic roles: {list(symbolic_token_ids.keys())[:5]}...")

    build_interface().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )