Create trainer.py
trainer.py ADDED  +549 -0

@@ -0,0 +1,549 @@
#################################################################################
## penta-classifier-prototype
#################################################################################
## Author: AbstractPhil
## Assistant: Claude Opus 4.1
#################################################################################
## License: Apache - cite with care and share with passionate individuals.
##
## This tiny model somehow defeated all of my larger variants.
## It is the first model to show direct evidence of potential pentachora scaling.
## No pretraining, pure noise initialization. Nothing bulky or extra - just run it.
##
## Somehow, this model packs 60+ classifiers into 3 pentachoron banks (coarse,
## medium, fine). I'm still uncertain as to why, since it defeated the projections.
## It needs additional research and additional time, but here is the model.
##
## This is based on one of my earlier prototypes and is labeled accordingly.
## Somewhere along the way it fell apart; today I put it back together.
##
#################################################################################

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from huggingface_hub import HfApi, create_repo, upload_folder
from safetensors.torch import save_file, load_file
import os
import json
import hashlib
from datetime import datetime
from google.colab import userdata

# ============== SETUP HF AND PATHS ==============
HF_TOKEN = userdata.get('HF_TOKEN')
REPO_ID = "AbstractPhil/penta-classifier-prototype"

# Create unique run ID based on timestamp and config
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
config_str = f"emnist_byclass_b1024_lr1e-3_{run_timestamp}"
run_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]

# Local directories
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("tensorboard_logs", exist_ok=True)

# TensorBoard setup
writer = SummaryWriter(f'tensorboard_logs/{run_hash}')

# Initialize HF API
api = HfApi()
try:
    create_repo(REPO_ID, repo_type="model", token=HF_TOKEN, exist_ok=True)
    print(f"Using HuggingFace repo: {REPO_ID}")
except Exception as e:
    print(f"Repo setup: {e}")

# ============== CONFIGURATION ==============
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

# Hyperparameters
config = {
    "input_dim": 28 * 28,
    "base_dim": 64,
    "batch_size": 1024,
    "epochs": 5,
    "initial_lr": 1e-3,
    "temp_contrastive": 0.1,
    "lambda_contrastive": 0.5,
    "lambda_cayley": 0.01,
    "dataset": "EMNIST_byclass",
    "run_hash": run_hash,
    "timestamp": run_timestamp
}

# Save config
config_path = f"checkpoints/config_{run_hash}.json"
with open(config_path, 'w') as f:
    json.dump(config, f, indent=2)

# Log config to TensorBoard
writer.add_text('Config', json.dumps(config, indent=2), 0)

# ============== DATASET ==============
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

train_dataset = datasets.EMNIST(root="./data", split='byclass', train=True, transform=transform, download=True)
test_dataset = datasets.EMNIST(root="./data", split='byclass', train=False, transform=transform, download=True)

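# EMNIST 'byclass' covers 62 classes: 10 digits plus 26 uppercase and 26 lowercase letters.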
num_classes = len(train_dataset.classes)
config["num_classes"] = num_classes
# Re-save the config now that the class count is known (the first dump above
# ran before the dataset was loaded, so it lacked num_classes).
with open(config_path, 'w') as f:
    json.dump(config, f, indent=2)

train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], pin_memory=True,
                          shuffle=True, num_workers=4, prefetch_factor=8)
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], pin_memory=True,
                         shuffle=False, num_workers=4, prefetch_factor=8)

print(f"Train: {len(train_dataset)} samples, Test: {len(test_dataset)} samples")
print(f"Classes: {num_classes}")

# ============== MODEL DEFINITIONS ==============
class AdaptiveEncoder(nn.Module):
    """Multi-layer encoder with normalization and multi-scale outputs"""
    def __init__(self, input_dim, base_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(0.2)

        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(0.2)

        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)

        self.fc_coarse = nn.Linear(256, base_dim // 4)
        self.fc_medium = nn.Linear(128, base_dim // 2)
        self.fc_fine = nn.Linear(128, base_dim)

        self.norm_coarse = nn.LayerNorm(base_dim // 4)
        self.norm_medium = nn.LayerNorm(base_dim // 2)
        self.norm_fine = nn.LayerNorm(base_dim)

    def forward(self, x):
        h1 = F.relu(self.bn1(self.fc1(x)))
        h1 = self.dropout1(h1)
        h2 = F.relu(self.bn2(self.fc2(h1)))
        h2 = self.dropout2(h2)
        h3 = F.relu(self.bn3(self.fc3(h2)))

        coarse = self.norm_coarse(self.fc_coarse(h2))
        medium = self.norm_medium(self.fc_medium(h3))
        fine = self.norm_fine(self.fc_fine(h3))

        return coarse, medium, fine

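# init_perfect_pentachora builds one 5-vertex simplex (a pentachoron) per class.
# When latent_dim // num_classes >= 4, each class gets its own 4-dimensional
# coordinate slice; otherwise the 4-D simplex is pushed into the latent space
# through a random projection. For this run (62 EMNIST classes, scale dims of
# 16/32/64) the projection branch is always taken. Vertices are unit-normalized,
# the whole bank is scaled by 2, and it is learned as an nn.Parameter.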
def init_perfect_pentachora(num_classes, latent_dim, device='cuda'):
    """Initialize as regular 4-simplices in orthogonal subspaces"""
    pentachora = torch.zeros(num_classes, 5, latent_dim, device=device)

    sqrt15 = np.sqrt(15)
    sqrt10 = np.sqrt(10)
    sqrt5 = np.sqrt(5)

    simplex = torch.tensor([
        [1.0, 0.0, 0.0, 0.0],
        [-0.25, sqrt15/4, 0.0, 0.0],
        [-0.25, -sqrt15/12, sqrt10/3, 0.0],
        [-0.25, -sqrt15/12, -sqrt10/6, sqrt5/2],
        [-0.25, -sqrt15/12, -sqrt10/6, -sqrt5/2]
    ], dtype=torch.float32, device=device)

    simplex = F.normalize(simplex, dim=1)

    dims_per_class = latent_dim // num_classes
    for c in range(num_classes):
        if dims_per_class >= 4:
            start = c * dims_per_class
            pentachora[c, :, start:start+4] = simplex
        else:
            rotation = torch.randn(4, latent_dim, device=device)
            rotation = F.normalize(rotation, dim=1)
            pentachora[c] = torch.mm(simplex, rotation[:4])

    return nn.Parameter(pentachora * 2.0)

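# PerfectPentachoron scores a sample by measuring, at each scale, the distance
# from the latent vector to all five vertices of every class simplex under a
# learnable linear metric, weighting the five distances with a per-class softmax
# over vertex_weights, and negating the weighted sum. The three per-scale scores
# are then blended with a softmax over scale_weights.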
class PerfectPentachoron(nn.Module):
    """Multi-scale pentachoron with learnable metric and vertex weights"""
    def __init__(self, num_classes, base_dim, device='cuda'):
        super().__init__()
        self.device = device
        self.num_classes = num_classes
        self.base_dim = base_dim

        self.penta_coarse = init_perfect_pentachora(num_classes, base_dim // 4, device)
        self.penta_medium = init_perfect_pentachora(num_classes, base_dim // 2, device)
        self.penta_fine = init_perfect_pentachora(num_classes, base_dim, device)

        self.vertex_weights = nn.Parameter(torch.ones(num_classes, 5, device=device) / 5)

        self.metric_coarse = nn.Parameter(torch.eye(base_dim // 4, device=device))
        self.metric_medium = nn.Parameter(torch.eye(base_dim // 2, device=device))
        self.metric_fine = nn.Parameter(torch.eye(base_dim, device=device))

        self.scale_weights = nn.Parameter(torch.tensor([0.2, 0.3, 0.5], device=device))

    def mahalanobis_distance(self, x, pentachora, metric):
        x_trans = torch.matmul(x, metric)
        p_trans = torch.einsum('cpd,de->cpe', pentachora, metric)
        diffs = p_trans.unsqueeze(0) - x_trans.unsqueeze(1).unsqueeze(2)
        dists = torch.norm(diffs, dim=-1)
        return dists

    def forward(self, x_coarse, x_medium, x_fine):
        dists_c = self.mahalanobis_distance(x_coarse, self.penta_coarse, self.metric_coarse)
        dists_m = self.mahalanobis_distance(x_medium, self.penta_medium, self.metric_medium)
        dists_f = self.mahalanobis_distance(x_fine, self.penta_fine, self.metric_fine)

        weights = F.softmax(self.vertex_weights, dim=1).unsqueeze(0)
        dists_c = dists_c * weights
        dists_m = dists_m * weights
        dists_f = dists_f * weights

        scores_c = -dists_c.sum(dim=-1)
        scores_m = -dists_m.sum(dim=-1)
        scores_f = -dists_f.sum(dim=-1)

        w = F.softmax(self.scale_weights, dim=0)
        scores = w[0] * scores_c + w[1] * scores_m + w[2] * scores_f

        return scores, (dists_c, dists_m, dists_f)

    def regularization_loss(self):
        mask = torch.triu(torch.ones(5, 5, device=self.device), diagonal=1).bool()

        diffs_c = self.penta_coarse.unsqueeze(2) - self.penta_coarse.unsqueeze(1)
        dists_c = torch.norm(diffs_c, dim=-1)
        edges_c = dists_c[:, mask]

        diffs_m = self.penta_medium.unsqueeze(2) - self.penta_medium.unsqueeze(1)
        dists_m = torch.norm(diffs_m, dim=-1)
        edges_m = dists_m[:, mask]

        diffs_f = self.penta_fine.unsqueeze(2) - self.penta_fine.unsqueeze(1)
        dists_f = torch.norm(diffs_f, dim=-1)
        edges_f = dists_f[:, mask]

        all_edges = torch.stack([edges_c, edges_m, edges_f], dim=0)

        edge_var = torch.var(all_edges, dim=2).mean()
        min_edges = torch.min(all_edges, dim=2)[0]
        collapse_penalty = torch.relu(0.5 - min_edges).mean()

        return edge_var + collapse_penalty

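# The contrastive term is an InfoNCE-style loss: for each sample, similarity to
# a class is the temperature-scaled negative distance to that class's nearest
# vertex, and the loss is -log of the softmax mass assigned to the true class.
# The row-wise max is subtracted before exponentiating for numerical stability.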
def contrastive_pentachoron_loss_batched(latents, targets, pentachora, temp=0.1):
    batch_size = latents.size(0)
    num_classes = pentachora.size(0)

    diffs = latents.unsqueeze(1).unsqueeze(2) - pentachora.unsqueeze(0)
    dists = torch.norm(diffs, dim=-1)
    min_dists, _ = torch.min(dists, dim=2)

    sims = -min_dists / temp
    targets_one_hot = F.one_hot(targets, num_classes).float()

    max_sims, _ = torch.max(sims, dim=1, keepdim=True)
    exp_sims = torch.exp(sims - max_sims)

    pos_sims = torch.sum(exp_sims * targets_one_hot, dim=1)
    all_sims = torch.sum(exp_sims, dim=1)

    loss = -torch.log(pos_sims / all_sims).mean()
    return loss

# ============== TRAINING SETUP ==============
encoder = AdaptiveEncoder(config["input_dim"], config["base_dim"]).to(device)
classifier = PerfectPentachoron(num_classes, config["base_dim"], device).to(device)

# Try to compile if available
try:
    encoder = torch.compile(encoder)
    classifier = torch.compile(classifier)
    print("Models compiled successfully")
except Exception:
    print("Torch compile not available, using eager mode")

optimizer = torch.optim.AdamW([
    {'params': encoder.parameters(), 'lr': config["initial_lr"]},
    {'params': classifier.parameters(), 'lr': config["initial_lr"] * 0.5}
], weight_decay=1e-5)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["epochs"])

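# save_checkpoint() below writes two artifacts per save: the model weights as a
# single safetensors file whose keys carry "encoder." / "classifier." prefixes
# (when torch.compile succeeded, the wrapped modules also contribute an
# "_orig_mod." segment), plus a companion *_state.pt holding the optimizer,
# scheduler, metrics, and config. The checkpoints/ folder and the TensorBoard
# run are then mirrored to the HF repo under weights/{run_hash} and runs/{run_hash}.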
# ============== CHECKPOINT FUNCTIONS ==============
def save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics, is_best=False):
    """Save checkpoint as safetensors with proper organization"""
    # Prepare state dict for safetensors
    encoder_state = {f"encoder.{k}": v.cpu() for k, v in encoder.state_dict().items()}
    classifier_state = {f"classifier.{k}": v.cpu() for k, v in classifier.state_dict().items()}

    # Combine all model weights
    model_state = {**encoder_state, **classifier_state}

    # Save model weights as safetensors
    checkpoint_name = f"checkpoint_{run_hash}_epoch_{epoch:03d}.safetensors"
    if is_best:
        checkpoint_name = f"best_{run_hash}.safetensors"

    checkpoint_path = os.path.join("checkpoints", checkpoint_name)
    save_file(model_state, checkpoint_path)

    # Save training state separately (optimizer, scheduler, metrics)
    training_state = {
        'epoch': epoch,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'metrics': metrics,
        'config': config
    }

    state_path = checkpoint_path.replace('.safetensors', '_state.pt')
    torch.save(training_state, state_path)

    print(f"Saved checkpoint: {checkpoint_name}")

    # Upload to HuggingFace
    try:
        # Create organized structure
        upload_folder(
            folder_path="checkpoints",
            repo_id=REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
            path_in_repo=f"weights/{run_hash}",
            commit_message=f"Epoch {epoch} - Test Acc: {metrics['test_acc']:.4f}"
        )

        # Upload tensorboard logs
        upload_folder(
            folder_path=f"tensorboard_logs/{run_hash}",
            repo_id=REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
            path_in_repo=f"runs/{run_hash}",
            commit_message=f"TensorBoard logs - Epoch {epoch}"
        )
    except Exception as e:
        print(f"HF upload error: {e}")

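# Per batch, the training objective is:
#   loss = CE(scores, targets)
#        + lambda_contrastive * mean(contrastive loss over the 3 scales)
#        + lambda_cayley * (edge-length variance + collapse penalty)
# with gradient norms clipped to 1.0 for both the encoder and the classifier.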
# ============== TRAINING FUNCTIONS ==============
def train_epoch(epoch):
    encoder.train()
    classifier.train()

    total_loss = 0.0
    total_ce = 0.0
    total_contr = 0.0
    total_reg = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
    for batch_idx, (inputs, targets) in enumerate(pbar):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        x_coarse, x_medium, x_fine = encoder(inputs)
        scores, all_dists = classifier(x_coarse, x_medium, x_fine)

        ce_loss = F.cross_entropy(scores, targets)

        contr_c = contrastive_pentachoron_loss_batched(x_coarse, targets, classifier.penta_coarse, config["temp_contrastive"])
        contr_m = contrastive_pentachoron_loss_batched(x_medium, targets, classifier.penta_medium, config["temp_contrastive"])
        contr_f = contrastive_pentachoron_loss_batched(x_fine, targets, classifier.penta_fine, config["temp_contrastive"])
        contr_loss = (contr_c + contr_m + contr_f) / 3

        reg_loss = classifier.regularization_loss()

        loss = ce_loss + config["lambda_contrastive"] * contr_loss + config["lambda_cayley"] * reg_loss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item() * inputs.size(0)
        total_ce += ce_loss.item() * inputs.size(0)
        total_contr += contr_loss.item() * inputs.size(0)
        total_reg += reg_loss.item() * inputs.size(0)

        preds = scores.argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += inputs.size(0)

        # Log batch metrics to TensorBoard
        if batch_idx % 50 == 0:
            global_step = epoch * len(train_loader) + batch_idx
            writer.add_scalar('Train/BatchLoss', loss.item(), global_step)
            writer.add_scalar('Train/BatchAcc', correct/total, global_step)

        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'acc': f"{correct/total:.4f}",
            'lr': f"{optimizer.param_groups[0]['lr']:.1e}"
        })

    return (total_loss/total, total_ce/total, total_contr/total,
            total_reg/total, correct/total)

@torch.no_grad()
def evaluate():
    encoder.eval()
    classifier.eval()

    correct = 0
    total = 0
    class_correct = [0] * num_classes
    class_total = [0] * num_classes

    pbar = tqdm(test_loader, desc="Evaluating")
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)

        x_coarse, x_medium, x_fine = encoder(inputs)
        scores, _ = classifier(x_coarse, x_medium, x_fine)

        preds = scores.argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += inputs.size(0)

        for i in range(targets.size(0)):
            label = targets[i].item()
            class_total[label] += 1
            if preds[i] == targets[i]:
                class_correct[label] += 1

        pbar.set_postfix({'acc': f"{correct/total:.4f}"})

    class_accs = [class_correct[i]/max(1, class_total[i]) for i in range(num_classes)]
    return correct/total, class_accs

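# Note: with epochs=5 and patience=7, early stopping cannot trigger in a default
# run; it only becomes relevant if config["epochs"] is raised.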
# ============== MAIN TRAINING LOOP ==============
print("\n" + "="*60)
print(f"PERFECT PENTACHORON TRAINING - Run {run_hash}")
print("="*60 + "\n")

best_acc = 0.0
train_history = []
test_history = []
patience = 7
no_improve = 0

for epoch in range(config["epochs"]):
    # Train
    train_loss, train_ce, train_contr, train_reg, train_acc = train_epoch(epoch)
    train_history.append(train_acc)

    # Evaluate
    test_acc, class_accs = evaluate()
    test_history.append(test_acc)

    # Log to TensorBoard
    writer.add_scalar('Loss/Total', train_loss, epoch)
    writer.add_scalar('Loss/CE', train_ce, epoch)
    writer.add_scalar('Loss/Contrastive', train_contr, epoch)
    writer.add_scalar('Loss/Regularization', train_reg, epoch)
    writer.add_scalar('Accuracy/Train', train_acc, epoch)
    writer.add_scalar('Accuracy/Test', test_acc, epoch)
    writer.add_scalar('Learning/LR', optimizer.param_groups[0]['lr'], epoch)
    writer.add_scalar('Learning/Generalization_Gap', train_acc - test_acc, epoch)

    # Log per-class accuracies
    for i, acc in enumerate(class_accs[:10]):  # Log first 10 classes
        writer.add_scalar(f'ClassAcc/Class_{i}', acc, epoch)

    # Log scale weights
    scale_weights = F.softmax(classifier.scale_weights, dim=0)
    writer.add_scalar('Scales/Coarse', scale_weights[0], epoch)
    writer.add_scalar('Scales/Medium', scale_weights[1], epoch)
    writer.add_scalar('Scales/Fine', scale_weights[2], epoch)

    scheduler.step()

    # Print results
    print(f"\n[Epoch {epoch+1}/{config['epochs']}]")
    print(f"Train | Loss: {train_loss:.4f} | CE: {train_ce:.4f} | "
          f"Contr: {train_contr:.4f} | Reg: {train_reg:.4f} | Acc: {train_acc:.4f}")
    print(f"Test | Acc: {test_acc:.4f} | Best: {best_acc:.4f}")

    # Save checkpoint
    metrics = {
        'train_acc': train_acc,
        'test_acc': test_acc,
        'train_loss': train_loss,
        'class_accs': class_accs
    }

    # Check if best
    if test_acc > best_acc:
        best_acc = test_acc
        no_improve = 0
        print("NEW BEST! Saving checkpoint...")
        save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics, is_best=True)
    else:
        no_improve += 1
        if (epoch + 1) % 5 == 0:  # Save every 5 epochs
            save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics)

    # Early stopping
    if no_improve >= patience:
        print(f"Early stopping triggered (no improvement for {patience} epochs)")
        break

# ============== FINAL RESULTS ==============
print("\n" + "="*60)
print("FINAL RESULTS")
print("="*60)
print(f"Best Test Accuracy: {best_acc:.4f}")
print(f"Final Train Accuracy: {train_history[-1]:.4f}")
print(f"Generalization Gap: {train_history[-1] - test_history[-1]:.4f}")

# Save final model
save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics, is_best=False)

# Log final pentachoron geometry
with torch.no_grad():
    vertex_importance = F.softmax(classifier.vertex_weights, dim=1)
    scale_weights = F.softmax(classifier.scale_weights, dim=0).cpu().numpy()

    geometry_info = {
        'scale_importance': {
            'coarse': float(scale_weights[0]),
            'medium': float(scale_weights[1]),
            'fine': float(scale_weights[2])
        },
        'dominant_vertices': {}
    }

    for c in range(min(10, num_classes)):
        weights = vertex_importance[c].cpu().numpy()
        dominant = np.argmax(weights)
        geometry_info['dominant_vertices'][f'class_{c}'] = {
            'vertex': int(dominant),
            'weight': float(weights[dominant])
        }

    writer.add_text('Final_Geometry', json.dumps(geometry_info, indent=2), epoch)

writer.close()
print(f"\n✨ Training Complete! Run hash: {run_hash}")
print(f"Results uploaded to: https://huggingface.co/{REPO_ID}")
print(f"TensorBoard: tensorboard --logdir tensorboard_logs/{run_hash}")
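
# ------------------------------------------------------------------------------
# Optional sanity check: reload the best checkpoint saved above. This is a
# minimal sketch based on save_checkpoint()'s key layout ("encoder." /
# "classifier." prefixes); the names below (best_path, enc_sd, cls_sd,
# fresh_encoder, fresh_classifier) are illustrative, and the "_orig_mod." strip
# only matters when torch.compile wrapped the models.
best_path = os.path.join("checkpoints", f"best_{run_hash}.safetensors")
if os.path.exists(best_path):
    reloaded = load_file(best_path)
    enc_sd = {k.split("encoder.", 1)[1].replace("_orig_mod.", ""): v
              for k, v in reloaded.items() if k.startswith("encoder.")}
    cls_sd = {k.split("classifier.", 1)[1].replace("_orig_mod.", ""): v
              for k, v in reloaded.items() if k.startswith("classifier.")}
    # Rebuild fresh, uncompiled CPU modules and load the saved weights into them.
    fresh_encoder = AdaptiveEncoder(config["input_dim"], config["base_dim"])
    fresh_classifier = PerfectPentachoron(num_classes, config["base_dim"], device="cpu")
    fresh_encoder.load_state_dict(enc_sd)
    fresh_classifier.load_state_dict(cls_sd)
    print(f"Reloaded best checkpoint from {best_path}")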