JasonLiao committed on
Commit
e473617
·
1 Parent(s): 88b532e

Update code/prediction.py

Browse files

Add the threshold_num and threshold_rate parameters to the function "add_sentence_table", and use (new_start, new_end) as the indices for post-processing.

Files changed (1) hide show
  1. code/prediction.py +30 -12
code/prediction.py CHANGED
@@ -66,9 +66,7 @@ def test_predict(test_loader, device, model, min_label=1, max_label=3):
66
  model.train()
67
  return result
68
 
69
- def add_sentence_table(result):
70
-
71
- pattern =":;。,?!~!: "
72
  for sample in result:
73
  boundary_list = []
74
  for i, char in enumerate(sample['text_a']):
@@ -77,16 +75,36 @@ def add_sentence_table(result):
77
  boundary_list.append(len(sample['text_a'])+1)
78
  start=0
79
  end =0
80
- pre_states =False
81
  sample["predict_sentence_table"] = torch.zeros(len(sample["predict_span_table"]))
82
  for boundary in boundary_list:
83
  end = boundary
84
- if(sum(sample["predict_span_table"][start:end])>0):
85
- if pre_states:
86
- sample["predict_sentence_table"][start-1:end] = 2
 
 
 
 
 
 
 
 
 
 
 
87
  else:
88
- sample["predict_sentence_table"][start:end] = 2
89
- sample["predict_sentence_table"][start] = 1
90
- pre_states=True
91
- else: pre_states =False
92
- start = end+1
 
 
 
 
 
 
 
 
 
 
66
  model.train()
67
  return result
68
 
69
+ def add_sentence_table(result, pattern =":;。,?!~!: ", threshold_num=5, threshold_rate=0.5):
 
 
70
  for sample in result:
71
  boundary_list = []
72
  for i, char in enumerate(sample['text_a']):
 
75
  boundary_list.append(len(sample['text_a'])+1)
76
  start=0
77
  end =0
78
+ fist_sentence = True
79
  sample["predict_sentence_table"] = torch.zeros(len(sample["predict_span_table"]))
80
  for boundary in boundary_list:
81
  end = boundary
82
+ new_start = start
83
+ new_end = end
84
+ for i, predict_char in enumerate(sample["predict_span_table"][start:end]):
85
+ if predict_char!=0:
86
+ prrdict_num+=1
87
+ prrdict_num = sum(sample["predict_span_table"][start:end])
88
+ sentence_num = len(sample["predict_span_table"][start:end])
89
+ if(prrdict_num > threshold_num) or (prrdict_num > sentence_num*threshold_rate):
90
+
91
+ while(sample["predict_span_table"][new_start]==0): new_start+=1
92
+ while(sample["predict_span_table"][new_end-1]==0): new_end-=1
93
+ if fist_sentence:
94
+ sample["predict_sentence_table"][new_start:new_end] = 2
95
+ sample["predict_sentence_table"][new_start] = 1
96
  else:
97
+ sample["predict_sentence_table"][new_start-1:new_end] = 2
98
+ fist_sentence=False
99
+ else: fist_sentence =True
100
+ start = end+1
101
+
102
def add_doc_id(result, test_data):
    """Attach the matching ``docid`` to every sample in ``result``.

    Samples are matched on their ``"text_a"`` string: a lookup table from
    ``"text_a"`` to ``"docid"`` is built from ``test_data`` and then used to
    set ``sample["docid"]`` on each entry of ``result``.  ``result`` is
    modified in place and nothing is returned.

    Args:
        result: Iterable of dict-like samples, each with a ``"text_a"`` key.
        test_data: Iterable of dict-like samples, each with ``"text_a"`` and
            ``"docid"`` keys.

    Raises:
        KeyError: If a sample in ``result`` has a ``"text_a"`` value that
            does not occur in ``test_data``.
    """
    # Map each text to its document id.  If several test_data entries share
    # the same text, the last one wins.
    text_to_id = {sample["text_a"]: sample["docid"] for sample in test_data}

    # Add the doc id to every predicted sample.
    for sample in result:
        sample["docid"] = text_to_id[sample["text_a"]]