import argparse
import logging
import sys
import time

import tensorflow as tf
from datasets import load_dataset
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--per_device_train_batch_size", type=int, default=16)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=8)
    parser.add_argument("--model_name_or_path", type=str)
    parser.add_argument("--learning_rate", type=float, default=5e-5)  # float, so the optimizer receives a numeric value
    parser.add_argument("--do_train", type=bool, default=True)  # note: argparse's type=bool treats any non-empty string as True
    parser.add_argument("--do_eval", type=bool, default=True)
    parser.add_argument("--output_dir", type=str)

    args, _ = parser.parse_known_args()

    # Overwrite batch sizes until we have tf_glue.py
    args.per_device_train_batch_size = 16
    args.per_device_eval_batch_size = 16
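    # Example invocation (hypothetical; the script and checkpoint names below are illustrative only,
    # and in a managed training setup these values arrive as hyperparameters from the client):
    #   python train.py --model_name_or_path distilbert-base-uncased --epochs 1 --output_dir ./output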
    # Set up logging
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    # Load model and tokenizer
    model = TFAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    # Load dataset
    train_dataset, test_dataset = load_dataset("imdb", split=["train", "test"])
    train_dataset = train_dataset.shuffle().select(range(5000))  # downsample the train split to 5k examples
    test_dataset = test_dataset.shuffle().select(range(500))  # downsample the test split to 500 examples
    # Preprocess train dataset
    train_dataset = train_dataset.map(
        lambda e: tokenizer(e["text"], truncation=True, padding="max_length"), batched=True
    )
    train_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"])
    train_features = {
        x: train_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length])
        for x in ["input_ids", "attention_mask"]
    }
    tf_train_dataset = tf.data.Dataset.from_tensor_slices((train_features, train_dataset["label"])).batch(
        args.per_device_train_batch_size
    )
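    # At this point each element of tf_train_dataset is a (features, label) pair: features is a dict
    # holding "input_ids" and "attention_mask" tensors of shape [batch, tokenizer.model_max_length],
    # and label is a [batch]-shaped tensor of class ids. The test pipeline below mirrors this layout.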
    # Preprocess test dataset
    test_dataset = test_dataset.map(
        lambda e: tokenizer(e["text"], truncation=True, padding="max_length"), batched=True
    )
    test_dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"])
    test_features = {
        x: test_dataset[x].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length])
        for x in ["input_ids", "attention_mask"]
    }
    tf_test_dataset = tf.data.Dataset.from_tensor_slices((test_features, test_dataset["label"])).batch(
        args.per_device_eval_batch_size
    )
    # Define optimizer, loss, and metrics, then compile the model
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    # Train
    start_train_time = time.time()
    # batching is already handled by tf_train_dataset, so batch_size is not passed to fit()
    train_results = model.fit(tf_train_dataset, epochs=args.epochs)
    end_train_time = time.time() - start_train_time

    logger.info("*** Train ***")
    logger.info(f"train_runtime = {end_train_time}")
    for key, value in train_results.history.items():
        logger.info(f"  {key} = {value}")
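    # The --do_eval and --output_dir arguments are parsed above but never used in the script as shown.
    # What follows is a minimal, hedged sketch of how they could be wired up, assuming evaluation should
    # run on the tf_test_dataset built above and the fine-tuned weights should be written to args.output_dir.
    if args.do_eval:
        logger.info("*** Evaluate ***")
        eval_results = model.evaluate(tf_test_dataset, return_dict=True)
        for key, value in eval_results.items():
            logger.info(f"  {key} = {value}")
    if args.output_dir:
        # save_pretrained stores the model weights/config and tokenizer files so they can be
        # reloaded later with from_pretrained
        model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)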