# AbstractPhil's picture
# Update app.py
# 42a1b9e verified
# app.py – FIXED encoder-only demo for bert-beatrix-2048
# launch: python app.py
# -----------------------------------------------
import json, re, sys, math
from pathlib import Path, PurePosixPath
import torch, torch.nn.functional as F
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from bert_handler import create_handler_from_checkpoint
# ------------------------------------------------------------------
# 0. Download & patch HF checkpoint --------------------------------
# Hub repository holding the checkpoint, and the local directory it is
# materialised into (LOCAL_CKPT is also the load path used further down).
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

# Pull the "main" revision into LOCAL_CKPT. local_dir_use_symlinks=False asks
# for real files instead of symlinks into the HF cache, presumably so the
# in-place config.json rewrite below only touches this local copy.
# NOTE(review): newer huggingface_hub releases deprecate this argument.
snapshot_download(
    repo_id=REPO_ID,
    local_dir=LOCAL_CKPT,
    revision="main",
    local_dir_use_symlinks=False,
)
# → strip repo prefix in auto_map (one-time)
# Remote-code references in config.json's "auto_map" are rewritten so that
# everything up to and including the first "--" is dropped (presumably the
# "<repo_id>--" prefix added on the Hub), leaving a module path relative to
# the local checkpoint. Idempotent: stripped entries contain no "--", so
# later launches find nothing to change and the file is not rewritten.


def _strip_repo_prefix(ref):
    """Return *ref* without its "<repo>--" prefix; non-strings (e.g. null) pass through."""
    if isinstance(ref, str) and "--" in ref:
        return PurePosixPath(ref.split("--", 1)[1]).as_posix()
    return ref


cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open(encoding="utf-8") as f:
    cfg = json.load(f)
amap = cfg.get("auto_map", {})
# Values are usually "module.Class" strings, but entries may also be lists
# (e.g. an AutoTokenizer [slow, fast] pair); strip every string inside those
# too instead of silently skipping them.
patched_amap = {
    key: [_strip_repo_prefix(ref) for ref in value] if isinstance(value, list) else _strip_repo_prefix(value)
    for key, value in amap.items()
}
if patched_amap != amap:
    cfg["auto_map"] = patched_amap
    with cfg_path.open("w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
amap = patched_amap
# ------------------------------------------------------------------
# 1. Load model & components --------------------------------------
# The project helper returns (handler, model, tokenizer) for the checkpoint.
# The model is put into eval mode and moved to the GPU once, at import time.
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model.eval()
full_model = full_model.cuda()
# ------------------------------------------------------------------
# 2. Symbolic roles -------------------------------------------------
# Role tokens this demo looks up in the checkpoint's tokenizer.
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

# Verify all symbolic tokens exist in tokenizer: a role that resolves to the
# unk id is not a known single token, so it is reported instead of used.
missing_tokens = []
symbolic_token_ids = {}  # role token -> vocabulary id, in SYMBOLIC_ROLES order
for token in SYMBOLIC_ROLES:
    token_id = tokenizer.convert_tokens_to_ids(token)
    if token_id == tokenizer.unk_token_id:
        missing_tokens.append(token)
    else:
        symbolic_token_ids[token] = token_id
if missing_tokens:
    print(f"⚠️ Missing symbolic tokens: {missing_tokens}")
    print("Available tokens will be used for classification")

MASK = tokenizer.mask_token
MASK_ID = tokenizer.mask_token_id
# Every analysis path writes MASK_ID into the input ids; fail at startup with
# a clear message rather than later inside a request handler.
if MASK_ID is None:
    raise RuntimeError(
        f"Tokenizer loaded from {LOCAL_CKPT!r} has no mask token; MLM analysis is impossible."
    )
print(f"✅ Loaded {len(symbolic_token_ids)} symbolic tokens")
# ------------------------------------------------------------------
# 3. FIXED MLM-based symbolic classification ----------------------
def get_symbolic_predictions(input_ids, attention_mask, mask_positions, selected_roles):
    """Rank the selected symbolic role tokens at each masked position.

    Runs one forward pass of ``full_model``. For every position in
    *mask_positions* it restricts the MLM logits of the first sequence in the
    batch to the vocabulary ids of *selected_roles* and softmaxes over that
    subset only.

    Args:
        input_ids: (B, S) token ids with the mask id at the positions to
            classify. Only batch row 0 is read.
        attention_mask: (B, S) attention mask.
        mask_positions: sequence indices to score.
        selected_roles: role tokens to rank. Roles missing from
            ``symbolic_token_ids`` are ignored; duplicates count once.

    Returns:
        A list with one dict per position::

            {"position": pos,
             "predictions": [{"token": str, "probability": float, "token_id": int}, ...]}

        with predictions sorted by descending probability. An empty list is
        returned when none of *selected_roles* is a known symbolic token.
    """
    # Candidate vocabulary ids, de-duplicated in first-seen order so a
    # repeated role cannot take a double share of the softmax.
    selected_token_ids = list(dict.fromkeys(
        symbolic_token_ids[role] for role in selected_roles if role in symbolic_token_ids
    ))
    if not selected_token_ids:
        # Always a list, so callers can iterate the result uniformly
        # (and no forward pass is wasted).
        return []

    with torch.no_grad():
        logits = full_model(input_ids=input_ids, attention_mask=attention_mask).logits  # (B, S, V)

    results = []
    for pos in mask_positions:
        # Softmax over the candidate roles only, not the whole vocabulary.
        probs = F.softmax(logits[0, pos][selected_token_ids], dim=-1)
        order = torch.argsort(probs, descending=True).tolist()
        prob_values = probs.tolist()
        ranked_ids = [selected_token_ids[i] for i in order]
        ranked_tokens = tokenizer.convert_ids_to_tokens(ranked_ids)
        predictions = [
            {"token": token, "probability": prob_values[i], "token_id": token_id}
            for i, token_id, token in zip(order, ranked_ids, ranked_tokens)
        ]
        results.append({"position": pos, "predictions": predictions})
    return results
def create_strategic_masks(text, tokenizer, strategy="content_words", max_masks=10):
    """Tokenize *text* and mask the positions picked by *strategy*.

    Strategies:
        ``"content_words"`` – whole-word alphabetic tokens longer than two
            characters that are neither special tokens nor common function
            words (function words compared case-insensitively).
        ``"every_nth"`` – every third position from index 1, excluding the
            last index (index 0 and the last index are assumed to be the
            CLS/SEP special tokens).
        ``"random"`` – 15% of the inner positions (at least one when any
            exist), sampled with the ``random`` module and returned sorted.
        ``"manual"`` or anything else – no positions are masked.

    Args:
        text: input string.
        tokenizer: Hugging Face-style tokenizer (callable, with
            ``convert_ids_to_tokens`` and the special-token attributes). Its
            own ``mask_token_id`` is written at the masked positions.
        strategy: masking strategy name (see above).
        max_masks: keep at most this many positions (``None`` = no limit);
            defaults to 10 for UI clarity.

    Returns:
        ``(masked_input_ids, attention_mask, original_tokens, mask_positions)``;
        both tensors have shape (1, S).

    Raises:
        ValueError: positions were selected but the tokenizer has no mask token.
    """
    import random

    batch = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    input_ids = batch.input_ids[0]            # (S,)
    attention_mask = batch.attention_mask[0]  # (S,)
    original_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    # Inner positions: everything between the leading and trailing special token.
    inner = range(1, len(original_tokens) - 1)

    if strategy == "content_words":
        special_tokens = {tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token}
        function_words = {
            "the", "a", "an", "and", "or", "but", "in", "on", "at", "to",
            "for", "of", "with", "by", "is", "are", "was", "were", "be", "been",
        }
        # isalpha() also rules out punctuation and "##" word-piece continuations.
        mask_positions = [
            i for i, token in enumerate(original_tokens)
            if token not in special_tokens
            and token.isalpha()
            and len(token) > 2
            and token.lower() not in function_words
        ]
    elif strategy == "every_nth":
        mask_positions = list(inner[::3])
    elif strategy == "random":
        candidates = list(inner)
        num_to_mask = max(1, int(len(candidates) * 0.15))
        mask_positions = sorted(random.sample(candidates, min(num_to_mask, len(candidates))))
    else:
        # "manual": positions are chosen elsewhere (Manual Masking tab).
        mask_positions = []

    if max_masks is not None:
        mask_positions = mask_positions[:max_masks]

    masked_input_ids = input_ids.clone()
    if mask_positions:
        # Use the mask id of the tokenizer we were given, not a global one.
        mask_id = tokenizer.mask_token_id
        if mask_id is None:
            raise ValueError("tokenizer has no mask token; cannot create masked input")
        masked_input_ids[mask_positions] = mask_id
    return masked_input_ids.unsqueeze(0), attention_mask.unsqueeze(0), original_tokens, mask_positions
@spaces.GPU
def symbolic_classification_analysis(text, selected_roles, masking_strategy="content_words", num_predictions=5):
    """Entry point of the "Automatic Analysis" tab.

    Picks the test by inspecting *text*:

    * if it already contains a known symbolic token, the tokens following it
      are masked and predicted (``test_descriptive_prediction``);
    * otherwise each selected role is prepended as context and the first
      content token after it is predicted (``test_with_context_injection``).

    Args:
        text: user input.
        selected_roles: roles to consider; empty/None means all known roles.
        masking_strategy: accepted for UI compatibility but currently not
            used by either test.
        num_predictions: top predictions to report per position; coerced to
            an int of at least 1 in case the UI delivers a float or 0.

    Returns:
        ``(json_text, summary, count)``. ``json_text`` feeds a ``gr.JSON``
        component, which parses strings as JSON, so on errors it is
        ``{"error": message}`` and the message is repeated in ``summary``.
    """
    if not selected_roles:
        selected_roles = list(symbolic_token_ids.keys())
    if not text or not text.strip():
        msg = "Please enter some text to analyze."
        return json.dumps({"error": msg}), msg, 0
    try:
        num_predictions = max(1, int(num_predictions))
        # DETECT if input follows training pattern vs needs conversion
        if any(role in text for role in symbolic_token_ids):
            # Input already has symbolic tokens - test descriptive prediction
            return test_descriptive_prediction(text, selected_roles, num_predictions)
        # Convert input to training-style format and test
        return test_with_context_injection(text, selected_roles, num_predictions)
    except Exception as e:
        # UI boundary: report instead of raising into Gradio.
        error_msg = f"Error during analysis: {str(e)}"
        print(error_msg)
        return json.dumps({"error": error_msg}), error_msg, 0
def test_descriptive_prediction(text, selected_roles, num_predictions):
    """Predict the tokens that follow symbolic role tokens already present in *text*.

    Each of the (up to) three tokens after every symbolic token is masked on
    its own, with one forward pass per masked copy. The model's top
    *num_predictions* vocabulary predictions are recorded for each. At most
    the first five such positions are evaluated.

    Args:
        text: input containing symbolic tokens, e.g. "<subject> a young woman".
        selected_roles: accepted for signature parity with
            ``test_with_context_injection``.
            NOTE(review): not used to filter here — every known symbolic
            token in *text* is considered.
        num_predictions: top-k size per masked position (coerced to int >= 1).

    Returns:
        ``(json_text, summary, count)``. ``json_text`` is always valid JSON
        (``{"error": ...}`` when no maskable position follows a symbolic token).
    """
    tokens = tokenizer.tokenize(text, add_special_tokens=True)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    top_k = max(1, int(num_predictions))
    # Separator/padding positions are never prediction targets.
    unmaskable = {tokenizer.sep_token, tokenizer.pad_token}

    # Up to three tokens following each symbolic token become mask targets.
    targets = []
    for i, token in enumerate(tokens):
        if token not in symbolic_token_ids:
            continue
        for j in range(i + 1, min(i + 4, len(tokens))):
            if tokens[j] not in unmaskable:
                targets.append({"mask_pos": j, "symbolic_token": token, "original_token": tokens[j]})

    if not targets:
        msg = "No symbolic tokens found in input. Try format like: '<subject> a young woman'"
        return json.dumps({"error": msg}), msg, 0

    results = []
    for target in targets[:5]:  # Limit to 5 positions
        pos = target["mask_pos"]
        masked_ids = list(token_ids)
        masked_ids[pos] = MASK_ID
        masked_input = torch.tensor([masked_ids], device="cuda")
        with torch.no_grad():
            outputs = full_model(input_ids=masked_input, attention_mask=torch.ones_like(masked_input))
        # Top-k over the full vocabulary at the masked position.
        probs = F.softmax(outputs.logits[0, pos], dim=-1)
        top = torch.topk(probs, min(top_k, probs.numel()))
        predictions = [
            {"token": token_text, "probability": prob}
            for token_text, prob in zip(
                tokenizer.convert_ids_to_tokens(top.indices.tolist()), top.values.tolist()
            )
        ]
        results.append({
            "symbolic_context": target["symbolic_token"],
            "position": pos,
            "original_token": target["original_token"],
            "predictions": predictions,
        })

    analysis = {
        "input_text": text,
        "test_type": "descriptive_prediction",
        "explanation": "Testing what descriptive words model predicts after symbolic tokens",
        "results": results,
    }
    summary_lines = ["🎯 Testing Descriptive Prediction (what model actually learned)\n"]
    for result in results:
        top_pred = result["predictions"][0]
        summary_lines.append(
            f"After {result['symbolic_context']}: '{result['original_token']}' → "
            f"'{top_pred['token']}' ({top_pred['probability']:.4f})"
        )
    return json.dumps(analysis, indent=2), "\n".join(summary_lines), len(results)
def test_with_context_injection(text, selected_roles, num_predictions):
    """Prepend symbolic roles to plain *text* and predict the first content token after each.

    For each of the first three *selected_roles*:

    1. Build ``"<role> <text>"``.
    2. Mask the first token after the role that is longer than two characters
       and not a filler word ("a", "an", "the", "some", "this", "that").
    3. Record the model's top *num_predictions* vocabulary predictions for it.

    A role is skipped when it is not a single token in the tokenized context,
    when fewer than two tokens follow it, or when no suitable token exists
    before the final token.

    Returns:
        ``(json_text, summary, count)`` with ``json_text`` a JSON document.
    """
    top_k = max(1, int(num_predictions))
    skip_words = {"a", "an", "the", "some", "this", "that"}
    results = []
    for role in selected_roles[:3]:  # Limit to 3 roles for speed
        context_text = f"{role} {text}"
        tokens = tokenizer.tokenize(context_text, add_special_tokens=True)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)

        if role not in tokens:
            continue  # role is not a single token for this tokenizer
        role_pos = tokens.index(role)
        if role_pos + 2 >= len(tokens):
            continue

        # The last token (presumably the closing [SEP], since special tokens
        # are added) is never a candidate. The old loop fell through to it
        # and masked it when no candidate existed.
        mask_pos = next(
            (
                j for j in range(role_pos + 1, len(tokens) - 1)
                if tokens[j].lower() not in skip_words and len(tokens[j]) > 2
            ),
            None,
        )
        if mask_pos is None:
            continue

        original_token = tokens[mask_pos]
        masked_ids = list(token_ids)
        masked_ids[mask_pos] = MASK_ID
        masked_input = torch.tensor([masked_ids], device="cuda")
        with torch.no_grad():
            outputs = full_model(input_ids=masked_input, attention_mask=torch.ones_like(masked_input))
        probs = F.softmax(outputs.logits[0, mask_pos], dim=-1)
        top = torch.topk(probs, min(top_k, probs.numel()))
        predictions = [
            {"token": token_text, "probability": prob}
            for token_text, prob in zip(
                tokenizer.convert_ids_to_tokens(top.indices.tolist()), top.values.tolist()
            )
        ]
        results.append({
            "symbolic_context": role,
            "position": mask_pos,
            "original_token": original_token,
            "context_text": context_text,
            "predictions": predictions,
        })

    analysis = {
        "input_text": text,
        "test_type": "context_injection",
        "explanation": "Injected symbolic tokens and tested descriptive predictions",
        "results": results,
    }
    summary_lines = ["🎯 Testing with Symbolic Context Injection\n"]
    for result in results:
        top_pred = result["predictions"][0]
        summary_lines.append(
            f"{result['symbolic_context']} context: '{result['original_token']}' → "
            f"'{top_pred['token']}' ({top_pred['probability']:.4f})"
        )
    if not results:
        summary_lines.append("No suitable position to mask was found for the selected roles.")
    return json.dumps(analysis, indent=2), "\n".join(summary_lines), len(results)
# CUDA tensors are created below. On ZeroGPU hardware a GPU is only attached
# inside @spaces.GPU-decorated calls (the decorator is a no-op elsewhere),
# matching symbolic_classification_analysis.
@spaces.GPU
def create_manual_mask_analysis(text, mask_positions_str, selected_roles):
    """Rank symbolic roles at user-specified token positions ("Manual Masking" tab).

    Args:
        text: input text, tokenized with special tokens added.
        mask_positions_str: comma-separated 0-based token positions.
            Non-numeric entries are ignored, duplicates are used once, and
            positions beyond the sequence are dropped.
        selected_roles: roles to rank; empty/None means all known roles
            (consistent with the automatic analysis).

    Returns:
        ``(results_text, summary, count)`` for the three output boxes.
        Problems are reported as text in ``results_text`` with count 0.
    """
    try:
        mask_positions = list(dict.fromkeys(
            int(part) for part in (mask_positions_str or "").split(",")
            if part.strip().isdecimal()
        ))
        if not mask_positions:
            return "Please specify valid mask positions (comma-separated numbers)", "", 0
        if not selected_roles:
            selected_roles = list(symbolic_token_ids.keys())
        if not any(role in symbolic_token_ids for role in selected_roles):
            return "None of the selected roles is a known symbolic token.", "", 0

        batch = tokenizer(text or "", return_tensors="pt", add_special_tokens=True)
        input_ids = batch.input_ids[0]
        attention_mask = batch.attention_mask[0]
        original_tokens = tokenizer.convert_ids_to_tokens(input_ids)
        n_tokens = len(input_ids)

        # isdecimal() already excludes negatives; only the upper bound remains.
        valid_positions = [pos for pos in mask_positions if pos < n_tokens]
        if not valid_positions:
            return f"Invalid positions. Text has {n_tokens} tokens (0-{n_tokens-1})", "", 0

        masked_input_ids = input_ids.clone()
        masked_input_ids[valid_positions] = MASK_ID

        predictions = get_symbolic_predictions(
            masked_input_ids.unsqueeze(0).to("cuda"),
            attention_mask.unsqueeze(0).to("cuda"),
            valid_positions,
            selected_roles,
        )

        lines = []
        for pred_data in predictions:
            ranked = pred_data["predictions"]
            if not ranked:
                continue
            pos = pred_data["position"]
            top_pred = ranked[0]
            lines.append(
                f"Pos {pos}: '{original_tokens[pos]}' → {top_pred['token']} ({top_pred['probability']:.4f})"
            )
        if not lines:
            lines.append("No predictions were produced for the selected roles.")
        return "\n".join(lines), f"Analyzed {len(valid_positions)} positions", len(valid_positions)
    except Exception as e:
        # UI boundary: show the failure in the results box instead of raising.
        return f"Error: {str(e)}", "", 0
# ------------------------------------------------------------------
# 4. Gradio UI -----------------------------------------------------
def build_interface():
    """Build the Gradio Blocks app.

    Tabs:
        * Automatic Analysis – runs ``symbolic_classification_analysis``.
        * Manual Masking – runs ``create_manual_mask_analysis`` on
          user-chosen positions.
        * Token Inspector – lists the tokenization with position indices.
        * Caption Examples – training-style captions; each "Test This"
          button copies its caption into the Automatic Analysis input box.

    Returns:
        The (not yet launched) ``gr.Blocks`` instance.
    """
    all_roles = list(symbolic_token_ids.keys())

    with gr.Blocks(title="🧠 MLM Symbolic Classifier", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 MLM-Based Symbolic Classification")
        gr.Markdown("Analyze text using masked language modeling to predict symbolic roles at specific positions.")

        with gr.Tab("Automatic Analysis"):
            with gr.Row():
                with gr.Column():
                    txt_input = gr.Textbox(
                        label="Input Text",
                        lines=4,
                        placeholder="Try: '<subject> a young woman wearing elegant dress' or just 'young woman wearing dress'",
                    )
                    with gr.Row():
                        # NOTE(review): this value is passed to
                        # symbolic_classification_analysis but not used there.
                        masking_strategy = gr.Dropdown(
                            choices=["content_words", "every_nth", "random"],
                            value="content_words",
                            label="Masking Strategy",
                        )
                        num_predictions = gr.Slider(
                            minimum=1, maximum=10, value=5, step=1,
                            label="Top Predictions per Position",
                        )
                    roles_selection = gr.CheckboxGroup(
                        choices=all_roles,
                        value=list(all_roles),
                        label="Symbolic Roles to Consider",
                    )
                    analyze_btn = gr.Button("🔍 Analyze", variant="primary")
                with gr.Column():
                    summary_output = gr.Textbox(
                        label="Analysis Summary",
                        lines=10,
                        max_lines=15,
                    )
                    with gr.Row():
                        positions_analyzed = gr.Number(label="Positions Analyzed", precision=0)
                        # NOTE(review): no event handler writes to this box yet.
                        max_confidence = gr.Textbox(label="Best Prediction", max_lines=1)
                    detailed_output = gr.JSON(label="Detailed Results")

        with gr.Tab("Manual Masking"):
            with gr.Row():
                with gr.Column():
                    manual_text = gr.Textbox(
                        label="Input Text",
                        lines=3,
                        placeholder="Enter text for manual analysis...",
                    )
                    mask_positions_input = gr.Textbox(
                        label="Mask Positions (comma-separated)",
                        placeholder="e.g., 2,5,8,12",
                        info="Specify token positions to mask (0-based indexing)",
                    )
                    manual_roles = gr.CheckboxGroup(
                        choices=all_roles,
                        value=all_roles[:10],  # Default to first 10
                        label="Symbolic Roles",
                    )
                    manual_analyze_btn = gr.Button("🎯 Analyze Specific Positions")
                with gr.Column():
                    manual_results = gr.Textbox(
                        label="Manual Analysis Results",
                        lines=8,
                    )
                    manual_summary = gr.Textbox(label="Summary")
                    manual_count = gr.Number(label="Positions", precision=0)

        with gr.Tab("Token Inspector"):
            with gr.Row():
                with gr.Column():
                    inspect_text = gr.Textbox(
                        label="Text to Inspect",
                        lines=2,
                        placeholder="Enter text to see tokenization...",
                    )
                    # Add example patterns button
                    example_patterns = gr.Button("📝 Load Image Caption Examples")
                    inspect_btn = gr.Button("🔍 Inspect Tokens")
                with gr.Column():
                    token_breakdown = gr.Textbox(
                        label="Token Breakdown",
                        lines=8,
                        info="Shows how text is tokenized with position indices",
                    )

        with gr.Tab("Caption Examples"):
            gr.Markdown("### 🖼️ Test with Training-Style Patterns")
            gr.Markdown(
                "**The model was trained to predict descriptive words AFTER symbolic tokens.**\n"
                "Test with patterns like:\n"
                "- `<subject> a young woman wearing elegant dress`\n"
                "- `<lighting> soft natural illumination on the scene`\n"
                "- `<emotion> happy expression while posing confidently`\n"
            )
            example_captions = [
                "<subject> a young woman wearing a blue dress",
                "<lighting> soft natural illumination in the scene",
                "<emotion> happy expression while posing confidently",
                "<pose> standing gracefully near the window",
                "<upper_body_clothing> elegant silk blouse with intricate patterns",
                "<material> luxurious velvet fabric with rich texture",
                "<accessory> delicate silver jewelry catching the light",
                "<surface> polished marble floor reflecting ambient glow",
            ]
            for caption in example_captions:
                with gr.Row():
                    example_box = gr.Textbox(
                        value=caption, label="Training-Style Example", interactive=False, scale=3
                    )
                    copy_btn = gr.Button("📋 Test This", scale=1)
                # Copy this row's caption into the Automatic Analysis input.
                # The caption is read from its own textbox, so there is no
                # loop-variable closure to get wrong.
                copy_btn.click(lambda caption_text: caption_text, inputs=[example_box], outputs=[txt_input])

        # Event handlers
        analyze_btn.click(
            symbolic_classification_analysis,
            inputs=[txt_input, roles_selection, masking_strategy, num_predictions],
            outputs=[detailed_output, summary_output, positions_analyzed],
        )
        manual_analyze_btn.click(
            create_manual_mask_analysis,
            inputs=[manual_text, mask_positions_input, manual_roles],
            outputs=[manual_results, manual_summary, manual_count],
        )

        def load_examples():
            """Return a plain sample caption for the Token Inspector."""
            return "a young woman wearing a blue dress"

        def inspect_tokens(text):
            """List each token of *text* (special tokens included) with its index."""
            if not text or not text.strip():
                return "Enter text to inspect tokenization"
            tokens = tokenizer.tokenize(text, add_special_tokens=True)
            return "\n".join(f"{i:2d}: '{token}'" for i, token in enumerate(tokens))

        example_patterns.click(
            load_examples,
            outputs=[inspect_text],
        )
        inspect_btn.click(
            inspect_tokens,
            inputs=[inspect_text],
            outputs=[token_breakdown],
        )
    return demo
if __name__ == "__main__":
    print("🚀 Starting MLM Symbolic Classifier...")
    print(f"✅ Model loaded with {len(symbolic_token_ids)} symbolic tokens")
    print(f"🎯 Available symbolic roles: {list(symbolic_token_ids.keys())[:5]}...")
    # Listen on all interfaces at port 7860. share=True additionally requests a
    # public Gradio link when run locally (Hugging Face Spaces ignores it).
    build_interface().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )