"""
Part of Speech Constraint
--------------------------
"""
import flair
from flair.data import Sentence
from flair.models import SequenceTagger
import lru
import nltk

import textattack
from textattack.constraints import Constraint
from textattack.shared.utils import LazyLoader, device
from textattack.shared.validators import transformation_consists_of_word_swaps

# Set global flair device to be TextAttack's current device, so that flair
# models loaded below run on the same device as the rest of TextAttack.
flair.device = device

# `stanza` is bound to a LazyLoader wrapper instead of being imported directly.
# NOTE(review): presumably this defers the real import until the module is
# first used (only the "stanza" tagger needs it) -- confirm against LazyLoader.
stanza = LazyLoader("stanza", globals(), "stanza")


class PartOfSpeech(Constraint):
    """Constrains word swaps to only swap words with the same part of speech.
    Uses the NLTK universal part-of-speech tagger by default. An implementation
    of `<https://arxiv.org/abs/1907.11932>`_ adapted from
    `<https://github.com/jind11/TextFooler>`_.

    POS taggers from Flair `<https://github.com/flairNLP/flair>`_ and
    Stanza `<https://github.com/stanfordnlp/stanza>`_ are also available

    Args:
        tagger_type (str): Name of the tagger to use (available choices: "nltk", "flair", "stanza").
        tagset (str): tagset to use for POS tagging (e.g. "universal")
        allow_verb_noun_swap (bool): If `True`, allow verbs to be swapped with nouns and vice versa.
        compare_against_original (bool): If `True`, compare against the original text.
            Otherwise, compare against the most recent text.
        language_nltk (str): Language to be used for nltk POS-Tagger
            (available choices: "eng", "rus")
        language_stanza (str): Language to be used for stanza POS-Tagger
            (available choices: https://stanfordnlp.github.io/stanza/available_models.html)
    """

    def __init__(
        self,
        tagger_type="nltk",
        tagset="universal",
        allow_verb_noun_swap=True,
        compare_against_original=True,
        language_nltk="eng",
        language_stanza="en",
    ):
        super().__init__(compare_against_original)
        self.tagger_type = tagger_type
        self.tagset = tagset
        self.allow_verb_noun_swap = allow_verb_noun_swap
        self.language_nltk = language_nltk
        self.language_stanza = language_stanza

        # Maps a space-joined context window to its (words, tags) result so
        # the same window is never tagged twice.
        self._pos_tag_cache = lru.LRU(2**14)

        # Only the selected tagger is loaded; NLTK needs no model object here.
        if tagger_type == "flair":
            flair_model = "upos-fast" if tagset == "universal" else "pos-fast"
            self._flair_pos_tagger = SequenceTagger.load(flair_model)
        elif tagger_type == "stanza":
            self._stanza_pos_tagger = stanza.Pipeline(
                lang=self.language_stanza,
                processors="tokenize, pos",
                tokenize_pretokenized=True,
            )

    def clear_cache(self):
        """Drop every cached tagging result."""
        self._pos_tag_cache.clear()

    def _can_replace_pos(self, pos_a, pos_b):
        """Return whether a word tagged ``pos_a`` may be swapped for one tagged
        ``pos_b``.

        Identical tags are always compatible. When ``allow_verb_noun_swap`` is
        set, any combination of "NOUN" and "VERB" is accepted as well.
        """
        if pos_a == pos_b:
            return True
        return self.allow_verb_noun_swap and {pos_a, pos_b} <= {"NOUN", "VERB"}

    def _get_pos(self, before_ctx, word, after_ctx):
        """Return the POS tag of ``word`` when placed between ``before_ctx``
        and ``after_ctx``.

        Args:
            before_ctx (list[str]): Words preceding ``word``.
            word (str): The word whose tag is wanted.
            after_ctx (list[str]): Words following ``word``.

        Returns:
            The tag the configured tagger assigned to ``word``.

        Raises:
            ValueError: If ``tagger_type`` is not "nltk", "flair" or "stanza".
            AssertionError: If the tagger output does not contain ``word``.
        """
        context_words = before_ctx + [word] + after_ctx
        context_key = " ".join(context_words)
        if context_key in self._pos_tag_cache:
            word_list, pos_list = self._pos_tag_cache[context_key]
        else:
            if self.tagger_type == "nltk":
                word_list, pos_list = zip(
                    *nltk.pos_tag(
                        context_words, tagset=self.tagset, lang=self.language_nltk
                    )
                )
            elif self.tagger_type == "flair":
                context_key_sentence = Sentence(
                    context_key,
                    use_tokenizer=textattack.shared.utils.TextAttackFlairTokenizer(),
                )
                self._flair_pos_tagger.predict(context_key_sentence)
                word_list, pos_list = textattack.shared.utils.zip_flair_result(
                    context_key_sentence
                )
            elif self.tagger_type == "stanza":
                word_list, pos_list = textattack.shared.utils.zip_stanza_result(
                    self._stanza_pos_tagger(context_key), tagset=self.tagset
                )
            else:
                # Previously this fell through to an UnboundLocalError below.
                raise ValueError(
                    f"Unsupported tagger_type {self.tagger_type!r}; "
                    'expected one of "nltk", "flair", "stanza".'
                )

            self._pos_tag_cache[context_key] = (word_list, pos_list)

        # `word` sits right after `before_ctx` in `context_words`. Use that
        # position directly when the tagger's tokens line up with the input:
        # searching by value would return the tag of an earlier duplicate of
        # `word` inside the context window.
        target_idx = len(before_ctx)
        if len(word_list) == len(context_words) and word_list[target_idx] == word:
            return pos_list[target_idx]

        # The tagger re-tokenized the text, so positions no longer line up;
        # fall back to locating the first token equal to `word`.
        assert word in word_list, "POS list not matched with original word list."
        return pos_list[word_list.index(word)]

    def _check_constraint(self, transformed_text, reference_text):
        """Return ``True`` iff every newly modified word keeps a compatible
        part of speech.

        Both the reference word and its replacement are tagged inside the same
        window of surrounding words taken from ``reference_text``.

        Raises:
            KeyError: If ``transformed_text.attack_attrs`` has no
                ``newly_modified_indices`` entry.
        """
        try:
            indices = transformed_text.attack_attrs["newly_modified_indices"]
        except KeyError:
            raise KeyError(
                "Cannot apply part-of-speech constraint without `newly_modified_indices`"
            ) from None

        reference_words = reference_text.words
        transformed_words = transformed_text.words

        for i in indices:
            # Window: up to 4 words before and up to 3 words after index i
            # (slicing clamps the upper bound to the end of the text).
            # NOTE(review): the window is asymmetric (4 before / 3 after);
            # kept as-is so constraint results do not change -- confirm intent.
            before_ctx = reference_words[max(i - 4, 0) : i]
            after_ctx = reference_words[i + 1 : i + 4]
            ref_pos = self._get_pos(before_ctx, reference_words[i], after_ctx)
            replace_pos = self._get_pos(before_ctx, transformed_words[i], after_ctx)
            if not self._can_replace_pos(ref_pos, replace_pos):
                return False

        return True

    def check_compatibility(self, transformation):
        """Return whether ``transformation`` consists only of word swaps, as
        reported by ``transformation_consists_of_word_swaps``."""
        return transformation_consists_of_word_swaps(transformation)

    def extra_repr_keys(self):
        """Return the attribute names to show in this constraint's repr, ahead
        of the keys contributed by the parent class."""
        return [
            "tagger_type",
            "tagset",
            "allow_verb_noun_swap",
        ] + super().extra_repr_keys()

    def __getstate__(self):
        """Return the instance state for pickling.

        The LRU cache object is replaced by its size; its contents are not
        carried over.
        """
        state = self.__dict__.copy()
        state["_pos_tag_cache"] = self._pos_tag_cache.get_size()
        return state

    def __setstate__(self, state):
        """Restore the instance state, rebuilding an empty LRU cache of the
        size recorded by ``__getstate__``."""
        self.__dict__ = state
        self._pos_tag_cache = lru.LRU(state["_pos_tag_cache"])