AbstractPhil committed on
Commit 65782fe · verified · 1 Parent(s): 047762f

Create trainer.py

Files changed (1)
trainer.py +549 -0
trainer.py ADDED
@@ -0,0 +1,549 @@
#################################################################################
## penta-classifier-prototype
#################################################################################
## Author: AbstractPhil
## Assistant: Claude Opus 4.1
#################################################################################
## License: Apache - cite with care and share with passionate individuals.
##
## This tiny model somehow defeated all my larger variants.
## It's the first model showing direct evidence of potential pentachora scaling.
## No pretraining, pure noise. Nothing bulky or extra, just run it.
##
## Somehow, this model contains 60+ classifiers in 3 pentachora.
## I'm still uncertain why, as it defeated the projections.
## I need additional research, additional time. But here's the model.
##
## This is based on one of my earlier prototypes and is labeled as such.
## Somewhere during development it fell apart; today I put it back together.
##
#################################################################################

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from huggingface_hub import HfApi, create_repo, upload_folder
from safetensors.torch import save_file, load_file
import os
import json
import hashlib
from datetime import datetime
from google.colab import userdata

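# Note (my annotation, not part of the original file): `google.colab.userdata` only
# exists inside a Colab runtime, so this script targets Colab. Outside Colab you would
# presumably read the token from the environment instead, e.g.
#     HF_TOKEN = os.environ.get("HF_TOKEN")
# and drop the google.colab import.
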
# ============== SETUP HF AND PATHS ==============
HF_TOKEN = userdata.get('HF_TOKEN')
REPO_ID = "AbstractPhil/penta-classifier-prototype"

# Create unique run ID based on timestamp and config
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
config_str = f"emnist_byclass_b1024_lr1e-3_{run_timestamp}"
run_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]

# Local directories
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("tensorboard_logs", exist_ok=True)

# TensorBoard setup
writer = SummaryWriter(f'tensorboard_logs/{run_hash}')

# Initialize HF API
api = HfApi()
try:
    create_repo(REPO_ID, repo_type="model", token=HF_TOKEN, exist_ok=True)
    print(f"Using HuggingFace repo: {REPO_ID}")
except Exception as e:
    print(f"Repo setup: {e}")

# ============== CONFIGURATION ==============
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

# Hyperparameters
config = {
    "input_dim": 28 * 28,
    "base_dim": 64,
    "batch_size": 1024,
    "epochs": 5,
    "initial_lr": 1e-3,
    "temp_contrastive": 0.1,
    "lambda_contrastive": 0.5,
    "lambda_cayley": 0.01,
    "dataset": "EMNIST_byclass",
    "run_hash": run_hash,
    "timestamp": run_timestamp
}

# Save config
config_path = f"checkpoints/config_{run_hash}.json"
with open(config_path, 'w') as f:
    json.dump(config, f, indent=2)

# Log config to TensorBoard
writer.add_text('Config', json.dumps(config, indent=2), 0)

# ============== DATASET ==============
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

train_dataset = datasets.EMNIST(root="./data", split='byclass', train=True, transform=transform, download=True)
test_dataset = datasets.EMNIST(root="./data", split='byclass', train=False, transform=transform, download=True)

num_classes = len(train_dataset.classes)
config["num_classes"] = num_classes

train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], pin_memory=True,
                          shuffle=True, num_workers=4, prefetch_factor=8)
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], pin_memory=True,
                         shuffle=False, num_workers=4, prefetch_factor=8)

print(f"Train: {len(train_dataset)} samples, Test: {len(test_dataset)} samples")
print(f"Classes: {num_classes}")

# ============== MODEL DEFINITIONS ==============
class AdaptiveEncoder(nn.Module):
    """Multi-layer encoder with normalization and multi-scale outputs"""
    def __init__(self, input_dim, base_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(0.2)

        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(0.2)

        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)

        self.fc_coarse = nn.Linear(256, base_dim // 4)
        self.fc_medium = nn.Linear(128, base_dim // 2)
        self.fc_fine = nn.Linear(128, base_dim)

        self.norm_coarse = nn.LayerNorm(base_dim // 4)
        self.norm_medium = nn.LayerNorm(base_dim // 2)
        self.norm_fine = nn.LayerNorm(base_dim)

    def forward(self, x):
        h1 = F.relu(self.bn1(self.fc1(x)))
        h1 = self.dropout1(h1)
        h2 = F.relu(self.bn2(self.fc2(h1)))
        h2 = self.dropout2(h2)
        h3 = F.relu(self.bn3(self.fc3(h2)))

        coarse = self.norm_coarse(self.fc_coarse(h2))
        medium = self.norm_medium(self.fc_medium(h3))
        fine = self.norm_fine(self.fc_fine(h3))

        return coarse, medium, fine

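# Shape walk-through (my annotation, not in the original): with input_dim=784 and the
# configured base_dim=64, the trunk maps 784 -> 512 -> 256 -> 128, and the three heads
# produce LayerNorm-ed embeddings of size 16 (coarse, taken from the 256-d features),
# 32 (medium) and 64 (fine), both of the latter taken from the 128-d features.
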
def init_perfect_pentachora(num_classes, latent_dim, device='cuda'):
    """Initialize as regular 4-simplices in orthogonal subspaces"""
    pentachora = torch.zeros(num_classes, 5, latent_dim, device=device)

    sqrt15 = np.sqrt(15)
    sqrt10 = np.sqrt(10)
    sqrt5 = np.sqrt(5)

    simplex = torch.tensor([
        [1.0, 0.0, 0.0, 0.0],
        [-0.25, sqrt15/4, 0.0, 0.0],
        [-0.25, -sqrt15/12, sqrt10/3, 0.0],
        [-0.25, -sqrt15/12, -sqrt10/6, sqrt5/2],
        [-0.25, -sqrt15/12, -sqrt10/6, -sqrt5/2]
    ], dtype=torch.float32, device=device)

    simplex = F.normalize(simplex, dim=1)

    dims_per_class = latent_dim // num_classes
    for c in range(num_classes):
        if dims_per_class >= 4:
            start = c * dims_per_class
            pentachora[c, :, start:start+4] = simplex
        else:
            rotation = torch.randn(4, latent_dim, device=device)
            rotation = F.normalize(rotation, dim=1)
            pentachora[c] = torch.mm(simplex, rotation[:4])

    return nn.Parameter(pentachora * 2.0)

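# Optional sanity check (my addition, not called anywhere in the original script): a
# pentachoron is regular when its 10 edges all have the same length, so near-zero edge
# variance per class indicates the initialization above survived the row normalization
# reasonably intact.
def _pentachora_edge_stats(pentachora):
    """Return (mean, std) of the 10 pairwise vertex distances across all classes."""
    with torch.no_grad():
        diffs = pentachora.unsqueeze(2) - pentachora.unsqueeze(1)  # (C, 5, 5, D)
        dists = torch.norm(diffs, dim=-1)                          # (C, 5, 5)
        mask = torch.triu(torch.ones(5, 5, device=pentachora.device), diagonal=1).bool()
        edges = dists[:, mask]                                     # (C, 10)
        return edges.mean().item(), edges.std().item()
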
class PerfectPentachoron(nn.Module):
    """Multi-scale pentachoron with learnable metric and vertex weights"""
    def __init__(self, num_classes, base_dim, device='cuda'):
        super().__init__()
        self.device = device
        self.num_classes = num_classes
        self.base_dim = base_dim

        self.penta_coarse = init_perfect_pentachora(num_classes, base_dim // 4, device)
        self.penta_medium = init_perfect_pentachora(num_classes, base_dim // 2, device)
        self.penta_fine = init_perfect_pentachora(num_classes, base_dim, device)

        self.vertex_weights = nn.Parameter(torch.ones(num_classes, 5, device=device) / 5)

        self.metric_coarse = nn.Parameter(torch.eye(base_dim // 4, device=device))
        self.metric_medium = nn.Parameter(torch.eye(base_dim // 2, device=device))
        self.metric_fine = nn.Parameter(torch.eye(base_dim, device=device))

        self.scale_weights = nn.Parameter(torch.tensor([0.2, 0.3, 0.5], device=device))

    def mahalanobis_distance(self, x, pentachora, metric):
        # x: (B, D), pentachora: (C, 5, D), metric: (D, D)
        x_trans = torch.matmul(x, metric)
        p_trans = torch.einsum('cpd,de->cpe', pentachora, metric)
        diffs = p_trans.unsqueeze(0) - x_trans.unsqueeze(1).unsqueeze(2)
        dists = torch.norm(diffs, dim=-1)  # (B, C, 5): distance to every vertex of every class
        return dists

    def forward(self, x_coarse, x_medium, x_fine):
        dists_c = self.mahalanobis_distance(x_coarse, self.penta_coarse, self.metric_coarse)
        dists_m = self.mahalanobis_distance(x_medium, self.penta_medium, self.metric_medium)
        dists_f = self.mahalanobis_distance(x_fine, self.penta_fine, self.metric_fine)

        weights = F.softmax(self.vertex_weights, dim=1).unsqueeze(0)
        dists_c = dists_c * weights
        dists_m = dists_m * weights
        dists_f = dists_f * weights

        scores_c = -dists_c.sum(dim=-1)
        scores_m = -dists_m.sum(dim=-1)
        scores_f = -dists_f.sum(dim=-1)

        # Blend the three scales with learned softmax weights
        w = F.softmax(self.scale_weights, dim=0)
        scores = w[0] * scores_c + w[1] * scores_m + w[2] * scores_f

        return scores, (dists_c, dists_m, dists_f)

    def regularization_loss(self):
        mask = torch.triu(torch.ones(5, 5, device=self.device), diagonal=1).bool()

        diffs_c = self.penta_coarse.unsqueeze(2) - self.penta_coarse.unsqueeze(1)
        dists_c = torch.norm(diffs_c, dim=-1)
        edges_c = dists_c[:, mask]

        diffs_m = self.penta_medium.unsqueeze(2) - self.penta_medium.unsqueeze(1)
        dists_m = torch.norm(diffs_m, dim=-1)
        edges_m = dists_m[:, mask]

        diffs_f = self.penta_fine.unsqueeze(2) - self.penta_fine.unsqueeze(1)
        dists_f = torch.norm(diffs_f, dim=-1)
        edges_f = dists_f[:, mask]

        all_edges = torch.stack([edges_c, edges_m, edges_f], dim=0)

        edge_var = torch.var(all_edges, dim=2).mean()
        min_edges = torch.min(all_edges, dim=2)[0]
        collapse_penalty = torch.relu(0.5 - min_edges).mean()

        return edge_var + collapse_penalty

def contrastive_pentachoron_loss_batched(latents, targets, pentachora, temp=0.1):
    batch_size = latents.size(0)
    num_classes = pentachora.size(0)

    diffs = latents.unsqueeze(1).unsqueeze(2) - pentachora.unsqueeze(0)
    dists = torch.norm(diffs, dim=-1)
    min_dists, _ = torch.min(dists, dim=2)

    sims = -min_dists / temp
    targets_one_hot = F.one_hot(targets, num_classes).float()

    max_sims, _ = torch.max(sims, dim=1, keepdim=True)
    exp_sims = torch.exp(sims - max_sims)

    pos_sims = torch.sum(exp_sims * targets_one_hot, dim=1)
    all_sims = torch.sum(exp_sims, dim=1)

    loss = -torch.log(pos_sims / all_sims).mean()
    return loss

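# As I read it (annotation, not original): the loss above is an InfoNCE-style objective.
# With per-class similarity s_c = -min_vertex_dist(x, pentachoron_c) / temp, it computes
#     L = -log( exp(s_y) / sum_c exp(s_c) )
# i.e. cross-entropy over temperature-scaled distances to each class's nearest vertex;
# subtracting max_sims is only for numerical stability and cancels in the ratio.
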
# ============== TRAINING SETUP ==============
encoder = AdaptiveEncoder(config["input_dim"], config["base_dim"]).to(device)
classifier = PerfectPentachoron(num_classes, config["base_dim"], device).to(device)

# Try to compile if available
try:
    encoder = torch.compile(encoder)
    classifier = torch.compile(classifier)
    print("Models compiled successfully")
except Exception:
    print("Torch compile not available, using eager mode")

optimizer = torch.optim.AdamW([
    {'params': encoder.parameters(), 'lr': config["initial_lr"]},
    {'params': classifier.parameters(), 'lr': config["initial_lr"] * 0.5}
], weight_decay=1e-5)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["epochs"])

# ============== CHECKPOINT FUNCTIONS ==============
def save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics, is_best=False):
    """Save checkpoint as safetensors with proper organization"""
    # Prepare state dict for safetensors
    encoder_state = {f"encoder.{k}": v.cpu() for k, v in encoder.state_dict().items()}
    classifier_state = {f"classifier.{k}": v.cpu() for k, v in classifier.state_dict().items()}

    # Combine all model weights
    model_state = {**encoder_state, **classifier_state}

    # Save model weights as safetensors
    checkpoint_name = f"checkpoint_{run_hash}_epoch_{epoch:03d}.safetensors"
    if is_best:
        checkpoint_name = f"best_{run_hash}.safetensors"

    checkpoint_path = os.path.join("checkpoints", checkpoint_name)
    save_file(model_state, checkpoint_path)

    # Save training state separately (optimizer, scheduler, metrics)
    training_state = {
        'epoch': epoch,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'metrics': metrics,
        'config': config
    }

    state_path = checkpoint_path.replace('.safetensors', '_state.pt')
    torch.save(training_state, state_path)

    print(f"Saved checkpoint: {checkpoint_name}")

    # Upload to HuggingFace
    try:
        # Create organized structure
        upload_folder(
            folder_path="checkpoints",
            repo_id=REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
            path_in_repo=f"weights/{run_hash}",
            commit_message=f"Epoch {epoch} - Test Acc: {metrics['test_acc']:.4f}"
        )

        # Upload tensorboard logs
        upload_folder(
            folder_path=f"tensorboard_logs/{run_hash}",
            repo_id=REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
            path_in_repo=f"runs/{run_hash}",
            commit_message=f"TensorBoard logs - Epoch {epoch}"
        )
    except Exception as e:
        print(f"HF upload error: {e}")

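# Loading sketch (my addition, not exercised by this script): restore weights written by
# save_checkpoint above. It assumes the "encoder."/"classifier." key prefixes used there;
# if the models were wrapped by torch.compile when saved, their keys carry an extra
# "_orig_mod." prefix, which is stripped so the weights load into plain modules.
def load_checkpoint(checkpoint_path, encoder, classifier):
    state = load_file(checkpoint_path)  # safetensors.torch.load_file (already imported)
    enc_state, cls_state = {}, {}
    for key, tensor in state.items():
        owner, _, rest = key.partition('.')
        if rest.startswith('_orig_mod.'):
            rest = rest[len('_orig_mod.'):]
        if owner == 'encoder':
            enc_state[rest] = tensor
        elif owner == 'classifier':
            cls_state[rest] = tensor
    encoder.load_state_dict(enc_state)
    classifier.load_state_dict(cls_state)
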
# ============== TRAINING FUNCTIONS ==============
def train_epoch(epoch):
    encoder.train()
    classifier.train()

    total_loss = 0.0
    total_ce = 0.0
    total_contr = 0.0
    total_reg = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
    for batch_idx, (inputs, targets) in enumerate(pbar):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        x_coarse, x_medium, x_fine = encoder(inputs)
        scores, all_dists = classifier(x_coarse, x_medium, x_fine)

        ce_loss = F.cross_entropy(scores, targets)

        contr_c = contrastive_pentachoron_loss_batched(x_coarse, targets, classifier.penta_coarse, config["temp_contrastive"])
        contr_m = contrastive_pentachoron_loss_batched(x_medium, targets, classifier.penta_medium, config["temp_contrastive"])
        contr_f = contrastive_pentachoron_loss_batched(x_fine, targets, classifier.penta_fine, config["temp_contrastive"])
        contr_loss = (contr_c + contr_m + contr_f) / 3

        reg_loss = classifier.regularization_loss()

        loss = ce_loss + config["lambda_contrastive"] * contr_loss + config["lambda_cayley"] * reg_loss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item() * inputs.size(0)
        total_ce += ce_loss.item() * inputs.size(0)
        total_contr += contr_loss.item() * inputs.size(0)
        total_reg += reg_loss.item() * inputs.size(0)

        preds = scores.argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += inputs.size(0)

        # Log batch metrics to TensorBoard
        if batch_idx % 50 == 0:
            global_step = epoch * len(train_loader) + batch_idx
            writer.add_scalar('Train/BatchLoss', loss.item(), global_step)
            writer.add_scalar('Train/BatchAcc', correct/total, global_step)

        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'acc': f"{correct/total:.4f}",
            'lr': f"{optimizer.param_groups[0]['lr']:.1e}"
        })

    return (total_loss/total, total_ce/total, total_contr/total,
            total_reg/total, correct/total)

@torch.no_grad()
def evaluate():
    encoder.eval()
    classifier.eval()

    correct = 0
    total = 0
    class_correct = [0] * num_classes
    class_total = [0] * num_classes

    pbar = tqdm(test_loader, desc="Evaluating")
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)

        x_coarse, x_medium, x_fine = encoder(inputs)
        scores, _ = classifier(x_coarse, x_medium, x_fine)

        preds = scores.argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += inputs.size(0)

        for i in range(targets.size(0)):
            label = targets[i].item()
            class_total[label] += 1
            if preds[i] == targets[i]:
                class_correct[label] += 1

        pbar.set_postfix({'acc': f"{correct/total:.4f}"})

    class_accs = [class_correct[i]/max(1, class_total[i]) for i in range(num_classes)]
    return correct/total, class_accs

# ============== MAIN TRAINING LOOP ==============
print("\n" + "="*60)
print(f"PERFECT PENTACHORON TRAINING - Run {run_hash}")
print("="*60 + "\n")

best_acc = 0.0
train_history = []
test_history = []
patience = 7
no_improve = 0

for epoch in range(config["epochs"]):
    # Train
    train_loss, train_ce, train_contr, train_reg, train_acc = train_epoch(epoch)
    train_history.append(train_acc)

    # Evaluate
    test_acc, class_accs = evaluate()
    test_history.append(test_acc)

    # Log to TensorBoard
    writer.add_scalar('Loss/Total', train_loss, epoch)
    writer.add_scalar('Loss/CE', train_ce, epoch)
    writer.add_scalar('Loss/Contrastive', train_contr, epoch)
    writer.add_scalar('Loss/Regularization', train_reg, epoch)
    writer.add_scalar('Accuracy/Train', train_acc, epoch)
    writer.add_scalar('Accuracy/Test', test_acc, epoch)
    writer.add_scalar('Learning/LR', optimizer.param_groups[0]['lr'], epoch)
    writer.add_scalar('Learning/Generalization_Gap', train_acc - test_acc, epoch)

    # Log per-class accuracies
    for i, acc in enumerate(class_accs[:10]):  # Log first 10 classes
        writer.add_scalar(f'ClassAcc/Class_{i}', acc, epoch)

    # Log scale weights
    scale_weights = F.softmax(classifier.scale_weights, dim=0)
    writer.add_scalar('Scales/Coarse', scale_weights[0], epoch)
    writer.add_scalar('Scales/Medium', scale_weights[1], epoch)
    writer.add_scalar('Scales/Fine', scale_weights[2], epoch)

    scheduler.step()

    # Print results
    print(f"\n[Epoch {epoch+1}/{config['epochs']}]")
    print(f"Train | Loss: {train_loss:.4f} | CE: {train_ce:.4f} | "
          f"Contr: {train_contr:.4f} | Reg: {train_reg:.4f} | Acc: {train_acc:.4f}")
    print(f"Test | Acc: {test_acc:.4f} | Best: {best_acc:.4f}")

    # Save checkpoint
    metrics = {
        'train_acc': train_acc,
        'test_acc': test_acc,
        'train_loss': train_loss,
        'class_accs': class_accs
    }

    # Check if best
    if test_acc > best_acc:
        best_acc = test_acc
        no_improve = 0
        print(f"NEW BEST! Saving checkpoint...")
        save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics, is_best=True)
    else:
        no_improve += 1
        if (epoch + 1) % 5 == 0:  # Save every 5 epochs
            save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics)

    # Early stopping
    if no_improve >= patience:
        print(f"Early stopping triggered (no improvement for {patience} epochs)")
        break

# ============== FINAL RESULTS ==============
print("\n" + "="*60)
print("FINAL RESULTS")
print("="*60)
print(f"Best Test Accuracy: {best_acc:.4f}")
print(f"Final Train Accuracy: {train_history[-1]:.4f}")
print(f"Generalization Gap: {train_history[-1] - test_history[-1]:.4f}")

# Save final model
save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics, is_best=False)

# Log final pentachoron geometry
with torch.no_grad():
    vertex_importance = F.softmax(classifier.vertex_weights, dim=1)
    scale_weights = F.softmax(classifier.scale_weights, dim=0).cpu().numpy()

    geometry_info = {
        'scale_importance': {
            'coarse': float(scale_weights[0]),
            'medium': float(scale_weights[1]),
            'fine': float(scale_weights[2])
        },
        'dominant_vertices': {}
    }

    for c in range(min(10, num_classes)):
        weights = vertex_importance[c].cpu().numpy()
        dominant = np.argmax(weights)
        geometry_info['dominant_vertices'][f'class_{c}'] = {
            'vertex': int(dominant),
            'weight': float(weights[dominant])
        }

    writer.add_text('Final_Geometry', json.dumps(geometry_info, indent=2), epoch)

writer.close()
print(f"\n✨ Training Complete! Run hash: {run_hash}")
print(f"Results uploaded to: https://huggingface.co/{REPO_ID}")
print(f"TensorBoard: tensorboard --logdir tensorboard_logs/{run_hash}")