Spaces:
Runtime error
Runtime error
Update code/prediction.py
Browse filesand threshold_num and threshold_rate in the function "add_sentence_table", and using the (new_start new_end) for postprocess index
- code/prediction.py +30 -12
code/prediction.py
CHANGED
@@ -66,9 +66,7 @@ def test_predict(test_loader, device, model, min_label=1, max_label=3):
|
|
66 |
model.train()
|
67 |
return result
|
68 |
|
69 |
-
def add_sentence_table(result):
|
70 |
-
|
71 |
-
pattern =":;。,?!~!: "
|
72 |
for sample in result:
|
73 |
boundary_list = []
|
74 |
for i, char in enumerate(sample['text_a']):
|
@@ -77,16 +75,36 @@ def add_sentence_table(result):
|
|
77 |
boundary_list.append(len(sample['text_a'])+1)
|
78 |
start=0
|
79 |
end =0
|
80 |
-
|
81 |
sample["predict_sentence_table"] = torch.zeros(len(sample["predict_span_table"]))
|
82 |
for boundary in boundary_list:
|
83 |
end = boundary
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
else:
|
88 |
-
sample["predict_sentence_table"][
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
model.train()
|
67 |
return result
|
68 |
|
69 |
+
def add_sentence_table(result, pattern =":;。,?!~!: ", threshold_num=5, threshold_rate=0.5):
|
|
|
|
|
70 |
for sample in result:
|
71 |
boundary_list = []
|
72 |
for i, char in enumerate(sample['text_a']):
|
|
|
75 |
boundary_list.append(len(sample['text_a'])+1)
|
76 |
start=0
|
77 |
end =0
|
78 |
+
fist_sentence = True
|
79 |
sample["predict_sentence_table"] = torch.zeros(len(sample["predict_span_table"]))
|
80 |
for boundary in boundary_list:
|
81 |
end = boundary
|
82 |
+
new_start = start
|
83 |
+
new_end = end
|
84 |
+
for i, predict_char in enumerate(sample["predict_span_table"][start:end]):
|
85 |
+
if predict_char!=0:
|
86 |
+
prrdict_num+=1
|
87 |
+
prrdict_num = sum(sample["predict_span_table"][start:end])
|
88 |
+
sentence_num = len(sample["predict_span_table"][start:end])
|
89 |
+
if(prrdict_num > threshold_num) or (prrdict_num > sentence_num*threshold_rate):
|
90 |
+
|
91 |
+
while(sample["predict_span_table"][new_start]==0): new_start+=1
|
92 |
+
while(sample["predict_span_table"][new_end-1]==0): new_end-=1
|
93 |
+
if fist_sentence:
|
94 |
+
sample["predict_sentence_table"][new_start:new_end] = 2
|
95 |
+
sample["predict_sentence_table"][new_start] = 1
|
96 |
else:
|
97 |
+
sample["predict_sentence_table"][new_start-1:new_end] = 2
|
98 |
+
fist_sentence=False
|
99 |
+
else: fist_sentence =True
|
100 |
+
start = end+1
|
101 |
+
|
102 |
+
def add_doc_id(result, test_data):
|
103 |
+
#make dict {'text_a':"docid"}
|
104 |
+
text_to_id = dict()
|
105 |
+
for sample in test_data:
|
106 |
+
text_to_id[sample["text_a"]] = sample["docid"]
|
107 |
+
|
108 |
+
#add doc_id
|
109 |
+
for sample in result:
|
110 |
+
sample["docid"] = text_to_id[sample["text_a"]]
|