Spaces:
Build error
Build error
update trainer
Browse files
weakly_supervised_parser/model/trainer.py
CHANGED
@@ -10,7 +10,7 @@ from pytorch_lightning import Trainer, seed_everything
|
|
10 |
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
11 |
from transformers import AutoTokenizer, logging
|
12 |
|
13 |
-
from onnxruntime import InferenceSession
|
14 |
from scipy.special import softmax
|
15 |
|
16 |
from weakly_supervised_parser.model.data_module_loader import DataModule
|
@@ -98,7 +98,10 @@ class InsideOutsideStringClassifier:
|
|
98 |
)
|
99 |
|
100 |
def load_model(self, pre_trained_model_path):
|
101 |
-
|
|
|
|
|
|
|
102 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
|
103 |
|
104 |
def preprocess_function(self, data):
|
|
|
10 |
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
11 |
from transformers import AutoTokenizer, logging
|
12 |
|
13 |
+
from onnxruntime import InferenceSession, SessionOptions
|
14 |
from scipy.special import softmax
|
15 |
|
16 |
from weakly_supervised_parser.model.data_module_loader import DataModule
|
|
|
98 |
)
|
99 |
|
100 |
def load_model(self, pre_trained_model_path):
|
101 |
+
options = SessionOptions()
|
102 |
+
options.intra_op_num_threads = 1
|
103 |
+
options.inter_op_num_threads = 1
|
104 |
+
self.model = InferenceSession(pre_trained_model_path, options, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
|
105 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
|
106 |
|
107 |
def preprocess_function(self, data):
|