# -*- coding: utf-8 -*-
"""
Created on Mon Aug 30 19:54:17 2021

@author: luol2

Input representation layer for the model.
"""

import sys

import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
from transformers import AutoTokenizer


class Hugface_RepresentationLayer(object):

    def __init__(self, tokenizer_name_or_path, label_file, lowercase=True):

        # Load the subword tokenizer and register the entity-marker tokens
        # used by the model's input representation.
        self.model_type = 'bert'
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path, use_fast=True, do_lower_case=lowercase)
        self.tokenizer.add_tokens(
            ["arg1s", "arg1e", "gene1s", "gene1e", "species1s", "species1e"])

        # Load the label vocabulary (label <-> index maps).
        self.label_2_index = {}
        self.index_2_label = {}
        self.label_table_size = 0
        self.load_label_vocab(label_file, self.label_2_index, self.index_2_label)
        self.label_table_size = len(self.label_2_index)
        self.vocab_len = len(self.tokenizer)

    def load_label_vocab(self, fea_file, fea_index, index_2_label):
        # The label file holds one label per line; the line number is the
        # label's index.
        with open(fea_file, 'r', encoding='utf-8') as fin:
            all_text = fin.read().strip().split('\n')
        for i, label in enumerate(all_text):
            fea_index[label] = i
            index_2_label[str(i)] = label
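
    # Illustrative label-file layout (the actual tag set is project-specific
    # and not shown in this module):
    #   O        -> index 0
    #   B-GENE   -> index 1
    #   I-GENE   -> index 2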

    def generate_label_list(self, bert_tokens, labels, word_index):
        # Project word-level labels onto subword tokens. Special tokens such
        # as [CLS] and [SEP] have no word index and keep the default 'O'.
        label_list = ['O'] * len(word_index)
        if len(word_index) != len(bert_tokens):
            print('index != tokens', word_index, bert_tokens)
            sys.exit()
        for i in range(len(word_index)):
            if word_index[i] is not None:
                label_list[i] = labels[word_index[i]]

        label_list_index = []
        bert_text_label = []
        for i in range(len(bert_tokens)):
            label_list_index.append(self.label_2_index[label_list[i]])
            bert_text_label.append([bert_tokens[i], label_list[i]])
        return label_list_index, bert_text_label
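
    # Example of the alignment above (illustrative tokenization; actual
    # subwords depend on the checkpoint): words ['BRCA1', 'binds'] with
    # labels ['B-GENE', 'O'] may tokenize to
    # ['[CLS]', 'br', '##ca', '##1', 'binds', '[SEP]'] with word_ids
    # [None, 0, 0, 0, 1, None], yielding subword labels
    # ['O', 'B-GENE', 'B-GENE', 'B-GENE', 'O', 'O'].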

    def load_data_hugface(self, instances, labels, word_max_len=100, label_type='softmax'):
        x_index = []
        x_seg = []
        x_mask = []
        y_list = []
        bert_text_labels = []
        # Corpus statistics: longest sentence, number of sequences truncated
        # at maxT, and a running token count for the average length.
        max_len = 0
        over_num = 0
        maxT = word_max_len
        ave_len = 0
        for sentence in instances:
            # Each sentence is a list of (word, ..., label) rows; column 0 is
            # the word and the last column is its label.
            sentence_text_list = []
            label_list = []
            for j in range(len(sentence)):
                sentence_text_list.append(sentence[j][0])
                label_list.append(sentence[j][-1])
            token_result = self.tokenizer(
                sentence_text_list,
                max_length=word_max_len,
                truncation=True,
                is_split_into_words=True)
            bert_tokens = self.tokenizer.convert_ids_to_tokens(token_result['input_ids'])
            word_index = token_result.word_ids(batch_index=0)
            ave_len += len(bert_tokens)
            if len(sentence_text_list) > max_len:
                max_len = len(sentence_text_list)
            if len(bert_tokens) >= maxT:
                over_num += 1
            x_index.append(token_result['input_ids'])
            if self.model_type in {"gpt2", "roberta"}:
                # GPT-2/RoBERTa tokenizers return no token_type_ids, so use
                # an all-zero segment vector instead.
                x_seg.append([0] * len(token_result['input_ids']))
            else:
                x_seg.append(token_result['token_type_ids'])
            x_mask.append(token_result['attention_mask'])
            label_list, bert_text_label = self.generate_label_list(
                bert_tokens, label_list, word_index)
            y_list.append(label_list)
            bert_text_labels.append(bert_text_label)

        # Right-pad (and right-truncate) every sequence to word_max_len.
        x1_np = pad_sequences(x_index, word_max_len, value=0, padding='post', truncating='post')
        x2_np = pad_sequences(x_seg, word_max_len, value=0, padding='post', truncating='post')
        x3_np = pad_sequences(x_mask, word_max_len, value=0, padding='post', truncating='post')
        y_np = pad_sequences(y_list, word_max_len, value=0, padding='post', truncating='post')
        if label_type == 'onehot':
            # One-hot encode the label indices for a dense classification head.
            y_np = np.eye(len(labels), dtype='float32')[y_np]
        elif label_type == 'softmax':
            # Sparse labels: add a trailing axis for
            # sparse_categorical_crossentropy.
            y_np = np.expand_dims(y_np, 2)
        elif label_type == 'crf':
            # CRF layers consume the integer label matrix as-is.
            pass
        return [x1_np, x2_np, x3_np], y_np, bert_text_labels


if __name__ == '__main__':
    pass
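    # Minimal usage sketch ('bert-base-uncased' and 'label.vocab' are
    # illustrative placeholders, not part of the original module). Kept
    # commented out to avoid a checkpoint download on a plain run.
    # rep = Hugface_RepresentationLayer('bert-base-uncased', 'label.vocab')
    # instances = [[['BRCA1', 'B-GENE'], ['mutations', 'O'], ['.', 'O']]]
    # x, y, token_labels = rep.load_data_hugface(
    #     instances, rep.label_2_index, word_max_len=32, label_type='softmax')
    # print(x[0].shape, y.shape)  # expected: (1, 32) (1, 32, 1)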