Spaces:
Sleeping
Sleeping
import warnings | |
from typing import TYPE_CHECKING, List, Optional, Tuple, Union | |
import numpy as np | |
from ..file_utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available | |
from ..modelcard import ModelCard | |
from ..models.bert.tokenization_bert import BasicTokenizer | |
from ..tokenization_utils import PreTrainedTokenizer | |
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline | |
if TYPE_CHECKING: | |
from ..modeling_tf_utils import TFPreTrainedModel | |
from ..modeling_utils import PreTrainedModel | |
if is_tf_available(): | |
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING | |
if is_torch_available(): | |
import torch | |
from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING | |
class TokenClassificationArgumentHandler(ArgumentHandler):
    """
    Normalizes the arguments given to a token classification call.
    """

    def __call__(self, inputs: Union[str, List[str]], **kwargs):
        """
        Turn ``inputs`` into a list of sentences and validate the optional ``offset_mapping`` keyword argument.

        Args:
            inputs (:obj:`str` or :obj:`List[str]`):
                A single text, or a non-empty list/tuple of texts.
            offset_mapping (`optional`, keyword only):
                Either one list of ``(start, end)`` tuples (for a single sentence) or one such list per sentence.

        Returns:
            A ``(sentences, offset_mapping)`` pair. ``sentences`` is always a list. ``offset_mapping`` is returned
            untouched when it is falsy, and wrapped in a one-element list when it was given for a single sentence.

        Raises:
            ValueError: if ``inputs`` is neither a string nor a non-empty list/tuple, or if the number of offset
            mappings differs from the number of sentences.
        """
        if isinstance(inputs, str):
            sentences = [inputs]
        elif isinstance(inputs, (list, tuple)) and len(inputs) > 0:
            sentences = list(inputs)
        else:
            raise ValueError("At least one input is required.")

        offsets = kwargs.get("offset_mapping")
        if not offsets:
            # Nothing to validate: hand back whatever falsy value (or None) was provided.
            return sentences, offsets

        # A flat list of tuples describes a single sentence; promote it to a batch of one.
        given_for_single_sentence = isinstance(offsets, list) and isinstance(offsets[0], tuple)
        if given_for_single_sentence:
            offsets = [offsets]

        if len(offsets) != len(sentences):
            raise ValueError("offset_mapping should have the same batch size as the input")
        return sentences, offsets
class AggregationStrategy(ExplicitEnum):
    """Every aggregation strategy that TokenClassificationPipeline accepts."""

    # One result per token, no grouping at all.
    NONE = "none"
    # Per-token labels, then adjacent tokens sharing a tag are grouped into entities.
    SIMPLE = "simple"
    # Each word takes the label predicted for its first token, then words are grouped.
    FIRST = "first"
    # Each word takes the label with the highest averaged score over its tokens, then words are grouped.
    AVERAGE = "average"
    # Each word takes the label of its most confident token, then words are grouped.
    MAX = "max"
class TokenClassificationPipeline(Pipeline):
    """
    Named Entity Recognition pipeline using any :obj:`ModelForTokenClassification`. See the `named entity recognition
    examples <../task_summary.html#named-entity-recognition>`__ for more information.

    This token recognition pipeline can currently be loaded from :func:`~transformers.pipeline` using the following
    task identifier: :obj:`"ner"` (for predicting the classes of tokens in a sequence: person, organisation, location
    or miscellaneous).

    The models that this pipeline can use are models that have been fine-tuned on a token classification task. See the
    up-to-date list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=token-classification>`__.
    """

    default_input_names = "sequences"

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        args_parser: ArgumentHandler = TokenClassificationArgumentHandler(),
        device: int = -1,
        binary_output: bool = False,
        ignore_labels: Optional[List[str]] = None,
        task: str = "",
        grouped_entities: Optional[bool] = None,
        ignore_subwords: Optional[bool] = None,
        aggregation_strategy: Optional[Union[AggregationStrategy, str]] = None,
    ):
        """
        Build the pipeline and resolve the aggregation strategy.

        Args:
            ignore_labels (:obj:`List[str]`, `optional`):
                Labels whose predictions are dropped from the output. Defaults to ``["O"]`` when left to ``None``.
            grouped_entities (:obj:`bool`, `optional`):
                Deprecated. When set (together with ``ignore_subwords``) it overrides ``aggregation_strategy``.
            ignore_subwords (:obj:`bool`, `optional`):
                Deprecated. When set (together with ``grouped_entities``) it overrides ``aggregation_strategy``.
            aggregation_strategy (:class:`AggregationStrategy` or :obj:`str`, `optional`):
                How token predictions are fused. A string is looked up by its upper-cased member name. Defaults to
                :obj:`AggregationStrategy.NONE`.

        The remaining arguments are forwarded unchanged to the parent ``Pipeline`` constructor.

        Raises:
            ValueError: if a word-level strategy (first / max / average) is requested with a slow tokenizer.
            KeyError: if ``aggregation_strategy`` is a string that names no :class:`AggregationStrategy` member.
        """
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            device=device,
            binary_output=binary_output,
            task=task,
        )

        self.check_model_type(
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
        )

        self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
        self._args_parser = args_parser
        # A fresh list per instance: a list literal in the signature would be shared by every pipeline.
        self.ignore_labels = ["O"] if ignore_labels is None else ignore_labels

        if aggregation_strategy is None:
            aggregation_strategy = AggregationStrategy.NONE

        if grouped_entities is not None or ignore_subwords is not None:
            # The deprecated flags take precedence over `aggregation_strategy` when either one is given.
            if grouped_entities and ignore_subwords:
                aggregation_strategy = AggregationStrategy.FIRST
            elif grouped_entities and not ignore_subwords:
                aggregation_strategy = AggregationStrategy.SIMPLE
            else:
                aggregation_strategy = AggregationStrategy.NONE

            # Report the strategy by its value ("first") rather than by the enum member's repr.
            strategy_name = aggregation_strategy.value
            if grouped_entities is not None:
                warnings.warn(
                    "`grouped_entities` is deprecated and will be removed in version v5.0.0, "
                    f'defaulted to `aggregation_strategy="{strategy_name}"` instead.'
                )
            if ignore_subwords is not None:
                warnings.warn(
                    "`ignore_subwords` is deprecated and will be removed in version v5.0.0, "
                    f'defaulted to `aggregation_strategy="{strategy_name}"` instead.'
                )

        if isinstance(aggregation_strategy, str):
            aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]

        if (
            aggregation_strategy in {AggregationStrategy.FIRST, AggregationStrategy.MAX, AggregationStrategy.AVERAGE}
            and not self.tokenizer.is_fast
        ):
            raise ValueError(
                "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option "
                'to `"simple"` or use a fast tokenizer.'
            )

        self.aggregation_strategy = aggregation_strategy

    def _tensor_to_numpy(self, tensor) -> np.ndarray:
        """Convert a framework tensor to a numpy array, moving it to host memory first for PyTorch."""
        if self.framework == "tf":
            return tensor.numpy()
        return tensor.cpu().numpy()

    def __call__(self, inputs: Union[str, List[str]], **kwargs):
        """
        Classify each token of the text(s) given as inputs.

        Args:
            inputs (:obj:`str` or :obj:`List[str]`):
                One or several texts (or one list of texts) for token classification.
            offset_mapping (`optional`, keyword only):
                Character offsets of the tokens, only used when the tokenizer is not a fast tokenizer (a fast
                tokenizer computes them itself).

        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a list of dictionaries (one for each token in
            the corresponding input, or each entity if this pipeline was instantiated with an aggregation_strategy)
            with the following keys:

            - **word** (:obj:`str`) -- The token/word classified.
            - **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
            - **entity** (:obj:`str`) -- The entity predicted for that token/word (it is named `entity_group` when
              `aggregation_strategy` is not :obj:`"none"`).
            - **index** (:obj:`int`, only present when ``aggregation_strategy="none"``) -- The index of the
              corresponding token in the sentence.
            - **start** (:obj:`int`, `optional`) -- The index of the start of the corresponding entity in the sentence.
              Only exists if the offsets are available within the tokenizer
            - **end** (:obj:`int`, `optional`) -- The index of the end of the corresponding entity in the sentence.
              Only exists if the offsets are available within the tokenizer

            When exactly one text was processed, its list of results is returned directly instead of being wrapped in
            an outer list.
        """
        _inputs, offset_mappings = self._args_parser(inputs, **kwargs)

        answers = []
        for i, sentence in enumerate(_inputs):
            # Manage correct placement of the tensors
            with self.device_placement():
                tokens = self.tokenizer(
                    sentence,
                    return_attention_mask=False,
                    return_tensors=self.framework,
                    truncation=True,
                    return_special_tokens_mask=True,
                    return_offsets_mapping=self.tokenizer.is_fast,
                )
                # These two entries are popped so they are not fed to the model.
                if self.tokenizer.is_fast:
                    offset_mapping = self._tensor_to_numpy(tokens.pop("offset_mapping"))[0]
                elif offset_mappings:
                    offset_mapping = offset_mappings[i]
                else:
                    offset_mapping = None

                special_tokens_mask = self._tensor_to_numpy(tokens.pop("special_tokens_mask"))[0]

                # Forward
                if self.framework == "tf":
                    logits = self.model(tokens.data)[0][0].numpy()
                    input_ids = tokens["input_ids"].numpy()[0]
                else:
                    with torch.no_grad():
                        tokens = self.ensure_tensor_on_device(**tokens)
                        logits = self.model(**tokens)[0][0].cpu().numpy()
                        input_ids = tokens["input_ids"].cpu().numpy()[0]

            # Softmax over the label axis; subtracting the row maximum keeps np.exp from overflowing.
            shifted_exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
            scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)

            pre_entities = self.gather_pre_entities(sentence, input_ids, scores, offset_mapping, special_tokens_mask)
            grouped_entities = self.aggregate(pre_entities, self.aggregation_strategy)
            # Filter anything that is in self.ignore_labels
            entities = [
                entity
                for entity in grouped_entities
                if entity.get("entity", None) not in self.ignore_labels
                and entity.get("entity_group", None) not in self.ignore_labels
            ]
            answers.append(entities)

        if len(answers) == 1:
            return answers[0]
        return answers

    def gather_pre_entities(
        self,
        sentence: str,
        input_ids: np.ndarray,
        scores: np.ndarray,
        offset_mapping: Optional[List[Tuple[int, int]]],
        special_tokens_mask: np.ndarray,
    ) -> List[dict]:
        """
        Fuse various numpy arrays into dicts with all the information needed for aggregation.

        Returns one dict per non-special token with the keys ``word``, ``scores``, ``start``, ``end``, ``index`` and
        ``is_subword``. ``start`` / ``end`` are :obj:`None` and ``is_subword`` is :obj:`False` when no
        ``offset_mapping`` is available.
        """
        pre_entities = []
        for idx, token_scores in enumerate(scores):
            # Filter special_tokens, they should only occur
            # at the sentence boundaries since we're not encoding pairs of
            # sentences so we don't have to keep track of those.
            if special_tokens_mask[idx]:
                continue

            token_id = int(input_ids[idx])
            word = self.tokenizer.convert_ids_to_tokens(token_id)
            if offset_mapping is not None:
                start_ind, end_ind = offset_mapping[idx]
                word_ref = sentence[start_ind:end_ind]
                # Heuristic: a token whose text length differs from the span it covers is treated as a subword.
                is_subword = len(word_ref) != len(word)

                if token_id == self.tokenizer.unk_token_id:
                    # Show the original text rather than the unknown-token placeholder.
                    word = word_ref
                    is_subword = False
            else:
                start_ind = None
                end_ind = None
                is_subword = False

            pre_entities.append(
                {
                    "word": word,
                    "scores": token_scores,
                    "start": start_ind,
                    "end": end_ind,
                    "index": idx,
                    "is_subword": is_subword,
                }
            )
        return pre_entities

    def aggregate(self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
        """
        Turn pre-entities into final predictions according to ``aggregation_strategy``.

        NONE and SIMPLE label every token independently; the other strategies first fuse the tokens of each word.
        Every strategy except NONE then groups adjacent entities that share a tag.
        """
        if aggregation_strategy in {AggregationStrategy.NONE, AggregationStrategy.SIMPLE}:
            entities = []
            for pre_entity in pre_entities:
                entity_idx = pre_entity["scores"].argmax()
                entities.append(
                    {
                        "entity": self.model.config.id2label[entity_idx],
                        "score": pre_entity["scores"][entity_idx],
                        "index": pre_entity["index"],
                        "word": pre_entity["word"],
                        "start": pre_entity["start"],
                        "end": pre_entity["end"],
                    }
                )
        else:
            entities = self.aggregate_words(pre_entities, aggregation_strategy)

        if aggregation_strategy == AggregationStrategy.NONE:
            return entities
        return self.group_entities(entities)

    def aggregate_word(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> dict:
        """
        Fuse the tokens of a single word into one entity dict.

        Args:
            entities (:obj:`List[dict]`): Non-empty list of pre-entities making up the word, in order.
            aggregation_strategy (:class:`AggregationStrategy`): One of FIRST, MAX or AVERAGE.

        Raises:
            ValueError: for any other strategy.
        """
        word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])

        # Pick the score vector that decides the label; the label/score lookup is then the same for all strategies.
        if aggregation_strategy == AggregationStrategy.FIRST:
            scores = entities[0]["scores"]
        elif aggregation_strategy == AggregationStrategy.MAX:
            scores = max(entities, key=lambda entity: entity["scores"].max())["scores"]
        elif aggregation_strategy == AggregationStrategy.AVERAGE:
            scores = np.nanmean(np.stack([entity["scores"] for entity in entities]), axis=0)
        else:
            raise ValueError("Invalid aggregation_strategy")

        label_idx = scores.argmax()
        return {
            "entity": self.model.config.id2label[label_idx],
            "score": scores[label_idx],
            "word": word,
            "start": entities[0]["start"],
            "end": entities[-1]["end"],
        }

    def aggregate_words(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
        """
        Override tokens from a given word that disagree to force agreement on word boundaries.

        Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with first strategy as microsoft|
        company| B-ENT I-ENT

        Returns an empty list when ``entities`` is empty.
        """
        assert aggregation_strategy not in {
            AggregationStrategy.NONE,
            AggregationStrategy.SIMPLE,
        }, "NONE and SIMPLE strategies are invalid"

        word_entities = []
        word_group = None
        for entity in entities:
            if word_group is None:
                word_group = [entity]
            elif entity["is_subword"]:
                word_group.append(entity)
            else:
                word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
                word_group = [entity]
        # Last item. `word_group` is still None when there was nothing to aggregate (e.g. only special tokens).
        if word_group is not None:
            word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
        return word_entities

    def group_sub_entities(self, entities: List[dict]) -> dict:
        """
        Group together the adjacent tokens with the same entity predicted.

        Args:
            entities (:obj:`List[dict]`): Non-empty run of entities predicted by the pipeline that belong together.

        Returns:
            A dict with the keys ``entity_group`` (tag of the first entity, without its ``B-``/``I-`` style prefix),
            ``score`` (nan-ignoring mean of the member scores), ``word``, ``start`` and ``end``.
        """
        # Get the first entity in the entity group
        entity_tag = entities[0]["entity"].split("-")[-1]
        mean_score = np.nanmean([entity["score"] for entity in entities])
        tokens = [entity["word"] for entity in entities]

        return {
            "entity_group": entity_tag,
            "score": mean_score,
            "word": self.tokenizer.convert_tokens_to_string(tokens),
            "start": entities[0]["start"],
            "end": entities[-1]["end"],
        }

    def get_tag(self, entity_name: str) -> Tuple[str, str]:
        """
        Split a label into its ``(bi, tag)`` parts.

        Labels that do not start with ``B-`` or ``I-`` are treated as a beginning (``"B"``) of the whole label.
        """
        if entity_name.startswith("B-"):
            bi = "B"
            tag = entity_name[2:]
        elif entity_name.startswith("I-"):
            bi = "I"
            tag = entity_name[2:]
        else:
            # It's not in B-, I- format
            bi = "B"
            tag = entity_name
        return bi, tag

    def group_entities(self, entities: List[dict]) -> List[dict]:
        """
        Find and group together the adjacent tokens with the same entity predicted.

        Args:
            entities (:obj:`List[dict]`): The entities predicted by the pipeline.
        """
        entity_groups = []
        entity_group_disagg = []

        for entity in entities:
            if not entity_group_disagg:
                entity_group_disagg.append(entity)
                continue

            # If the current entity is similar and adjacent to the previous entity,
            # append it to the disaggregated entity group
            # The split is meant to account for the "B" and "I" prefixes
            # Shouldn't merge if both entities are B-type
            bi, tag = self.get_tag(entity["entity"])
            _, last_tag = self.get_tag(entity_group_disagg[-1]["entity"])

            if tag == last_tag and bi != "B":
                # Modify subword type to be previous_type
                entity_group_disagg.append(entity)
            else:
                # If the current entity is different from the previous entity
                # aggregate the disaggregated entity group
                entity_groups.append(self.group_sub_entities(entity_group_disagg))
                entity_group_disagg = [entity]
        if entity_group_disagg:
            # it's the last entity, add it to the entity groups
            entity_groups.append(self.group_sub_entities(entity_group_disagg))

        return entity_groups
# Alias: the same pipeline class exposed under the name `NerPipeline`.
NerPipeline = TokenClassificationPipeline