Spaces:
Runtime error
Runtime error
# Copyright 2022 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import argparse | |
import math | |
import os | |
from abc import ABC | |
from functools import partial | |
import torch | |
import torch.nn.functional as F | |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP | |
from ..optimizer import AcceleratedOptimizer | |
from ..scheduler import AcceleratedScheduler | |
from .imports import is_megatron_lm_available, is_transformers_available | |
from .operations import recursively_apply, send_to_device | |
if is_transformers_available(): | |
from transformers.modeling_outputs import ( | |
CausalLMOutputWithCrossAttentions, | |
Seq2SeqLMOutput, | |
SequenceClassifierOutput, | |
) | |
if is_megatron_lm_available(): | |
from megatron import ( | |
get_args, | |
get_num_microbatches, | |
get_tensorboard_writer, | |
get_tokenizer, | |
print_rank_last, | |
) | |
from megatron.arguments import ( | |
_add_data_args, | |
_add_validation_args, | |
core_transformer_config_from_args, | |
parse_args, | |
validate_args, | |
) | |
from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint, save_checkpoint | |
from megatron.core import mpu, tensor_parallel | |
from megatron.core.distributed import DistributedDataParallel as LocalDDP | |
from megatron.core.distributed import finalize_model_grads | |
from megatron.core.enums import ModelType | |
from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_model_parallel_src_rank | |
from megatron.core.pipeline_parallel import get_forward_backward_func | |
from megatron.core.utils import get_model_config | |
from megatron.data.dataset_utils import build_train_valid_test_datasets | |
from megatron.global_vars import set_global_variables | |
from megatron.initialize import ( | |
_compile_dependencies, | |
_init_autoresume, | |
_initialize_distributed, | |
_set_random_seed, | |
set_jit_fusion_options, | |
write_args_to_tensorboard, | |
) | |
from megatron.model import BertModel, Float16Module, GPTModel, T5Model | |
from megatron.model.classification import Classification | |
from megatron.optimizer import get_megatron_optimizer | |
from megatron.text_generation.communication import broadcast_int_list, broadcast_tensor | |
from megatron.text_generation.generation import ( | |
beam_search_and_return_on_first_stage, | |
generate_tokens_probs_and_return_on_first_stage, | |
) | |
from megatron.tokenizer.tokenizer import _vocab_size_with_padding | |
from megatron.training import ( | |
build_train_valid_test_data_iterators, | |
get_optimizer_param_scheduler, | |
num_floating_point_operations, | |
setup_model_and_optimizer, | |
train_step, | |
training_log, | |
) | |
from megatron.utils import ( | |
average_losses_across_data_parallel_group, | |
calc_params_l2_norm, | |
get_ltor_masks_and_position_ids, | |
unwrap_model, | |
) | |
# model utilities | |
def model_provider_func(pre_process=True, post_process=True, add_encoder=True, add_decoder=True):
    """Build the Megatron-LM model selected by the global Megatron arguments.

    Args:
        pre_process: Forwarded unchanged to the model constructor.
        post_process: Forwarded unchanged to the model constructor.
        add_encoder: Forwarded to `T5Model` only.
        add_decoder: Forwarded to `T5Model` only.

    Returns:
        A `BertModel`, `Classification`, `GPTModel` or `T5Model` instance.

    Raises:
        ValueError: If `args.model_type_name` is not "bert", "gpt" or "t5".
    """
    args = get_args()
    if args.pretraining_flag:
        mode = "pre-training"
    else:
        mode = "fine-tuning"
    if args.rank == 0:
        # Only rank 0 announces what is being built.
        print(f"Building {args.model_type_name} model in the {mode} mode.")
        print(
            "The Megatron LM model weights are initialized at random in `accelerator.prepare`. "
            "Please use `accelerator.load_checkpoint` to load a pre-trained checkpoint matching the distributed setup."
        )
    config = core_transformer_config_from_args(args)
    # Arguments every supported model constructor receives.
    stage_kwargs = {"pre_process": pre_process, "post_process": post_process}
    model_type_name = args.model_type_name

    if model_type_name == "bert":
        if not args.pretraining_flag:
            # Fine-tuning uses the classification head.
            return Classification(
                config=config,
                num_classes=args.num_labels,
                num_tokentypes=2,
                **stage_kwargs,
            )
        return BertModel(
            config=config,
            num_tokentypes=2 if args.bert_binary_head else 0,
            add_binary_head=args.bert_binary_head,
            parallel_output=True,
            **stage_kwargs,
        )

    if model_type_name == "gpt":
        return GPTModel(
            config=config,
            num_tokentypes=0,
            parallel_output=True,
            **stage_kwargs,
        )

    if model_type_name == "t5":
        return T5Model(
            config=config,
            num_tokentypes=0,
            parallel_output=True,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
            **stage_kwargs,
        )

    raise ValueError(f"Unsupported model type: {args.model_type_name}")
def prepare_model_optimizer_scheduler(accelerator):
    """Create the Megatron-LM model, optimizer and scheduler.

    When the plugin supplies `custom_prepare_model_function`, that function builds the model
    (it requires `custom_model_provider_function`) and the optimizer and scheduler are created
    by `prepare_optimizer` / `prepare_scheduler`. Otherwise everything is delegated to
    Megatron's `setup_model_and_optimizer`.

    Args:
        accelerator: Object exposing `print` and `state.megatron_lm_plugin`.

    Returns:
        Tuple `(model, optimizer, scheduler)`. `len(model)` is also stored on the global
        Megatron args as `model_len`.

    Raises:
        ValueError: If a custom prepare function is given without a custom model provider.
    """
    accelerator.print("Preparing model optimizer scheduler")
    args = get_args()
    plugin = accelerator.state.megatron_lm_plugin

    if plugin.custom_prepare_model_function is None:
        model_type = (
            ModelType.encoder_and_decoder if args.model_type_name == "t5" else ModelType.encoder_or_decoder
        )
        # A user supplied provider takes precedence over the built-in one.
        provider = (
            plugin.custom_model_provider_function
            if plugin.custom_model_provider_function is not None
            else model_provider_func
        )
        model, optimizer, scheduler = setup_model_and_optimizer(
            provider,
            model_type,
            no_wd_decay_cond=args.no_wd_decay_cond,
            scale_lr_cond=args.scale_lr_cond,
            lr_mult=args.lr_mult,
        )
    else:
        if plugin.custom_model_provider_function is None:
            raise ValueError(
                "You must provide a `custom_model_provider_function` when using a `custom_prepare_model_function`."
            )
        model = plugin.custom_prepare_model_function(plugin.custom_model_provider_function)
        optimizer = prepare_optimizer(accelerator, model)
        scheduler = prepare_scheduler(accelerator, optimizer, scheduler=None)

    args.model_len = len(model)
    return model, optimizer, scheduler
# dataloader utilities | |
class MegatronLMDummyDataLoader:
    """
    Placeholder dataloader that only carries Megatron dataset arguments.

    It holds the data/validation arguments needed to build Megatron datasets and
    knows how to turn them into train/valid/test data iterators.

    Args:
        **dataset_kwargs: Megatron data arguments.
    """

    def __init__(self, **dataset_kwargs):
        # Start from the defaults of Megatron's data and validation argument groups
        # (parsed from the command line, unknown arguments ignored) ...
        arg_parser = _add_validation_args(_add_data_args(argparse.ArgumentParser()))
        known_args, _ = arg_parser.parse_known_args()
        self.dataset_args = vars(known_args)
        # ... then let explicit keyword arguments win.
        self.dataset_args.update(dataset_kwargs)
        self.dataset_args["megatron_dataset_flag"] = True

    def set_megatron_data_args(self):
        """Copy the stored dataset arguments onto the global Megatron args, warning on changes."""
        megatron_args = get_args()
        for key, value in self.dataset_args.items():
            previous = getattr(megatron_args, key, "")
            if previous != value:
                print(
                    f"WARNING: MegatronLMDummyDataLoader overriding arguments for "
                    f"{key}:{previous} with {key}:{value}"
                )
            setattr(megatron_args, key, value)

    def get_train_valid_test_datasets_provider(self, accelerator):
        """Return the callable used to build the train/valid/test datasets.

        Preference order: the plugin's custom provider, then the provider from Megatron's
        `pretrain_{bert,gpt,t5}` script (if importable), then the built-in fallback below.
        """

        def train_valid_test_datasets_provider(train_val_test_num_samples):
            """Build train, valid, and test datasets."""
            args = get_args()
            data_path = args.data_path
            if not isinstance(data_path, (list, tuple)):
                data_path = [data_path]
            dataset_args = {
                "data_prefix": data_path,
                "splits_string": args.split,
                "train_valid_test_num_samples": train_val_test_num_samples,
                "seed": args.seed,
            }
            model_type_name = args.model_type_name
            if model_type_name == "bert":
                dataset_args["max_seq_length"] = args.seq_length
                dataset_args["binary_head"] = args.bert_binary_head
            elif model_type_name == "gpt":
                dataset_args["max_seq_length"] = args.seq_length
            elif model_type_name == "t5":
                dataset_args["max_seq_length"] = args.encoder_seq_length
                dataset_args["max_seq_length_dec"] = args.decoder_seq_length
                dataset_args["dataset_type"] = "t5"
            else:
                raise ValueError(f"Unsupported model type: {args.model_type_name}")
            train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**dataset_args)
            return train_ds, valid_ds, test_ds

        custom_provider = accelerator.state.megatron_lm_plugin.custom_megatron_datasets_provider_function
        if custom_provider is not None:
            return custom_provider

        # Try to replace the fallback above with the provider shipped in Megatron's
        # pretraining scripts; a failed import leaves the fallback bound.
        imported_from_megatron = True
        try:
            model_type_name = get_args().model_type_name
            # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
            if model_type_name == "bert":
                from pretrain_bert import train_valid_test_datasets_provider
            elif model_type_name == "gpt":
                from pretrain_gpt import train_valid_test_datasets_provider
            elif model_type_name == "t5":
                from pretrain_t5 import train_valid_test_datasets_provider
            else:
                imported_from_megatron = False
        except ImportError:
            imported_from_megatron = False

        if imported_from_megatron:
            train_valid_test_datasets_provider.is_distributed = True
        return train_valid_test_datasets_provider

    def build_train_valid_test_data_iterators(self, accelerator):
        """Build the train/valid/test iterators.

        Returns three iterators, or three lists of iterators (one entry per model chunk)
        when virtual pipeline parallelism is configured.
        """
        args = get_args()
        provider = self.get_train_valid_test_datasets_provider(accelerator)

        if args.virtual_pipeline_model_parallel_size is None:
            # Inside a method this name resolves to the module-level Megatron helper.
            train_data_iterator, valid_data_iterator, test_data_iterator = build_train_valid_test_data_iterators(
                provider
            )
            return train_data_iterator, valid_data_iterator, test_data_iterator

        train_data_iterator, valid_data_iterator, test_data_iterator = [], [], []
        for chunk_rank in range(getattr(args, "model_len", 0)):
            mpu.set_virtual_pipeline_model_parallel_rank(chunk_rank)
            chunk_iterators = build_train_valid_test_data_iterators(provider)
            train_data_iterator.append(chunk_iterators[0])
            valid_data_iterator.append(chunk_iterators[1])
            test_data_iterator.append(chunk_iterators[2])
        return train_data_iterator, valid_data_iterator, test_data_iterator
def _handle_megatron_data_iterator(accelerator, data_iterator):
    """Substitute a stand-in iterator on ranks that have no data iterator of their own.

    The tensor-parallel source rank broadcasts whether its iterator is `None`. A rank whose
    own iterator is `None` while the source rank's is not receives an endless iterator that
    yields empty dicts; every other rank gets `data_iterator` back unchanged.
    """

    class DummyMegatronDataloader:
        """Endless iterator yielding empty dicts."""

        def __iter__(self):
            return self

        def __next__(self):
            return {}

    local_is_empty = data_iterator is None
    src_is_empty = torch.tensor(local_is_empty, dtype=torch.bool, device=accelerator.device)
    # After the broadcast the tensor holds the source rank's flag on every rank of the group.
    torch.distributed.broadcast(
        src_is_empty,
        get_tensor_model_parallel_src_rank(),
        group=get_tensor_model_parallel_group(),
    )
    if not src_is_empty and local_is_empty:
        return DummyMegatronDataloader()
    return data_iterator
def prepare_data_loader(accelerator, dataloader):
    """Prepare a dataloader for Megatron-LM training.

    Args:
        accelerator: Object exposing `print`, `device`, `rng_types`, `dispatch_batches`.
        dataloader: Either a PyTorch-style dataloader (when `args.megatron_dataset_flag` is
            false) or an object providing `build_train_valid_test_data_iterators`.

    Returns:
        A prepared dataloader in the PyTorch case, otherwise the tuple
        `(train_data_iterator, valid_data_iterator, test_data_iterator)`.
    """
    accelerator.print("Preparing dataloader")
    args = get_args()
    if not args.megatron_dataset_flag:
        # NOTE: this local import intentionally shadows the current function's name.
        from ..data_loader import _PYTORCH_DATALOADER_KWARGS, prepare_data_loader

        # Each fetched batch carries all micro batches of one step.
        micro_batch_size = args.micro_batch_size * args.num_micro_batches
        kwargs = {k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) for k in _PYTORCH_DATALOADER_KWARGS}
        if kwargs["batch_size"] is None:
            if isinstance(kwargs["sampler"], torch.utils.data.BatchSampler):
                kwargs["sampler"].batch_size = micro_batch_size
            else:
                del kwargs["sampler"]
                del kwargs["shuffle"]
                del kwargs["batch_size"]
                kwargs["batch_sampler"].batch_size = micro_batch_size
        else:
            del kwargs["batch_sampler"]
            kwargs["batch_size"] = micro_batch_size

        dataloader = torch.utils.data.DataLoader(dataloader.dataset, **kwargs)
        # split_batches:
        # Megatron only needs to fetch different data between different dp groups,
        # and does not need to split the data within the dp group.
        return prepare_data_loader(
            dataloader,
            accelerator.device,
            num_processes=mpu.get_data_parallel_world_size(),
            process_index=mpu.get_data_parallel_rank(),
            split_batches=False,
            put_on_device=True,
            rng_types=accelerator.rng_types.copy(),
            dispatch_batches=accelerator.dispatch_batches,
        )
    else:
        if args.consumed_samples is not None:
            (
                args.consumed_train_samples,
                args.consumed_valid_samples,
                args.consumed_test_samples,
            ) = args.consumed_samples
        else:
            args.consumed_train_samples, args.consumed_valid_samples, args.consumed_test_samples = 0, 0, 0

        # In order to be compatible with data in transform format,
        # it needs to increase the size of mbs first,
        # and then split the large batch data into some mbs.
        args.micro_batch_size = args.micro_batch_size * args.num_micro_batches
        try:
            (
                train_data_iterator,
                valid_data_iterator,
                test_data_iterator,
            ) = dataloader.build_train_valid_test_data_iterators(accelerator)
        finally:
            # Undo the temporary scaling even when building the iterators fails, so the
            # shared Megatron args are never left with an inflated micro batch size.
            args.micro_batch_size = args.micro_batch_size // args.num_micro_batches

        train_data_iterator = _handle_megatron_data_iterator(
            accelerator=accelerator, data_iterator=train_data_iterator
        )
        valid_data_iterator = _handle_megatron_data_iterator(
            accelerator=accelerator, data_iterator=valid_data_iterator
        )
        test_data_iterator = _handle_megatron_data_iterator(accelerator=accelerator, data_iterator=test_data_iterator)

        return train_data_iterator, valid_data_iterator, test_data_iterator
# optimizer utilities | |
class MegatronLMOptimizerWrapper(AcceleratedOptimizer):
    """Optimizer wrapper whose `zero_grad` and `step` are intentionally no-ops."""

    def __init__(self, optimizer):
        super().__init__(optimizer, device_placement=False, scaler=None)

    def zero_grad(self, set_to_none=None):
        # Nothing to do here: `model(**batch)` already takes care of it.
        return None

    def step(self):
        # Nothing to do here: `model(**batch)` already takes care of it.
        return None

    # NOTE(review): the base class may expose this as a property; it is a plain
    # method here as in the visible source - confirm against `AcceleratedOptimizer`.
    def step_was_skipped(self):
        """Whether or not the optimizer step was done, or skipped because of gradient overflow."""
        return self.optimizer.skipped_iter
def prepare_optimizer(accelerator, model):
    """Build the Megatron optimizer for `model` using settings from the global Megatron args."""
    accelerator.print("Preparing optimizer")
    megatron_args = get_args()
    optimizer = get_megatron_optimizer(
        model,
        megatron_args.no_wd_decay_cond,
        megatron_args.scale_lr_cond,
        megatron_args.lr_mult,
    )
    return optimizer
# scheduler utilities | |
class MegatronLMDummyScheduler:
    """
    Placeholder scheduler that only records its construction arguments.

    Args:
        optimizer (`torch.optim.optimizer.Optimizer`):
            The optimizer to wrap.
        total_num_steps (int):
            Total number of steps.
        warmup_num_steps (int):
            Number of steps for warmup.
        **kwargs (additional keyword arguments, *optional*):
            Other arguments.
    """

    def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, **kwargs):
        # Plain storage; nothing is computed or validated here.
        self.optimizer, self.total_num_steps, self.warmup_num_steps = (
            optimizer,
            total_num_steps,
            warmup_num_steps,
        )
        self.kwargs = kwargs
class MegatronLMSchedulerWrapper(AcceleratedScheduler):
    """Scheduler wrapper whose `step` is intentionally a no-op."""

    def __init__(self, scheduler, optimizers):
        super().__init__(scheduler, optimizers)

    def step(self, *args, **kwargs):
        # Nothing to do here: `model(**batch)` already takes care of it.
        return None
def prepare_scheduler(accelerator, optimizer, scheduler):
    """Build Megatron's parameter scheduler for `optimizer`.

    The incoming `scheduler` argument is not used; a new scheduler is always created.
    """
    accelerator.print("Preparing scheduler")
    return get_optimizer_param_scheduler(optimizer)
class AbstractTrainStep(ABC):
    """Abstract class for batching, forward pass and loss handler."""

    def __init__(self, name):
        super().__init__()
        self.name = name

    def get_batch_func(self, accelerator, megatron_dataset_flag):
        """Hook for subclasses; the base implementation returns nothing."""
        return None

    def get_forward_step_func(self):
        """Hook for subclasses; the base implementation returns nothing."""
        return None

    def get_loss_func(self, accelerator):
        """Hook for subclasses; the base implementation returns nothing."""
        return None
class BertTrainStep(AbstractTrainStep):
    """
    Bert train step class.

    Args:
        accelerator: Object exposing `state.megatron_lm_plugin`.
        args (`argparse.Namespace`): Megatron-LM arguments.
    """

    def __init__(self, accelerator, args):
        super().__init__("BertTrainStep")
        self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
        self.loss_func = self.get_loss_func(accelerator, args.pretraining_flag, args.num_labels)
        self.forward_step = self.get_forward_step_func(args.pretraining_flag, args.bert_binary_head)
        if not args.model_return_dict:
            self.model_output_class = None
        else:
            self.model_output_class = SequenceClassifierOutput

    def get_batch_func(self, accelerator, megatron_dataset_flag):
        """Return the batch-building callable.

        Preference order: the plugin's custom function; for Megatron datasets the `get_batch`
        of Megatron's `pretrain_bert` script (if importable) or the built-in Megatron variant;
        otherwise the variant that reads transformers-style batches.
        """

        def get_batch_megatron(data_iterator):
            """Build the batch."""

            # Items and their type.
            keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"]
            datatype = torch.int64

            # Broadcast data.
            if data_iterator is not None:
                data = next(data_iterator)
            else:
                data = None
            data_b = tensor_parallel.broadcast_data(keys, data, datatype)

            # Unpack.
            tokens = data_b["text"].long()
            types = data_b["types"].long()
            sentence_order = data_b["is_random"].long()
            loss_mask = data_b["loss_mask"].float()
            lm_labels = data_b["labels"].long()
            padding_mask = data_b["padding_mask"].long()

            return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask

        def get_batch_transformer(data_iterator):
            """Build the batch."""
            data = next(data_iterator)
            data = send_to_device(data, torch.cuda.current_device())

            # Unpack. Optional entries become `None` when absent from the batch.
            tokens = data["input_ids"].long()
            padding_mask = data["attention_mask"].long()
            if "token_type_ids" in data:
                types = data["token_type_ids"].long()
            else:
                types = None
            if "labels" in data:
                lm_labels = data["labels"].long()
                # Positions labelled -100 are excluded from the loss.
                loss_mask = (data["labels"] != -100).to(torch.float)
            else:
                lm_labels = None
                loss_mask = None
            if "next_sentence_label" in data:
                sentence_order = data["next_sentence_label"].long()
            else:
                sentence_order = None

            return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask

        if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:
            return accelerator.state.megatron_lm_plugin.custom_get_batch_function
        if megatron_dataset_flag:
            try:
                # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
                from pretrain_bert import get_batch

                return get_batch
            except ImportError:
                pass
            return get_batch_megatron
        else:
            return get_batch_transformer

    def get_loss_func(self, accelerator, pretraining_flag, num_labels):
        """Return the loss callable (custom, pre-training or fine-tuning).

        Each built-in callable returns `(loss, dict_of_averaged_losses)`.
        """

        def loss_func_pretrain(loss_mask, sentence_order, output_tensor):
            lm_loss_, sop_logits = output_tensor

            lm_loss_ = lm_loss_.float()
            loss_mask = loss_mask.float()
            # Masked mean of the per-token losses.
            lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

            if sop_logits is not None:
                sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1)
                sop_loss = sop_loss.float()
                loss = lm_loss + sop_loss
                averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])
                return loss, {"lm loss": averaged_losses[0], "sop loss": averaged_losses[1]}

            else:
                loss = lm_loss
                averaged_losses = average_losses_across_data_parallel_group([lm_loss])
                return loss, {"lm loss": averaged_losses[0]}

        def loss_func_finetune(labels, logits):
            if num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            # Use the captured `num_labels`: the instance never defines `self.num_labels`,
            # so reading it here raised AttributeError for every non-regression task.
            elif num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
            else:
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
            averaged_losses = average_losses_across_data_parallel_group([loss])
            return loss, {"loss": averaged_losses[0]}

        if accelerator.state.megatron_lm_plugin.custom_loss_function is not None:
            return accelerator.state.megatron_lm_plugin.custom_loss_function
        if pretraining_flag:
            return loss_func_pretrain
        else:
            return loss_func_finetune

    def get_forward_step_func(self, pretraining_flag, bert_binary_head):
        """Return `forward_step(data_iterator, model)` yielding `(model_output, loss_callable)`."""

        def forward_step(data_iterator, model):
            """Forward step."""
            tokens, types, sentence_order, loss_mask, labels, padding_mask = self.get_batch(data_iterator)
            if not bert_binary_head:
                types = None
            # Forward pass through the model.
            if pretraining_flag:
                output_tensor = model(tokens, padding_mask, tokentype_ids=types, lm_labels=labels)
                return output_tensor, partial(self.loss_func, loss_mask, sentence_order)
            else:
                logits = model(tokens, padding_mask, tokentype_ids=types)
                return logits, partial(self.loss_func, labels)

        return forward_step
class GPTTrainStep(AbstractTrainStep):
    """
    GPT train step class.

    Args:
        accelerator: Object exposing `state.megatron_lm_plugin`.
        args (`argparse.Namespace`): Megatron-LM arguments.
    """

    def __init__(self, accelerator, args):
        super().__init__("GPTTrainStep")
        self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
        self.loss_func = self.get_loss_func(accelerator)
        self.forward_step = self.get_forward_step_func()

        # Default end-of-document id is the last padded vocab entry; a tokenizer,
        # when a vocab file is configured, overrides it.
        eod_token = args.padded_vocab_size - 1
        if args.vocab_file is not None:
            eod_token = get_tokenizer().eod
        self.eod_token = eod_token

        self.reset_position_ids = args.reset_position_ids
        self.reset_attention_mask = args.reset_attention_mask
        self.eod_mask_loss = args.eod_mask_loss
        self.model_output_class = CausalLMOutputWithCrossAttentions if args.model_return_dict else None

    def get_batch_func(self, accelerator, megatron_dataset_flag):
        """Return the batch-building callable (custom, Megatron-script, Megatron or transformers)."""

        def get_batch_megatron(data_iterator):
            """Generate a batch"""
            # Only the token ids are broadcast, as int64.
            raw = next(data_iterator) if data_iterator is not None else None
            broadcasted = tensor_parallel.broadcast_data(["text"], raw, torch.int64)

            # Labels are the inputs shifted left by one position.
            full_tokens = broadcasted["text"].long()
            labels = full_tokens[:, 1:].contiguous()
            tokens = full_tokens[:, :-1].contiguous()

            # Get the masks and position ids.
            attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
                tokens, self.eod_token, self.reset_position_ids, self.reset_attention_mask, self.eod_mask_loss
            )
            return tokens, labels, loss_mask, attention_mask, position_ids

        def get_batch_transformer(data_iterator):
            """Generate a batch from a transformers-style dict holding `input_ids`."""
            batch = next(data_iterator)
            batch = send_to_device({"input_ids": batch["input_ids"]}, torch.cuda.current_device())
            input_ids = batch["input_ids"].long()

            # Append one end-of-document column so inputs and shifted labels keep the same length.
            eod_column = (
                torch.zeros((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device) + self.eod_token
            )
            full_tokens = torch.cat([input_ids, eod_column], dim=1)
            labels = full_tokens[:, 1:].contiguous()
            tokens = full_tokens[:, :-1].contiguous()

            # Get the masks and position ids.
            attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
                tokens, self.eod_token, self.reset_position_ids, self.reset_attention_mask, True
            )
            return tokens, labels, loss_mask, attention_mask, position_ids

        plugin = accelerator.state.megatron_lm_plugin
        if plugin.custom_get_batch_function is not None:
            return plugin.custom_get_batch_function
        if not megatron_dataset_flag:
            return get_batch_transformer
        try:
            # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
            from pretrain_gpt import get_batch
        except ImportError:
            return get_batch_megatron
        return get_batch

    def get_loss_func(self, accelerator):
        """Return the loss callable; it yields `(loss, dict)` with the averaged "lm loss"."""
        args = get_args()

        def loss_func(loss_mask, output_tensor):
            if args.return_logits:
                losses, logits = output_tensor
            else:
                losses = output_tensor
            losses = losses.float()
            flat_mask = loss_mask.view(-1).float()
            masked_sum = torch.sum(losses.view(-1) * flat_mask)

            if args.context_parallel_size > 1:
                # Reduce numerator and denominator together across the context-parallel group.
                loss = torch.cat([masked_sum.view(1), flat_mask.sum().view(1)])
                torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
                loss = loss[0] / loss[1]
            else:
                loss = masked_sum / flat_mask.sum()

            # Check individual rank losses are not NaN prior to DP all-reduce.
            if args.check_for_nan_in_loss_and_grad:
                global_rank = torch.distributed.get_rank()
                assert not loss.isnan(), (
                    f"Rank {global_rank}: found NaN in local forward loss calculation. "
                    f"Device: {torch.cuda.current_device()}, node: {os.uname()[1]}"
                )

            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])

            output_dict = {"lm loss": averaged_loss[0]}
            if args.return_logits:
                output_dict["logits"] = logits
            return loss, output_dict

        custom_loss = accelerator.state.megatron_lm_plugin.custom_loss_function
        return custom_loss if custom_loss is not None else loss_func

    def get_forward_step_func(self):
        """Return `forward_step(data_iterator, model)` yielding `(model_output, loss_callable)`."""

        def forward_step(data_iterator, model):
            """Forward step."""
            tokens, labels, loss_mask, attention_mask, position_ids = self.get_batch(data_iterator)
            model_output = model(tokens, position_ids, attention_mask, labels=labels)
            return model_output, partial(self.loss_func, loss_mask)

        return forward_step
class T5TrainStep(AbstractTrainStep):
    """
    T5 train step class.

    Args:
        accelerator: Object exposing `state.megatron_lm_plugin`.
        args (`argparse.Namespace`): Megatron-LM arguments.
    """

    def __init__(self, accelerator, args):
        super().__init__("T5TrainStep")
        self.get_batch = self.get_batch_func(accelerator, args.megatron_dataset_flag)
        self.loss_func = self.get_loss_func(accelerator)
        self.forward_step = self.get_forward_step_func()
        if not args.model_return_dict:
            self.model_output_class = None
        else:
            self.model_output_class = Seq2SeqLMOutput

    # The three mask helpers take no `self`; `@staticmethod` makes them callable
    # through an instance as well as through the class.
    @staticmethod
    def attn_mask_postprocess(attention_mask):
        """Expand a 2D mask into a boolean 3D mask; `True` where the pairwise product is below 0.5."""
        # We create a 3D attention mask from a 2D tensor mask.
        # [b, 1, s]
        attention_mask_b1s = attention_mask.unsqueeze(1)
        # [b, s, 1]
        attention_mask_bs1 = attention_mask.unsqueeze(2)
        # [b, s, s]
        attention_mask_bss = attention_mask_b1s * attention_mask_bs1
        # Convert attention mask to binary:
        extended_attention_mask = attention_mask_bss < 0.5
        return extended_attention_mask

    @staticmethod
    def get_decoder_mask(seq_length, device):
        """Return a `(1, seq_length, seq_length)` boolean mask that is `True` above the diagonal."""
        attention_mask = torch.tril(torch.ones((1, seq_length, seq_length), device=device))
        attention_mask = attention_mask < 0.5
        return attention_mask

    @staticmethod
    def get_enc_dec_mask(attention_mask, dec_seq_length, device):
        """Repeat a 2D encoder mask for `dec_seq_length` rows; `True` where the entry is below 0.5."""
        batch_size, _ = attention_mask.shape
        # We create a 3D attention mask from a 2D tensor mask.
        # [b, 1, s]
        attention_mask_b1s = attention_mask.unsqueeze(1)
        # [b, s, 1]
        attention_mask_bs1 = torch.ones((batch_size, dec_seq_length, 1), device=device)
        attention_mask_bss = attention_mask_bs1 * attention_mask_b1s
        extended_attention_mask = attention_mask_bss < 0.5
        return extended_attention_mask

    def get_batch_func(self, accelerator, megatron_dataset_flag):
        """Return the batch-building callable (custom, Megatron-script, Megatron or transformers)."""

        def get_batch_megatron(data_iterator):
            """Build the batch."""

            keys = ["text_enc", "text_dec", "labels", "loss_mask", "enc_mask", "dec_mask", "enc_dec_mask"]
            datatype = torch.int64

            # Broadcast data.
            if data_iterator is not None:
                data = next(data_iterator)
            else:
                data = None
            data_b = tensor_parallel.broadcast_data(keys, data, datatype)

            # Unpack.
            tokens_enc = data_b["text_enc"].long()
            tokens_dec = data_b["text_dec"].long()
            labels = data_b["labels"].long()
            loss_mask = data_b["loss_mask"].float()

            enc_mask = data_b["enc_mask"] < 0.5
            dec_mask = data_b["dec_mask"] < 0.5
            enc_dec_mask = data_b["enc_dec_mask"] < 0.5

            return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask

        def get_batch_transformer(data_iterator):
            """Build the batch."""
            data = next(data_iterator)
            data = send_to_device(data, torch.cuda.current_device())

            tokens_enc = data["input_ids"].long()
            labels = data["labels"].long()
            # Positions labelled -100 are excluded from the loss.
            loss_mask = (labels != -100).to(torch.float)
            if "decoder_input_ids" in data:
                tokens_dec = data["decoder_input_ids"].long()
            else:
                # Derive decoder inputs by shifting the labels right by one position,
                # starting with 0 and replacing -100 entries by 0.
                tokens_dec = labels.new_zeros(labels.shape, device=labels.device, dtype=torch.long)
                tokens_dec[..., 1:] = labels[..., :-1].clone()
                tokens_dec[..., 0] = 0
                tokens_dec.masked_fill_(tokens_dec == -100, 0)
            enc_mask = T5TrainStep.attn_mask_postprocess(data["attention_mask"].long())
            dec_mask = T5TrainStep.get_decoder_mask(tokens_dec.shape[1], tokens_dec.device)
            enc_dec_mask = T5TrainStep.get_enc_dec_mask(
                data["attention_mask"].long(), tokens_dec.shape[1], tokens_dec.device
            )

            return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask

        if accelerator.state.megatron_lm_plugin.custom_get_batch_function is not None:
            return accelerator.state.megatron_lm_plugin.custom_get_batch_function
        if megatron_dataset_flag:
            try:
                # Use '--no-use-pep517 -e' to pip install nvidia's megatron from source
                from pretrain_t5 import get_batch

                return get_batch
            except ImportError:
                pass
            return get_batch_megatron
        else:
            return get_batch_transformer

    def get_loss_func(self, accelerator):
        """Return the loss callable; it yields `(loss, dict)` with the averaged "lm loss"."""

        def loss_func(loss_mask, output_tensor):
            lm_loss_ = output_tensor.float()
            # Masked mean of the per-token losses.
            lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

            loss = lm_loss
            averaged_losses = average_losses_across_data_parallel_group([lm_loss])

            return loss, {"lm loss": averaged_losses[0]}

        if accelerator.state.megatron_lm_plugin.custom_loss_function is not None:
            return accelerator.state.megatron_lm_plugin.custom_loss_function
        return loss_func

    def get_forward_step_func(self):
        """Return `forward_step(data_iterator, model)` yielding `(model_output, loss_callable)`."""

        def forward_step(data_iterator, model):
            """Forward step."""
            # Get the batch.
            tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask = self.get_batch(
                data_iterator
            )
            # Forward model lm_labels
            output_tensor = model(
                tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, tokentype_ids=None, lm_labels=lm_labels
            )

            return output_tensor, partial(self.loss_func, loss_mask)

        return forward_step
def finish_mpu_init():
    """Initialize torch.distributed through Megatron and seed the random number generators."""
    megatron_args = get_args()

    # Pytorch distributed.
    _initialize_distributed()

    # Random seeds for reproducibility; only rank 0 reports the seed.
    is_first_rank = megatron_args.rank == 0
    if is_first_rank:
        print(f"> setting random seeds to {megatron_args.seed} ...")
    _set_random_seed(megatron_args.seed, megatron_args.data_parallel_random_init)
# initialize megatron setup
def initialize(accelerator, extra_args_provider=None, args_defaults=None):
    """Parse Megatron arguments and bring up the Megatron-LM runtime.

    Args:
        accelerator: Object exposing `print`.
        extra_args_provider: Forwarded to Megatron's `parse_args`.
        args_defaults: Optional mapping of argument names to values that are written onto
            the parsed arguments. `None` (the default) means no overrides.

    Raises:
        AssertionError: If CUDA is unavailable, or checkpoint arguments are requested
            without `args.load` being set.
    """
    # Avoid a shared mutable default: create a fresh mapping per call.
    if args_defaults is None:
        args_defaults = {}

    accelerator.print("Initializing Megatron-LM")
    assert torch.cuda.is_available(), "Megatron requires CUDA."

    # Parse arguments
    args = parse_args(extra_args_provider, ignore_unknown_args=True)

    # Set defaults
    for key, value in args_defaults.items():
        if getattr(args, key, None) is not None:
            if args.rank == 0:
                print(
                    f"WARNING: overriding default arguments for " f"{key}:{getattr(args, key)} with {key}:{value}",
                    flush=True,
                )
        setattr(args, key, value)

    if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
        assert args.load is not None, "--use-checkpoints-args requires --load argument"
        load_args_from_checkpoint(args)

    validate_args(args)

    # set global args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(args)

    # Megatron's MPU is the master. Complete initialization right away.
    finish_mpu_init()

    # Autoresume.
    _init_autoresume()

    # Compile dependencies.
    _compile_dependencies()

    # Set pytorch JIT layer fusion options and warmup JIT functions.
    set_jit_fusion_options()
    args = get_args()
    if getattr(args, "padded_vocab_size", None) is None:
        args.padded_vocab_size = _vocab_size_with_padding(args.orig_vocab_size, args)
    # The binary head is only enabled for BERT pre-training with exactly two labels.
    if args.model_type_name == "bert" and args.pretraining_flag and args.num_labels == 2:
        args.bert_binary_head = True
    else:
        args.bert_binary_head = False
    args.iteration = 0
class MegatronEngine(torch.nn.Module):
    """
    Megatron-LM model wrapper

    Args:
        accelerator (:class:`~accelerate.Accelerator`): The accelerator object to use.
        model: Megatron-LM model. It is indexed and iterated by this class, i.e. a sequence of model chunks.
        optimizer: Megatron-LM optimizer
        scheduler: Megatron-LM lr scheduler
    """

    def __init__(self, accelerator, model, optimizer, scheduler):
        super().__init__()
        self.module = model
        self.base_model = model[0]
        self.optimizer = optimizer
        self.scheduler = scheduler
        args = get_args()

        # Pick the train-step handler: a user supplied class takes precedence over the built-in ones.
        plugin = accelerator.state.megatron_lm_plugin
        if plugin.custom_train_step_class is not None:
            self.train_step_handler = plugin.custom_train_step_class(args, **plugin.custom_train_step_kwargs)
        elif args.model_type_name == "bert":
            self.train_step_handler = BertTrainStep(accelerator, args)
        elif args.model_type_name == "gpt":
            self.train_step_handler = GPTTrainStep(accelerator, args)
        elif args.model_type_name == "t5":
            self.train_step_handler = T5TrainStep(accelerator, args)
        else:
            raise ValueError(f"Unsupported model type: {args.model_type_name}")
        self.optimizer.skipped_iter = False

        # Tracking loss.
        self.total_loss_dict = {}
        self.eval_total_loss_dict = {}
        self.iteration = 0
        self.report_memory_flag = True
        self.num_floating_point_operations_so_far = 0
        # Built lazily on the first call to `train()` / `eval()`.
        self.module_config = None

        if args.tensorboard_dir is not None:
            write_args_to_tensorboard()

    def get_module_config(self):
        """
        Build the model config of the first model chunk and attach the training hooks to it.

        Returns:
            The object returned by `get_model_config`, with `grad_scale_func`, `finalize_model_grads_func` and,
            depending on the global arguments, `no_sync_func`, `grad_sync_func` and `param_sync_func` set.
        """
        args = get_args()
        config = get_model_config(self.module[0])
        num_chunks = len(self.module)

        # Setup some training config params
        config.grad_scale_func = self.optimizer.scale_loss
        if isinstance(self.module[0], LocalDDP) and args.overlap_grad_reduce:
            assert config.no_sync_func is None, (
                "When overlap_grad_reduce is True, config.no_sync_func must be None; "
                "a custom no_sync_func is not supported when overlapping grad-reduce"
            )
            config.no_sync_func = [model_chunk.no_sync for model_chunk in self.module]
            # With a single chunk the hook is stored directly instead of a one-element list.
            if num_chunks == 1:
                config.no_sync_func = config.no_sync_func[0]
            if args.delay_grad_reduce:
                config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.module]
                if num_chunks == 1:
                    config.grad_sync_func = config.grad_sync_func[0]
        if args.overlap_param_gather and args.delay_param_gather:
            # `model_index` is bound as a default argument so every callable keeps its own index; a plain
            # closure would see only the last value of the loop variable.
            config.param_sync_func = [
                lambda x, model_index=model_index: self.optimizer.finish_param_sync(model_index, x)
                for model_index in range(num_chunks)
            ]
            if num_chunks == 1:
                config.param_sync_func = config.param_sync_func[0]
        config.finalize_model_grads_func = finalize_model_grads
        return config

    def train(self):
        """
        Put every model chunk in training mode, build the module config if it does not exist yet and log the
        evaluation results accumulated so far.
        """
        for model_module in self.module:
            model_module.train()

        if self.module_config is None:
            self.module_config = self.get_module_config()

        self.log_eval_results()

    def eval(self):
        """
        Put every model chunk in evaluation mode and build the module config if it does not exist yet.
        """
        for model_module in self.module:
            model_module.eval()

        if self.module_config is None:
            self.module_config = self.get_module_config()

    def get_batch_data_iterator(self, batch_data):
        """
        Split a batch into micro batches and wrap them in the iterator(s) handed to Megatron.

        Args:
            batch_data (:obj:`dict`): The batch data; every value is sliced along its first dimension.

        Returns:
            With a single model chunk: an iterator over the micro batches, or `None` when `batch_data` is
            empty. With several model chunks: a list holding one such iterator (or `None`) per chunk.
        """
        args = get_args()
        has_data = len(batch_data) > 0

        data_chunks = []
        if has_data:
            if args.num_micro_batches > 1:
                micro_batch_size = args.micro_batch_size
                data_chunks = [
                    {k: v[i * micro_batch_size : (i + 1) * micro_batch_size] for k, v in batch_data.items()}
                    for i in range(args.num_micro_batches)
                ]
            else:
                data_chunks = [batch_data]

        num_chunks = len(self.module)
        if num_chunks > 1:
            # Every model chunk gets its own, independent iterator over the same micro batches.
            if has_data:
                return [iter(data_chunks) for _ in range(num_chunks)]
            return [None] * num_chunks
        return iter(data_chunks) if has_data else None

    def train_step(self, **batch_data):
        """
        Training step for Megatron-LM

        Args:
            batch_data (:obj:`dict`): The batch data to train on.

        Returns:
            The tuple `(loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad)` produced by Megatron's
            `train_step` function.
        """
        batch_data_iterator = self.get_batch_data_iterator(batch_data)

        # Inside the method body `train_step` resolves to the module level function, not to this method.
        loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
            forward_step_func=self.train_step_handler.forward_step,
            data_iterator=batch_data_iterator,
            model=self.module,
            optimizer=self.optimizer,
            opt_param_scheduler=self.scheduler,
            config=self.module_config,
        )

        self.optimizer.skipped_iter = skipped_iter == 1

        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad

    def eval_step(self, **batch_data):
        """
        Evaluation step for Megatron-LM

        Args:
            batch_data (:obj:`dict`): The batch data to evaluate on.

        Returns:
            :obj:`dict`: On the last pipeline stage, one entry per loss key: 0-dim values are averaged over
            the micro batches, all other values are concatenated. An empty dict on every other stage.
        """
        args = get_args()
        batch_data_iterator = self.get_batch_data_iterator(batch_data)
        forward_backward_func = get_forward_backward_func()
        loss_dicts = forward_backward_func(
            forward_step_func=self.train_step_handler.forward_step,
            data_iterator=batch_data_iterator,
            model=self.module,
            num_microbatches=get_num_microbatches(),
            seq_length=args.seq_length,
            micro_batch_size=args.micro_batch_size,
            forward_only=True,
        )
        # Empty unused memory
        if args.empty_unused_memory_level >= 1:
            torch.cuda.empty_cache()

        args.consumed_valid_samples += (
            mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
        )

        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Average loss across microbatches.
            loss_reduced = {}
            for key in loss_dicts[0]:
                losses_reduced_for_key = [x[key] for x in loss_dicts]
                if len(losses_reduced_for_key[0].shape) == 0:
                    loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
                else:
                    loss_reduced[key] = torch.concat(losses_reduced_for_key)
            return loss_reduced
        return {}

    def forward(self, **batch_data):
        """
        Run one training step (when the first model chunk is in training mode) or one evaluation step.

        Args:
            batch_data (:obj:`dict`): The batch data for this step.

        Returns:
            The sum of all 0-dim entries of the loss dict. When the train-step handler defines a
            `model_output_class`, an instance of it built from that loss and the `"logits"` entry of the loss
            dict (`None` if absent) is returned instead.
        """
        # During training, we use train_step()
        # model(**batch_data) performs following operations by delegating it to `self.train_step`:
        # 1. Prepare **batch_data for Tensor, Pipeline and Model Parallelism
        # 2. Set grad to zero.
        # 3. forward pass and backward pass using Pipeline Parallelism
        # 4. Empty unused memory.
        # 5. Reduce gradients.
        # 6. Update parameters.
        # 7. Gather params when using Distributed Optimizer (Data Parallelism).
        # 8. Update learning rate if scheduler is specified.
        # 9. Empty unused memory.
        # 10. Average loss across microbatches and across DP ranks.
        #
        # During evaluation, we use eval_step()
        args = get_args()
        if self.module[0].training:
            loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = self.train_step(**batch_data)
            self.iteration += 1
            batch_size = mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
            args.consumed_train_samples += batch_size
            self.num_floating_point_operations_so_far += num_floating_point_operations(args, batch_size)
            if args.tensorboard_dir is not None:
                # Logging.
                loss_scale = self.optimizer.get_loss_scale().item()
                params_norm = None
                if args.log_params_norm:
                    # The model chunks live in `self.module`; this class never defines `self.model`.
                    params_norm = calc_params_l2_norm(self.module)
                self.report_memory_flag = training_log(
                    loss_dict,
                    self.total_loss_dict,
                    self.optimizer.param_groups[0]["lr"],
                    self.iteration,
                    loss_scale,
                    self.report_memory_flag,
                    skipped_iter,
                    grad_norm,
                    params_norm,
                    num_zeros_in_grad,
                )
        else:
            loss_dict = self.eval_step(**batch_data)
            if args.tensorboard_dir is not None:
                # Keep a running sum and a step counter per key; `log_eval_results` divides the two.
                for key in loss_dict:
                    self.eval_total_loss_dict[key] = (
                        self.eval_total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
                    )
                    self.eval_total_loss_dict[key + "_num_iters"] = self.eval_total_loss_dict.get(
                        key + "_num_iters", torch.cuda.FloatTensor([0.0])
                    ) + torch.cuda.FloatTensor([1.0])

        # Only 0-dim entries contribute to the returned loss.
        loss = torch.tensor(0.0, device=torch.cuda.current_device())
        for key in loss_dict:
            if len(loss_dict[key].shape) == 0:
                loss += loss_dict[key]

        logits = None
        if "logits" in loss_dict:
            logits = loss_dict["logits"]
        if self.train_step_handler.model_output_class is not None:
            return self.train_step_handler.model_output_class(loss=loss, logits=logits)
        return loss

    def log_eval_results(self):
        """
        Print and, if a tensorboard writer exists, record the mean of the accumulated evaluation values, then
        reset the accumulator.

        Does nothing when `args.tensorboard_dir` is `None` or no training iteration has run yet.
        """
        args = get_args()
        if args.tensorboard_dir is None or self.iteration == 0:
            return
        writer = get_tensorboard_writer()
        string = f"validation loss at iteration {self.iteration} | "
        for key in self.eval_total_loss_dict:
            # The "<key>_num_iters" entries are the counters, not values to report.
            if key.endswith("_num_iters"):
                continue
            value = self.eval_total_loss_dict[key] / self.eval_total_loss_dict[key + "_num_iters"]
            string += f"{key} value: {value} | "
            # The exponent is capped at 20 before computing the perplexity.
            ppl = math.exp(min(20, value.item()))
            if args.pretraining_flag:
                string += f"{key} PPL: {ppl} | "
            if writer:
                writer.add_scalar(f"{key} validation", value.item(), self.iteration)
                if args.pretraining_flag:
                    writer.add_scalar(f"{key} validation ppl", ppl, self.iteration)

        length = len(string) + 1
        print_rank_last("-" * length)
        print_rank_last(string)
        print_rank_last("-" * length)
        self.eval_total_loss_dict = {}

    def save_checkpoint(self, output_dir):
        """
        Log pending evaluation results and save a Megatron checkpoint.

        Args:
            output_dir: Stored in `args.save` before Megatron's `save_checkpoint` function is called.
        """
        self.log_eval_results()
        args = get_args()
        args.save = output_dir
        torch.distributed.barrier()
        # Inside the method body `save_checkpoint` resolves to the module level function.
        save_checkpoint(
            self.iteration,
            self.module,
            self.optimizer,
            self.scheduler,
            num_floating_point_operations_so_far=self.num_floating_point_operations_so_far,
        )
        torch.distributed.barrier()

    def load_checkpoint(self, input_dir):
        """
        Load a Megatron checkpoint and restore the iteration and floating point operation counters from it.

        Args:
            input_dir: Stored in `args.load` before Megatron's `load_checkpoint` function is called.
        """
        args = get_args()
        args.load = input_dir
        args.consumed_train_samples = 0
        args.consumed_valid_samples = 0
        torch.distributed.barrier()
        # Inside the method body `load_checkpoint` resolves to the module level function.
        iteration, num_floating_point_operations_so_far = load_checkpoint(self.module, self.optimizer, self.scheduler)
        torch.distributed.barrier()
        self.iteration = iteration
        self.num_floating_point_operations_so_far = num_floating_point_operations_so_far
        if args.fp16 and self.iteration == 0:
            self.optimizer.reload_model_params()

    def megatron_generate(
        self,
        inputs,
        attention_mask=None,
        max_length=None,
        max_new_tokens=None,
        num_beams=None,
        temperature=None,
        top_k=None,
        top_p=None,
        length_penalty=None,
        **kwargs,
    ):
        """
        Generate method for GPT2 model. This method is used for inference. Supports both greedy and beam search along
        with sampling. Refer the Megatron-LM repo for more details

        Args:
            inputs (torch.Tensor): input ids
            attention_mask (torch.Tensor, optional): attention mask. Defaults to None.
            max_length (int, optional): max length of the generated sequence. Defaults to None.
                Either this or max_new_tokens should be provided.
            max_new_tokens (int, optional): max number of tokens to be generated. Defaults to None.
                Either this or max_length should be provided.
            num_beams (int, optional): number of beams to use for beam search. Defaults to None.
            temperature (float, optional): temperature for sampling. Defaults to 1.0.
            top_k (int, optional): top k tokens to consider for sampling. Defaults to 0.0.
            top_p (float, optional): tokens in top p probability are considered for sampling. Defaults to 0.0.
            length_penalty (float, optional): length penalty for beam search. Defaults to None.
            kwargs: additional key-value arguments. The keys read here are `top_p_decay`, `top_p_bound`,
                `add_BOS`, `stop_token` and `random_seed`.

        Returns:
            The tokens returned by the Megatron beam search (when `num_beams` is given) or sampling routine.

        Raises:
            NotImplementedError: if the model type is not "gpt".
            ValueError: if the global arguments or any of the generation parameters are invalid.
        """
        # checking if required arguments are passed
        args = get_args()
        if args.model_type_name != "gpt":
            raise NotImplementedError("Generate method is not implemented for this model")

        if args.data_parallel_size > 1:
            raise ValueError("Generate method requires data parallelism to be 1")

        if args.sequence_parallel:
            raise ValueError("Generate method requires sequence parallelism to be False")

        if args.recompute_granularity is not None:
            raise ValueError("Checkpoint activations cannot be set for inference")

        if args.vocab_file is None:
            raise ValueError("Vocab file is required for inference")

        # Prepare inputs
        if max_length is None and max_new_tokens is None:
            raise ValueError("`max_length` or `max_new_tokens` are required for inference")

        if temperature is None:
            temperature = 1.0
        elif not (0.0 < temperature <= 100.0):
            raise ValueError("temperature must be a positive number less than or equal to 100.0")

        if top_k is None:
            top_k = 0
        elif not (0 <= top_k <= 1000):
            raise ValueError("top_k must be a positive number less than or equal to 1000")

        if top_p is None:
            top_p = 0.0
        elif top_p > 0.0 and top_k > 0.0:
            raise ValueError("top_p and top_k sampling cannot be set together")
        elif not (0.0 <= top_p <= 1.0):
            raise ValueError("top_p must be less than or equal to 1.0")

        top_p_decay = kwargs.get("top_p_decay", 0.0)
        if not (0.0 <= top_p_decay <= 1.0):
            raise ValueError("top_p_decay must be less than or equal to 1.0")

        top_p_bound = kwargs.get("top_p_bound", 0.0)
        if not (0.0 <= top_p_bound <= 1.0):
            raise ValueError("top_p_bound must be less than or equal to 1.0")

        add_BOS = kwargs.get("add_BOS", False)
        if not (isinstance(add_BOS, bool)):
            raise ValueError("add_BOS must be a boolean")

        beam_width = num_beams
        if beam_width is not None:
            if not isinstance(beam_width, int):
                raise ValueError("beam_width must be an integer")
            if beam_width < 1:
                raise ValueError("beam_width must be greater than 0")
            if inputs.shape[0] > 1:
                # Raised like every other invalid argument; returning the message would hand the caller a
                # string where generated tokens are expected.
                raise ValueError("When doing beam_search, batch size must be 1")

        tokenizer = get_tokenizer()

        stop_token = kwargs.get("stop_token", tokenizer.eod)
        if stop_token is not None:
            if not isinstance(stop_token, int):
                raise ValueError("stop_token must be an integer")

        if length_penalty is None:
            length_penalty = 1.0

        # Only rank 0 builds the prompt tensors; the other ranks receive them through the broadcasts below.
        sizes_list = None
        prompts_tokens_tensor = None
        prompts_length_tensor = None
        if torch.distributed.get_rank() == 0:
            # Get the prompts length.
            if attention_mask is None:
                prompts_length_tensor = torch.cuda.LongTensor([inputs.shape[1]] * inputs.shape[0])
            else:
                prompts_length_tensor = attention_mask.sum(axis=-1).cuda()

            if max_new_tokens is None:
                max_new_tokens = max_length - inputs.shape[1]
            if max_new_tokens <= 0:
                raise ValueError("max_new_tokens must be greater than 0")

            if add_BOS:
                max_length = max_new_tokens + inputs.shape[1] + 1
                # making sure that `max_length` is a multiple of 4 to leverage fused kernels
                max_length = 4 * math.ceil(max_length / 4)
                max_new_tokens = max_length - (inputs.shape[1] + 1)
                padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
                # The first padding column (an `eod` token id) is prepended to the prompts.
                prompts_tokens_tensor = torch.concat(
                    [torch.unsqueeze(padding[:, 0], axis=-1), inputs.cuda(), padding], axis=-1
                )
            else:
                # making sure that `max_length` is a multiple of 4 to leverage fused kernels
                max_length = max_new_tokens + inputs.shape[1]
                max_length = 4 * math.ceil(max_length / 4)
                max_new_tokens = max_length - inputs.shape[1]
                padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
                prompts_tokens_tensor = torch.concat([inputs.cuda(), padding], axis=-1)

            # We need the sizes of these tensors for the broadcast
            sizes_list = [
                prompts_tokens_tensor.size(0),  # Batch size
                prompts_tokens_tensor.size(1),  # Sequence length
            ]

        # First, broadcast the sizes.
        sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=0)

        # Now that we have the sizes, we can broadcast the tokens
        # and length tensors.
        sizes = sizes_tensor.tolist()
        context_tokens_tensor = broadcast_tensor(sizes, torch.int64, tensor=prompts_tokens_tensor, rank=0)
        context_length_tensor = broadcast_tensor(sizes[0], torch.int64, tensor=prompts_length_tensor, rank=0)

        # Run the inference
        random_seed = kwargs.get("random_seed", 0)
        torch.random.manual_seed(random_seed)
        unwrapped_model = unwrap_model(self.base_model, (torchDDP, LocalDDP, Float16Module))
        if beam_width is not None:
            tokens, _ = beam_search_and_return_on_first_stage(
                unwrapped_model,
                context_tokens_tensor,
                context_length_tensor,
                beam_width,
                stop_token=stop_token,
                num_return_gen=1,
                length_penalty=length_penalty,
            )
        else:
            tokens, _, _ = generate_tokens_probs_and_return_on_first_stage(
                unwrapped_model,
                context_tokens_tensor,
                context_length_tensor,
                return_output_log_probs=False,
                top_k=top_k,
                top_p=top_p,
                top_p_decay=top_p_decay,
                top_p_bound=top_p_bound,
                temperature=temperature,
                use_eod_token_for_early_termination=True,
            )
        return tokens
# other utilities | |
def avg_losses_across_data_parallel_group(losses):
    """
    Average the given losses over the data parallel group.

    Thin wrapper that forwards to `average_losses_across_data_parallel_group`.

    Args:
        losses (List[Tensor]): List of losses to average across data parallel group.

    Returns:
        Whatever `average_losses_across_data_parallel_group` returns for `losses`.
    """
    averaged_losses = average_losses_across_data_parallel_group(losses)
    return averaged_losses
def gather_across_data_parallel_groups(tensor):
    """
    Gather the tensors of a nested list/tuple/dictionary from all data parallel ranks.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to gather across data parallel ranks.

    Returns:
        The result of `recursively_apply`, where every tensor has been passed through an `all_gather` over the
        data parallel group and the per-rank pieces have been concatenated along dimension 0.
    """

    def _gather_single(local_tensor):
        # A 0-dim tensor gets a leading dimension so the per-rank pieces can be concatenated along dim 0.
        if local_tensor.ndim == 0:
            local_tensor = local_tensor.clone()[None]

        group_size = torch.distributed.get_world_size(group=mpu.get_data_parallel_group())

        # One receive buffer per rank of the data parallel group.
        receive_buffers = []
        for _ in range(group_size):
            receive_buffers.append(torch.empty_like(local_tensor))

        torch.distributed.all_gather(receive_buffers, local_tensor, group=mpu.get_data_parallel_group())
        return torch.cat(receive_buffers, dim=0)

    return recursively_apply(_gather_single, tensor, error_on_other_type=True)