aayushraina commited on
Commit
4d53bb6
·
verified ·
1 Parent(s): 840b176

Upload train_shakespeare.py

Browse files
Files changed (1) hide show
  1. train_shakespeare.py +0 -9
train_shakespeare.py CHANGED
@@ -6,14 +6,11 @@ from dataclasses import dataclass
6
  import torch
7
  import torch.nn as nn
8
  from torch.nn import functional as F
9
- import wandb
10
 
11
  # Set MPS memory management
12
  os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
13
  os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.5'
14
 
15
- # Initialize wandb
16
- wandb.init(project="shakespeare-gpt", name="gpt2-124M-training")
17
 
18
  class CausalSelfAttention(nn.Module):
19
  def __init__(self, config):
@@ -204,10 +201,6 @@ for iter in range(num_iters):
204
  if iter % eval_interval == 0:
205
  current_loss = loss.item()
206
  print(f'step {iter}, loss: {current_loss:.4f}')
207
- wandb.log({
208
- "iter": iter,
209
- "loss": current_loss
210
- })
211
 
212
  # Save if this is the best model so far
213
  if current_loss < best_loss:
@@ -243,5 +236,3 @@ torch.save({
243
  'loss': loss.item(),
244
  'best_loss': best_loss,
245
  }, final_path)
246
-
247
- wandb.finish()
 
6
  import torch
7
  import torch.nn as nn
8
  from torch.nn import functional as F
 
9
 
10
  # Set MPS memory management
11
  os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
12
  os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.5'
13
 
 
 
14
 
15
  class CausalSelfAttention(nn.Module):
16
  def __init__(self, config):
 
201
  if iter % eval_interval == 0:
202
  current_loss = loss.item()
203
  print(f'step {iter}, loss: {current_loss:.4f}')
 
 
 
 
204
 
205
  # Save if this is the best model so far
206
  if current_loss < best_loss:
 
236
  'loss': loss.item(),
237
  'best_loss': best_loss,
238
  }, final_path)