File size: 24,408 Bytes
0cba5a9
5e20a2a
fd4a12a
 
aa28bbb
872b08b
fd4a12a
0a14990
 
81a0ae4
5e20a2a
096fe3a
872b08b
5e20a2a
096fe3a
fd4a12a
5e20a2a
fd4a12a
5e20a2a
fd4a12a
 
 
 
 
 
5e20a2a
fd4a12a
 
 
5e20a2a
fd4a12a
 
0a14990
fd4a12a
 
 
9a08859
 
fd4a12a
 
096fe3a
535151e
5e20a2a
fd4a12a
096fe3a
785df91
 
 
 
 
872b08b
785df91
5e20a2a
0cba5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
fd4a12a
0cba5a9
 
 
5e20a2a
096fe3a
 
0cba5a9
5e20a2a
0cba5a9
5e20a2a
0cba5a9
 
 
 
 
 
 
 
 
 
5e20a2a
0cba5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd4a12a
 
 
0cba5a9
 
 
42a1b9e
0cba5a9
fd4a12a
0cba5a9
 
 
 
 
 
42a1b9e
 
 
 
 
 
 
0cba5a9
42a1b9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cba5a9
42a1b9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cba5a9
42a1b9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cba5a9
42a1b9e
 
 
 
 
 
 
 
 
 
 
0cba5a9
42a1b9e
 
 
 
0cba5a9
42a1b9e
 
 
0cba5a9
42a1b9e
 
 
0cba5a9
42a1b9e
 
 
 
 
 
 
 
 
 
 
 
0cba5a9
42a1b9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cba5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20331bc
45ddfa3
096fe3a
0cba5a9
5e20a2a
0cba5a9
 
 
 
 
 
 
 
 
 
42a1b9e
0cba5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44c0214
0cba5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ef5341
 
 
 
0cba5a9
 
 
 
 
 
 
 
 
2ef5341
42a1b9e
 
 
 
 
 
 
 
 
2ef5341
 
42a1b9e
 
 
 
 
 
 
 
2ef5341
 
 
 
42a1b9e
 
2ef5341
0cba5a9
 
 
 
 
 
 
 
 
 
 
 
 
2ef5341
 
 
0cba5a9
 
 
 
 
 
 
 
 
 
 
 
2ef5341
 
 
 
 
 
0cba5a9
 
 
 
 
 
535151e
ea2994b
5e20a2a
0cba5a9
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
# app.py – FIXED encoder-only demo for bert-beatrix-2048
# launch:  python app.py
# -----------------------------------------------
import json, re, sys, math
from pathlib import Path, PurePosixPath

import torch, torch.nn.functional as F
import gradio as gr
import spaces
from huggingface_hub import snapshot_download

from bert_handler import create_handler_from_checkpoint


# ------------------------------------------------------------------
# 0. Download & patch HF checkpoint --------------------------------
REPO_ID = "AbstractPhil/bert-beatrix-2048"
LOCAL_CKPT = "bert-beatrix-2048"

# Fetch the "main" revision of the repo into LOCAL_CKPT as regular files
# (symlinks disabled) so that config.json can be rewritten in place below.
snapshot_download(
    repo_id=REPO_ID,
    revision="main",
    local_dir=LOCAL_CKPT,
    local_dir_use_symlinks=False,
)

# Strip the "<repo>--" prefix from auto_map entries so they point at the
# bare "module.Class" path. Re-running this is harmless: entries without
# "--" are left untouched.
cfg_path = Path(LOCAL_CKPT) / "config.json"
# Explicit UTF-8: JSON files are UTF-8 and the platform default may not be.
with cfg_path.open(encoding="utf-8") as f:
    cfg = json.load(f)

# `or {}` also covers an explicit `"auto_map": null` in the config.
amap = cfg.get("auto_map") or {}
for k, v in amap.items():
    # Only string entries can carry the prefix; leave any other value as-is.
    if isinstance(v, str) and "--" in v:
        amap[k] = PurePosixPath(v.split("--", 1)[1]).as_posix()
cfg["auto_map"] = amap
with cfg_path.open("w", encoding="utf-8") as f:
    json.dump(cfg, f, indent=2)

# ------------------------------------------------------------------
# 1.  Load model & components --------------------------------------
# create_handler_from_checkpoint is imported from the project-local
# bert_handler module; it is unpacked here into (handler, model, tokenizer)
# for the checkpoint directory LOCAL_CKPT.
handler, full_model, tokenizer = create_handler_from_checkpoint(LOCAL_CKPT)
# Put the model in eval mode and move it to the GPU once, at import time.
# NOTE(review): this requires a CUDA device to be available when the module
# is imported — confirm that holds for the deployment target.
full_model = full_model.eval().cuda()

# ------------------------------------------------------------------
# 2. Symbolic roles -------------------------------------------------
# Role tokens the demo works with; each is looked up in the tokenizer below.
SYMBOLIC_ROLES = [
    "<subject>", "<subject1>", "<subject2>", "<pose>", "<emotion>",
    "<surface>", "<lighting>", "<material>", "<accessory>", "<footwear>",
    "<upper_body_clothing>", "<hair_style>", "<hair_length>", "<headwear>",
    "<texture>", "<pattern>", "<grid>", "<zone>", "<offset>",
    "<object_left>", "<object_right>", "<relation>", "<intent>", "<style>",
    "<fabric>", "<jewelry>",
]

# Verify all symbolic tokens exist in tokenizer.
# symbolic_token_ids maps role token -> vocabulary id for every role the
# tokenizer knows; roles that resolve to the unknown-token id (or to no id
# at all) are collected in missing_tokens and excluded from the mapping.
missing_tokens = []
symbolic_token_ids = {}
for token in SYMBOLIC_ROLES:
    token_id = tokenizer.convert_tokens_to_ids(token)
    if token_id is None or token_id == tokenizer.unk_token_id:
        missing_tokens.append(token)
    else:
        symbolic_token_ids[token] = token_id

if missing_tokens:
    print(f"⚠️ Missing symbolic tokens: {missing_tokens}")
    print("Available tokens will be used for classification")

# Mask token (string and id) used whenever a position is masked for MLM.
MASK = tokenizer.mask_token
MASK_ID = tokenizer.mask_token_id

print(f"✅ Loaded {len(symbolic_token_ids)} symbolic tokens")


# ------------------------------------------------------------------
# 3. FIXED MLM-based symbolic classification ----------------------

def get_symbolic_predictions(input_ids, attention_mask, mask_positions, selected_roles):
    """
    Rank the selected symbolic role tokens at each masked position.

    The model's MLM logits at every position in ``mask_positions`` are
    restricted to the vocabulary ids of the selected roles and renormalised
    with a softmax, so the probabilities are relative to the selected roles
    only, not to the full vocabulary.

    Args:
        input_ids: token-id tensor passed to the model as ``input_ids``;
            only batch element 0 of the resulting logits is read
        attention_mask: attention mask passed to the model alongside input_ids
        mask_positions: sequence positions to classify
        selected_roles: symbolic role tokens to consider; roles missing from
            ``symbolic_token_ids`` are ignored (``None`` is treated as empty)

    Returns:
        A list with one dict per position:
        ``{"position": pos, "predictions": [{"token", "probability",
        "token_id"}, ...]}`` with predictions sorted by descending
        probability. An empty list when no selected role is known.
    """
    # Resolve the roles first: with nothing to rank there is no reason to
    # run the model, and the return type stays a plain list either way.
    selected_token_ids = [symbolic_token_ids[role] for role in (selected_roles or [])
                          if role in symbolic_token_ids]
    if not selected_token_ids:
        return []

    # Get MLM logits from the model
    with torch.no_grad():
        outputs = full_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    # The candidate token strings are the same for every position.
    candidate_tokens = tokenizer.convert_ids_to_tokens(selected_token_ids)

    results = []
    for pos in mask_positions:
        # Logits of the candidate role tokens at this position only
        symbolic_logits = logits[0, pos][selected_token_ids]
        symbolic_probs = F.softmax(symbolic_logits, dim=-1)

        # Plain Python ints/floats keep the result JSON-friendly.
        ranking = torch.argsort(symbolic_probs, descending=True).tolist()
        prob_values = symbolic_probs.tolist()

        results.append({
            "position": pos,
            "predictions": [
                {
                    "token": candidate_tokens[idx],
                    "probability": prob_values[idx],
                    "token_id": selected_token_ids[idx],
                }
                for idx in ranking
            ],
        })

    return results


def create_strategic_masks(text, tokenizer, strategy="content_words"):
    """
    Tokenize *text* and mask a strategy-dependent subset of its tokens.

    Args:
        text: input text
        tokenizer: tokenizer used both to encode the text and to map the
            ids back to token strings
        strategy: "content_words", "every_nth", "random" or "manual";
            "manual" and any unrecognised value select no positions

    Returns:
        (masked_input_ids, attention_mask, original_tokens, mask_positions).
        Both tensors get a leading dimension via unsqueeze(0); at most the
        first 10 selected positions are masked and reported.
    """
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    ids = encoded.input_ids[0]
    attn = encoded.attention_mask[0]

    # Token strings of the unmasked sequence, kept for reference
    original_tokens = tokenizer.convert_ids_to_tokens(ids)
    n_tokens = len(original_tokens)

    if strategy == "content_words":
        # Special tokens, punctuation and common function words are never masked
        skip_tokens = {
            tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token,
            ".", ",", "!", "?", ":", ";", "'", '"', "-", "(", ")", "[", "]",
            "the", "a", "an", "and", "or", "but", "in", "on", "at", "to",
            "for", "of", "with", "by", "is", "are", "was", "were", "be", "been"
        }
        chosen = [
            idx
            for idx, tok in enumerate(original_tokens)
            if tok not in skip_tokens
            and not tok.startswith("##")  # subword continuation pieces
            and len(tok) > 2
            and tok.isalpha()
        ]
    elif strategy == "every_nth":
        # Every 3rd position, leaving out the first and last token
        chosen = list(range(1, n_tokens - 1, 3))
    elif strategy == "random":
        # Roughly 15% of the inner positions, but at least one when possible
        import random
        candidates = list(range(1, n_tokens - 1))
        num_to_mask = max(1, int(len(candidates) * 0.15))
        chosen = random.sample(candidates, min(num_to_mask, len(candidates)))
        chosen.sort()
    else:
        # "manual" (positions are supplied elsewhere) or unknown strategy
        chosen = []

    # Cap the number of masks so the UI stays readable
    mask_positions = chosen[:10]

    masked_input_ids = ids.clone()
    for position in mask_positions:
        masked_input_ids[position] = MASK_ID

    return (
        masked_input_ids.unsqueeze(0),
        attn.unsqueeze(0),
        original_tokens,
        mask_positions,
    )


@spaces.GPU
def symbolic_classification_analysis(text, selected_roles, masking_strategy="content_words", num_predictions=5):
    """
    Entry point for the "Automatic Analysis" tab.

    Dispatches on the input: text that already contains one of the known
    symbolic tokens goes to test_descriptive_prediction, anything else goes
    to test_with_context_injection.

    Args:
        text: user-supplied text
        selected_roles: symbolic roles to consider; when empty, every known
            role is used
        masking_strategy: accepted for interface compatibility; this function
            does not read it
        num_predictions: number of top predictions requested per position

    Returns:
        (detailed_json, summary, count). ``detailed_json`` is always a valid
        JSON string because it feeds a JSON output component; on failure it
        is ``{"error": message}`` and the same message is returned as the
        summary so the user can see it.
    """
    def _failure(message):
        # A bare sentence is not valid JSON, so wrap it for the JSON output
        # and repeat it in the summary box.
        return json.dumps({"error": message}, indent=2), message, 0

    if not selected_roles:
        selected_roles = list(symbolic_token_ids)

    if not text or not text.strip():
        return _failure("Please enter some text to analyze.")

    try:
        # DETECT if input follows training pattern vs needs conversion
        if any(role in text for role in symbolic_token_ids):
            # Input already has symbolic tokens - test descriptive prediction
            return test_descriptive_prediction(text, selected_roles, num_predictions)
        # Otherwise inject a role token in front of the text and test that
        return test_with_context_injection(text, selected_roles, num_predictions)

    except Exception as e:
        # UI boundary: report the failure instead of crashing the request.
        error_msg = f"Error during analysis: {str(e)}"
        print(error_msg)
        return _failure(error_msg)


def test_descriptive_prediction(text, selected_roles, num_predictions):
    """
    Test what descriptive words the model predicts after symbolic tokens.

    For every symbolic token found in *text*, each of the (up to three)
    tokens that follow it is masked in turn and the model's top
    full-vocabulary predictions for that position are collected. At most
    five masked positions are evaluated.

    Args:
        text: input text expected to contain at least one symbolic token
        selected_roles: not read by this function; kept so both analysis
            helpers share one signature
        num_predictions: number of top predictions per position (values
            below 1 are raised to 1)

    Returns:
        (detailed_json, summary, count): a JSON string with all results, a
        human-readable summary, and the number of positions evaluated. When
        no position can be masked, detailed_json is ``{"error": message}``
        and the message is also returned as the summary.
    """
    tokens = tokenizer.tokenize(text, add_special_tokens=True)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    # Separator / padding tokens are never masked.
    non_maskable = {"[SEP]", "[PAD]", tokenizer.sep_token, tokenizer.pad_token}

    # Collect the positions that follow a symbolic token
    symbolic_positions = []
    for i, token in enumerate(tokens):
        if token not in symbolic_token_ids:
            continue
        # range() already keeps i + offset inside the sequence
        for offset in range(1, min(4, len(tokens) - i)):
            following = tokens[i + offset]
            if following not in non_maskable:
                symbolic_positions.append({
                    'mask_pos': i + offset,
                    'symbolic_token': token,
                    'original_token': following
                })

    if not symbolic_positions:
        message = "No symbolic tokens found in input. Try format like: '<subject> a young woman'"
        # The first output must be valid JSON; the message itself goes to
        # the summary.
        return json.dumps({"error": message}, indent=2), message, 0

    # At least one prediction is needed for the summary line below.
    top_k = max(1, int(num_predictions))

    results = []
    for pos_info in symbolic_positions[:5]:  # Limit to 5 positions
        mask_pos = pos_info['mask_pos']
        masked_ids = list(token_ids)
        masked_ids[mask_pos] = MASK_ID

        masked_input = torch.tensor([masked_ids]).to("cuda")
        attention_mask = torch.ones_like(masked_input)

        with torch.no_grad():
            outputs = full_model(input_ids=masked_input, attention_mask=attention_mask)
            logits = outputs.logits[0, mask_pos]  # Logits for masked position

        # Top predictions over the full vocabulary; topk returns them in
        # descending order without sorting every vocabulary entry.
        probs = F.softmax(logits, dim=-1)
        top_probs, top_ids = torch.topk(probs, min(top_k, probs.numel()))
        top_tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())

        predictions = [
            {"token": token_text, "probability": prob}
            for token_text, prob in zip(top_tokens, top_probs.tolist())
        ]

        results.append({
            "symbolic_context": pos_info['symbolic_token'],
            "position": mask_pos,
            "original_token": pos_info['original_token'],
            "predictions": predictions
        })

    # Format results
    analysis = {
        "input_text": text,
        "test_type": "descriptive_prediction",
        "explanation": "Testing what descriptive words model predicts after symbolic tokens",
        "results": results
    }

    summary_lines = ["🎯 Testing Descriptive Prediction (what model actually learned)\n"]
    for result in results:
        ctx = result["symbolic_context"]
        orig = result["original_token"]
        top_pred = result["predictions"][0]

        summary_lines.append(
            f"After {ctx}: '{orig}' → '{top_pred['token']}' ({top_pred['probability']:.4f})"
        )

    summary = "\n".join(summary_lines)
    return json.dumps(analysis, indent=2), summary, len(results)


def test_with_context_injection(text, selected_roles, num_predictions):
    """
    Inject symbolic context and test what descriptive words are predicted
    """
    results = []
    
    # Test each selected symbolic role as context
    for role in selected_roles[:3]:  # Limit to 3 roles for speed
        # Create training-style context
        context_text = f"{role} {text}"
        
        # Tokenize and find good positions to mask
        tokens = tokenizer.tokenize(context_text, add_special_tokens=True)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        
        # Find role position and mask next content word
        role_pos = None
        for i, token in enumerate(tokens):
            if token == role:
                role_pos = i
                break
        
        if role_pos is None or role_pos + 2 >= len(tokens):
            continue
            
        # Mask position after role (skip articles like "a", "the")
        mask_pos = role_pos + 1
        skip_words = {'a', 'an', 'the', 'some', 'this', 'that'}
        while mask_pos < len(tokens) - 1:
            current_token = tokens[mask_pos].lower()
            if current_token not in skip_words and len(current_token) > 2:
                break
            mask_pos += 1
        
        if mask_pos >= len(tokens):
            continue
            
        # Create masked input
        masked_ids = token_ids.copy()
        original_token = tokens[mask_pos]
        masked_ids[mask_pos] = MASK_ID
        
        # Get predictions
        masked_input = torch.tensor([masked_ids]).to("cuda")
        attention_mask = torch.ones_like(masked_input)
        
        with torch.no_grad():
            outputs = full_model(input_ids=masked_input, attention_mask=attention_mask)
            logits = outputs.logits[0, mask_pos]
        
        # Get top predictions
        probs = F.softmax(logits, dim=-1)
        top_indices = torch.argsort(probs, descending=True)[:num_predictions]
        
        predictions = []
        for idx in top_indices:
            token_text = tokenizer.convert_ids_to_tokens([idx.item()])[0]
            prob = probs[idx].item()
            predictions.append({
                "token": token_text,
                "probability": prob
            })
        
        results.append({
            "symbolic_context": role,
            "position": mask_pos,
            "original_token": original_token,
            "context_text": context_text,
            "predictions": predictions
        })
    
    # Format results
    analysis = {
        "input_text": text,
        "test_type": "context_injection",
        "explanation": "Injected symbolic tokens and tested descriptive predictions",
        "results": results
    }
    
    summary_lines = [f"🎯 Testing with Symbolic Context Injection\n"]
    for result in results:
        role = result["symbolic_context"]
        orig = result["original_token"]
        top_pred = result["predictions"][0]
        
        summary_lines.append(
            f"{role} context: '{orig}' β†’ '{top_pred['token']}' ({top_pred['probability']:.4f})"
        )
    
    summary = "\n".join(summary_lines)
    return json.dumps(analysis, indent=2), summary, len(results)


def create_manual_mask_analysis(text, mask_positions_str, selected_roles):
    """
    Mask user-chosen token positions and rank the selected symbolic roles there.

    Args:
        text: input text
        mask_positions_str: comma-separated, 0-based token positions; entries
            that are not plain non-negative integers are ignored
        selected_roles: symbolic role tokens to rank at each position

    Returns:
        (results_text, summary, count): one line per analysed position with
        its best-scoring role, a short summary, and the number of positions
        analysed. On invalid input or failure the first element is a message
        and the other two are "" and 0.
    """
    try:
        # Parse mask positions. isdecimal() only admits characters int() can
        # parse, so stray input cannot raise here.
        pieces = (piece.strip() for piece in (mask_positions_str or "").split(","))
        mask_positions = [int(piece) for piece in pieces if piece.isdecimal()]

        if not mask_positions:
            return "Please specify valid mask positions (comma-separated numbers)", "", 0

        # Without at least one known role there is nothing to rank; say so
        # instead of failing later on an empty prediction set.
        usable_roles = [role for role in (selected_roles or []) if role in symbolic_token_ids]
        if not usable_roles:
            return "Please select at least one symbolic role", "", 0

        # Tokenize text
        batch = tokenizer(text, return_tensors="pt", add_special_tokens=True)
        input_ids = batch.input_ids[0]
        attention_mask = batch.attention_mask[0]
        original_tokens = tokenizer.convert_ids_to_tokens(input_ids)

        # Keep in-range positions, dropping duplicates but preserving order
        valid_positions = list(dict.fromkeys(
            pos for pos in mask_positions if 0 <= pos < len(input_ids)
        ))
        if not valid_positions:
            return f"Invalid positions. Text has {len(input_ids)} tokens (0-{len(input_ids)-1})", "", 0

        # Create masked input
        masked_input_ids = input_ids.clone()
        for pos in valid_positions:
            masked_input_ids[pos] = MASK_ID

        # Run analysis
        masked_input_ids = masked_input_ids.unsqueeze(0).to("cuda")
        attention_mask = attention_mask.unsqueeze(0).to("cuda")

        predictions = get_symbolic_predictions(
            masked_input_ids, attention_mask, valid_positions, usable_roles
        )

        # One line per position: original token and best-scoring role
        results = []
        for pred_data in predictions:
            pos = pred_data["position"]
            original = original_tokens[pos]
            if not pred_data["predictions"]:
                continue
            top_pred = pred_data["predictions"][0]
            results.append(
                f"Pos {pos}: '{original}' → {top_pred['token']} ({top_pred['probability']:.4f})"
            )

        return "\n".join(results), f"Analyzed {len(valid_positions)} positions", len(valid_positions)

    except Exception as e:
        # UI boundary: show the failure in the results box.
        return f"Error: {str(e)}", "", 0


# ------------------------------------------------------------------
# 4. Gradio UI -----------------------------------------------------
def build_interface():
    """
    Build the Gradio Blocks UI and return it (not yet launched).

    Tabs:
        - "Automatic Analysis": runs symbolic_classification_analysis
        - "Manual Masking": runs create_manual_mask_analysis
        - "Token Inspector": shows the tokenization with position indices
        - "Caption Examples": read-only sample captions; each "Test This"
          button copies its caption into the Automatic Analysis input box
    """
    def _caption_loader(caption):
        # One closure per caption so every button loads its own text rather
        # than whatever the loop variable holds last.
        def _load():
            return caption
        return _load

    with gr.Blocks(title="🧠 MLM Symbolic Classifier", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧠 MLM-Based Symbolic Classification")
        gr.Markdown("Analyze text using masked language modeling to predict symbolic roles at specific positions.")

        with gr.Tab("Automatic Analysis"):
            with gr.Row():
                with gr.Column():
                    txt_input = gr.Textbox(
                        label="Input Text",
                        lines=4,
                        placeholder="Try: '<subject> a young woman wearing elegant dress' or just 'young woman wearing dress'"
                    )

                    with gr.Row():
                        masking_strategy = gr.Dropdown(
                            choices=["content_words", "every_nth", "random"],
                            value="content_words",
                            label="Masking Strategy"
                        )
                        num_predictions = gr.Slider(
                            minimum=1, maximum=10, value=5, step=1,
                            label="Top Predictions per Position"
                        )

                    roles_selection = gr.CheckboxGroup(
                        choices=list(symbolic_token_ids.keys()),
                        value=list(symbolic_token_ids.keys()),
                        label="Symbolic Roles to Consider"
                    )

                    analyze_btn = gr.Button("🔍 Analyze", variant="primary")

                with gr.Column():
                    summary_output = gr.Textbox(
                        label="Analysis Summary",
                        lines=10,
                        max_lines=15
                    )

                    with gr.Row():
                        positions_analyzed = gr.Number(label="Positions Analyzed", precision=0)
                        # NOTE: no event handler below writes to this box.
                        max_confidence = gr.Textbox(label="Best Prediction", max_lines=1)

            detailed_output = gr.JSON(label="Detailed Results")

        with gr.Tab("Manual Masking"):
            with gr.Row():
                with gr.Column():
                    manual_text = gr.Textbox(
                        label="Input Text",
                        lines=3,
                        placeholder="Enter text for manual analysis..."
                    )

                    mask_positions_input = gr.Textbox(
                        label="Mask Positions (comma-separated)",
                        placeholder="e.g., 2,5,8,12",
                        info="Specify token positions to mask (0-based indexing)"
                    )

                    manual_roles = gr.CheckboxGroup(
                        choices=list(symbolic_token_ids.keys()),
                        value=list(symbolic_token_ids.keys())[:10],  # Default to first 10
                        label="Symbolic Roles"
                    )

                    manual_analyze_btn = gr.Button("🎯 Analyze Specific Positions")

                with gr.Column():
                    manual_results = gr.Textbox(
                        label="Manual Analysis Results",
                        lines=8
                    )

                    manual_summary = gr.Textbox(label="Summary")
                    manual_count = gr.Number(label="Positions", precision=0)

        with gr.Tab("Token Inspector"):
            with gr.Row():
                with gr.Column():
                    inspect_text = gr.Textbox(
                        label="Text to Inspect",
                        lines=2,
                        placeholder="Enter text to see tokenization..."
                    )

                    # Fills the inspector box with a sample caption
                    example_patterns = gr.Button("📝 Load Image Caption Examples")

                    inspect_btn = gr.Button("🔍 Inspect Tokens")

                with gr.Column():
                    token_breakdown = gr.Textbox(
                        label="Token Breakdown",
                        lines=8,
                        info="Shows how text is tokenized with position indices"
                    )

        with gr.Tab("Caption Examples"):
            gr.Markdown("### 🖼️ Test with Training-Style Patterns")
            gr.Markdown("""
            **The model was trained to predict descriptive words AFTER symbolic tokens.**
            
            Test with patterns like:
            - `<subject> a young woman wearing elegant dress`
            - `<lighting> soft natural illumination on the scene`
            - `<emotion> happy expression while posing confidently`
            """)

            example_captions = [
                "<subject> a young woman wearing a blue dress",
                "<lighting> soft natural illumination in the scene",
                "<emotion> happy expression while posing confidently",
                "<pose> standing gracefully near the window",
                "<upper_body_clothing> elegant silk blouse with intricate patterns",
                "<material> luxurious velvet fabric with rich texture",
                "<accessory> delicate silver jewelry catching the light",
                "<surface> polished marble floor reflecting ambient glow"
            ]

            for caption in example_captions:
                with gr.Row():
                    gr.Textbox(value=caption, label="Training-Style Example", interactive=False, scale=3)
                    copy_btn = gr.Button("📋 Test This", scale=1)
                    # Copy this caption into the Automatic Analysis input so
                    # it can be analysed from that tab.
                    copy_btn.click(
                        _caption_loader(caption),
                        inputs=None,
                        outputs=[txt_input]
                    )

        # Event handlers: analysis tabs
        analyze_btn.click(
            symbolic_classification_analysis,
            inputs=[txt_input, roles_selection, masking_strategy, num_predictions],
            outputs=[detailed_output, summary_output, positions_analyzed]
        )

        manual_analyze_btn.click(
            create_manual_mask_analysis,
            inputs=[manual_text, mask_positions_input, manual_roles],
            outputs=[manual_results, manual_summary, manual_count]
        )

        def load_examples():
            # Sample caption without symbolic tokens
            return "a young woman wearing a blue dress"

        def inspect_tokens(text):
            # List every token together with its 0-based position
            if not text.strip():
                return "Enter text to inspect tokenization"

            tokens = tokenizer.tokenize(text, add_special_tokens=True)
            return "\n".join(f"{i:2d}: '{token}'" for i, token in enumerate(tokens))

        # Event handlers: token inspector
        example_patterns.click(
            load_examples,
            outputs=[inspect_text]
        )

        inspect_btn.click(
            inspect_tokens,
            inputs=[inspect_text],
            outputs=[token_breakdown]
        )

    return demo


if __name__ == "__main__":
    # Startup banner: how many symbolic roles resolved, plus the first five.
    print("🚀 Starting MLM Symbolic Classifier...")
    print(f"✅ Model loaded with {len(symbolic_token_ids)} symbolic tokens")
    print(f"🎯 Available symbolic roles: {list(symbolic_token_ids)[:5]}...")

    # Serve on all interfaces at port 7860; share=True additionally asks
    # Gradio for a public share link.
    build_interface().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )