# NOTE(review): the six lines that preceded the imports ("Spaces:",
# "Runtime error", a byte count, commit hashes and a column ruler) were
# viewer/extraction residue, not Python source, and have been removed.
import torch
from args import args, config
from tqdm import tqdm
from items_dataset import items_dataset
def test_predict(test_loader, device, model, min_label=1, max_label=3):
    """Run the model over ``test_loader`` and decode character-level span tables.

    Parameters
    ----------
    test_loader : iterable of tokenized batches. Each batch must provide
        'batch_text', 'input_ids', 'token_type_ids', 'attention_mask',
        'crf_mask' and 'overflow_to_sample_mapping', and support
        ``test_batch[row].token_to_chars(tok)`` (presumably a HuggingFace
        ``BatchEncoding`` — TODO confirm against the dataset code).
    device : torch device the input tensors are moved to.
    model : sequence-labelling model; called with the keyword tensors below
        and, when ``args.use_crf`` is set, must expose ``model.crf.decode``.
    min_label, max_label : agreement bounds forwarded to
        ``items_dataset.cal_agreement_span``.

    Returns
    -------
    list[dict]
        One dict per source sample with 'text_a' and a 'predict_span_table'
        tensor over characters (0 = outside, 1 = span begin, 2 = inside).
    """
    model.eval()
    result = []
    for test_batch in tqdm(test_loader):
        input_ids = test_batch['input_ids'].to(device)
        token_type_ids = test_batch['token_type_ids'].to(device)
        attention_mask = test_batch['attention_mask'].to(device)
        crf_mask = test_batch["crf_mask"].to(device)
        sample_mapping = test_batch["overflow_to_sample_mapping"]
        output = model(input_ids=input_ids, token_type_ids=token_type_ids,
                       attention_mask=attention_mask, labels=None,
                       crf_mask=crf_mask)
        if args.use_crf:
            prediction = model.crf.decode(output[0], crf_mask)
        else:
            prediction = torch.max(output[0], -1).indices

        # Reassemble per-sample results: long texts overflow into several
        # batch rows, and sample_mapping maps each row back to its sample.
        sample_id = -1
        sample_result = {"text_a": test_batch['batch_text'][0]}
        for batch_id in range(len(sample_mapping)):
            change_sample = sample_id != sample_mapping[batch_id]
            if change_sample:
                sample_id = sample_mapping[batch_id]
                sample_result = {"text_a": test_batch['batch_text'][sample_id]}
                decode_span_table = torch.zeros(len(test_batch['batch_text'][sample_id]))
            spans = items_dataset.cal_agreement_span(
                None, agreement_table=prediction[batch_id],
                min_agree=min_label, max_agree=max_label)
            # Convert token-index spans into character offsets.
            for span in spans:
                # Step off the special token at position 0 (presumably [CLS]
                # — TODO confirm) and keep the span non-empty afterwards.
                if span[0] == 0:
                    span[0] += 1
                if span[1] == 1:
                    span[1] += 1
                # Advance the start past tokens with no character mapping
                # (special/padding tokens) until one maps, or the span empties.
                while True:
                    start = test_batch[batch_id].token_to_chars(span[0])
                    if start is not None or span[0] >= span[1]:
                        break
                    span[0] += 1
                # Likewise retreat the end to the last mappable token.
                while True:
                    end = test_batch[batch_id].token_to_chars(span[1])
                    if end is not None or span[0] >= span[1]:
                        break
                    span[1] -= 1
                if span[0] < span[1]:
                    de_start = test_batch[batch_id].token_to_chars(span[0])[0]
                    de_end = test_batch[batch_id].token_to_chars(span[1] - 1)[0]
                    decode_span_table[de_start:de_end] = 2  # inside
                    decode_span_table[de_start] = 1         # begin
            if change_sample:
                # The tensor is appended once per sample but mutated in place,
                # so later overflow rows of the same sample still update it.
                sample_result["predict_span_table"] = decode_span_table
                #sample_result["boundary"] = test_batch["boundary"][id]
                result.append(sample_result)
    model.train()
    return result
def add_sentence_table(result, pattern =":;。,,?!~!: ", threshold_num=5, threshold_rate=0.5):
for sample in result:
boundary_list = []
for i, char in enumerate(sample['text_a']):
if char in pattern:
boundary_list.append(i)
boundary_list.append(len(sample['text_a'])+1)
start=0
end =0
fist_sentence = True
sample["predict_sentence_table"] = torch.zeros(len(sample["predict_span_table"]))
for boundary in boundary_list:
end = boundary
predict_num = sum(sample["predict_span_table"][start:end]>0)
sentence_num = len(sample["predict_span_table"][start:end])
if(predict_num > threshold_num) or (predict_num > sentence_num*threshold_rate):
if fist_sentence:
sample["predict_sentence_table"][start:end] = 2
sample["predict_sentence_table"][start] = 1
fist_sentence=False
else:
sample["predict_sentence_table"][start-1:end] = 2
else: fist_sentence =True
start = end+1
def add_doc_id(result, test_data):
    """Attach the source document id to every predicted sample.

    Builds a text -> docid lookup from ``test_data`` (each entry must carry
    'text_a' and 'docid'), then stamps the matching 'docid' onto every dict
    in ``result`` keyed by its 'text_a'.  Mutates ``result`` in place.
    Raises KeyError if a result text is absent from ``test_data``.
    """
    text_to_id = {entry["text_a"]: entry["docid"] for entry in test_data}
    for prediction in result:
        prediction["docid"] = text_to_id[prediction["text_a"]]