File size: 3,077 Bytes
963134f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import pandas as pd
import os

import torch
import torch.nn as nn

from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from transformers import DataCollatorWithPadding, GPT2Config, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, RobertaTokenizerFast

import datasets
from datasets import disable_caching
disable_caching()
from datasets import IterableDataset

from conditional_gpt2_model import ConditionalGPT2LMHeadModel


ENCODER_MODEL_NAME = "entropy/roberta_zinc_480m"  # encoder model name
TOKENIZER_MAX_LEN = 256                           # max_length param on tokenizer

DATA_SUBSHARDS = 10                               # number of shards to break each data chunk into

DATA_DIR = None                                   # directory with saved data shards
TRAINER_SAVE_DIR = None                           # directory to save model checkpoints

# Fail fast with an explicit exception rather than `assert`: assertions are
# stripped when Python runs with -O, which would let a misconfigured run start
# and fail much later (or silently save checkpoints nowhere).
if DATA_DIR is None:
    raise ValueError("data directory must be specified")
if TRAINER_SAVE_DIR is None:
    raise ValueError("trainer save directory must be specified")



def gen_dataset():
    """Yield training examples one at a time from the saved dataset shards.

    Walks every ``*.hf`` dataset saved under ``DATA_DIR`` in sorted
    (deterministic) order, keeps only the columns the model consumes, and
    splits each chunk into ``DATA_SUBSHARDS`` contiguous shards so examples
    stream without materializing a whole chunk at once.

    Yields:
        dict: ``input_ids`` and ``encoder_hidden_states`` as torch tensors,
        with ``encoder_hidden_states`` given a leading unit axis so it has
        the (sequence, hidden) shape cross-attention expects.
    """
    # sorted() keeps chunk order deterministic across runs; endswith avoids
    # accidentally matching names that merely contain ".hf" in the middle.
    data_filenames = sorted(f for f in os.listdir(DATA_DIR) if f.endswith('.hf'))

    for filename in data_filenames:
        # BUG FIX: the path previously contained a literal placeholder instead
        # of the loop variable, so every iteration tried to load the same
        # nonexistent path and `filename` was never used.
        dataset = datasets.Dataset.load_from_disk(f'{DATA_DIR}/{filename}')

        keep_cols = ['input_ids', 'encoder_hidden_states']

        dataset = dataset.remove_columns(
            [col for col in dataset.column_names if col not in keep_cols]
        ).with_format("torch")

        # contiguous shards load faster than strided ones
        shards = [dataset.shard(num_shards=DATA_SUBSHARDS, index=index, contiguous=True)
                  for index in range(DATA_SUBSHARDS)]

        for shard in shards:
            for example in shard:
                # add unit axis to hidden states: (hidden,) -> (1, hidden)
                example['encoder_hidden_states'] = example['encoder_hidden_states'][None, :]
                yield example

# Stream examples lazily from the generator so the full dataset never has to
# fit in memory; with_format("torch") returns examples as torch tensors.
dataset = IterableDataset.from_generator(gen_dataset)
dataset = dataset.with_format("torch")

# Reuse the encoder's tokenizer so the decoder's token ids share the encoder
# vocabulary. mlm=False makes the collator produce causal-LM labels
# (labels = input_ids with padding masked out).
tokenizer = RobertaTokenizerFast.from_pretrained(ENCODER_MODEL_NAME, max_len=TOKENIZER_MAX_LEN)
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)


# Decoder config: a small GPT-2 (6 layers, 8 heads) sized to the encoder's
# vocabulary and context length; add_cross_attention inserts cross-attention
# layers so the decoder can condition on encoder_hidden_states.
config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=TOKENIZER_MAX_LEN,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    n_layer=6,
    n_head=8,
    add_cross_attention=True,
)

# Trained from scratch (config only, no pretrained weights).
model = ConditionalGPT2LMHeadModel(config)

# change trainer args as needed
args = TrainingArguments(
    output_dir=TRAINER_SAVE_DIR,
    per_device_train_batch_size=192,
    logging_steps=25,
    gradient_accumulation_steps=8,  # effective batch = 192 * 8 per device
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1000,
    lr_scheduler_type="cosine",
    learning_rate=1e-5,
    save_steps=200,
    save_total_limit=30,
    fp16=True,  # mixed precision; requires a CUDA device
    push_to_hub=False,
    max_steps=50000,  # with an IterableDataset, max_steps (not epochs) bounds training
)


# NOTE(review): passing `tokenizer=` to Trainer is deprecated in newer
# transformers releases (renamed to `processing_class`) — confirm against the
# pinned transformers version before upgrading.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=collator,
    train_dataset=dataset,
)

# Kicks off training immediately at import/run time; checkpoints land in
# TRAINER_SAVE_DIR every `save_steps` steps.
trainer.train()