# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import unicodedata

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning import loggers
from pytorch_lightning.callbacks import ModelCheckpoint

import transformers
from transformers import BertModel, BertPreTrainedModel, BertTokenizer
from transformers.optimization import get_linear_schedule_with_warmup

transformers.logging.set_verbosity_error()
# os.environ["CUDA_VISIBLE_DEVICES"] = '6'

def search(pattern, sequence):
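    """Return every [start, end] span (both ends inclusive) at which
    `pattern` occurs as a contiguous sub-sequence of `sequence`."""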
    n = len(pattern)
    res = []
    for i in range(len(sequence)):
        if sequence[i:i + n] == pattern:
            res.append([i, i + n - 1])
    return res

class UbertDataset(Dataset):
    def __init__(self, data, tokenizer, args, used_mask=True):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = args.max_length
        self.num_labels = args.num_labels
        self.used_mask = used_mask
        self.data = data
        self.args = args

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.encode(self.data[index], self.used_mask)

    def encode(self, item, used_mask=False):
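        """Encode one sample into exactly `num_labels` (prompt, text) pairs.

        Choices whose span matrix contains at least one gold span are kept in
        the `*1` buffers; all-negative choices go into the `*0` buffers, a
        random subset of them is appended afterwards, and zero rows pad the
        sample up to `num_labels` choices.
        """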
        input_ids1 = []
        attention_mask1 = []
        token_type_ids1 = []
        span_labels1 = []
        span_labels_masks1 = []
        input_ids0 = []
        attention_mask0 = []
        token_type_ids0 = []
        span_labels0 = []
        span_labels_masks0 = []
        subtask_type = item['subtask_type']
        for choice in item['choices']:
            try:
                texta = item['task_type'] + '[SEP]' + \
                    subtask_type + '[SEP]' + choice['entity_type']
                textb = item['text']
                encode_dict = self.tokenizer.encode_plus(texta, textb,
                                                         max_length=self.max_length,
                                                         padding='max_length',
                                                         truncation='longest_first')
                encode_sent = encode_dict['input_ids']
                encode_token_type_ids = encode_dict['token_type_ids']
                encode_attention_mask = encode_dict['attention_mask']
                span_label = np.zeros((self.max_length, self.max_length))
                span_label_mask = np.zeros(
                    (self.max_length, self.max_length)) - 10000
                if item['task_type'] == '分类任务':  # classification task
                    span_label_mask[0, 0] = 0
                    span_label[0, 0] = choice['label']
                else:
                    # only unmask the text-to-text block after the prompt
                    question_len = len(self.tokenizer.encode(texta))
                    span_label_mask[question_len:, question_len:] = np.zeros(
                        (self.max_length - question_len, self.max_length - question_len))
                    for entity in choice['entity_list']:
                        entity_idx_list = entity['entity_idx']
                        if not entity_idx_list:
                            continue
                        for entity_idx in entity_idx_list:
                            if not entity_idx:
                                continue
                            # re-encode the text prefix to translate character
                            # offsets into token offsets
                            start_idx_text = item['text'][:entity_idx[0]]
                            start_idx_text_encode = self.tokenizer.encode(
                                start_idx_text, add_special_tokens=False)
                            start_idx = question_len + \
                                len(start_idx_text_encode)
                            end_idx_text = item['text'][:entity_idx[1] + 1]
                            end_idx_text_encode = self.tokenizer.encode(
                                end_idx_text, add_special_tokens=False)
                            end_idx = question_len + \
                                len(end_idx_text_encode) - 1
                            if start_idx < self.max_length and end_idx < self.max_length:
                                span_label[start_idx, end_idx] = 1
                if np.sum(span_label) < 1:
                    input_ids0.append(encode_sent)
                    attention_mask0.append(encode_attention_mask)
                    token_type_ids0.append(encode_token_type_ids)
                    span_labels0.append(span_label)
                    span_labels_masks0.append(span_label_mask)
                else:
                    input_ids1.append(encode_sent)
                    attention_mask1.append(encode_attention_mask)
                    token_type_ids1.append(encode_token_type_ids)
                    span_labels1.append(span_label)
                    span_labels_masks1.append(span_label_mask)
            except Exception as e:
                print(e)
                print(item)
                print(texta)
                print(textb)

        randomize = np.arange(len(input_ids0))
        np.random.shuffle(randomize)
        cur = 0
        count = len(input_ids1)
        while count < self.args.num_labels:
            if cur < len(randomize):
                input_ids1.append(input_ids0[randomize[cur]])
                attention_mask1.append(attention_mask0[randomize[cur]])
                token_type_ids1.append(token_type_ids0[randomize[cur]])
                span_labels1.append(span_labels0[randomize[cur]])
                span_labels_masks1.append(span_labels_masks0[randomize[cur]])
                cur += 1
            count += 1
        while len(input_ids1) < self.args.num_labels:
            input_ids1.append([0] * self.max_length)
            attention_mask1.append([0] * self.max_length)
            token_type_ids1.append([0] * self.max_length)
            span_labels1.append(np.zeros((self.max_length, self.max_length)))
            span_labels_masks1.append(
                np.zeros((self.max_length, self.max_length)) - 10000)
        input_ids = input_ids1[:self.args.num_labels]
        attention_mask = attention_mask1[:self.args.num_labels]
        token_type_ids = token_type_ids1[:self.args.num_labels]
        span_labels = span_labels1[:self.args.num_labels]
        span_labels_masks = span_labels_masks1[:self.args.num_labels]
        span_labels = np.array(span_labels)
        span_labels_masks = np.array(span_labels_masks)
        if np.sum(span_labels) < 1:
            # guarantee at least one positive cell so the losses stay finite
            span_labels[-1, -1, -1] = 1
            span_labels_masks[-1, -1, -1] = 10000
        sample = {
            "input_ids": torch.tensor(input_ids).long(),
            "token_type_ids": torch.tensor(token_type_ids).long(),
            "attention_mask": torch.tensor(attention_mask).float(),
            "span_labels": torch.tensor(span_labels).float(),
            "span_labels_mask": torch.tensor(span_labels_masks).float()
        }
        return sample

class UbertDataModel(pl.LightningDataModule):
    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('TASK NAME DataModel')
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--batchsize', default=8, type=int)
        parser.add_argument('--max_length', default=128, type=int)
        return parent_args

    def __init__(self, train_data, val_data, tokenizer, args):
        super().__init__()
        self.batchsize = args.batchsize
        self.train_data = UbertDataset(train_data, tokenizer, args, True)
        self.valid_data = UbertDataset(val_data, tokenizer, args, False)

    def train_dataloader(self):
        return DataLoader(self.train_data, shuffle=True,
                          batch_size=self.batchsize, pin_memory=False)

    def val_dataloader(self):
        return DataLoader(self.valid_data, shuffle=False,
                          batch_size=self.batchsize, pin_memory=False)

class biaffine(nn.Module):
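    """Biaffine span scorer: scores every (start, end) token pair through a
    learned tensor U, optionally appending a constant bias feature to each
    of the two inputs."""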
    def __init__(self, in_size, out_size, bias_x=True, bias_y=True):
        super().__init__()
        self.bias_x = bias_x
        self.bias_y = bias_y
        self.out_size = out_size
        self.U = torch.nn.Parameter(torch.zeros(
            in_size + int(bias_x), out_size, in_size + int(bias_y)))
        torch.nn.init.normal_(self.U, mean=0, std=0.1)

    def forward(self, x, y):
        if self.bias_x:
            x = torch.cat((x, torch.ones_like(x[..., :1])), dim=-1)
        if self.bias_y:
            y = torch.cat((y, torch.ones_like(y[..., :1])), dim=-1)
        # (b, x, i) x (i, o, j) x (b, y, j) -> (b, x, y, o)
        bilinear_mapping = torch.einsum('bxi,ioj,byj->bxyo', x, self.U, y)
        return bilinear_mapping

class MultilabelCrossEntropy(nn.Module):
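    """Multilabel categorical cross-entropy: pushes the logits of positive
    spans above zero and those of negative spans below zero via two
    log-sum-exp terms, so no per-class threshold needs to be tuned."""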
    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        y_true = y_true.float()
        # flip the sign of positive-class logits
        y_pred = torch.mul((1.0 - torch.mul(y_true, 2.0)), y_pred)
        # mask positives out of the negative term and vice versa
        y_pred_neg = y_pred - torch.mul(y_true, 1e12)
        y_pred_pos = y_pred - torch.mul(1.0 - y_true, 1e12)
        zeros = torch.zeros_like(y_pred[..., :1])
        y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
        y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
        neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
        pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
        loss = torch.mean(neg_loss + pos_loss)
        return loss

class UbertModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.query_layer = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.config.hidden_size,
                            out_features=self.config.biaffine_size),
            torch.nn.GELU())
        self.key_layer = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.config.hidden_size,
                            out_features=self.config.biaffine_size),
            torch.nn.GELU())
        self.biaffine_query_key_cls = biaffine(self.config.biaffine_size, 1)
        self.loss_softmax = MultilabelCrossEntropy()
        self.loss_sigmoid = torch.nn.BCEWithLogitsLoss(reduction='mean')

    def forward(self,
                input_ids,
                attention_mask,
                token_type_ids,
                span_labels=None,
                span_labels_mask=None):
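        """`input_ids` is (batch, num_label, seq_len); the choice dimension is
        flattened into the batch for BERT, and the biaffine head produces span
        logits of shape (batch, num_label, seq_len, seq_len)."""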
        batch_size, num_label, seq_len = input_ids.shape
        input_ids = input_ids.view(-1, seq_len)
        attention_mask = attention_mask.view(-1, seq_len)
        token_type_ids = token_type_ids.view(-1, seq_len)
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            output_hidden_states=True)  # (bsz, seq, dim)
        hidden_states = outputs[0]
        batch_size, seq_len, hidden_size = hidden_states.shape
        query = self.query_layer(hidden_states)
        key = self.key_layer(hidden_states)
        span_logits = self.biaffine_query_key_cls(
            query, key).reshape(-1, num_label, seq_len, seq_len)
        # additive mask: impossible positions carry -10000
        span_logits = span_logits + span_labels_mask
        if span_labels is None:
            return 0, span_logits
        else:
            soft_loss1 = self.loss_softmax(
                span_logits.reshape(-1, num_label, seq_len * seq_len),
                span_labels.reshape(-1, num_label, seq_len * seq_len))
            soft_loss2 = self.loss_softmax(
                span_logits.permute(0, 2, 3, 1),
                span_labels.permute(0, 2, 3, 1))
            sig_loss = self.loss_sigmoid(span_logits, span_labels)
            all_loss = 10 * (100 * sig_loss + soft_loss1 + soft_loss2)
            return all_loss, span_logits

class UbertLitModel(pl.LightningModule):
    @staticmethod
    def add_model_specific_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--learning_rate', default=1e-5, type=float)
        parser.add_argument('--weight_decay', default=0.1, type=float)
        parser.add_argument('--warmup', default=0.01, type=float)
        parser.add_argument('--num_labels', default=10, type=int)
        return parent_args

    def __init__(self, args, num_data=1):
        super().__init__()
        self.args = args
        self.num_data = num_data
        self.model = UbertModel.from_pretrained(
            self.args.pretrained_model_path)
        self.count = 0

    def setup(self, stage) -> None:
        if stage == 'fit':
            num_gpus = self.trainer.gpus if self.trainer.gpus is not None else 0
            self.total_step = int(self.trainer.max_epochs * self.num_data /
                                  (max(1, num_gpus) * self.trainer.accumulate_grad_batches))
            print('Total training step:', self.total_step)

    def training_step(self, batch, batch_idx):
        loss, span_logits = self.model(**batch)
        span_acc, recall, precise = self.comput_metrix_span(
            span_logits, batch['span_labels'])
        self.log('train_loss', loss)
        self.log('train_span_acc', span_acc)
        self.log('train_span_recall', recall)
        self.log('train_span_precise', precise)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, span_logits = self.model(**batch)
        span_acc, recall, precise = self.comput_metrix_span(
            span_logits, batch['span_labels'])
        self.log('val_loss', loss)
        self.log('val_span_acc', span_acc)
        self.log('val_span_recall', recall)
        self.log('val_span_precise', precise)
    def predict_step(self, batch, batch_idx):
        loss, span_logits = self.model(**batch)
        # comput_metrix_span returns an (f1, recall, precision) tuple
        span_f1, recall, precise = self.comput_metrix_span(
            span_logits, batch['span_labels'])
        return span_f1.item()
    def configure_optimizers(self):
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        paras = list(
            filter(lambda p: p[1].requires_grad, self.named_parameters()))
        paras = [{
            'params':
            [p for n, p in paras if not any(nd in n for nd in no_decay)],
            'weight_decay': self.args.weight_decay
        }, {
            'params': [p for n, p in paras if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        optimizer = torch.optim.AdamW(paras, lr=self.args.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, int(self.total_step * self.args.warmup),
            self.total_step)
        return [{
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }]
    def comput_metrix_span(self, logits, labels):
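        """Binarize the logits at zero and compute span-level
        (f1, recall, precision); note the first return value is F1 even
        though the call sites name it `span_acc`."""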
        ones = torch.ones_like(logits)
        zero = torch.zeros_like(logits)
        logits = torch.where(logits < 0, zero, ones)
        y_pred = logits.view(size=(-1,))
        y_true = labels.view(size=(-1,))
        corr = torch.eq(y_pred, y_true).float()
        corr = torch.multiply(y_true, corr)
        recall = torch.sum(corr.float()) / (torch.sum(y_true.float()) + 1e-5)
        precise = torch.sum(corr.float()) / (torch.sum(y_pred.float()) + 1e-5)
        f1 = 2 * recall * precise / (recall + precise + 1e-5)
        return f1, recall, precise

class TaskModelCheckpoint:
    @staticmethod
    def add_argparse_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--monitor', default='train_loss', type=str)
        parser.add_argument('--mode', default='min', type=str)
        parser.add_argument('--checkpoint_path',
                            default='./checkpoint/', type=str)
        parser.add_argument(
            '--filename', default='model-{epoch:02d}-{train_loss:.4f}', type=str)
        parser.add_argument('--save_top_k', default=3, type=int)
        parser.add_argument('--every_n_epochs', default=1, type=int)
        parser.add_argument('--every_n_train_steps', default=100, type=int)
        # NB: argparse's type=bool treats any non-empty string as True
        parser.add_argument('--save_weights_only', default=True, type=bool)
        return parent_args

    def __init__(self, args):
        self.callbacks = ModelCheckpoint(monitor=args.monitor,
                                         save_top_k=args.save_top_k,
                                         mode=args.mode,
                                         save_last=True,
                                         every_n_train_steps=args.every_n_train_steps,
                                         save_weights_only=args.save_weights_only,
                                         dirpath=args.checkpoint_path,
                                         filename=args.filename)

class OffsetMapping:
    def __init__(self):
        self._do_lower_case = True

    @staticmethod
    def stem(token):
        # strip the '##' wordpiece-continuation prefix
        if token[:2] == '##':
            return token[2:]
        else:
            return token

    @staticmethod
    def _is_control(ch):
        return unicodedata.category(ch) in ('Cc', 'Cf')

    @staticmethod
    def _is_special(ch):
        return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')
    def rematch(self, text, tokens):
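        """Map each wordpiece token back to the character indices of `text`
        that produced it, mirroring the tokenizer's lower-casing and NFD
        accent stripping so token and character offsets stay aligned."""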
        if self._do_lower_case:
            text = text.lower()
        normalized_text, char_mapping = '', []
        for i, ch in enumerate(text):
            if self._do_lower_case:
                ch = unicodedata.normalize('NFD', ch)
                ch = ''.join(
                    [c for c in ch if unicodedata.category(c) != 'Mn'])
            ch = ''.join([
                c for c in ch
                if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c))
            ])
            normalized_text += ch
            char_mapping.extend([i] * len(ch))
        text, token_mapping, offset = normalized_text, [], 0
        for token in tokens:
            if self._is_special(token):
                token_mapping.append([offset])
                offset += 1
            else:
                token = self.stem(token)
                start = text[offset:].index(token) + offset
                end = start + len(token)
                token_mapping.append(char_mapping[start:end])
                offset = end
        return token_mapping

class extractModel:
    def get_actual_id(self, text, query_text, tokenizer, args):
        text_encode = tokenizer.encode(text)
        one_input_encode = tokenizer.encode(query_text)
        # locate the raw text inside the full (prompt + text) input
        text_start_id = search(text_encode[1:-1], one_input_encode)[0][0]
        text_end_id = text_start_id + len(text_encode) - 1
        if text_end_id > args.max_length:
            text_end_id = args.max_length
        text_token = tokenizer.tokenize(text)
        text_mapping = OffsetMapping().rematch(text, text_token)
        return text_start_id, text_end_id, text_mapping, one_input_encode
    def extract_index(self, span_logits, sample_length, split_value=0.5):
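        """Collect every upper-triangular (start, end) pair whose score
        exceeds `split_value`, together with that score."""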
        result = []
        for i in range(sample_length):
            for j in range(i, sample_length):
                if span_logits[i, j] > split_value:
                    result.append((i, j, span_logits[i, j]))
        return result
    def extract_entity(self, text, entity_idx, text_start_id, text_mapping):
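        """Recover the surface string for a predicted (start, end) token span
        via the offset mapping; returns '' when the span falls outside the
        mapped text."""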
        start_split = text_mapping[entity_idx[0] - text_start_id] \
            if 0 <= entity_idx[0] - text_start_id < len(text_mapping) else []
        end_split = text_mapping[entity_idx[1] - text_start_id] \
            if 0 <= entity_idx[1] - text_start_id < len(text_mapping) else []
        entity = ''
        if start_split != [] and end_split != []:
            entity = text[start_split[0]:end_split[-1] + 1]
        return entity

    def extract(self, batch_data, model, tokenizer, args):
        input_ids = []
        attention_mask = []
        token_type_ids = []
        span_labels_masks = []
        for item in batch_data:
            input_ids0 = []
            attention_mask0 = []
            token_type_ids0 = []
            span_labels_masks0 = []
            for choice in item['choices']:
                texta = item['task_type'] + '[SEP]' + \
                    item['subtask_type'] + '[SEP]' + choice['entity_type']
                textb = item['text']
                encode_dict = tokenizer.encode_plus(texta, textb,
                                                    max_length=args.max_length,
                                                    padding='max_length',
                                                    truncation='longest_first')
                encode_sent = encode_dict['input_ids']
                encode_token_type_ids = encode_dict['token_type_ids']
                encode_attention_mask = encode_dict['attention_mask']
                span_label_mask = np.zeros(
                    (args.max_length, args.max_length)) - 10000
                if item['task_type'] == '分类任务':  # classification task
                    span_label_mask[0, 0] = 0
                else:
                    question_len = len(tokenizer.encode(texta))
                    span_label_mask[question_len:, question_len:] = np.zeros(
                        (args.max_length - question_len, args.max_length - question_len))
                input_ids0.append(encode_sent)
                attention_mask0.append(encode_attention_mask)
                token_type_ids0.append(encode_token_type_ids)
                span_labels_masks0.append(span_label_mask)
            input_ids.append(input_ids0)
            attention_mask.append(attention_mask0)
            token_type_ids.append(token_type_ids0)
            span_labels_masks.append(span_labels_masks0)
        input_ids = torch.tensor(input_ids).to(model.device)
        attention_mask = torch.tensor(attention_mask).to(model.device)
        token_type_ids = torch.tensor(token_type_ids).to(model.device)
        span_labels_mask = torch.tensor(span_labels_masks).to(model.device)
        _, span_logits = model.model(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     token_type_ids=token_type_ids,
                                     span_labels=None,
                                     span_labels_mask=span_labels_mask)
        span_logits = torch.sigmoid(span_logits)
        span_logits = span_logits.cpu().detach().numpy()
        for i, item in enumerate(batch_data):
            if item['task_type'] == '分类任务':  # classification task
                cls_idx = 0
                max_c = np.argmax(span_logits[i, :, cls_idx, cls_idx])
                batch_data[i]['choices'][max_c]['label'] = 1
                batch_data[i]['choices'][max_c]['score'] = span_logits[
                    i, max_c, cls_idx, cls_idx]
            else:
                if item['subtask_type'] == '抽取式阅读理解':  # extractive MRC
                    for c in range(len(item['choices'])):
                        # rebuild the exact prompt used at encoding time so
                        # that token offsets line up with the span logits
                        texta = item['task_type'] + '[SEP]' + item['subtask_type'] + \
                            '[SEP]' + item['choices'][c]['entity_type']
                        textb = item['text']
                        text_start_id, text_end_id, offset_mapping, input_ids = self.get_actual_id(
                            item['text'], texta + '[SEP]' + textb, tokenizer, args)
                        logits = span_logits[i, c, :, :]
                        # MRC keeps only the single best-scoring span
                        max_index = np.unravel_index(
                            np.argmax(logits, axis=None), logits.shape)
                        entity_list = []
                        if logits[max_index] > args.threshold:
                            entity = self.extract_entity(
                                item['text'], (max_index[0], max_index[1]),
                                text_start_id, offset_mapping)
                            entity = {
                                'entity_name': entity,
                                'score': logits[max_index]
                            }
                            if entity not in entity_list:
                                entity_list.append(entity)
                        batch_data[i]['choices'][c]['entity_list'] = entity_list
                else:
                    for c in range(len(item['choices'])):
                        texta = item['task_type'] + '[SEP]' + item['subtask_type'] + \
                            '[SEP]' + item['choices'][c]['entity_type']
                        textb = item['text']
                        text_start_id, text_end_id, offset_mapping, input_ids = self.get_actual_id(
                            item['text'], texta + '[SEP]' + textb, tokenizer, args)
                        logits = span_logits[i, c, :, :]
                        sample_length = len(input_ids)
                        entity_idx_type_list = self.extract_index(
                            logits, sample_length, split_value=args.threshold)
                        entity_list = []
                        for entity_idx in entity_idx_type_list:
                            entity = self.extract_entity(
                                item['text'], (entity_idx[0], entity_idx[1]),
                                text_start_id, offset_mapping)
                            entity = {
                                'entity_name': entity,
                                'score': entity_idx[2]
                            }
                            if entity not in entity_list:
                                entity_list.append(entity)
                        batch_data[i]['choices'][c]['entity_list'] = entity_list
        return batch_data

class UbertPiplines:
    @staticmethod
    def piplines_args(parent_args):
        total_parser = parent_args.add_argument_group("piplines args")
        total_parser.add_argument(
            '--pretrained_model_path', default='IDEA-CCNL/Erlangshen-Ubert-110M-Chinese', type=str)
        total_parser.add_argument('--output_save_path',
                                  default='./predict.json', type=str)
        total_parser.add_argument('--load_checkpoints_path',
                                  default='', type=str)
        total_parser.add_argument('--max_extract_entity_number',
                                  default=1, type=float)
        total_parser.add_argument('--train', action='store_true')
        total_parser.add_argument('--threshold',
                                  default=0.5, type=float)
        total_parser = UbertDataModel.add_data_specific_args(total_parser)
        total_parser = TaskModelCheckpoint.add_argparse_args(total_parser)
        total_parser = UbertLitModel.add_model_specific_args(total_parser)
        pl.Trainer.add_argparse_args(parent_args)
        return parent_args
    def __init__(self, args):
        if args.load_checkpoints_path != '':
            self.model = UbertLitModel.load_from_checkpoint(
                args.load_checkpoints_path, args=args)
        else:
            self.model = UbertLitModel(args)
        self.args = args
        self.checkpoint_callback = TaskModelCheckpoint(args).callbacks
        self.logger = loggers.TensorBoardLogger(save_dir=args.default_root_dir)
        self.trainer = pl.Trainer.from_argparse_args(args,
                                                     logger=self.logger,
                                                     callbacks=[self.checkpoint_callback])
        self.tokenizer = BertTokenizer.from_pretrained(
            args.pretrained_model_path,
            additional_special_tokens=['[unused' + str(i + 1) + ']' for i in range(99)])
        self.em = extractModel()

    def fit(self, train_data, dev_data):
        data_model = UbertDataModel(
            train_data, dev_data, self.tokenizer, self.args)
        self.model.num_data = len(train_data)
        self.trainer.fit(self.model, data_model)

    def predict(self, test_data, cuda=True):
        result = []
        start = 0
        if cuda:
            self.model = self.model.cuda()
        self.model.eval()
        while start < len(test_data):
            batch_data = test_data[start:start + self.args.batchsize]
            start += self.args.batchsize
            batch_result = self.em.extract(
                batch_data, self.model, self.tokenizer, self.args)
            result.extend(batch_result)
        return result
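

# Illustrative usage sketch, not part of the original file: the sample record
# below and the flag values are assumptions based on how UbertDataset and
# extractModel.extract read their inputs; adapt them to your own data.
if __name__ == '__main__':
    total_parser = argparse.ArgumentParser('Ubert')
    total_parser = UbertPiplines.piplines_args(total_parser)
    args = total_parser.parse_args()

    # one extraction-task sample; predict() fills in 'entity_list'
    test_data = [{
        'task_type': '抽取任务',        # extraction task
        'subtask_type': '实体识别',     # named entity recognition
        'text': '北京大学坐落于北京市海淀区。',
        'choices': [{'entity_type': '地点', 'label': 0, 'entity_list': []}],
        'id': 0,
    }]

    model = UbertPiplines(args)
    result = model.predict(test_data, cuda=torch.cuda.is_available())
    print(json.dumps(result, ensure_ascii=False, indent=2))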