Spaces:
Sleeping
Sleeping
Upload train_shakespeare.py
Browse files- train_shakespeare.py +0 -9
train_shakespeare.py
CHANGED
@@ -6,14 +6,11 @@ from dataclasses import dataclass
|
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
from torch.nn import functional as F
|
9 |
-
import wandb
|
10 |
|
11 |
# Set MPS memory management
|
12 |
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
|
13 |
os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.5'
|
14 |
|
15 |
-
# Initialize wandb
|
16 |
-
wandb.init(project="shakespeare-gpt", name="gpt2-124M-training")
|
17 |
|
18 |
class CausalSelfAttention(nn.Module):
|
19 |
def __init__(self, config):
|
@@ -204,10 +201,6 @@ for iter in range(num_iters):
|
|
204 |
if iter % eval_interval == 0:
|
205 |
current_loss = loss.item()
|
206 |
print(f'step {iter}, loss: {current_loss:.4f}')
|
207 |
-
wandb.log({
|
208 |
-
"iter": iter,
|
209 |
-
"loss": current_loss
|
210 |
-
})
|
211 |
|
212 |
# Save if this is the best model so far
|
213 |
if current_loss < best_loss:
|
@@ -243,5 +236,3 @@ torch.save({
|
|
243 |
'loss': loss.item(),
|
244 |
'best_loss': best_loss,
|
245 |
}, final_path)
|
246 |
-
|
247 |
-
wandb.finish()
|
|
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
from torch.nn import functional as F
|
|
|
9 |
|
10 |
# Set MPS memory management
|
11 |
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
|
12 |
os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.5'
|
13 |
|
|
|
|
|
14 |
|
15 |
class CausalSelfAttention(nn.Module):
|
16 |
def __init__(self, config):
|
|
|
201 |
if iter % eval_interval == 0:
|
202 |
current_loss = loss.item()
|
203 |
print(f'step {iter}, loss: {current_loss:.4f}')
|
|
|
|
|
|
|
|
|
204 |
|
205 |
# Save if this is the best model so far
|
206 |
if current_loss < best_loss:
|
|
|
236 |
'loss': loss.item(),
|
237 |
'best_loss': best_loss,
|
238 |
}, final_path)
|
|
|
|