TestTaker committed
Commit 703d114 · 1 Parent(s): c51f116

Fix bert bugs

utilities_language_bert/rus_main_workflow_bert.py CHANGED
@@ -106,7 +106,7 @@ def main_workflow(
 
     # Get summary. May choose between round_summary_length and summary_length
     SUMMARY = summarization(current_text, num_sentences=round_summary_length)
-    logs.update('Нашли интересные предложения. Пригодятся!')
+    logs.success('Нашли интересные предложения. Пригодятся!')
     progress.progress(25)
 
     for sentence in workflow:
@@ -174,7 +174,7 @@ def main_workflow(
     logs.update(label='Подобрали неправильные варианты!', state='running')
 
     for task in RESULT_TASKS:
-        task.inflect_distractors()
+        task.inflect_distractors(level_name=level)
     progress.progress(80)
     logs.update(label='Просклоняли и проспрягали неправильные варианты!', state='running')
 
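A note on the first hunk: Streamlit's status widget exposes a keyword-only update(label=..., expanded=..., state=...), so the old positional call logs.update('Нашли...') would raise a TypeError at runtime, while logs.success(...) writes a success message inside the container. A minimal sketch, assuming logs is an st.status container as the surrounding calls suggest:

import streamlit as st

logs = st.status('Обрабатываю текст...', state='running')

logs.update(label='Подобрали неправильные варианты!', state='running')  # keyword-only API
logs.success('Нашли интересные предложения. Пригодятся!')               # message inside the box
# logs.update('Нашли интересные предложения!')  # TypeError: positional args not accepted

The second hunk simply threads the workflow's CEFR level into TASK.inflect_distractors (see rus_sentence_bert.py below).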
utilities_language_bert/rus_sentence_bert.py CHANGED
@@ -150,6 +150,7 @@ class TASK:
         self.tags = task_data['tags']
         self.lemma = task_data['lemma']
         self.gender = task_data['gender']
+        self.in_summary = task_data['in_summary']
         self.max_num_distractors = max_num_distractors
         self.original_text = task_data['original_text']
         self.sentence_text = task_data['sentence_text']
@@ -180,13 +181,13 @@ class TASK:
         self.distractors = [d[0] for i, d in enumerate(distractors_sentence) if i < 15]
         self.distractors_number = len(distractors_sentence) if distractors_sentence is not None else 0
 
-    def inflect_distractors(self):
+    def inflect_distractors(self, level_name):
         inflected_distractors = []
         if self.distractors is None:
             self.bad_target_word = True
             return
         for distractor_lemma in self.distractors:
-            inflected = make_inflection(text=distractor_lemma, pos=self.pos[1], tags=self.tags)
+            inflected = make_inflection(text=distractor_lemma, pos=self.pos[1], tags=self.tags, level=level_name)
             if inflected is not None:
                 inflected_distractors.append(inflected)
         num_distractors = min(4, self.max_num_distractors) if self.max_num_distractors >= 4 \
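With the new signature, the CEFR level reaches make_inflection. A hypothetical sketch of what a level-aware make_inflection could look like on top of pymorphy2 (the real implementation lives in the repo's general utilities; the rules shown here are illustrative assumptions, not the repo's):

import pymorphy2

_morph = pymorphy2.MorphAnalyzer()

def make_inflection(text, pos, tags, level='Без уровня'):
    # Illustrative sketch: inflect a candidate lemma into the grammemes of the
    # target word; `level` could be used to veto forms that are too advanced
    # for low CEFR levels (the actual level rules are the repo's, not shown).
    for parse in _morph.parse(text):
        if parse.tag.POS != pos:
            continue
        inflected = parse.inflect(set(tags))  # None if the form does not exist
        if inflected is not None:
            return inflected.word
    return None

# Matches the new call site in TASK.inflect_distractors:
# make_inflection(text='собака', pos='NOUN', tags={'sing', 'datv'}, level='B1')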
utilities_language_general/rus_constants.py CHANGED
@@ -34,7 +34,7 @@ def load_spacy():
 @st.cache_resource
 def load_bert():
     with st.spinner('Загружаю языковую модель'):
-        _pipeline = pipeline(task="fill-mask", model="a-v-white/bert-base-spanish-wwm-cased-finetuned-literature-pro")
+        _pipeline = pipeline(task="fill-mask", model="a-v-bely/ruBert-base-finetuned-russian-moshkov-child-corpus-pro")
     return _pipeline
 
 
@@ -113,6 +113,7 @@ COMBINE_POS = {
         'B2': {'VERB': ['AUX']},
         'C1': {'VERB': ['AUX']},
         'C2': {'VERB': ['AUX']},
+        'Без уровня': {'VERB': ['AUX']}
         },
     'phrase':
         {
@@ -122,5 +123,6 @@ COMBINE_POS = {
         'B2': {'VERB': ['AUX']},
         'C1': {'VERB': ['AUX']},
         'C2': {'VERB': ['AUX']},
+        'Без уровня': {'VERB': ['AUX']}
         },
     }
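Two separate fixes here: load_bert previously pointed at a Spanish fill-mask model (apparently a copy-paste slip from this app's Spanish sibling) and now loads the Russian one, and both COMBINE_POS sub-dicts gain a 'Без уровня' ('no level') key so that level-agnostic runs no longer raise a KeyError in lookups like COMBINE_POS['phrase'][level_name]. A minimal usage sketch of the swapped-in pipeline, assuming the standard transformers fill-mask API:

from transformers import pipeline

fill_mask = pipeline(task="fill-mask",
                     model="a-v-bely/ruBert-base-finetuned-russian-moshkov-child-corpus-pro")

# Each candidate carries a token string and a score; downstream code treats
# the score as `bert_score` when filtering distractor candidates.
for candidate in fill_mask('Мама мыла [MASK].', top_k=5):
    print(candidate['token_str'], round(candidate['score'], 4))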
utilities_language_general/rus_utils.py CHANGED
@@ -41,7 +41,7 @@ def compute_frequency_dict(text: str) -> dict:
 
 
 def convert_gender(gender_spacy):
-    genders = {'Masc': 'masc', 'Fem': 'femn', 'Neut': 'neut'}
+    genders = {'Masc': 'masc', 'Fem': 'femn', 'Neut': 'neut', None: False}
     return genders[gender_spacy]
 
 
@@ -359,21 +359,23 @@ def get_distractors_from_model_bert(model, scaler, classifier, pos_dict:dict, le
         distractor_lemma, distractor_pos = candidate_morph.lemma_, candidate_morph.pos_
         distractor_similarity = candidate_distractor[1]
         candidate_gender = define_gender(distractor_lemma)
+        # print(distractor_lemma, candidate_gender, distractor_pos, pos)
         length_ratio = abs(len(lemma) - len(distractor_lemma))
         decision = make_decision(doc=None, model_type='bert', scaler=scaler, classifier=classifier, pos_dict=pos_dict, level=level_name,
                                  target_lemma=lemma, target_text=None, target_pos=pos, target_position=None,
                                  substitute_lemma=distractor_lemma, substitute_pos=distractor_pos, bert_score=distractor_similarity)
-        if (((distractor_pos == pos)
-             or (COMBINE_POS['phrase'][level_name].get(pos) is not None and COMBINE_POS['phrase'][level_name].get(distractor_pos) is not None
-                 and distractor_pos in COMBINE_POS['phrase'][level_name][pos] and pos in COMBINE_POS['phrase'][level_name][distractor_pos]))
-            and decision
-            and distractor_lemma != lemma
-            and (len(_distractors) < max_num_distractors + 10)
-            and (candidate_gender == gender and level_name in ('B1', 'B2', 'C1', 'C2'))
-            and (length_ratio <= max_length_ratio)  # May be changed if case of phrases
-            and (distractor_lemma not in global_distractors)
-            and (edit_distance(lemma, distractor_lemma)  # May be changed if case of phrases
-                 / ((len(lemma) + len(distractor_lemma)) / 2) > min_edit_distance_ratio)):
+        condition = (((distractor_pos == pos)
+                      or (COMBINE_POS['phrase'][level_name].get(pos) is not None and COMBINE_POS['phrase'][level_name].get(distractor_pos) is not None
+                          and distractor_pos in COMBINE_POS['phrase'][level_name][pos] and pos in COMBINE_POS['phrase'][level_name][distractor_pos]))
+                     and decision
+                     and distractor_lemma != lemma
+                     and (len(_distractors) < max_num_distractors + 10)
+                     and (candidate_gender == gender and level_name in ('B1', 'B2', 'C1', 'C2'))
+                     and (length_ratio <= max_length_ratio)  # May be changed if case of phrases
+                     and (distractor_lemma not in global_distractors)
+                     and (edit_distance(lemma, distractor_lemma)  # May be changed if case of phrases
+                          / ((len(lemma) + len(distractor_lemma)) / 2) > min_edit_distance_ratio))
+        if condition:
             if distractor_minimum is not None:
                 if distractor_lemma in distractor_minimum:
                     _distractors.append((distractor_lemma, candidate_distractor[1]))
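The None: False entry handles tokens for which spaCy reports no Gender feature: previously convert_gender(None) raised a KeyError, whereas now it returns a falsy value and the later candidate_gender == gender test simply fails. A tiny illustration:

genders = {'Masc': 'masc', 'Fem': 'femn', 'Neut': 'neut', None: False}

print(genders['Masc'])  # 'masc' -- pymorphy-style tag
print(genders[None])    # False  -- no crash when spaCy has no Gender morph

The second hunk is behavior-preserving: the long filter expression is hoisted into a named condition variable, which is easier to inspect alongside the commented-out debug print.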
utilities_language_general/similarity_measures.py CHANGED
@@ -185,6 +185,8 @@ def get_context_linked_words(doc, target_position, target_text):
 
 
 def compute_all_necessary_metrics(target_lemma, target_text, target_position, substitute_lemma, doc, model_type:str, model=None):
+    if model_type == 'bert':
+        return
 
     target_vector = get_vector_for_token(model, target_lemma)
     substitute_vector = get_vector_for_token(model, substitute_lemma)
@@ -246,9 +248,11 @@ def make_decision(doc, model_type, scaler, classifier, pos_dict, level, target_l
     metrics = compute_all_necessary_metrics(target_lemma=target_lemma, target_text=target_text, target_position=target_position,
                                             substitute_lemma=substitute_lemma, doc=doc, model_type=model_type, model=model)
     target_multiword, substitute_multiword = target_lemma.count('_') > 2, substitute_lemma.count('_') > 2
-    data = [LEVEL_NUMBERS.get(level), pos_dict.get(target_pos), target_multiword, pos_dict.get(substitute_pos), substitute_multiword] + scaler.transform([metrics]).tolist()[0]
     if model_type == 'bert':
-        data = [LEVEL_NUMBERS.get(level), pos_dict.get(target_pos), target_multiword, pos_dict.get(substitute_pos), substitute_multiword, bert_score]
+        scaled_data = scaler.transform([[bert_score]]).tolist()[0]
+    else:
+        scaled_data = scaler.transform([metrics]).tolist()[0]
+    data = [LEVEL_NUMBERS.get(level), pos_dict.get(target_pos), target_multiword, pos_dict.get(substitute_pos), substitute_multiword] + scaled_data
     predict = classifier.predict(data)
     return bool(predict)
 
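Before this commit, make_decision built data twice: the first assignment always called scaler.transform([metrics]), which breaks on the BERT path now that compute_all_necessary_metrics returns None early, and the BERT branch then rebuilt data with a raw, unscaled bert_score. After the fix, exactly one feature vector is built and the lone BERT score goes through the scaler. A toy sketch of the resulting shapes, assuming the scaler and classifier are fitted scikit-learn objects (the real ones are loaded elsewhere in the repo):

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

# Toy stand-ins: a 1-feature scaler for the BERT path and a 6-feature classifier.
scaler = StandardScaler().fit([[0.1], [0.5], [0.9]])
classifier = LogisticRegression().fit(
    [[1, 2, 0, 2, 0, -1.0], [3, 4, 0, 4, 0, 1.0]], [0, 1])

bert_score = 0.75
scaled_data = scaler.transform([[bert_score]]).tolist()[0]   # scale the lone score
data = [2, 3, False, 3, False] + scaled_data                 # level, POS ids, multiword flags, score
print(classifier.predict([data]))                            # scikit-learn expects 2-D input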