JasonLiao commited on
Commit
53d9863
·
1 Parent(s): 9a610df

fix add_sentence_table third times

Browse files
Files changed (1) hide show
  1. code/prediction.py +5 -13
code/prediction.py CHANGED
@@ -79,23 +79,15 @@ def add_sentence_table(result, pattern =":;。,?!~!: ", threshold_nu
79
  sample["predict_sentence_table"] = torch.zeros(len(sample["predict_span_table"]))
80
  for boundary in boundary_list:
81
  end = boundary
82
- new_start = start
83
- new_end = end
84
- predict_num = 0
85
- for i, predict_char in enumerate(sample["predict_span_table"][start:end]):
86
- if predict_char!=0:
87
- predict_num+=1
88
  sentence_num = len(sample["predict_span_table"][start:end])
89
  if(predict_num > threshold_num) or (predict_num > sentence_num*threshold_rate):
90
-
91
- while(sample["predict_span_table"][new_start]==0): new_start+=1
92
- while(sample["predict_span_table"][new_end-1]==0): new_end-=1
93
  if fist_sentence:
94
- sample["predict_sentence_table"][new_start:new_end] = 2
95
- sample["predict_sentence_table"][new_start] = 1
 
96
  else:
97
- sample["predict_sentence_table"][new_start-1:new_end] = 2
98
- fist_sentence=False
99
  else: fist_sentence =True
100
  start = end+1
101
 
 
79
  sample["predict_sentence_table"] = torch.zeros(len(sample["predict_span_table"]))
80
  for boundary in boundary_list:
81
  end = boundary
82
+ predict_num = sum(sample["predict_span_table"][start:end]>0)
 
 
 
 
 
83
  sentence_num = len(sample["predict_span_table"][start:end])
84
  if(predict_num > threshold_num) or (predict_num > sentence_num*threshold_rate):
 
 
 
85
  if fist_sentence:
86
+ sample["predict_sentence_table"][start:end] = 2
87
+ sample["predict_sentence_table"][start] = 1
88
+ fist_sentence=False
89
  else:
90
+ sample["predict_sentence_table"][start-1:end] = 2
 
91
  else: fist_sentence =True
92
  start = end+1
93