# Before running, install required packages:
{% if notebook %}
!
{%- else %}
#
{%- endif %}
pip install datasets transformers[sentencepiece] accelerate sacrebleu==1.4.14 sacremoses
import collections
import logging
import math
import random
import babel
import datasets
import numpy as np
import torch
import transformers
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer,
                          DataCollatorForLanguageModeling,
                          DataCollatorForSeq2Seq, MBartTokenizer,
                          MBartTokenizerFast, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments, default_data_collator,
                          get_scheduler, set_seed)
from transformers.utils.versions import require_version
{{ header("Setup") }}
logger = logging.getLogger(__name__)
require_version("datasets>=1.8.0")
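# Seed Python, NumPy and PyTorch RNGs so runs are reproducible.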
set_seed({{ seed }})
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.ERROR,
)
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
{{ header("Load model and dataset") }}
{% if subset == 'default' %}
raw_datasets = load_dataset('{{dataset}}')
{% else %}
raw_datasets = load_dataset('{{dataset}}', '{{ subset }}')
{% endif %}
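# sacreBLEU computes corpus-level BLEU on detokenized text and expects one or more reference strings per prediction.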
metric = load_metric("sacrebleu")
model_checkpoint = "{{model_checkpoint}}"
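# use_fast=True loads the Rust-backed ("fast") tokenizer when the checkpoint provides one.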
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
{% if pretrained %}
model = AutoModelFor{{task}}.from_pretrained(model_checkpoint)
{% else %}
config = AutoConfig.from_pretrained(model_checkpoint)
model = AutoModelFor{{task}}.from_config(config)
{% endif %}
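# Resize the embedding matrix in case the tokenizer vocabulary differs from the checkpoint's.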
model.resize_token_embeddings(len(tokenizer))
model_name = model_checkpoint.split("/")[-1]
{{ header("Preprocessing") }}
source_lang = '{{ source_language }}'
target_lang = '{{ target_language }}'
{% if 'mbart' in model_checkpoint %}
# Set decoder_start_token_id
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
    assert (
        target_lang is not None and source_lang is not None
    ), "mBART requires --target_lang and --source_lang"
    if isinstance(tokenizer, MBartTokenizer):
        model.config.decoder_start_token_id = tokenizer.lang_code_to_id[target_lang]
    else:
        model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(target_lang)
{% endif %}
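# The original T5 checkpoints were trained with a natural-language task prefix,
# e.g. "translate English to French: ", so we build one from the language codes.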
{% if 't5' in model_checkpoint %}
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    for language in (source_lang, target_lang):
        if language != language[:2]:
            logger.warning(
                'Extended language code %s not supported. Falling back on %s.',
                language, language[:2]
            )
    lang_id_to_string = {
        source_lang: babel.Locale(source_lang[:2]).english_name,
        target_lang: babel.Locale(target_lang[:2]).english_name,
    }
    src_str = 'translate {}'.format(lang_id_to_string[source_lang])
    tgt_str = ' to {}: '.format(lang_id_to_string[target_lang])
    prefix = src_str + tgt_str
else:
    prefix = ""
{% else %}
prefix = ""
{% endif %}
{% if 'mbart' in model_checkpoint %}
# For translation we set the codes of our source and target languages (only useful for mBART, the others will
# ignore those attributes).
if isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
    lang_codes = ['ar_AR', 'cs_CZ', 'de_DE', 'en_XX', 'es_XX', 'et_EE', 'fi_FI', 'fr_XX', 'gu_IN', 'hi_IN', 'it_IT', 'ja_XX', 'kk_KZ', 'ko_KR', 'lt_LT', 'lv_LV', 'my_MM', 'ne_NP', 'nl_XX', 'ro_RO', 'ru_RU', 'si_LK', 'tr_TR', 'vi_VN', 'zh_CN']
    source_code = [item for item in lang_codes if item.startswith(source_lang)][0]
    target_code = [item for item in lang_codes if item.startswith(target_lang)][0]
    if source_lang is not None:
        tokenizer.src_lang = source_code
    if target_lang is not None:
        tokenizer.tgt_lang = target_code
{% endif %}
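# Examples longer than these limits are truncated; padding is deferred to the data collator.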
max_input_length = {{ block_size }}
max_target_length = {{ block_size }}
def preprocess_function(examples):
    inputs = [prefix + ex[source_lang] for ex in examples["translation"]]
    targets = [ex[target_lang] for ex in examples["translation"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
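# Tokenize every split with multiple worker processes; remove_columns drops the
# raw text columns so only model inputs remain.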
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True, num_proc=4, remove_columns=list(
    set(sum(list(raw_datasets.column_names.values()), []))), desc="Running tokenizer on dataset")
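# DataCollatorForSeq2Seq pads inputs dynamically per batch and pads labels with -100
# so the padded positions are ignored by the loss.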
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
batch_size = {{ batch_size }}
{{ header("Training") }}
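# Decode generated ids and reference labels, then score the detokenized text with sacreBLEU.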
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100s in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Some simple post-processing
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]
    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels)
    return {"bleu": result["score"]}
def postprocess(predictions, labels):
    predictions = predictions.cpu().numpy()
    labels = labels.cpu().numpy()
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Some simple post-processing
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]
    return decoded_preds, decoded_labels
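# predict_with_generate makes the Trainer call model.generate() during evaluation,
# so BLEU is computed on actual generations rather than on logits.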
training_args = Seq2SeqTrainingArguments(
    output_dir=f"{model_name}-finetuned",
    per_device_train_batch_size={{ batch_size }},
    per_device_eval_batch_size={{ batch_size }},
    predict_with_generate=True,
    evaluation_strategy='epoch',
    logging_strategy='epoch',
    save_strategy='epoch',
    optim='{{ optimizer }}',
    learning_rate={{ lr }},
    num_train_epochs={{ num_epochs }},
    gradient_accumulation_steps={{ gradient_accumulation_steps }},
    lr_scheduler_type='{{ lr_scheduler_type }}',
    warmup_steps={{ num_warmup_steps }},
{% if use_weight_decay %}
    weight_decay={{ weight_decay }},
{% endif %}
    push_to_hub=False,
    dataloader_num_workers=0,
{% if task == "MaskedLM" %}
{% if whole_word_masking %}
    remove_unused_columns=False,
{% endif %}
{% endif %}
    load_best_model_at_end=True,
    log_level='error',
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["{{ train }}"],
    eval_dataset=tokenized_datasets["{{ validation }}"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
train_result = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
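# Final evaluation on the validation split (reports eval_bleu via compute_metrics).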
eval_results = trainer.evaluate()
trainer.log_metrics("eval", eval_results)
trainer.save_metrics("eval", eval_results)
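# Quick sanity check after training: translate one validation example with the
# fine-tuned model. A minimal sketch, assuming the raw "translation" column is
# still available in raw_datasets and that greedy decoding is good enough here.
sample = raw_datasets["{{ validation }}"][0]["translation"][source_lang]
inputs = tokenizer(prefix + sample, return_tensors="pt").to(model.device)
generated_ids = model.generate(**inputs, max_length=max_target_length)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))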