en_tako_query_analyzer / custom_textcat.py
noahjax's picture
Update spaCy pipeline
723fe48 verified
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from spacy.util import registry
from thinc.types import Floats2d
from spacy.tokens import Doc
from spacy.pipeline import TextCategorizer
from spacy.training import Example, validate_examples
from spacy.pipeline.textcat import textcat_score
from spacy.vocab import Vocab
from spacy.scorer import Scorer
from spacy.language import Language
from thinc.api import Model
import numpy
@Language.factory(
    "weighted_textcat",
    assigns=["doc.cats"],
    default_config={
        "threshold": 0.0,
        "scorer": {"@scorers": "spacy.textcat_scorer.v2"},
    },
    default_score_weights={
        "cats_score": 1.0,
        "cats_score_desc": None,
        "cats_micro_p": None,
        "cats_micro_r": None,
        "cats_micro_f": None,
        "cats_macro_p": None,
        "cats_macro_r": None,
        "cats_macro_f": None,
        "cats_macro_auc": None,
        "cats_f_per_type": None,
    },
)
def make_textcat(
    nlp: Language,
    name: str,
    model: Model[List[Doc], List[Floats2d]],
    threshold: float,
    scorer: Optional[Callable],
    class_weights: Optional[List] = None,
) -> "TextCategorizer":
    """Create a weighted single-label text categorizer component.

    The text categorizer predicts categories over a whole document. It can
    learn one or more labels, and the labels are considered to be mutually
    exclusive (i.e. one true label per doc).

    nlp (Language): The pipeline the component is added to; its vocab is
        shared with the component.
    name (str): The component instance name.
    model (Model[List[Doc], List[Floats2d]]): A model instance that predicts
        scores for each category.
    threshold (float): Stored in the component's cfg as "threshold"; the
        component documents it as unused for exclusive classes.
    scorer (Optional[Callable]): The scoring method.
    class_weights (Optional[List]): Optional per-label weights applied to the
        training loss and its gradient (presumably in the same order as the
        component's labels -- confirm against the training config). None, or
        the string "null", disables weighting. Defaults to None so configs
        may omit it.
    RETURNS (CustomTextcat): The configured component.
    """
    # Some configs hand over the literal string "null" rather than None;
    # treat both as "no weighting". The isinstance guard avoids an elementwise
    # comparison if an array-like is ever passed in.
    if isinstance(class_weights, str) and class_weights == "null":
        class_weights = None
    return CustomTextcat(
        nlp.vocab,
        model,
        name,
        threshold=threshold,
        scorer=scorer,
        weights=class_weights,
    )
def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
    """Score single-label predictions stored in ``doc.cats``.

    The labels are scored as mutually exclusive (``multi_label=False``).

    examples (Iterable[Example]): The examples to score.
    **kwargs: Extra options forwarded unchanged to ``Scorer.score_cats``.
    RETURNS (Dict[str, Any]): The score dictionary from ``Scorer.score_cats``.
    """
    score_cats = Scorer.score_cats
    attribute = "cats"
    return score_cats(examples, attribute, multi_label=False, **kwargs)
@registry.scorers("spacy.textcat_scorer.v2")
def make_textcat_scorer():
    """Return the module-level ``textcat_score`` function as a registered scorer.

    This is the scorer named in the ``weighted_textcat`` factory's default
    config.

    NOTE(review): the ``spacy.`` prefix suggests spaCy registers this same
    name itself, so this registration presumably replaces the built-in
    entry -- confirm this is intended.
    """
    scorer_fn = textcat_score
    return scorer_fn
class CustomTextcat(TextCategorizer):
    """Single-label text categorizer whose training loss can weight each label.

    It behaves like spaCy's ``TextCategorizer``, except that ``get_loss`` can
    scale each label's squared error, and its gradient, by a fixed weight.
    """

    # Per-label loss weights, or None for the plain (unweighted) loss.
    # Declared at class level so an instance that never ran ``__init__``
    # falls back to the unweighted loss instead of raising AttributeError.
    weights: Optional[numpy.ndarray] = None

    def __init__(
        self,
        vocab: Vocab,
        model: Model,
        name: str = "textcat",
        *,
        threshold: float,
        scorer: Optional[Callable] = textcat_score,
        weights: Optional[List[float]] = None,
    ) -> None:
        """Initialize a text categorizer for single-label classification.

        vocab (Vocab): The shared vocabulary.
        model (thinc.api.Model): The Thinc Model powering the pipeline component.
        name (str): The component instance name, used to add entries to the
            losses during training.
        threshold (float): Unused, not needed for single-label (exclusive
            classes) classification. Stored in ``cfg["threshold"]``.
        scorer (Optional[Callable]): The scoring method. Defaults to
            Scorer.score_cats for the attribute "cats".
        weights (Optional[List[float]]): Optional per-label loss weights,
            broadcast along the last (label) axis of the score matrix in
            ``get_loss``. None means an unweighted loss.

        DOCS: https://spacy.io/api/textcategorizer#init
        """
        self.vocab = vocab
        self.model = model
        self.name = name
        self._rehearsal_model = None
        cfg: Dict[str, Any] = {
            "labels": [],
            "threshold": threshold,
            "positive_label": None,
        }
        self.cfg = dict(cfg)
        self.scorer = scorer
        # Always set the attribute. Previously it was only set when weights
        # were given, so get_loss crashed for unweighted components.
        self.weights = None if weights is None else numpy.array(weights)

    def get_loss(self, examples: Iterable[Example], scores) -> Tuple[float, float]:
        """Find the loss and gradient of loss for the batch of documents and
        their predicted scores.

        Without weights, the loss is the mean squared error over all
        (doc, label) cells. With weights, each label's squared error is
        multiplied by its weight and the total is divided by
        ``sum(weights) * n_docs``. The gradient is scaled by the same weights.

        examples (Iterable[Examples]): The batch of examples.
        scores: Scores representing the model's predictions.
        RETURNS (Tuple[float, float]): The loss and the gradient.

        DOCS: https://spacy.io/api/textcategorizer#get_loss
        """
        validate_examples(examples, "TextCategorizer.get_loss")
        self._validate_categories(examples)
        truths, not_missing = self._examples_to_truth(examples)
        not_missing = self.model.ops.asarray(not_missing)  # type: ignore
        # Squared-error gradient, masked by the not_missing array returned
        # from _examples_to_truth.
        d_scores = scores - truths
        d_scores *= not_missing
        # Check the stored weights *before* converting them: ops.asarray(None)
        # does not return None, which used to make the unweighted branch
        # unreachable.
        if self.weights is None:
            mean_square_error = (d_scores**2).mean()
        else:
            # Convert with the model's ops so the weights match the array
            # type of the scores.
            weights = self.model.ops.asarray(self.weights)  # type: ignore
            squared = d_scores**2
            # Dividing by sum(weights) gives the same value as the unweighted
            # mean when every weight is 1.
            mean_square_error = (squared * weights).sum() / (
                weights.sum() * len(squared)
            )
            d_scores *= weights
        return float(mean_square_error), d_scores