File size: 3,494 Bytes
8920c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# -*- coding: utf-8 -*-
"""

Created on Sat Dec 21 13:24:21 2024



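This script pre-trains the LWM model on masked inputs (masking ratio MASK_PERCENT) and can resume training from a saved checkpoint when `load_model` is True.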



@author: salikha4

"""
import torch
import torch.nn as nn
from torch.utils.data import random_split
from input_preprocess import tokenizer, scenarios_list
from utils import create_dataloader, count_parameters
import numpy as np
import lwm_model
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import AdamW
from train import train_lwm
import warnings
# Silence every UserWarning for the remainder of the run. Note this is broad:
# it hides UserWarnings raised by any library, not only the noisy ones.
warnings.filterwarnings(action="ignore", category=UserWarning)
#%% SETTINGS
# --- Training schedule -----------------------------------------------------
EPOCHS = 50
WARMUP_EPOCHS = 5
BATCH_SIZE = 128
VAL_BATCH_SIZE = 64

# --- Optimizer (AdamW) -----------------------------------------------------
# NOTE: BASE_LR is reassigned to 5e-5 in the OPTIMIZER AND SCHEDULER cell,
# before the optimizer is built, so this 5e-4 value is never used.
BASE_LR = 5e-4
WEIGHT_DECAY = 0.05
BETA1, BETA2 = 0.9, 0.999

# --- Masking / tokenizer input ----------------------------------------------
# Both are passed to input_preprocess.tokenizer in the GENERATE DATASET cell.
MASK_PERCENT = 0.40
MAX_LEN = 513

# --- Element geometry -------------------------------------------------------
N_ROWS = 4
N_COLUMNS = 4
# Factor 2 per (row, column) cell — presumably real/imaginary parts; confirm.
ELEMENT_LENGTH = 2 * N_ROWS * N_COLUMNS

# --- Transformer architecture -----------------------------------------------
# NOTE(review): none of these are passed to lwm_model.lwm(), which this script
# calls with no arguments; presumably they mirror that model's defaults.
D_MODEL = 128
N_LAYERS = 12
N_HEADS = 8
DROPOUT = 0.1
#%% GENERATE DATASET
# NOTE(review): bs_idxs is not read anywhere else in this script.
bs_idxs = list(range(1, 4))
# Keep only the first 80 entries returned by input_preprocess.scenarios_list().
selected_scenario_names = scenarios_list()[:80]
# Run the project tokenizer with masking enabled. The seed is fixed at 42,
# which is the same value later assigned to SEED in the split cell. The
# result is consumed below as a mapping (iterated with .items()).
preprocessed_data = tokenizer(
    selected_scenario_names, MAX_LEN,
    masking_percent=MASK_PERCENT, mask=True, seed=42,
)
#%% SPLIT DATASET
# Seed both RNGs up front. random_split is called below without a generator
# argument, so it draws from torch's default generator seeded here.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

train_ratio = 0.8
val_ratio = 0.2
# Because of the two int() floors, 0 or 1 leftover sample per key lands in
# test_data; this script never uses test_data afterwards.
train_data, val_data, test_data = {}, {}, {}
for key in preprocessed_data:
    samples = preprocessed_data[key]
    print(f"key: {key}")
    total_samples = len(samples)
    train_size = int(train_ratio * total_samples)
    val_size = int(val_ratio * total_samples)
    test_size = total_samples - (train_size + val_size)
    (train_data[key], val_data[key], test_data[key]) = random_split(
        samples, [train_size, val_size, test_size]
    )

# One loader collection per split; only training batches are shuffled.
train_loaders = create_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loaders = create_dataloader(val_data, batch_size=VAL_BATCH_SIZE, shuffle=False)
#%% INITIALIZE MODEL
load_model = True
gpu_ids = [0]
# nn.DataParallel requires the module's parameters to live on device_ids[0],
# so derive the primary device from gpu_ids instead of hard-coding it.
device = torch.device(f"cuda:{gpu_ids[0]}")
model = lwm_model.lwm().to(device)

if load_model:
    model_name = "lwm_epoch50_train0.0077_val0.0060_masking0.40.pth"
    state_dict = torch.load(f"models/{model_name}", map_location=device)
    # A state_dict saved from a DataParallel-wrapped model has every key
    # prefixed with "module.". Strip only that leading prefix: str.replace
    # would also rewrite any key containing "module." elsewhere
    # (e.g. "...submodule.weight" -> "...subweight").
    prefix = "module."
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    model.load_state_dict(new_state_dict)

model = nn.DataParallel(model, gpu_ids)
# Only claim a checkpoint was loaded when one actually was.
if load_model:
    print(f"Model loaded successfully on GPU {device.index}")
else:
    print(f"Model initialized on GPU {device.index}")

n_parameters = count_parameters(model)
print(f"Number of trainable parameters: {n_parameters:,}")
#%% OPTIMIZER AND SCHEDULER
# This overrides BASE_LR from the SETTINGS cell. The optimizer below and the
# lr_lambda schedule both use 5e-5, not 5e-4.
BASE_LR = 5e-5
MIN_LR = 1e-8
# Batches in one pass over every loader in train_loaders. The step counts
# below assume the scheduler is stepped once per batch — TODO confirm in
# train_lwm.
steps_per_epoch = sum(map(len, train_loaders.values()))
TOTAL_STEPS = EPOCHS * steps_per_epoch
WARMUP_STEPS = WARMUP_EPOCHS * steps_per_epoch

optimizer = AdamW(
    model.parameters(),
    weight_decay=WEIGHT_DECAY,
    betas=(BETA1, BETA2),
    lr=BASE_LR,
)
def lr_lambda(current_step, *, warmup_steps=None, total_steps=None,
              base_lr=None, min_lr=None):
    """Return the LR multiplier for ``current_step`` (linear warmup + cosine).

    During the first ``warmup_steps`` steps the factor rises linearly from 0.
    After that it follows a half-cosine from 1 down to ``min_lr / base_lr``,
    which it reaches at ``total_steps`` and keeps from then on.

    Args:
        current_step: Scheduler step index (LambdaLR passes its step count).
        warmup_steps: Length of the warmup phase. Defaults to the module
            global WARMUP_STEPS.
        total_steps: Step at which the cosine decay ends. Defaults to
            TOTAL_STEPS.
        base_lr: The optimizer's base learning rate. Defaults to BASE_LR.
        min_lr: Learning-rate floor reached at the end of decay. Defaults to
            MIN_LR.

    Returns:
        A Python float multiplier applied to the optimizer's base lr.
    """
    # Globals are read at call time, exactly like the original one-argument
    # form used by LambdaLR.
    if warmup_steps is None:
        warmup_steps = WARMUP_STEPS
    if total_steps is None:
        total_steps = TOTAL_STEPS
    if base_lr is None:
        base_lr = BASE_LR
    if min_lr is None:
        min_lr = MIN_LR

    if current_step < warmup_steps:
        # Linear warmup
        return current_step / warmup_steps

    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0:
        # No decay phase: sit at the floor instead of dividing by zero.
        progress = 1.0
    else:
        # Clamp so the factor stays at the floor past total_steps instead of
        # following the cosine back up toward 1.
        progress = min(max((current_step - warmup_steps) / decay_steps, 0.0), 1.0)
    # Scaled cosine decay
    cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
    return float(cosine_decay * (base_lr - min_lr) / base_lr + min_lr / base_lr)
    
# LambdaLR multiplies the optimizer's initial lr by lr_lambda(step).
scheduler = LambdaLR(optimizer, lr_lambda)
#%% PRE-TRAIN THE MODEL
# Hand everything to train.train_lwm; its return value is kept as
# pretrained_model.
pretrained_model = train_lwm(
    model, train_loaders, val_loaders,
    optimizer, scheduler, EPOCHS,
    device=device,
)