import os
import datasets
import numpy as np
import pandas as pd

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from transformers import AutoTokenizer, logging

from onnxruntime import InferenceSession
from scipy.special import softmax

from weakly_supervised_parser.model.data_module_loader import DataModule
from weakly_supervised_parser.model.span_classifier import LightningModel


# Disable model checkpoint warnings
logging.set_verbosity_error()


class InsideOutsideStringClassifier:
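    """Binary classifier over inside/outside span strings.

    Fine-tunes a Hugging Face backbone with PyTorch Lightning, exports the
    result to ONNX, and serves predictions through ONNX Runtime.
    """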
    def __init__(self, model_name_or_path: str, num_labels: int = 2, max_seq_length: int = 256):
        self.model_name_or_path = model_name_or_path
        self.num_labels = num_labels
        self.max_seq_length = max_seq_length
    def fit(
        self,
        train_df: pd.DataFrame,
        eval_df: pd.DataFrame,
        outputdir: str,
        filename: str,
        devices: int = 1,
        enable_progress_bar: bool = True,
        enable_model_summary: bool = False,
        enable_checkpointing: bool = True,
        logger: bool = False,
        accelerator: str = "auto",
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        learning_rate: float = 5e-6,
        max_epochs: int = 10,
        dataloader_num_workers: int = 16,
        seed: int = 42,
    ):
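        """Fine-tune on train_df/eval_df, checkpoint the best val_loss epoch, and export the model to ONNX."""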
        
        data_module = DataModule(
            model_name_or_path=self.model_name_or_path,
            train_df=train_df,
            eval_df=eval_df,
            test_df=None,
            max_seq_length=self.max_seq_length,
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            num_workers=dataloader_num_workers,
        )

        model = LightningModel(
            model_name_or_path=self.model_name_or_path,
            lr=learning_rate,
            num_labels=self.num_labels,
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
        )

        seed_everything(seed, workers=True)

        # Stop early when val_loss plateaus and keep only the best checkpoint
        callbacks = [
            EarlyStopping(monitor="val_loss", patience=2, mode="min", check_finite=True),
            ModelCheckpoint(monitor="val_loss", dirpath=outputdir, filename=filename, save_top_k=1, save_weights_only=True, mode="min"),
        ]

        trainer = Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=max_epochs,
            callbacks=callbacks,
            enable_progress_bar=enable_progress_bar,
            enable_model_summary=enable_model_summary,
            enable_checkpointing=enable_checkpointing,
            logger=logger,
        )
        trainer.fit(model, data_module)
        trainer.validate(model, data_module.val_dataloader())

        # Export to ONNX for fast inference, tracing with one training batch;
        # move the sample to the model's device so export also works on
        # CPU-only machines (the original hard-coded .cuda())
        train_batch = next(iter(data_module.train_dataloader()))
        device = model.device

        model.to_onnx(
            file_path=f"{outputdir}/{filename}.onnx",
            input_sample=(train_batch["input_ids"].to(device), train_batch["attention_mask"].to(device)),
            export_params=True,
            opset_version=11,
            input_names=["input", "attention_mask"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size"}, "attention_mask": {0: "batch_size"}, "output": {0: "batch_size"}},
        )

    def load_model(self, pre_trained_model_path):
        # CPU inference by default; switch providers to
        # ["CUDAExecutionProvider", "CPUExecutionProvider"] to run on GPU
        self.model = InferenceSession(pre_trained_model_path, providers=["CPUExecutionProvider"])
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def preprocess_function(self, data):
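        # Tokenize the "sentence" column into fixed-length arrays for ONNX Runtime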
        features = self.tokenizer(
            data["sentence"], max_length=self.max_seq_length, padding="max_length", add_special_tokens=True, truncation=True, return_tensors="np"
        )
        return features

    def process_spans(self, spans, scale_axis):
        spans_dataset = datasets.Dataset.from_pandas(spans)
        processed = spans_dataset.map(self.preprocess_function, batched=True, batch_size=None)
        # ONNX Runtime expects numpy arrays keyed by the export-time input names;
        # dataset columns come back as plain lists, so convert them here
        inputs = {
            "input": np.array(processed["input_ids"], dtype=np.int64),
            "attention_mask": np.array(processed["attention_mask"], dtype=np.int64),
        }
        return softmax(self.model.run(None, inputs)[0], axis=scale_axis)

    def predict_proba(self, spans, scale_axis=1, predict_batch_size=256):
        # Defaults (softmax over the class axis, batches of 256) are fallbacks
        # so predict() can be called without inference arguments
        if spans.shape[0] > predict_batch_size:
            # Chunk large span sets so ONNX inference stays memory-bounded
            output = []
            span_batches = np.array_split(spans, spans.shape[0] // predict_batch_size)
            for span_batch in span_batches:
                output.extend(self.process_spans(span_batch, scale_axis))
            return np.vstack(output)
        return self.process_spans(spans, scale_axis)

    def predict(self, spans):
        # Highest-probability class index per span
        return self.predict_proba(spans).argmax(axis=1)

    
class InsideOutsideStringPredictor:
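    """Runs prediction from a fine-tuned Lightning checkpoint instead of the ONNX export."""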

    def __init__(self, model_name_or_path, max_seq_length, pre_trained_model_path, num_workers=32):
        self.model_name_or_path = model_name_or_path
        self.pre_trained_model_path = pre_trained_model_path
        self.max_seq_length = max_seq_length
        self.num_workers = num_workers 

    def predict_proba(self, test_df):
        data_module = DataModule(
            model_name_or_path=self.model_name_or_path,
            train_df=None,
            eval_df=None,
            test_df=test_df,
            max_seq_length=self.max_seq_length,
            num_workers=self.num_workers,
        )

        # The original referenced an undefined trainer/model; restore the
        # fine-tuned weights from the Lightning checkpoint (assumes
        # LightningModel saves its hyperparameters; otherwise pass lr and
        # num_labels explicitly) and predict over the test dataloader
        model = LightningModel.load_from_checkpoint(self.pre_trained_model_path, model_name_or_path=self.model_name_or_path)
        trainer = Trainer(accelerator="auto", devices=1, logger=False)
        return trainer.predict(model, dataloaders=data_module.test_dataloader())
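

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module: the file names, the
    # "roberta-base" backbone, and a CSV layout with a "sentence" column
    # (plus a binary "label" column) are all assumptions for illustration.
    train_df = pd.read_csv("train.csv")
    eval_df = pd.read_csv("dev.csv")

    classifier = InsideOutsideStringClassifier(model_name_or_path="roberta-base")
    classifier.fit(train_df=train_df, eval_df=eval_df, outputdir="checkpoints", filename="inside_model")

    # Score the evaluation spans with the exported ONNX model
    classifier.load_model(pre_trained_model_path="checkpoints/inside_model.onnx")
    print(classifier.predict_proba(eval_df, scale_axis=1, predict_batch_size=256)[:5])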