|
import os |
|
from pathlib import Path |
|
|
|
import pandas as pd |
|
import pytorch_lightning as pl |
|
import torch |
|
import wandb |
|
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint |
|
from pytorch_lightning.loggers import WandbLogger |
|
from sklearn.model_selection import train_test_split |
|
from torch import Tensor |
|
|
|
from src.data_reader import load_train_test_df |
|
from src.model_finetuning.config import CONFIG |
|
from src.model_finetuning.dataloader import ClassificationDataloader |
|
from src.model_finetuning.losses import MCRMSELoss |
|
from src.model_finetuning.model import BertLightningModel |
|
from src.utils import get_target_columns, seed_everything |
|
|
|
# Seed all RNGs at module import so training runs are reproducible.
# NOTE(review): project helper — confirm which libraries (torch/numpy/random) it seeds.
seed_everything()
|
|
|
|
|
def train(
    config: dict,
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    verbose: bool = False
) -> BertLightningModel:
    """Fine-tune a BertLightningModel and return the best checkpoint's model.

    Args:
        config: Hyperparameter dict; must provide 'accelerator', 'max_epochs',
            'accumulate_grad_batches', 'gradient_clip_val' and 'precision'.
        train_df: Training split of the essay data.
        val_df: Validation split of the essay data.
        verbose: If True, print the model's best class metrics after training.

    Returns:
        The model reloaded from the checkpoint with the lowest
        ``val/epoch_loss`` (not the last epoch's weights).

    Raises:
        TypeError: If the global wandb run was not initialized by WandbLogger.
        KeyError: If the ``GROUP_NAME`` environment variable is not set.
    """
    # Checkpoints are written under logs/ next to the working directory.
    log_dir = Path("logs/")
    log_dir.mkdir(exist_ok=True)

    model = BertLightningModel(config)

    dataloader = ClassificationDataloader(
        tokenizer=model.tokenizer,
        train_df=train_df,
        val_df=val_df,
        config=config
    )

    logger = WandbLogger(
        project="automated_essay_evaluator",
        entity="parmezano",
        config=config,
        log_model='all',
        group=os.environ['GROUP_NAME'],
    )

    # WandbLogger is expected to have created the global run by now; fail fast
    # with a diagnosable message instead of an AttributeError further down.
    if not wandb.run:
        raise TypeError("wandb run was not initialized by WandbLogger")

    wandb.run.log_code(".")
    wandb.watch(model, criterion=MCRMSELoss())

    lr_monitor = LearningRateMonitor(logging_interval='step')

    # Track the minimum validation epoch loss; save weights only (no optimizer
    # state) since the checkpoint is reloaded just for inference below.
    model_checkpoint = ModelCheckpoint(
        dirpath=str(log_dir.resolve()),
        monitor='val/epoch_loss',
        verbose=True,
        mode='min',
        auto_insert_metric_name=True,
        save_weights_only=True,
    )

    trainer = pl.Trainer(
        logger=logger,
        callbacks=[lr_monitor, model_checkpoint],
        accelerator=config['accelerator'],
        max_epochs=config['max_epochs'],
        accumulate_grad_batches=config['accumulate_grad_batches'],
        gradient_clip_val=config['gradient_clip_val'],
        precision=config['precision'],
    )

    train_dataloader = dataloader.train_dataloader()
    val_dataloader = dataloader.val_dataloader()
    trainer.fit(model, train_dataloader, val_dataloader)
    wandb.finish()

    if verbose:
        print(f"Best class metric: {model.best_metric}, class scores: {model.class_metric}")

    # Return the best-on-validation weights rather than the last epoch's.
    model = BertLightningModel.load_from_checkpoint(model_checkpoint.best_model_path, config=config)
    return model
|
|
|
|
|
def predict(config: dict, model: BertLightningModel, df: pd.DataFrame) -> pd.DataFrame:
    """Run inference on ``df`` and return per-target predictions.

    Args:
        config: Hyperparameter dict providing 'accelerator',
            'gradient_clip_val' and 'precision'.
        model: A trained BertLightningModel.
        df: Essays to score; must contain a 'text_id' column.

    Returns:
        A DataFrame indexed like ``df`` with one column per target
        (``get_target_columns()``) plus the original 'text_id' column.

    Raises:
        TypeError: If the trainer produced no prediction batches.
    """
    trainer = pl.Trainer(
        accelerator=config['accelerator'],
        gradient_clip_val=config['gradient_clip_val'],
        precision=config['precision'],
    )

    # Reuse the validation pipeline for inference; train_df is required by the
    # dataloader's constructor but only the val loader is consumed here.
    predict_dataloader = ClassificationDataloader(
        tokenizer=model.tokenizer,
        train_df=df,
        val_df=df,
        config=config
    ).val_dataloader()

    validation_predictions = trainer.predict(model, predict_dataloader, return_predictions=True)
    if not validation_predictions:
        raise TypeError("trainer.predict returned no prediction batches")

    # Stack per-batch outputs into a single (n_samples, n_targets) tensor.
    batch_tensors = [Tensor(batch) for batch in validation_predictions]
    predictions_tensor = torch.vstack(batch_tensors)

    predictions_df = pd.DataFrame({
        column: predictions_tensor[:, ii] for ii, column in enumerate(get_target_columns())
    }, index=df.index)
    predictions_df['text_id'] = df['text_id']

    # BUG FIX: previously returned the raw list of batch predictions, which
    # contradicted the annotated return type and discarded the assembled frame.
    return predictions_df
|
|
|
|
|
def main():
    """Entry point: load the data, split it, and launch fine-tuning."""
    full_train_df, _ = load_train_test_df()
    fit_df, holdout_df = train_test_split(full_train_df, train_size=CONFIG['train_size'])
    train(config=CONFIG, train_df=fit_df, val_df=holdout_df, verbose=True)
|
|
|
|
|
# Allow running this module directly as a training script.
if __name__ == '__main__':
    main()
|
|