""" | |
This mod fine-tunes a BERT model on the ACARIS dataset for comparison with ACARISMdl. | |
""" | |
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, TrainingArguments, Trainer, AdamW, EarlyStoppingCallback, PreTrainedModel, DistilBertModel
from transformers.modeling_outputs import SequenceClassifierOutput
from datasets import load_dataset, Dataset # datasets.Dataset (Dataset.from_pandas) is the one used below
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_auc_score
import huggingface_hub
import os
import random
import numpy as np
import wandb # required by wandb.finish() in the __main__ block
config = {
    "mdl": "distilbert-base-uncased",
    "epochs": 5,
    "batchSize": 14,
    "maxLen": 512,
    "warmupSteps": 0.1, # proportion of total steps, NOT absolute
    "weightDecay": 0.02,
    "outputDir": "./output",
    "earlyStopping": True,
    "earlyStoppingPatience": 2,
    "dropout": 0.1,
    "initlr": 5e-5,
    "epsilon": 1e-8
}
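# Note on the config above: "batchSize" is applied per device by the Trainer, and
# "warmupSteps" is a fraction that train() converts into an absolute step count
# (total optimizer steps are estimated from the dataset size, batch size, and epochs).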
# wandb.init(project="MarkIII_ACARIS", entity="simtoonia", config=config) # optional: the Trainer (report_to="wandb") initializes a run if none is active
def lockSeed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

# disabled, as full determinism is not guaranteed and it lowers performance
#lockSeed(69) # setting a fixed seed for *some* reproducibility

class DistilBertForMulticlassSequenceClassification(DistilBertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.distilbert(input_ids, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
        hidden_state = outputs[0] # (batch, seq_len, dim)
        pooled_output = hidden_state[:, 0] # [CLS] token representation
        pooled_output = self.pre_classifier(pooled_output)
        pooled_output = nn.ReLU()(pooled_output)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            lossFct = nn.CrossEntropyLoss()
            loss = lossFct(logits.view(-1, self.num_labels), labels.view(-1))
        if not return_dict:
            output = (logits,) + outputs[1:] # DistilBERT returns no pooled output, so slice from index 1
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
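
# The subclass above mirrors the stock DistilBertForSequenceClassification forward pass;
# it only makes the multi-class cross-entropy loss over the three sentiment labels explicit.
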
class ACARISBERT:
    def __init__(self, trainPath, valPath):
        self.trainPath = trainPath
        self.valPath = valPath
        self.tokenizer = DistilBertTokenizerFast.from_pretrained(config["mdl"])
        self.model = DistilBertForMulticlassSequenceClassification.from_pretrained(config["mdl"], num_labels=3, id2label={0: "neg", 1: "neu", 2: "pos"}, label2id={"neg": 0, "neu": 1, "pos": 2}, dropout=config["dropout"], attention_dropout=config["dropout"])
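
    # Expected input files (consumed by read_data below): pipe-separated CSVs with at least
    # a "content" column (raw text) and a "sentiment" column whose values are
    # "neg", "neu", or "pos" (illustrative row: some message text|pos).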
    def read_data(self, path):
        df = pd.read_csv(path, sep="|", usecols=["content", "sentiment"])
        return Dataset.from_pandas(df)

    def tokenize_data(self, dataset):
        sentMapping = {"pos": 2, "neg": 0, "neu": 1}
        tokenized = dataset.map(
            lambda x: {
                **self.tokenizer(x["content"], truncation=True, padding="max_length", max_length=config["maxLen"]),
                "labels": torch.tensor([sentMapping[sent] for sent in x["sentiment"]])
            },
            batched=True,
            remove_columns=["content", "sentiment"]
        )
        tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        return tokenized

    # Not called anywhere in this module; the Trainer below builds its own DataLoaders.
    def get_data_loaders(self, trainDS, valDS):
        trainLoader = DataLoader(trainDS, batch_size=config["batchSize"], shuffle=False)
        valLoader = DataLoader(valDS, batch_size=config["batchSize"], shuffle=False)
        return trainLoader, valLoader

    def compute_metrics(self, evalPred):
        logits, labels = evalPred
        preds = torch.argmax(torch.Tensor(logits), dim=1)
        probs = torch.nn.functional.softmax(torch.Tensor(logits), dim=1) # class probabilities for ROC-AUC
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average=None) # per-class scores
        accuracy = accuracy_score(labels, preds)
        rocAUC = roc_auc_score(labels, probs, multi_class="ovr")
        metrics = {
            "accuracy": accuracy,
            "roc_auc": rocAUC
        }
        metricNames = ["precision", "recall", "f1"]
        labelNames = ["neg", "neu", "pos"]
        for metricName, metricValue in zip(metricNames, [precision, recall, f1]):
            for labelName, value in zip(labelNames, metricValue):
                metrics[f"{metricName}_{labelName}"] = float(value)
        return metrics
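
    # The Trainer prefixes these keys with "eval_" when logging, so metric_for_best_model="accuracy"
    # below resolves to the reported "eval_accuracy".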
    def train(self):
        trainDS = self.tokenize_data(self.read_data(self.trainPath))
        valDS = self.tokenize_data(self.read_data(self.valPath))
        totalSteps = len(trainDS) // config["batchSize"] * config["epochs"]
        warmupSteps = int(totalSteps * config["warmupSteps"]) # convert warmup proportion to absolute steps
        trainingArgs = TrainingArguments(
            output_dir=config["outputDir"],
            num_train_epochs=config["epochs"],
            per_device_train_batch_size=config["batchSize"],
            per_device_eval_batch_size=config["batchSize"],
            warmup_steps=warmupSteps,
            weight_decay=config["weightDecay"],
            logging_dir="./logs",
            logging_steps=100,
            learning_rate=config["initlr"],
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            save_total_limit=5,
            adam_epsilon=config["epsilon"],
            report_to="wandb",
            fp16=True
        )
        trainer = Trainer(
            model=self.model,
            args=trainingArgs,
            train_dataset=trainDS,
            eval_dataset=valDS,
            compute_metrics=self.compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=config["earlyStoppingPatience"])]
        )
        print(f"Number of parameters: {trainer.model.num_parameters()}")
        print("Running eval ...")
        trainer.evaluate()
        print("Running training ...")
        trainer.train()
        print("Saving model ...")
        trainer.save_model(config["outputDir"])

if __name__ == "__main__":
    acaris_bert = ACARISBERT("./datasets/train.csv", "./datasets/val.csv")
    acaris_bert.train()
    wandb.finish()
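
# Minimal inference sketch (illustrative only, not executed by this script). It assumes the run
# above finished and wrote the fine-tuned weights to config["outputDir"]; the tokenizer is reloaded
# from the base checkpoint, since trainer.save_model() only saves the model.
#
#   tok = DistilBertTokenizerFast.from_pretrained(config["mdl"])
#   mdl = DistilBertForMulticlassSequenceClassification.from_pretrained(config["outputDir"])
#   enc = tok("i love this", return_tensors="pt", truncation=True, max_length=config["maxLen"])
#   with torch.no_grad():
#       pred = mdl(**enc).logits.argmax(dim=-1).item()
#   print(mdl.config.id2label[pred]) # "neg", "neu", or "pos"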