Sin2pi committed
Commit 803aa74 · verified · 1 Parent(s): 6cae1c4

Create FAMOptimizer.py

Files changed (1)
  1. FAMOptimizer.py +555 -0
FAMOptimizer.py ADDED
@@ -0,0 +1,555 @@
import torch
import torch.nn as nn
import numpy as np
import json
import os
import math
from datetime import datetime
from torch.optim.lr_scheduler import _LRScheduler

class FrequencyHandler:
    """Base class for parameter-specific frequency analysis functions"""

    def analyze(self, grad_sample, n_bands, eps=1e-8):
        """Default frequency analysis implementation"""
        freq_repr = torch.fft.rfft(grad_sample.float())
        freq_power = torch.abs(freq_repr)

        if freq_power.sum() > 0:
            freq_power = freq_power / (freq_power.sum() + eps)

        band_size = freq_power.shape[0] // n_bands
        if band_size <= 0:
            return [0.0] * n_bands

        band_powers = []
        for i in range(n_bands):
            start_idx = i * band_size
            end_idx = min((i + 1) * band_size, freq_power.shape[0])
            if start_idx < end_idx:
                band_power = freq_power[start_idx:end_idx].sum().item()
                band_powers.append(band_power)
            else:
                band_powers.append(0.0)

        return band_powers

    def get_adaptive_momentum(self, band_values, base_alpha):
        """Default adaptive momentum calculation"""
        n_bands = len(band_values)
        high_freq_activity = sum(band_values[n_bands // 2:])

        if high_freq_activity > 0.3:
            return min(0.95, base_alpha + 0.05)
        return base_alpha

class ConvFrequencyHandler(FrequencyHandler):
    """Specialized handler for convolutional layers"""

    def analyze(self, grad_sample, n_bands, eps=1e-8):
        freq_repr = torch.fft.rfft(grad_sample.float())
        freq_power = torch.abs(freq_repr)

        if freq_power.sum() > 0:
            freq_power = freq_power / (freq_power.sum() + eps)

        # Logarithmically spaced bands: finer resolution at low frequencies
        band_powers = []
        total_freqs = freq_power.shape[0]

        for i in range(n_bands):
            start_idx = int((total_freqs ** (i / n_bands)) - 1)
            end_idx = int((total_freqs ** ((i + 1) / n_bands)) - 1)
            start_idx = max(0, start_idx)
            end_idx = min(end_idx, total_freqs)

            if start_idx < end_idx:
                band_power = freq_power[start_idx:end_idx].sum().item()
                band_powers.append(band_power)
            else:
                band_powers.append(0.0)

        return band_powers

    def get_adaptive_momentum(self, band_values, base_alpha):
        """Convolutional layers benefit from more smoothing in mid-frequencies"""
        n_bands = len(band_values)
        mid_freq_activity = sum(band_values[n_bands // 4:(3 * n_bands) // 4])
        high_freq_activity = sum(band_values[(3 * n_bands) // 4:])
        if mid_freq_activity > 0.4:
            return min(0.97, base_alpha + 0.07)
        elif high_freq_activity > 0.3:
            return min(0.95, base_alpha + 0.05)
        return base_alpha

class AttentionFrequencyHandler(FrequencyHandler):
    """Specialized handler for attention layers"""

    def analyze(self, grad_sample, n_bands, eps=1e-8):
        freq_repr = torch.fft.rfft(grad_sample.float())
        freq_power = torch.abs(freq_repr)

        if freq_power.sum() > 0:
            freq_power = freq_power / (freq_power.sum() + eps)

        band_powers = []
        half_bands = n_bands // 2

        # Lower half of the spectrum, split evenly into half_bands bands
        low_band_size = (freq_power.shape[0] // 2) // half_bands
        for i in range(half_bands):
            start_idx = i * low_band_size
            end_idx = min((i + 1) * low_band_size, freq_power.shape[0] // 2)
            if start_idx < end_idx:
                band_power = freq_power[start_idx:end_idx].sum().item()
                band_powers.append(band_power)
            else:
                band_powers.append(0.0)

        # Upper half of the spectrum, split evenly into the remaining bands
        high_band_size = (freq_power.shape[0] - (freq_power.shape[0] // 2)) // (n_bands - half_bands)
        for i in range(half_bands, n_bands):
            start_idx = (freq_power.shape[0] // 2) + (i - half_bands) * high_band_size
            end_idx = min((freq_power.shape[0] // 2) + (i - half_bands + 1) * high_band_size, freq_power.shape[0])
            if start_idx < end_idx:
                band_power = freq_power[start_idx:end_idx].sum().item()
                band_powers.append(band_power)
            else:
                band_powers.append(0.0)

        return band_powers

    def get_adaptive_momentum(self, band_values, base_alpha):
        """Custom adaptive momentum for attention layers"""
        n_bands = len(band_values)
        max_band_idx = np.argmax(band_values)
        if max_band_idx < n_bands // 4:
            return max(0.85, base_alpha - 0.05)
        elif max_band_idx > 3 * n_bands // 4:
            return min(0.98, base_alpha + 0.08)
        return base_alpha

class EmbeddingFrequencyHandler(FrequencyHandler):
    """Specialized handler for embedding layers"""

    def get_adaptive_momentum(self, band_values, base_alpha):
        """Embeddings often benefit from very stable updates"""
        n_bands = len(band_values)
        high_freq_activity = sum(band_values[(3 * n_bands) // 4:])
        if high_freq_activity > 0.2:
            return min(0.98, base_alpha + 0.08)
        return base_alpha

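# Illustrative sketch (assumption, for reference only): the two handler calls that
# FAMOptimizer.step makes for each eligible parameter, shown here on a synthetic
# gradient and without the EMA smoothing that step() applies between the calls.
def _demo_frequency_handler(n_bands=8, base_alpha=0.9):
    handler = ConvFrequencyHandler()
    grad_sample = torch.randn(4096)  # stand-in for a flattened gradient sample
    band_powers = handler.analyze(grad_sample, n_bands)  # normalized per-band energy
    effective_alpha = handler.get_adaptive_momentum(band_powers, base_alpha)
    return band_powers, effective_alpha
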
class FAMOptimizer(torch.optim.Optimizer):
    """
    Frequency-Adaptive Momentum optimizer with parameter-specific handlers.

    Args:
        params: Iterable of parameters or dicts defining parameter groups
        lr (float, optional): Learning rate (default: 1e-3)
        alpha (float, optional): Base momentum coefficient (default: 0.9)
        beta (float, optional): EMA coefficient for the per-band frequency history (default: 0.99)
        eps (float, optional): Term added for numerical stability (default: 1e-8)
        weight_decay (float, optional): Weight decay coefficient (default: 0.0)
        n_bands (int, optional): Number of frequency bands to analyze (default: 8)
        fam_start_step (int, optional): Step after which frequency adaptation is applied (default: 100)
        layer_boost (bool, optional): Per-group flag stored in the defaults (default: True)
        min_size (int, optional): Minimum parameter size (numel) for frequency analysis (default: 256)
        debug (bool, optional): Whether to collect debug information (default: False)
        debug_dir (str, optional): Directory to save debug info (default: './fam_debug')
        debug_interval (int, optional): Steps between debug dumps (default: 1000)
    """
    def __init__(self, params, lr=1e-3, alpha=0.9, beta=0.99, eps=1e-8,
                 weight_decay=0.0, n_bands=8, fam_start_step=100,
                 layer_boost=True, min_size=256, debug=False,
                 debug_dir='./fam_debug', debug_interval=1000):
        defaults = dict(lr=lr, alpha=alpha, beta=beta, eps=eps,
                        weight_decay=weight_decay, n_bands=n_bands,
                        fam_start_step=fam_start_step,
                        layer_boost=layer_boost, min_size=min_size)
        self.debug = debug
        self.debug_info = {} if debug else None
        self.debug_dir = debug_dir
        self.debug_interval = debug_interval
        self.last_dump_step = 0

        if debug and debug_dir:
            os.makedirs(debug_dir, exist_ok=True)
            self.debug_file = os.path.join(
                debug_dir,
                f"fam_debug_{datetime.now().strftime('%m%d_%H%M%S')}.json"
            )
            with open(self.debug_file, 'w') as f:
                json.dump({
                    "optimizer": "FAMOptimizer",
                    "settings": {
                        "lr": lr,
                        "alpha": alpha,
                        "beta": beta,
                        "n_bands": n_bands,
                        "fam_start_step": fam_start_step,
                    },
                    "parameters": {},
                    "steps_recorded": []
                }, f, indent=2)

        self.handlers = {
            "default": FrequencyHandler(),
            "conv": ConvFrequencyHandler(),
            "attention": AttentionFrequencyHandler(),
            "embedding": EmbeddingFrequencyHandler()
        }
        param_groups = self._add_handlers_to_groups(params)
        super(FAMOptimizer, self).__init__(params=param_groups, defaults=defaults)

    def _add_handlers_to_groups(self, params):
        """Add appropriate handlers to parameter groups based on type"""
        if isinstance(params, list) and all(isinstance(pg, dict) for pg in params):
            for pg in params:
                if 'handler' not in pg:
                    names = [name.lower() for name in pg.get('names', [])]
                    if any('conv' in name for name in names):
                        pg['handler'] = 'conv'
                    elif any(kw in name for name in names
                             for kw in ['attention', 'mha', 'self_attn']):
                        pg['handler'] = 'attention'
                    elif any(kw in name for name in names
                             for kw in ['embed', 'token']):
                        pg['handler'] = 'embedding'
                    else:
                        pg['handler'] = 'default'
            return params
        else:
            return [{'params': params, 'handler': 'default'}]

    def get_handler(self, group):
        """Get the appropriate frequency handler for the parameter group"""
        handler_name = group.get('handler', 'default')
        return self.handlers[handler_name]

    def dump_debug_info(self, force=False):
        """Save the current debug information to file"""
        if not self.debug or not hasattr(self, 'debug_file'):
            return
        current_step = max([self.state[p]['step'] for p in self.state], default=0)
        if force or (current_step - self.last_dump_step >= self.debug_interval):
            try:
                with open(self.debug_file, 'r') as f:
                    debug_data = json.load(f)
                debug_data["steps_recorded"].append(current_step)

                for param_name, param_info in self.debug_info.items():
                    if param_name not in debug_data["parameters"]:
                        debug_data["parameters"][param_name] = {
                            "handler": param_info.get('handler', 'default'),
                            "steps": [],
                            "bands": [],
                            "alpha": []
                        }
                    last_recorded = len(debug_data["parameters"][param_name]["steps"])
                    if last_recorded < len(param_info['steps']):
                        debug_data["parameters"][param_name]["steps"].extend(param_info['steps'][last_recorded:])
                        debug_data["parameters"][param_name]["bands"].extend(param_info['bands'][last_recorded:])
                        debug_data["parameters"][param_name]["alpha"].extend(param_info['alpha'][last_recorded:])

                with open(self.debug_file, 'w') as f:
                    json.dump(debug_data, f)

                self.last_dump_step = current_step
                # Keep only a short in-memory tail once the data has been flushed to disk
                for param_info in self.debug_info.values():
                    param_info['steps'] = param_info['steps'][-10:]
                    param_info['bands'] = param_info['bands'][-10:]
                    param_info['alpha'] = param_info['alpha'][-10:]

            except Exception as e:
                print(f"Error dumping FAM debug info: {e}")

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p_idx, p in enumerate(group['params']):
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('FAMOptimizer does not support sparse gradients')

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['freq_history'] = {}
                    state['param_name'] = f"param_{p_idx}"

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=group['weight_decay'])

                exp_avg = state['exp_avg']
                alpha = group['alpha']
                beta = group['beta']
                lr = group['lr']
                n_bands = group['n_bands']
                handler = self.get_handler(group)

                should_apply_fam = (
                    state['step'] > group['fam_start_step'] and
                    p.numel() > group['min_size']
                )

                if should_apply_fam:
                    try:
                        # Sub-sample large tensors so the FFT stays cheap
                        if p.numel() > 10000:
                            if p.dim() > 1:
                                row_indices = torch.randperm(p.size(0), device=grad.device)[:min(p.size(0), 64)]
                                col_indices = torch.randperm(p.size(1), device=grad.device)[:min(p.size(1), 64)]
                                grad_sample = grad[row_indices][:, col_indices].flatten()
                            else:
                                sample_idx = torch.randperm(p.numel(), device=grad.device)[:1000]
                                grad_sample = grad.flatten()[sample_idx]
                        else:
                            grad_sample = grad.flatten()

                        band_powers = handler.analyze(grad_sample, n_bands, group['eps'])
                        if state['step'] <= 10 and p_idx == 0:
                            print(f"Step {state['step']}: Found {len(band_powers)} frequency bands")
                            print(f"Band powers: {[f'{v:.4f}' for v in band_powers]}")

                        # Exponential moving average of the per-band energies
                        for i, power in enumerate(band_powers):
                            band_key = f'band_{i}'
                            if band_key not in state['freq_history']:
                                state['freq_history'][band_key] = power
                            else:
                                state['freq_history'][band_key] = (
                                    beta * state['freq_history'][band_key] +
                                    (1 - beta) * power
                                )

                        band_values = [state['freq_history'].get(f'band_{i}', 0)
                                       for i in range(n_bands)]
                        effective_alpha = handler.get_adaptive_momentum(band_values, alpha)

                        if self.debug:
                            param_name = state['param_name']
                            if param_name not in self.debug_info:
                                self.debug_info[param_name] = {
                                    'steps': [],
                                    'bands': [],
                                    'handler': group.get('handler', 'default'),
                                    'alpha': []
                                }

                            if state['step'] % 10 == 0:
                                self.debug_info[param_name]['steps'].append(state['step'])
                                self.debug_info[param_name]['bands'].append(band_values)
                                self.debug_info[param_name]['alpha'].append(effective_alpha)

                        exp_avg.mul_(effective_alpha).add_(grad, alpha=1 - effective_alpha)
                    except Exception as e:
                        import traceback
                        print(f"Error in FAM processing for parameter {p_idx}:")
                        print(f"Error type: {type(e).__name__}")
                        print(f"Error message: {e}")
                        print(f"Parameter shape: {p.shape}, numel: {p.numel()}")
                        print(traceback.format_exc())
                        exp_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
                else:
                    exp_avg.mul_(alpha).add_(grad, alpha=1 - alpha)

                p.add_(exp_avg, alpha=-lr)

        if self.debug:
            self.dump_debug_info()

        return loss

    def __del__(self):
        """Clean up and final debug dump when optimizer is destroyed"""
        if getattr(self, 'debug', False):
            self.dump_debug_info(force=True)

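# Optional helper sketch, for illustration: summarize a debug JSON file produced when
# debug=True. It assumes the schema written in __init__ and extended by dump_debug_info
# ("optimizer", "settings", "parameters", "steps_recorded"); the function name is not
# part of the optimizer's API.
def summarize_fam_debug(path):
    with open(path, 'r') as f:
        data = json.load(f)
    print(f"{data['optimizer']} settings: {data['settings']}")
    for name, info in data.get("parameters", {}).items():
        if info.get("alpha"):
            mean_alpha = sum(info["alpha"]) / len(info["alpha"])
            print(f"{name} (handler={info['handler']}): "
                  f"{len(info['steps'])} records, mean effective alpha {mean_alpha:.3f}")
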
def get_parameter_groups(model, lr=1e-3, weight_decay=0.0):
    """
    Create parameter groups for FAMOptimizer with appropriate handlers based on layer type
    """
    param_groups = []

    conv_params = []
    conv_names = []

    attn_params = []
    attn_names = []

    embed_params = []
    embed_names = []

    norm_params = []
    norm_names = []

    other_params = []
    other_names = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if any(x in name.lower() for x in ['conv', 'cnn']):
            conv_params.append(param)
            conv_names.append(name)
        elif any(x in name.lower() for x in ['attention', 'mha', 'self_attn']):
            attn_params.append(param)
            attn_names.append(name)
        elif any(x in name.lower() for x in ['embed', 'token']):
            embed_params.append(param)
            embed_names.append(name)
        elif any(x in name.lower() for x in ['norm', 'batch', 'layer']):
            norm_params.append(param)
            norm_names.append(name)
        else:
            other_params.append(param)
            other_names.append(name)

    if conv_params:
        param_groups.append({
            'params': conv_params,
            'names': conv_names,
            'lr': lr,
            'weight_decay': weight_decay,
            'alpha': 0.9,
            'handler': 'conv',
            'n_bands': 10
        })

    if attn_params:
        param_groups.append({
            'params': attn_params,
            'names': attn_names,
            'lr': lr,
            'weight_decay': weight_decay,
            'alpha': 0.92,
            'handler': 'attention',
            'n_bands': 12
        })

    if embed_params:
        param_groups.append({
            'params': embed_params,
            'names': embed_names,
            'lr': lr * 0.8,
            'weight_decay': weight_decay * 1.5,
            'alpha': 0.95,
            'handler': 'embedding',
            'n_bands': 8
        })

    if norm_params:
        param_groups.append({
            'params': norm_params,
            'names': norm_names,
            'lr': lr,
            'weight_decay': 0.0,
            'alpha': 0.9,
            'handler': 'default',
            'n_bands': 4
        })

    if other_params:
        param_groups.append({
            'params': other_params,
            'names': other_names,
            'lr': lr,
            'weight_decay': weight_decay,
            'alpha': 0.9,
            'handler': 'default',
            'n_bands': 8
        })

    return param_groups

class FAMSchedulerb(_LRScheduler):
    """
    Scheduler with linear warmup followed by cosine annealing.

    Args:
        optimizer: Wrapped optimizer
        warmup_epochs: Number of epochs for the linear warmup
        max_epochs: Total number of epochs
        warmup_start_lr: Initial learning rate for warmup
        eta_min: Minimum learning rate after cosine annealing
    """
    def __init__(self, optimizer, warmup_epochs, max_epochs, warmup_start_lr=1e-8, eta_min=1e-8, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        super(FAMSchedulerb, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            alpha = self.last_epoch / self.warmup_epochs
            return [self.warmup_start_lr + (base_lr - self.warmup_start_lr) * alpha
                    for base_lr in self.base_lrs]
        else:
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) /
                                  (self.max_epochs - self.warmup_epochs))) / 2
                    for base_lr in self.base_lrs]

class SimpleFAM(torch.optim.Optimizer):
    """
    Simplified Frequency-Adaptive Momentum optimizer

    A lightweight implementation that focuses on the core concepts
    without complex debugging or parameter-specific handlers.
    """
    def __init__(self, params, lr=0.001, alpha=0.9, beta=0.99):
        defaults = dict(lr=lr, alpha=alpha, beta=beta)
        super(SimpleFAM, self).__init__(params, defaults)
        print(f"SimpleFAM initialized with lr={lr}, alpha={alpha}")

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)

                state['step'] += 1
                exp_avg = state['exp_avg']
                alpha = group['alpha']

                if p.numel() > 1000 and state['step'] > 100:
                    # Boost momentum when the upper half of the spectrum dominates
                    grad_sample = p.grad.flatten()[:min(1000, p.numel())]
                    freq = torch.fft.rfft(grad_sample.float())
                    power = torch.abs(freq)
                    half = power.shape[0] // 2
                    high_ratio = (power[half:].sum() / (power.sum() + 1e-8)).item()
                    effective_alpha = min(0.98, alpha + 0.05 * high_ratio)
                    exp_avg.mul_(effective_alpha).add_(p.grad, alpha=1 - effective_alpha)
                else:
                    exp_avg.mul_(alpha).add_(p.grad, alpha=1 - alpha)

                p.add_(exp_avg, alpha=-group['lr'])

        return loss

class FAMScheduler(torch.optim.lr_scheduler._LRScheduler):
    """
    Step-based learning rate scheduler for FAM optimizer
    with warmup and cosine annealing.
    """
    def __init__(self, optimizer, warmup_steps=1000, total_steps=100000,
                 decay_start_step=None, warmup_start_lr=1e-6, eta_min=1e-6,
                 last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.decay_start_step = decay_start_step if decay_start_step is not None else warmup_steps
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        super(FAMScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            alpha = self.last_epoch / self.warmup_steps
            return [self.warmup_start_lr + (base_lr - self.warmup_start_lr) * alpha
                    for base_lr in self.base_lrs]
        elif self.last_epoch < self.decay_start_step:
            return self.base_lrs
        else:
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1 + math.cos(math.pi * (self.last_epoch - self.decay_start_step) /
                                  (self.total_steps - self.decay_start_step))) / 2 + 1e-8
                    for base_lr in self.base_lrs]
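
A minimal usage sketch, assuming a toy model whose attribute names (token_embedding, self_attn, norm) exercise the name-based grouping in get_parameter_groups; the model, shapes, and hyperparameters are placeholders. It wires FAMOptimizer together with the step-based FAMScheduler.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyModel(nn.Module):
    # Toy model; attribute names are chosen so get_parameter_groups routes the
    # parameters to the embedding, attention, norm, and default groups.
    def __init__(self, vocab=1000, dim=64):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab, dim)
        self.self_attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, ids):
        x = self.token_embedding(ids)
        attn_out, _ = self.self_attn(x, x, x)
        x = self.norm(x + attn_out)
        return self.head(x)

model = TinyModel()
groups = get_parameter_groups(model, lr=1e-3, weight_decay=0.01)
optimizer = FAMOptimizer(groups, lr=1e-3, fam_start_step=100, debug=False)
scheduler = FAMScheduler(optimizer, warmup_steps=100, total_steps=5000)

for _ in range(200):  # shortened loop for illustration
    ids = torch.randint(0, 1000, (8, 32))
    logits = model(ids)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), ids.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

The per-group settings from get_parameter_groups (handler, alpha, n_bands, lr, weight_decay) override the optimizer defaults, while unspecified keys such as beta, eps, fam_start_step, and min_size are filled in from the constructor arguments.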