singletongue committed (verified)
Commit 0ae2871 · 1 Parent(s): ddb7b37

Show scores for top-k entities

Files changed (1): app.py (+22, -7)
app.py CHANGED
@@ -167,22 +167,34 @@ def get_topk_entities_from_texts(
         model_outputs = model(**tokenized_examples)
         token_spans = get_token_spans(tokenizer, text)
         entity_spans = get_predicted_entity_spans(model_outputs.ner_logits[0], token_spans, entity_span_sensitivity)
-        entity_spans = entity_spans[:tokenizer.max_entity_length]
+        entity_spans = entity_spans[: tokenizer.max_entity_length]
         batch_entity_spans.append(entity_spans)
 
         tokenized_examples = tokenizer(text, entity_spans=entity_spans or None, truncation=True, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
 
         if model_outputs.topic_entity_logits is not None:
-            _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(entity_k)
-            topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])
+            topk_normal_entity_scores, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(entity_k)
+            topk_normal_entities.append(
+                [
+                    f"{id2normal_entity[id_]} ({score:.3f})"
+                    for score, id_ in zip(topk_normal_entity_scores, topk_normal_entity_ids.tolist())
+                ]
+            )
         else:
             topk_normal_entities.append([])
 
         if model_outputs.topic_category_logits is not None:
             model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")
-            _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(category_k)
-            topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
+            topk_category_entity_scores, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(
+                category_k
+            )
+            topk_category_entities.append(
+                [
+                    f"{id2category_entity[id_]} ({score:.3f})"
+                    for score, id_ in zip(topk_category_entity_scores, topk_category_entity_ids.tolist())
+                ]
+            )
         else:
             topk_category_entities.append([])
 
@@ -197,9 +209,12 @@ def get_topk_entities_from_texts(
             )
             span_entity_logits += nayose_coef * nayose_scores
 
-            _, topk_span_entity_ids = span_entity_logits.topk(entity_k)
+            topk_span_entity_scores, topk_span_entity_ids = span_entity_logits.topk(entity_k)
             topk_span_entities.append(
-                [[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()]
+                [
+                    [f"{id2normal_entity[id_]} ({score:.3f})" for score, id_ in zip(scores, ids)]
+                    for scores, ids in zip(topk_span_entity_scores, topk_span_entity_ids.tolist())
+                ]
             )
         else:
             topk_span_entities.append([])
 
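For context, a minimal, self-contained sketch of the pattern this commit applies. The entity vocabulary and logits below are made-up stand-ins for the app's id2normal_entity mapping and model_outputs.topic_entity_logits, not real data: Tensor.topk returns a (values, indices) pair, so keeping the values (previously discarded as "_") is all that is needed to attach a score to each predicted label.

import torch

# Hypothetical stand-ins: a tiny entity vocabulary and fake topic-entity logits.
id2normal_entity = {0: "Tokyo", 1: "Kyoto", 2: "Osaka", 3: "Nagoya"}
topic_entity_logits = torch.tensor([2.1, 0.4, 1.7, -0.3])
entity_k = 3

# topk returns both scores and indices; formatting a 0-dim tensor with
# "{score:.3f}" works because PyTorch delegates __format__ to .item().
scores, ids = topic_entity_logits.topk(entity_k)
topk_entities = [
    f"{id2normal_entity[id_]} ({score:.3f})"
    for score, id_ in zip(scores, ids.tolist())
]
print(topk_entities)  # ['Tokyo (2.100)', 'Osaka (1.700)', 'Kyoto (0.400)']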