""" Fine-tuning a 🤗 Flax Transformers model on token classification tasks (NER, POS, CHUNKS)""" |
|
import json |
|
import logging |
|
import math |
|
import os |
|
import random |
|
import sys |
|
import time |
|
from dataclasses import asdict, dataclass, field |
|
from enum import Enum |
|
|
from pathlib import Path |
|
from typing import Any, Callable, Dict, Optional, Tuple |
|
|
|
import datasets |
|
import evaluate |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
import optax |
|
from datasets import ClassLabel, load_dataset |
|
from flax import struct, traverse_util |
|
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate |
|
from flax.training import train_state |
|
from flax.training.common_utils import get_metrics, onehot, shard |
|
from huggingface_hub import Repository, create_repo |
|
from tqdm import tqdm |
|
|
|
import transformers |
|
from transformers import ( |
|
AutoConfig, |
|
AutoTokenizer, |
|
FlaxAutoModelForTokenClassification, |
|
HfArgumentParser, |
|
is_tensorboard_available, |
|
) |
|
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry |
|
from transformers.utils.versions import require_version |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.28.0.dev0")
|
|
|
require_version("datasets>=1.8.0", "To fix: pip install -r examples/flax/token-classification/requirements.txt")
|
|
|
Array = Any |
|
Dataset = datasets.arrow_dataset.Dataset |
|
PRNGKey = Any |
|
|
|
|
|
@dataclass |
|
class TrainingArguments: |
|
output_dir: str = field( |
|
metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, |
|
) |
|
overwrite_output_dir: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": ( |
|
"Overwrite the content of the output directory. " |
|
"Use this to continue training if output_dir points to a checkpoint directory." |
|
) |
|
}, |
|
) |
|
do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) |
|
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) |
|
per_device_train_batch_size: int = field( |
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."} |
|
) |
|
per_device_eval_batch_size: int = field( |
|
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."} |
|
) |
|
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."}) |
|
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."}) |
|
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}) |
|
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}) |
|
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}) |
|
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."}) |
|
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) |
|
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) |
|
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."}) |
|
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."}) |
|
    eval_steps: Optional[int] = field(
        default=None, metadata={"help": "Run an evaluation every X steps. No evaluation during training if unset."}
    )
|
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) |
|
push_to_hub: bool = field( |
|
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."} |
|
) |
|
hub_model_id: str = field( |
|
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."} |
|
) |
|
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) |
|
|
|
def __post_init__(self): |
|
if self.output_dir is not None: |
|
self.output_dir = os.path.expanduser(self.output_dir) |
|
|
|
def to_dict(self): |
|
""" |
|
        Serializes this instance, replacing `Enum` members by their values (for JSON serialization support). It
        obfuscates token values by replacing them with a placeholder.
|
""" |
|
d = asdict(self) |
|
for k, v in d.items(): |
|
if isinstance(v, Enum): |
|
d[k] = v.value |
|
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): |
|
d[k] = [x.value for x in v] |
|
if k.endswith("_token"): |
|
d[k] = f"<{k.upper()}>" |
|
return d |
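
        # For example, with `hub_token="hf_..."` set, `to_dict()` yields
        # {"hub_token": "<HUB_TOKEN>", ...}, so secrets never reach logs or the
        # TensorBoard hparams.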
|
|
|
|
|
@dataclass |
|
class ModelArguments: |
|
""" |
|
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. |
|
""" |
|
|
|
model_name_or_path: str = field( |
|
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} |
|
) |
|
config_name: Optional[str] = field( |
|
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} |
|
) |
|
tokenizer_name: Optional[str] = field( |
|
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} |
|
) |
|
cache_dir: Optional[str] = field( |
|
default=None, |
|
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, |
|
) |
|
model_revision: str = field( |
|
default="main", |
|
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, |
|
) |
|
use_auth_token: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": ( |
|
"Will use the token generated when running `huggingface-cli login` (necessary to use this script " |
|
"with private models)." |
|
) |
|
}, |
|
) |
|
|
|
|
|
@dataclass |
|
class DataTrainingArguments: |
|
""" |
|
Arguments pertaining to what data we are going to input our model for training and eval. |
|
""" |
|
|
|
task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."}) |
|
dataset_name: Optional[str] = field( |
|
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} |
|
) |
|
dataset_config_name: Optional[str] = field( |
|
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} |
|
) |
|
train_file: Optional[str] = field( |
|
default=None, metadata={"help": "The input training data file (a csv or JSON file)."} |
|
) |
|
validation_file: Optional[str] = field( |
|
default=None, |
|
metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."}, |
|
) |
|
test_file: Optional[str] = field( |
|
default=None, |
|
metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."}, |
|
) |
|
text_column_name: Optional[str] = field( |
|
default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."} |
|
) |
|
label_column_name: Optional[str] = field( |
|
default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."} |
|
) |
|
overwrite_cache: bool = field( |
|
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} |
|
) |
|
preprocessing_num_workers: Optional[int] = field( |
|
default=None, |
|
metadata={"help": "The number of processes to use for the preprocessing."}, |
|
) |
|
max_seq_length: int = field( |
|
default=None, |
|
metadata={ |
|
"help": ( |
|
"The maximum total input sequence length after tokenization. If set, sequences longer " |
|
"than this will be truncated, sequences shorter will be padded." |
|
) |
|
}, |
|
) |
|
max_train_samples: Optional[int] = field( |
|
default=None, |
|
metadata={ |
|
"help": ( |
|
"For debugging purposes or quicker training, truncate the number of training examples to this " |
|
"value if set." |
|
) |
|
}, |
|
) |
|
max_eval_samples: Optional[int] = field( |
|
default=None, |
|
metadata={ |
|
"help": ( |
|
"For debugging purposes or quicker training, truncate the number of evaluation examples to this " |
|
"value if set." |
|
) |
|
}, |
|
) |
|
max_predict_samples: Optional[int] = field( |
|
default=None, |
|
metadata={ |
|
"help": ( |
|
"For debugging purposes or quicker training, truncate the number of prediction examples to this " |
|
"value if set." |
|
) |
|
}, |
|
) |
|
label_all_tokens: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": ( |
|
"Whether to put the label for one word on all tokens of generated by that word or just on the " |
|
"one (in which case the other tokens will have a padding index)." |
|
) |
|
}, |
|
) |
|
return_entity_level_metrics: bool = field( |
|
default=False, |
|
metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."}, |
|
) |
|
|
|
def __post_init__(self): |
|
if self.dataset_name is None and self.train_file is None and self.validation_file is None: |
|
raise ValueError("Need either a dataset name or a training/validation file.") |
|
else: |
|
if self.train_file is not None: |
|
extension = self.train_file.split(".")[-1] |
|
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." |
|
if self.validation_file is not None: |
|
extension = self.validation_file.split(".")[-1] |
|
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." |
|
self.task_name = self.task_name.lower() |
|
|
|
|
|
def create_train_state( |
|
model: FlaxAutoModelForTokenClassification, |
|
learning_rate_fn: Callable[[int], float], |
|
num_labels: int, |
|
training_args: TrainingArguments, |
|
) -> train_state.TrainState: |
|
"""Create initial training state.""" |
|
|
|
class TrainState(train_state.TrainState): |
|
"""Train state with an Optax optimizer. |
|
|
|
        Keeps the logits post-processing and loss functions on the state so the
        pmapped train and eval steps can use them.
|
|
|
Args: |
|
logits_fn: Applied to last layer to obtain the logits. |
|
loss_fn: Function to compute the loss. |
|
""" |
|
|
|
logits_fn: Callable = struct.field(pytree_node=False) |
|
loss_fn: Callable = struct.field(pytree_node=False) |
|
|
|
|
|
|
|
|
|
|
|
def decay_mask_fn(params): |
|
flat_params = traverse_util.flatten_dict(params) |
|
|
|
layer_norm_candidates = ["layernorm", "layer_norm", "ln"] |
|
layer_norm_named_params = { |
|
layer[-2:] |
|
for layer_norm_name in layer_norm_candidates |
|
for layer in flat_params.keys() |
|
if layer_norm_name in "".join(layer).lower() |
|
} |
|
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params} |
|
return traverse_util.unflatten_dict(flat_mask) |
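
    # A sketch of what `decay_mask_fn` returns for a toy parameter tree (the
    # layer names below are illustrative, not from a real checkpoint):
    #
    #   params = {"dense": {"kernel": w, "bias": b},
    #             "layer_norm": {"scale": g, "bias": b2}}
    #   decay_mask_fn(params)
    #   # -> {"dense": {"kernel": True, "bias": False},
    #   #     "layer_norm": {"scale": False, "bias": False}}
    #
    # i.e. weight decay is applied only to non-bias, non-LayerNorm parameters.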
|
|
|
    if training_args.adafactor:
        # Honor the `--adafactor` flag (otherwise it would be silently ignored); uses optax's default
        # Adafactor hyperparameters.
        tx = optax.adafactor(learning_rate=learning_rate_fn)
    else:
        tx = optax.adamw(
            learning_rate=learning_rate_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )
|
|
|
def cross_entropy_loss(logits, labels): |
|
xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels)) |
|
return jnp.mean(xentropy) |
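
    # Behavioral note: labels of -100 one-hot to the all-zero vector under
    # `onehot` (no class index matches), so masked tokens contribute 0 to the
    # summed cross-entropy, although they still count in the `jnp.mean`
    # denominator.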
|
|
|
return TrainState.create( |
|
apply_fn=model.__call__, |
|
params=model.params, |
|
tx=tx, |
|
logits_fn=lambda logits: logits.argmax(-1), |
|
loss_fn=cross_entropy_loss, |
|
) |
|
|
|
|
|
def create_learning_rate_fn( |
|
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float |
|
) -> Callable[[int], jnp.array]: |
|
"""Returns a linear warmup, linear_decay learning rate function.""" |
|
steps_per_epoch = train_ds_size // train_batch_size |
|
num_train_steps = steps_per_epoch * num_train_epochs |
|
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) |
|
decay_fn = optax.linear_schedule( |
|
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps |
|
) |
|
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) |
|
return schedule_fn |
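
# A minimal usage sketch (the numbers are illustrative, not defaults):
#
#   lr_fn = create_learning_rate_fn(
#       train_ds_size=10_000, train_batch_size=32, num_train_epochs=3,
#       num_warmup_steps=100, learning_rate=5e-5,
#   )
#   lr_fn(0)    # ~0.0  (start of linear warmup)
#   lr_fn(100)  # 5e-5  (peak, end of warmup)
#   lr_fn(936)  # ~0.0  (end of linear decay: 3 * (10_000 // 32) steps total)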
|
|
|
|
|
def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int): |
|
"""Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices.""" |
|
steps_per_epoch = len(dataset) // batch_size |
|
perms = jax.random.permutation(rng, len(dataset)) |
|
perms = perms[: steps_per_epoch * batch_size] |
|
perms = perms.reshape((steps_per_epoch, batch_size)) |
|
|
|
for perm in perms: |
|
batch = dataset[perm] |
|
batch = {k: np.array(v) for k, v in batch.items()} |
|
batch = shard(batch) |
|
|
|
yield batch |
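
# Shape sketch: assuming 8 local devices and train_batch_size=64, each yielded
# `batch` maps every key to an array of shape (8, 8, seq_len), i.e.
# (local_devices, per_device_batch, ...), which is the layout `jax.pmap` expects.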
|
|
|
|
|
def eval_data_collator(dataset: Dataset, batch_size: int): |
|
"""Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop.""" |
|
batch_idx = np.arange(len(dataset)) |
|
|
|
steps_per_epoch = math.ceil(len(dataset) / batch_size) |
|
batch_idx = np.array_split(batch_idx, steps_per_epoch) |
|
|
|
for idx in batch_idx: |
|
batch = dataset[idx] |
|
batch = {k: np.array(v) for k, v in batch.items()} |
|
|
|
yield batch |
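
# Unlike the train collator, batches are not sharded here and the last batch may
# be smaller than `batch_size`; in the eval loops below, `pad_shard_unpad` pads
# each batch to a device-divisible size, shards it, and strips the padding from
# the returned predictions.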
|
|
|
|
|
def main(): |
|
|
|
|
|
|
|
|
|
    # See all possible arguments by passing the --help flag to this script. We keep distinct sets of args
    # for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
else: |
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
|
|
|
|
|
|
|
    # Sending telemetry. Tracking example usage helps us better allocate resources to maintain the examples.
    # The information sent is the one passed as arguments, along with your Python and library versions.
    send_example_telemetry("run_ner", model_args, data_args, framework="flax")
|
|
|
|
|
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
|
datefmt="%m/%d/%Y %H:%M:%S", |
|
level=logging.INFO, |
|
) |
|
|
|
    # Setup logging: we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
|
if jax.process_index() == 0: |
|
datasets.utils.logging.set_verbosity_warning() |
|
transformers.utils.logging.set_verbosity_info() |
|
else: |
|
datasets.utils.logging.set_verbosity_error() |
|
transformers.utils.logging.set_verbosity_error() |
|
|
|
|
|
    # Handle the repository creation
    if training_args.push_to_hub:
|
if training_args.hub_model_id is None: |
|
repo_name = get_full_repo_name( |
|
Path(training_args.output_dir).absolute().name, token=training_args.hub_token |
|
) |
|
else: |
|
repo_name = training_args.hub_model_id |
|
create_repo(repo_name, exist_ok=True, token=training_args.hub_token) |
|
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) or
    # provide the name of a public dataset available on the hub at https://huggingface.co/datasets/ (the dataset
    # will be downloaded automatically from the datasets Hub).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
|
data_args.dataset_name, |
|
data_args.dataset_config_name, |
|
cache_dir=model_args.cache_dir, |
|
use_auth_token=True if model_args.use_auth_token else None, |
|
) |
|
    else:
        # Loading the dataset from local csv or json files.
        data_files = {}
|
if data_args.train_file is not None: |
|
data_files["train"] = data_args.train_file |
|
if data_args.validation_file is not None: |
|
data_files["validation"] = data_args.validation_file |
|
        extension = (data_args.train_file if data_args.train_file is not None else data_args.validation_file).split(".")[-1]
|
raw_datasets = load_dataset( |
|
extension, |
|
data_files=data_files, |
|
cache_dir=model_args.cache_dir, |
|
use_auth_token=True if model_args.use_auth_token else None, |
|
) |
|
|
|
|
|
|
|
if raw_datasets["train"] is not None: |
|
column_names = raw_datasets["train"].column_names |
|
features = raw_datasets["train"].features |
|
else: |
|
column_names = raw_datasets["validation"].column_names |
|
features = raw_datasets["validation"].features |
|
|
|
if data_args.text_column_name is not None: |
|
text_column_name = data_args.text_column_name |
|
elif "tokens" in column_names: |
|
text_column_name = "tokens" |
|
else: |
|
text_column_name = column_names[0] |
|
|
|
if data_args.label_column_name is not None: |
|
label_column_name = data_args.label_column_name |
|
elif f"{data_args.task_name}_tags" in column_names: |
|
label_column_name = f"{data_args.task_name}_tags" |
|
else: |
|
label_column_name = column_names[1] |
|
|
|
|
|
|
|
    # If the labels are of `ClassLabel` type, they are already integers and we have the map stored somewhere.
    # Otherwise, we have to get the list of labels manually.
    def get_label_list(labels):
|
unique_labels = set() |
|
for label in labels: |
|
unique_labels = unique_labels | set(label) |
|
label_list = list(unique_labels) |
|
label_list.sort() |
|
return label_list |
|
|
|
if isinstance(features[label_column_name].feature, ClassLabel): |
|
label_list = features[label_column_name].feature.names |
|
|
|
        # No need to convert the labels since they are already ints.
        label_to_id = {i: i for i in range(len(label_list))}
|
else: |
|
label_list = get_label_list(raw_datasets["train"][label_column_name]) |
|
label_to_id = {l: i for i, l in enumerate(label_list)} |
|
num_labels = len(label_list) |
|
|
|
|
|
    # Load pretrained config, tokenizer, and model.
    config = AutoConfig.from_pretrained(
|
model_args.config_name if model_args.config_name else model_args.model_name_or_path, |
|
num_labels=num_labels, |
|
label2id=label_to_id, |
|
id2label={i: l for l, i in label_to_id.items()}, |
|
finetuning_task=data_args.task_name, |
|
cache_dir=model_args.cache_dir, |
|
revision=model_args.model_revision, |
|
use_auth_token=True if model_args.use_auth_token else None, |
|
) |
|
tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path |
|
    if config.model_type in {"gpt2", "roberta"}:
        # GPT-2/RoBERTa byte-level BPE tokenizers require `add_prefix_space=True` to tokenize pre-split words.
|
tokenizer = AutoTokenizer.from_pretrained( |
|
tokenizer_name_or_path, |
|
cache_dir=model_args.cache_dir, |
|
revision=model_args.model_revision, |
|
use_auth_token=True if model_args.use_auth_token else None, |
|
add_prefix_space=True, |
|
) |
|
else: |
|
tokenizer = AutoTokenizer.from_pretrained( |
|
tokenizer_name_or_path, |
|
cache_dir=model_args.cache_dir, |
|
revision=model_args.model_revision, |
|
use_auth_token=True if model_args.use_auth_token else None, |
|
) |
|
model = FlaxAutoModelForTokenClassification.from_pretrained( |
|
model_args.model_name_or_path, |
|
config=config, |
|
cache_dir=model_args.cache_dir, |
|
revision=model_args.model_revision, |
|
use_auth_token=True if model_args.use_auth_token else None, |
|
) |
|
|
|
|
|
|
|
    # Preprocessing the datasets: tokenize all texts and align the labels with them.
    def tokenize_and_align_labels(examples):
|
tokenized_inputs = tokenizer( |
|
examples[text_column_name], |
|
max_length=data_args.max_seq_length, |
|
padding="max_length", |
|
truncation=True, |
|
|
|
            # We use this argument because the texts in our dataset are lists of words (with a label for each word).
            is_split_into_words=True,
|
) |
|
|
|
labels = [] |
|
|
|
for i, label in enumerate(examples[label_column_name]): |
|
word_ids = tokenized_inputs.word_ids(batch_index=i) |
|
previous_word_idx = None |
|
label_ids = [] |
|
for word_idx in word_ids: |
|
|
|
|
|
                # Special tokens have a word id that is None. We set the label to -100 so they are automatically
                # ignored in the loss function.
                if word_idx is None:
|
label_ids.append(-100) |
|
|
|
                # We set the label for the first token of each word.
                elif word_idx != previous_word_idx:
|
label_ids.append(label_to_id[label[word_idx]]) |
|
|
|
|
|
                # For the other tokens in a word, we set the label to either the current label or -100, depending on
                # the label_all_tokens flag.
                else:
|
label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100) |
|
previous_word_idx = word_idx |
|
|
|
labels.append(label_ids) |
|
tokenized_inputs["labels"] = labels |
|
return tokenized_inputs |
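
    # A worked alignment example (hypothetical tokenization, for illustration):
    #   words:    ["EU", "rejects", "German"]    word labels: [3, 0, 7]
    #   tokens:   [CLS]  EU    rej   ##ects  German  [SEP]
    #   word_ids: None   0     1     1       2       None
    #   labels:   -100   3     0     -100*   7       -100
    # (*) the word's label instead of -100 when `label_all_tokens` is set.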
|
|
|
processed_raw_datasets = raw_datasets.map( |
|
tokenize_and_align_labels, |
|
batched=True, |
|
num_proc=data_args.preprocessing_num_workers, |
|
load_from_cache_file=not data_args.overwrite_cache, |
|
remove_columns=raw_datasets["train"].column_names, |
|
desc="Running tokenizer on dataset", |
|
) |
|
|
|
train_dataset = processed_raw_datasets["train"] |
|
eval_dataset = processed_raw_datasets["validation"] |
|
|
|
|
|
    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
|
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") |
|
|
|
|
|
    # Define a summary writer.
    has_tensorboard = is_tensorboard_available()
|
if has_tensorboard and jax.process_index() == 0: |
|
try: |
|
from flax.metrics.tensorboard import SummaryWriter |
|
|
|
summary_writer = SummaryWriter(training_args.output_dir) |
|
summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)}) |
|
except ImportError as ie: |
|
has_tensorboard = False |
|
logger.warning( |
|
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" |
|
) |
|
else: |
|
logger.warning( |
|
"Unable to display metrics through TensorBoard because the package is not installed: " |
|
"Please run pip install tensorboard to enable." |
|
) |
|
|
|
def write_train_metric(summary_writer, train_metrics, train_time, step): |
|
summary_writer.scalar("train_time", train_time, step) |
|
|
|
train_metrics = get_metrics(train_metrics) |
|
for key, vals in train_metrics.items(): |
|
tag = f"train_{key}" |
|
for i, val in enumerate(vals): |
|
summary_writer.scalar(tag, val, step - len(vals) + i + 1) |
|
|
|
def write_eval_metric(summary_writer, eval_metrics, step): |
|
for metric_name, value in eval_metrics.items(): |
|
summary_writer.scalar(f"eval_{metric_name}", value, step) |
|
|
|
num_epochs = int(training_args.num_train_epochs) |
|
rng = jax.random.PRNGKey(training_args.seed) |
|
dropout_rngs = jax.random.split(rng, jax.local_device_count()) |
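    # One dropout rng per local device; `train_step` folds in a fresh split on
    # every step so dropout masks differ across steps and devices.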
|
|
|
train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count() |
|
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) |
|
eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count() |
|
|
|
learning_rate_fn = create_learning_rate_fn( |
|
len(train_dataset), |
|
train_batch_size, |
|
training_args.num_train_epochs, |
|
training_args.warmup_steps, |
|
training_args.learning_rate, |
|
) |
|
|
|
state = create_train_state(model, learning_rate_fn, num_labels=num_labels, training_args=training_args) |
|
|
|
|
|
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Trains the model on `batch` and returns `(new_state, metrics, new_dropout_rng)`."""
|
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) |
|
targets = batch.pop("labels") |
|
|
|
def loss_fn(params): |
|
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] |
|
loss = state.loss_fn(logits, targets) |
|
return loss |
|
|
|
grad_fn = jax.value_and_grad(loss_fn) |
|
loss, grad = grad_fn(state.params) |
|
grad = jax.lax.pmean(grad, "batch") |
|
new_state = state.apply_gradients(grads=grad) |
|
metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch") |
|
return new_state, metrics, new_dropout_rng |
|
|
|
p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,)) |
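    # `axis_name="batch"` names the device axis referenced by `jax.lax.pmean`
    # inside `train_step`; `donate_argnums=(0,)` lets XLA reuse the old state's
    # device buffers for the new state, reducing peak memory.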
|
|
|
def eval_step(state, batch): |
|
logits = state.apply_fn(**batch, params=state.params, train=False)[0] |
|
return state.logits_fn(logits) |
|
|
|
p_eval_step = jax.pmap(eval_step, axis_name="batch") |
|
|
|
metric = evaluate.load("seqeval") |
|
|
|
def get_labels(y_pred, y_true): |
|
|
|
|
|
|
|
true_predictions = [ |
|
[label_list[p] for (p, l) in zip(pred, gold_label) if l != -100] |
|
for pred, gold_label in zip(y_pred, y_true) |
|
] |
|
true_labels = [ |
|
[label_list[l] for (p, l) in zip(pred, gold_label) if l != -100] |
|
for pred, gold_label in zip(y_pred, y_true) |
|
] |
|
return true_predictions, true_labels |
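
    # Sketch: with label_list = ["O", "B-PER", "I-PER"] (illustrative),
    #   get_labels(y_pred=[[1, 2, 0]], y_true=[[1, 2, -100]])
    # drops the position labeled -100 and returns
    #   ([["B-PER", "I-PER"]], [["B-PER", "I-PER"]])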
|
|
|
def compute_metrics(): |
|
results = metric.compute() |
|
if data_args.return_entity_level_metrics: |
|
|
|
final_results = {} |
|
for key, value in results.items(): |
|
if isinstance(value, dict): |
|
for n, v in value.items(): |
|
final_results[f"{key}_{n}"] = v |
|
else: |
|
final_results[key] = value |
|
return final_results |
|
else: |
|
return { |
|
"precision": results["overall_precision"], |
|
"recall": results["overall_recall"], |
|
"f1": results["overall_f1"], |
|
"accuracy": results["overall_accuracy"], |
|
} |
|
|
|
logger.info(f"===== Starting training ({num_epochs} epochs) =====") |
|
train_time = 0 |
|
|
|
|
|
    # Make sure weights are replicated on each device.
    state = replicate(state)
|
step_per_epoch = len(train_dataset) // train_batch_size |
|
total_steps = step_per_epoch * num_epochs |
|
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) |
|
for epoch in epochs: |
|
train_start = time.time() |
|
train_metrics = [] |
|
|
|
|
|
        # Create a sampling rng for shuffling this epoch's batches.
        rng, input_rng = jax.random.split(rng)
|
|
|
|
|
for step, batch in enumerate( |
|
tqdm( |
|
train_data_collator(input_rng, train_dataset, train_batch_size), |
|
total=step_per_epoch, |
|
desc="Training...", |
|
position=1, |
|
) |
|
): |
|
state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs) |
|
train_metrics.append(train_metric) |
|
|
|
cur_step = (epoch * step_per_epoch) + (step + 1) |
|
|
|
if cur_step % training_args.logging_steps == 0 and cur_step > 0: |
|
|
|
train_metric = unreplicate(train_metric) |
|
train_time += time.time() - train_start |
|
if has_tensorboard and jax.process_index() == 0: |
|
write_train_metric(summary_writer, train_metrics, train_time, cur_step) |
|
|
|
epochs.write( |
|
f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:" |
|
f" {train_metric['learning_rate']})" |
|
) |
|
|
|
train_metrics = [] |
|
|
|
            if training_args.eval_steps is not None and cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
eval_metrics = {} |
|
|
|
for batch in tqdm( |
|
eval_data_collator(eval_dataset, eval_batch_size), |
|
total=math.ceil(len(eval_dataset) / eval_batch_size), |
|
desc="Evaluating ...", |
|
position=2, |
|
): |
|
labels = batch.pop("labels") |
|
predictions = pad_shard_unpad(p_eval_step)( |
|
state, batch, min_device_batch=per_device_eval_batch_size |
|
) |
|
predictions = np.array(predictions) |
|
                    # Mask out padded tokens so they are ignored when computing the metrics.
                    labels[np.array(batch["attention_mask"]) == 0] = -100
|
preds, refs = get_labels(predictions, labels) |
|
metric.add_batch( |
|
predictions=preds, |
|
references=refs, |
|
) |
|
|
|
eval_metrics = compute_metrics() |
|
|
|
if data_args.return_entity_level_metrics: |
|
logger.info(f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics}") |
|
else: |
|
logger.info( |
|
f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc:" |
|
f" {eval_metrics['accuracy']})" |
|
) |
|
|
|
if has_tensorboard and jax.process_index() == 0: |
|
write_eval_metric(summary_writer, eval_metrics, cur_step) |
|
|
|
            # Save checkpoint every `save_steps` and at the end of training; optionally push to the hub.
            if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
|
|
|
if jax.process_index() == 0: |
|
params = jax.device_get(unreplicate(state.params)) |
|
model.save_pretrained(training_args.output_dir, params=params) |
|
tokenizer.save_pretrained(training_args.output_dir) |
|
if training_args.push_to_hub: |
|
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) |
|
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}" |
|
|
|
|
|
    # Final evaluation after training.
    if training_args.do_eval:
        eval_metrics = {}
        eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
        for batch in tqdm(
            eval_loader, total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2
        ):
|
labels = batch.pop("labels") |
|
predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size) |
|
predictions = np.array(predictions) |
|
        # Mask out padded tokens so they are ignored when computing the metrics.
        labels[np.array(batch["attention_mask"]) == 0] = -100
|
preds, refs = get_labels(predictions, labels) |
|
metric.add_batch(predictions=preds, references=refs) |
|
|
|
eval_metrics = compute_metrics() |
|
|
|
        # Save the final eval metrics to a json file (main process only).
        if jax.process_index() == 0:
|
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()} |
|
path = os.path.join(training_args.output_dir, "eval_results.json") |
|
with open(path, "w") as f: |
|
json.dump(eval_metrics, f, indent=4, sort_keys=True) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|