# app.py – FIXED encoder-only demo for bert-beatrix-2048
# launch:  python app.py
# -----------------------------------------------
import json
import random
from pathlib import Path, PurePosixPath

import torch
import torch.nn.functional as F
import gradio as gr
import spaces
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint
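# (bert_handler is assumed to ship alongside this app; per the call below it
#  returns a (handler, full_model, tokenizer) triple for the checkpoint)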


# ------------------------------------------------------------------
# 0. Download & patch HF checkpoint --------------------------------
REPO_ID   = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)

# → strip repo prefix in auto_map (one-time)
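# (entries downloaded from the Hub can look like
#  "AbstractPhil/bert-beatrix-2048--modeling.SomeClass"; the class name here is
#  illustrative. Dropping everything up to "--" leaves a plain module path so
#  transformers loads the code from the local snapshot.)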
cfg_path = Path(LOCAL_CKPT) / "config.json"
with cfg_path.open() as f:
    cfg = json.load(f)

amap = cfg.get("auto_map", {})
for k, v in amap.items():
    if "--" in v:
        amap[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
cfg["auto_map"] = amap

with cfg_path.open("w") as f:
    json.dump(cfg, f, indent=2)

# ------------------------------------------------------------------
# 1.  Load model & components --------------------------------------
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
full_model = full_model.eval().cuda()

# ------------------------------------------------------------------
# 2. Symbolic roles -------------------------------------------------
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

# Verify all symbolic tokens exist in tokenizer
missing_tokens = []
symbolic_token_ids = {}
for token in SYMBOLIC_ROLES:
    token_id = tokenizer.convert_tokens_to_ids(token)
    if token_id == tokenizer.unk_token_id:
        missing_tokens.append(token)
    else:
        symbolic_token_ids[token] = token_id

if missing_tokens:
    print(f"⚠️ Missing symbolic tokens: {missing_tokens}")
    print("Only the roles present in the vocabulary will be used for classification")

MASK = tokenizer.mask_token
MASK_ID = tokenizer.mask_token_id

print(f"βœ… Loaded {len(symbolic_token_ids)} symbolic tokens")


# ------------------------------------------------------------------
# 3. FIXED MLM-based symbolic classification ----------------------

def get_symbolic_predictions(input_ids, attention_mask, mask_positions, selected_roles):
    """
    Proper MLM-based prediction for symbolic tokens at masked positions
    
    Args:
        input_ids: (B, S) token IDs with [MASK] at positions to classify
        attention_mask: (B, S) attention mask
        mask_positions: list of positions that are masked
        selected_roles: list of symbolic role tokens to consider
        
    Returns:
        a list of dicts, one per masked position, each holding the ranked
        symbolic-role predictions with their probabilities (assumes a batch
        of size 1; only the first sequence is read)
    """
    # Get MLM logits from the model (this is what it was trained for)
    with torch.no_grad():
        outputs = full_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # (B, S, V)
    
    # Filter to only selected symbolic role token IDs
    selected_token_ids = [symbolic_token_ids[role] for role in selected_roles 
                         if role in symbolic_token_ids]
    
    if not selected_token_ids:
        return []
    
    results = []
    
    for pos in mask_positions:
        # Get logits for this masked position
        pos_logits = logits[0, pos]  # (V,)
        
        # Extract logits for symbolic tokens only
        symbolic_logits = pos_logits[selected_token_ids]  # (num_symbolic,)
        
        # Apply softmax to get probabilities
        symbolic_probs = F.softmax(symbolic_logits, dim=-1)
        
        # Get top predictions
        top_indices = torch.argsort(symbolic_probs, descending=True)
        
        pos_results = []
        for i in top_indices:
            token_idx = selected_token_ids[i]
            token = tokenizer.convert_ids_to_tokens([token_idx])[0]
            prob = symbolic_probs[i].item()
            pos_results.append({
                "token": token,
                "probability": prob,
                "token_id": token_idx
            })
        
        results.append({
            "position": pos,
            "predictions": pos_results
        })
    
    return results
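
# Note on the softmax above: probabilities are renormalised over the selected
# role tokens only, not the full vocabulary. Illustrative values:
#   role logits [2.0, 0.5, -1.0] -> probs ≈ [0.79, 0.18, 0.04]
# so the scores rank roles against each other rather than against all tokens.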


def create_strategic_masks(text, tokenizer, strategy="content_words"):
    """
    Create strategic mask positions based on different strategies
    
    Args:
        text: input text
        tokenizer: tokenizer
        strategy: masking strategy
        
    Returns:
        input_ids with masks, attention_mask, original_tokens, mask_positions
    """
    # Tokenize original text
    batch = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    input_ids = batch.input_ids[0]  # (S,)
    attention_mask = batch.attention_mask[0]  # (S,)
    
    # Get original tokens for reference
    original_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    
    # Find positions to mask based on strategy
    mask_positions = []
    
    if strategy == "content_words":
        # Mask content words (avoid special tokens, punctuation, common words)
        skip_tokens = {
            tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token,
            ".", ",", "!", "?", ":", ";", "'", '"', "-", "(", ")", "[", "]",
            "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", 
            "for", "of", "with", "by", "is", "are", "was", "were", "be", "been"
        }
        
        for i, token in enumerate(original_tokens):
            if (token not in skip_tokens and 
                not token.startswith("##") and  # avoid subword tokens
                len(token) > 2 and
                token.isalpha()):
                mask_positions.append(i)
    
    elif strategy == "every_nth":
        # Mask every 3rd token (avoiding special tokens)
        for i in range(1, len(original_tokens) - 1, 3):  # skip CLS and SEP
            mask_positions.append(i)
    
    elif strategy == "random":
        # Randomly mask 15% of tokens
        import random
        candidates = list(range(1, len(original_tokens) - 1))  # skip CLS and SEP
        num_to_mask = max(1, int(len(candidates) * 0.15))
        mask_positions = random.sample(candidates, min(num_to_mask, len(candidates)))
        mask_positions.sort()
    
    elif strategy == "manual":
        # For manual specification - return original for now
        # Users can specify positions in the UI
        pass
    
    # Limit to reasonable number of masks
    mask_positions = mask_positions[:10]  # Max 10 masks for UI clarity
    
    # Create masked input
    masked_input_ids = input_ids.clone()
    for pos in mask_positions:
        masked_input_ids[pos] = MASK_ID
    
    return masked_input_ids.unsqueeze(0), attention_mask.unsqueeze(0), original_tokens, mask_positions
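
# Quick usage sketch (illustrative; the example text is made up and the model
# loaded above is assumed to be on GPU):
#   ids, attn, toks, pos = create_strategic_masks(
#       "a woman in a red dress under soft lighting", tokenizer
#   )
#   preds = get_symbolic_predictions(ids.cuda(), attn.cuda(), pos, SYMBOLIC_ROLES)
#   for p in preds:
#       print(toks[p["position"]], "->", p["predictions"][0]["token"])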


@spaces.GPU
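# (on HF ZeroGPU Spaces this decorator requests a GPU slice for each call)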
def symbolic_classification_analysis(text, selected_roles, masking_strategy="content_words", num_predictions=5):
    """
    Perform symbolic classification analysis using MLM prediction
    """
    if not selected_roles:
        selected_roles = list(symbolic_token_ids.keys())
    
    if not text.strip():
        return "Please enter some text to analyze.", "", 0
    
    try:
        # Create strategically masked input
        masked_input_ids, attention_mask, original_tokens, mask_positions = create_strategic_masks(
            text, tokenizer, masking_strategy
        )
        
        if not mask_positions:
            return None, "No suitable positions found for masking. Try a different text or strategy.", 0, ""
        
        # Move to device
        masked_input_ids = masked_input_ids.to("cuda")
        attention_mask = attention_mask.to("cuda")
        
        # Get symbolic predictions
        predictions = get_symbolic_predictions(
            masked_input_ids, attention_mask, mask_positions, selected_roles
        )
        
        # Build detailed analysis
        analysis = {
            "input_text": text,
            "masking_strategy": masking_strategy,
            "total_tokens": len(original_tokens),
            "masked_positions": len(mask_positions),
            "available_symbolic_roles": len(selected_roles),
            "analysis_results": []
        }
        
        for pred_data in predictions:
            pos = pred_data["position"]
            original_token = original_tokens[pos]
            
            # Show top N predictions
            top_preds = pred_data["predictions"][:num_predictions]
            
            position_analysis = {
                "position": pos,
                "original_token": original_token,
                "top_predictions": []
            }
            
            for pred in top_preds:
                position_analysis["top_predictions"].append({
                    "symbolic_role": pred["token"],
                    "probability": f"{pred['probability']:.4f}",
                    "confidence": "High" if pred["probability"] > 0.3 else "Medium" if pred["probability"] > 0.1 else "Low"
                })
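                # (the 0.3 / 0.1 cut-offs above are ad-hoc bands, not calibrated)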
            
            analysis["analysis_results"].append(position_analysis)
        
        # Create readable summary
        summary_lines = []
        max_prob = 0
        best_prediction = None
        
        for result in analysis["analysis_results"]:
            pos = result["position"]
            orig = result["original_token"]
            top_pred = result["top_predictions"][0] if result["top_predictions"] else None
            
            if top_pred:
                prob = float(top_pred["probability"])
                role = top_pred["symbolic_role"]
                summary_lines.append(
                    f"Position {pos:2d}: '{orig}' → {role} ({top_pred['probability']}, {top_pred['confidence']})"
                )
                
                if prob > max_prob:
                    max_prob = prob
                    best_prediction = f"{role} (confidence: {top_pred['confidence']})"
        
        summary = "\n".join(summary_lines)
        if best_prediction:
            summary = f"🎯 Best Match: {best_prediction}\n\n" + summary
        
        return analysis, summary, len(mask_positions), best_prediction or "n/a"
        
    except Exception as e:
        error_msg = f"Error during analysis: {e}"
        print(error_msg)
        return None, error_msg, 0, ""


def create_manual_mask_analysis(text, mask_positions_str, selected_roles):
    """
    Allow manual specification of mask positions
    """
    try:
        # Parse mask positions
        mask_positions = [int(x.strip()) for x in mask_positions_str.split(",") if x.strip().isdigit()]
        
        if not mask_positions:
            return "Please specify valid mask positions (comma-separated numbers)", "", 0
        
        # Tokenize text
        batch = tokenizer(text, return_tensors="pt", add_special_tokens=True)
        input_ids = batch.input_ids[0]
        attention_mask = batch.attention_mask[0]
        original_tokens = tokenizer.convert_ids_to_tokens(input_ids)
        
        # Validate positions
        valid_positions = [pos for pos in mask_positions if 0 <= pos < len(input_ids)]
        if not valid_positions:
            return f"Invalid positions. Text has {len(input_ids)} tokens (0-{len(input_ids)-1})", "", 0
        
        # Create masked input
        masked_input_ids = input_ids.clone()
        for pos in valid_positions:
            masked_input_ids[pos] = MASK_ID
        
        # Run analysis
        masked_input_ids = masked_input_ids.unsqueeze(0).to("cuda")
        attention_mask = attention_mask.unsqueeze(0).to("cuda")
        
        predictions = get_symbolic_predictions(
            masked_input_ids, attention_mask, valid_positions, selected_roles
        )
        
        # Format results
        results = []
        for pred_data in predictions:
            pos = pred_data["position"]
            original = original_tokens[pos]
            top_pred = pred_data["predictions"][0] if pred_data["predictions"] else None
            
            if top_pred:
                results.append(
                    f"Pos {pos}: '{original}' → {top_pred['token']} ({top_pred['probability']:.4f})"
                )
        
        return "\n".join(results), f"Analyzed {len(valid_positions)} positions", len(valid_positions)
        
    except Exception as e:
        return f"Error: {str(e)}", "", 0


# ------------------------------------------------------------------
# 4. Gradio UI -----------------------------------------------------
def build_interface():
    with gr.Blocks(title="🧠 MLM Symbolic Classifier", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 MLM-Based Symbolic Classification")
        gr.Markdown("Analyze text using masked language modeling to predict symbolic roles at specific positions.")
        
        with gr.Tab("Automatic Analysis"):
            with gr.Row():
                with gr.Column():
                    txt_input = gr.Textbox(
                        label="Input Text", 
                        lines=4,
                        placeholder="Enter text to analyze for symbolic role classification..."
                    )
                    
                    with gr.Row():
                        masking_strategy = gr.Dropdown(
                            choices=["content_words", "every_nth", "random"],
                            value="content_words",
                            label="Masking Strategy"
                        )
                        num_predictions = gr.Slider(
                            minimum=1, maximum=10, value=5, step=1,
                            label="Top Predictions per Position"
                        )
                    
                    roles_selection = gr.CheckboxGroup(
                        choices=list(symbolic_token_ids.keys()),
                        value=list(symbolic_token_ids.keys()),
                        label="Symbolic Roles to Consider"
                    )
                    
                    analyze_btn = gr.Button("🔍 Analyze", variant="primary")
                
                with gr.Column():
                    summary_output = gr.Textbox(
                        label="Analysis Summary", 
                        lines=10,
                        max_lines=15
                    )
                    
                    with gr.Row():
                        positions_analyzed = gr.Number(label="Positions Analyzed", precision=0)
                        max_confidence = gr.Textbox(label="Best Prediction", max_lines=1)
            
            detailed_output = gr.JSON(label="Detailed Results")
        
        with gr.Tab("Manual Masking"):
            with gr.Row():
                with gr.Column():
                    manual_text = gr.Textbox(
                        label="Input Text",
                        lines=3,
                        placeholder="Enter text for manual analysis..."
                    )
                    
                    mask_positions_input = gr.Textbox(
                        label="Mask Positions (comma-separated)",
                        placeholder="e.g., 2,5,8,12",
                        info="Specify token positions to mask (0-based indexing)"
                    )
                    
                    manual_roles = gr.CheckboxGroup(
                        choices=list(symbolic_token_ids.keys()),
                        value=list(symbolic_token_ids.keys())[:10],  # Default to first 10
                        label="Symbolic Roles"
                    )
                    
                    manual_analyze_btn = gr.Button("🎯 Analyze Specific Positions")
                
                with gr.Column():
                    manual_results = gr.Textbox(
                        label="Manual Analysis Results",
                        lines=8
                    )
                    
                    manual_summary = gr.Textbox(label="Summary")
                    manual_count = gr.Number(label="Positions", precision=0)
        
        with gr.Tab("Token Inspector"):
            with gr.Row():
                with gr.Column():
                    inspect_text = gr.Textbox(
                        label="Text to Inspect",
                        lines=2,
                        placeholder="Enter text to see tokenization..."
                    )
                    inspect_btn = gr.Button("🔍 Inspect Tokens")
                
                with gr.Column():
                    token_breakdown = gr.Textbox(
                        label="Token Breakdown",
                        lines=8,
                        info="Shows how text is tokenized with position indices"
                    )
        
        # Event handlers
        analyze_btn.click(
            symbolic_classification_analysis,
            inputs=[txt_input, roles_selection, masking_strategy, num_predictions],
            outputs=[detailed_output, summary_output, positions_analyzed, max_confidence]
        )
        
        manual_analyze_btn.click(
            create_manual_mask_analysis,
            inputs=[manual_text, mask_positions_input, manual_roles],
            outputs=[manual_results, manual_summary, manual_count]
        )
        
        def inspect_tokens(text):
            if not text.strip():
                return "Enter text to inspect tokenization"
            
            # Encode the same way the analysis functions do, so the indices
            # shown here match the positions expected by "Manual Masking"
            ids = tokenizer(text, add_special_tokens=True).input_ids
            tokens = tokenizer.convert_ids_to_tokens(ids)
            
            return "\n".join(f"{i:2d}: '{token}'" for i, token in enumerate(tokens))
        
        inspect_btn.click(
            inspect_tokens,
            inputs=[inspect_text],
            outputs=[token_breakdown]
        )
    
    return demo


if __name__ == "__main__":
    print("πŸš€ Starting MLM Symbolic Classifier...")
    print(f"βœ… Model loaded with {len(symbolic_token_ids)} symbolic tokens")
    print(f"🎯 Available symbolic roles: {list(symbolic_token_ids.keys())[:5]}...")
    
    build_interface().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )