singletongue committed (verified)
Commit ddb7b37 · Parent: 9baca8d

Limit the number of entity spans detected by NER, fix `get_token_spans()`

Files changed (1): app.py (+4 -0)
app.py CHANGED
```diff
@@ -80,6 +80,9 @@ def get_token_spans(tokenizer, text: str) -> list[tuple[int, int]]:
     end = 0
     for token in tokenizer.tokenize(text):
         token = token.removeprefix("##")
+        if token not in text:  # e.g., token == "[UNK]"
+            continue
+
         start = text.index(token, end)
         end = start + len(token)
         token_spans.append((start, end))
```
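The guard in this hunk matters because WordPiece tokenizers map out-of-vocabulary characters to special tokens such as `[UNK]`, which never occur verbatim in the input, so `text.index(token, end)` would raise `ValueError`. A minimal sketch of the failure mode, using `bert-base-cased` as a stand-in (the diff does not show which tokenizer app.py actually loads):

```python
# Minimal repro sketch; "bert-base-cased" is an assumption, not the app's tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
text = "Tokyo Tower 🗼"  # the emoji is out-of-vocabulary and tokenizes to "[UNK]"

end = 0
for token in tokenizer.tokenize(text):
    token = token.removeprefix("##")
    if token not in text:   # "[UNK]" never appears literally in the input...
        continue            # ...so without this guard, text.index() raises ValueError
    start = text.index(token, end)
    end = start + len(token)
    print((start, end), text[start:end])
```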
```diff
@@ -164,6 +167,7 @@ def get_topk_entities_from_texts(
         model_outputs = model(**tokenized_examples)
         token_spans = get_token_spans(tokenizer, text)
         entity_spans = get_predicted_entity_spans(model_outputs.ner_logits[0], token_spans, entity_span_sensitivity)
+        entity_spans = entity_spans[:tokenizer.max_entity_length]
         batch_entity_spans.append(entity_spans)
 
         tokenized_examples = tokenizer(text, entity_spans=entity_spans or None, truncation=True, return_tensors="pt")
```
 
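The second hunk caps the detected spans because LUKE-family tokenizers encode at most `tokenizer.max_entity_length` entity spans per example (32 by default in transformers), so feeding every NER candidate from a long text can overflow the model's entity inputs. A rough sketch of the capping step, assuming `studio-ousia/mluke-base` as the checkpoint (the diff does not show which model app.py actually uses):

```python
# Rough sketch; the checkpoint name and span values are illustrative assumptions.
from transformers import MLukeTokenizer

tokenizer = MLukeTokenizer.from_pretrained("studio-ousia/mluke-base")
text = "Beyoncé lives in Los Angeles."
# Pretend the NER head over-detected: two real spans repeated well past the cap.
entity_spans = [(0, 7), (17, 28)] * 20  # 40 candidate (start, end) character spans

entity_spans = entity_spans[:tokenizer.max_entity_length]  # keep at most 32 by default
encoding = tokenizer(text, entity_spans=entity_spans or None, truncation=True, return_tensors="pt")
print(len(entity_spans))  # 32
```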