nickil commited on
Commit
093eca9
·
1 Parent(s): 7c79391

update trainer

Browse files
weakly_supervised_parser/model/trainer.py CHANGED
@@ -10,7 +10,7 @@ from pytorch_lightning import Trainer, seed_everything
10
  from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
11
  from transformers import AutoTokenizer, logging
12
 
13
- from onnxruntime import InferenceSession
14
  from scipy.special import softmax
15
 
16
  from weakly_supervised_parser.model.data_module_loader import DataModule
@@ -98,7 +98,10 @@ class InsideOutsideStringClassifier:
98
  )
99
 
100
  def load_model(self, pre_trained_model_path):
101
- self.model = InferenceSession(pre_trained_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
 
 
 
102
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
103
 
104
  def preprocess_function(self, data):
 
10
  from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
11
  from transformers import AutoTokenizer, logging
12
 
13
+ from onnxruntime import InferenceSession, SessionOptions
14
  from scipy.special import softmax
15
 
16
  from weakly_supervised_parser.model.data_module_loader import DataModule
 
98
  )
99
 
100
  def load_model(self, pre_trained_model_path):
101
+ options = SessionOptions()
102
+ options.intra_op_num_threads = 1
103
+ options.inter_op_num_threads = 1
104
+ self.model = InferenceSession(pre_trained_model_path, options, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
105
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
106
 
107
  def preprocess_function(self, data):