# Exported from GitHub (branch: ci-cd/hugging-face, commit 8b414b0) — file-view
# chrome from the export was removed so this module parses as valid Python.
import os
from pathlib import Path
import pandas as pd
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger # noqa
from sklearn.model_selection import train_test_split
from torch import Tensor
from src.data_reader import load_train_test_df
from src.model_finetuning.config import CONFIG
from src.model_finetuning.dataloader import ClassificationDataloader
from src.model_finetuning.losses import MCRMSELoss
from src.model_finetuning.model import BertLightningModel
from src.utils import get_target_columns, seed_everything
# Seed RNGs at import time so every run of this module is reproducible
# (exact seeding scope is defined in src.utils.seed_everything).
seed_everything()
def train(
    config: dict,
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    verbose: bool = False
) -> BertLightningModel:
    """Fine-tune a BertLightningModel and return the best checkpoint.

    Args:
        config: Training configuration; must provide 'accelerator',
            'max_epochs', 'accumulate_grad_batches', 'gradient_clip_val'
            and 'precision' (plus whatever the model/dataloader consume).
        train_df: Training split.
        val_df: Validation split.
        verbose: If True, print the model's best metric after training.

    Returns:
        The model reloaded from the checkpoint with the lowest
        'val/epoch_loss'.

    Raises:
        TypeError: If the W&B run was not initialized by WandbLogger.
        KeyError: If the GROUP_NAME environment variable is unset.
    """
    # parents=True so a fresh checkout without a logs/ tree still works.
    log_dir = Path("logs/")
    log_dir.mkdir(parents=True, exist_ok=True)

    model = BertLightningModel(config)
    dataloader = ClassificationDataloader(
        tokenizer=model.tokenizer,
        train_df=train_df,
        val_df=val_df,
        config=config
    )

    logger = WandbLogger(
        project="automated_essay_evaluator",
        entity="parmezano",
        config=config,
        log_model='all',
        group=os.environ['GROUP_NAME'],
    )
    # WandbLogger lazily creates wandb.run; fail loudly if that didn't happen
    # instead of crashing later with an opaque AttributeError.
    if not wandb.run:
        raise TypeError("wandb.run was not initialized by WandbLogger")
    wandb.run.log_code(".")
    wandb.watch(model, criterion=MCRMSELoss())

    lr_monitor = LearningRateMonitor(logging_interval='step')
    model_checkpoint = ModelCheckpoint(
        dirpath=str(log_dir.resolve()),
        monitor='val/epoch_loss',
        verbose=True,
        mode='min',
        auto_insert_metric_name=True,
        save_weights_only=True,  # weights only: smaller artifacts for W&B
    )
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[lr_monitor, model_checkpoint],
        accelerator=config['accelerator'],
        max_epochs=config['max_epochs'],
        accumulate_grad_batches=config['accumulate_grad_batches'],
        gradient_clip_val=config['gradient_clip_val'],
        precision=config['precision'],
        # Debug overrides — uncomment to run a single tiny epoch:
        # max_epochs=1,
        # limit_train_batches=1,
        # limit_val_batches=1,
        # limit_predict_batches=1,
    )

    train_dataloader = dataloader.train_dataloader()
    val_dataloader = dataloader.val_dataloader()
    trainer.fit(model, train_dataloader, val_dataloader)
    wandb.finish()

    if verbose:
        print(f"Best class metric: {model.best_metric}, class scores: {model.class_metric}")

    # Return the best (lowest val loss) weights, not the last-epoch ones.
    model = BertLightningModel.load_from_checkpoint(model_checkpoint.best_model_path, config=config)
    return model
def predict(config: dict, model: BertLightningModel, df: pd.DataFrame) -> pd.DataFrame:
    """Run inference on `df` and return per-target predictions.

    Args:
        config: Must provide 'accelerator', 'gradient_clip_val' and
            'precision' (plus whatever the dataloader consumes).
        model: A trained BertLightningModel.
        df: Input essays; must contain a 'text_id' column.

    Returns:
        A DataFrame indexed like `df` with one column per target
        (see get_target_columns()) plus the 'text_id' column.

    Raises:
        TypeError: If trainer.predict returned no predictions.
    """
    trainer = pl.Trainer(
        accelerator=config['accelerator'],
        gradient_clip_val=config['gradient_clip_val'],
        precision=config['precision'],
    )
    # Reuse the validation pipeline (no shuffling/augmentation) for inference;
    # train_df is required by the constructor but unused here.
    predict_dataloader = ClassificationDataloader(
        tokenizer=model.tokenizer,
        train_df=df,
        val_df=df,
        config=config
    ).val_dataloader()

    validation_predictions = trainer.predict(model, predict_dataloader, return_predictions=True)
    if not validation_predictions:
        raise TypeError("trainer.predict returned no predictions")

    # Stack the per-batch outputs into one (n_samples, n_targets) tensor.
    validation_predictions_conv = [Tensor(_) for _ in validation_predictions]
    validation_predictions_tensor = torch.vstack(validation_predictions_conv)

    validation_predictions_df = pd.DataFrame({
        column: validation_predictions_tensor[:, ii] for ii, column in enumerate(get_target_columns())
    }, index=df.index)
    validation_predictions_df['text_id'] = df['text_id']
    # BUG FIX: previously returned the raw list of batch tensors instead of
    # the DataFrame promised by the signature.
    return validation_predictions_df
def main() -> None:
    """Split the training data and launch fine-tuning with the global CONFIG."""
    full_train_df, _ = load_train_test_df()
    fit_df, holdout_df = train_test_split(full_train_df, train_size=CONFIG['train_size'])
    train(
        config=CONFIG,
        train_df=fit_df,
        val_df=holdout_df,
        verbose=True,
    )


if __name__ == '__main__':
    main()