import collections

import numpy as np

MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])


def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT)."""
    # When a word has been split into WordPieces, the first token does not
    # have any marker and any subsequent tokens are prefixed with ##. So
    # whenever we see a ## token, we append it to the previous set of word
    # indexes.
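    # Example: is_start_piece("play") -> True, is_start_piece("##ing") -> False.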
    return not piece.startswith("##")


def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False,
                                 geometric_dist=False,
                                 masking_style="bert",
                                 zh_tokenizer=None):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""
    '''
    modified from Megatron-LM
    Args:
        tokens: 输入
        vocab_id_list: 词表token_id_list
        vocab_id_to_token_dict: token_id到token字典
        masked_lm_prob:mask概率
        cls_id、sep_id、mask_id:特殊token
        max_predictions_per_seq:最大mask个数
        np_rng:mask随机数
        max_ngrams:最大词长度
        do_whole_word_mask:是否做全词掩码
        favor_longer_ngram:优先用长的词
        do_permutation:是否打乱
        geometric_dist:用np_rng.geometric做随机
        masking_style:mask类型
        zh_tokenizer:WWM的分词器,比如用jieba.lcut做分词之类的
    '''
    cand_indexes = []
    # Note(mingdachen): We create a list recording whether each piece is the
    # starting piece of its word (1 means true), so that on-the-fly whole word
    # masking is possible.
    token_boundary = [0] * len(tokens)
    # If no Chinese word segmenter is given, fall back to the ## (WordPiece) convention.
    if zh_tokenizer is None:
        for (i, token) in enumerate(tokens):
            if token == cls_id or token == sep_id:
                token_boundary[i] = 1
                continue
            # Whole Word Masking means that we mask all of the WordPieces
            # corresponding to an original word.
            #
            # Note that Whole Word Masking does *not* change the training code
            # at all -- we still predict each WordPiece independently, softmaxed
            # over the entire vocabulary.
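            # e.g. the pieces "play" and "##ing" end up in one cand_indexes
            # entry, so they are masked (or left unmasked) together.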
            if (do_whole_word_mask and len(cand_indexes) >= 1 and
                    not is_start_piece(vocab_id_to_token_dict[token])):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
                if is_start_piece(vocab_id_to_token_dict[token]):
                    token_boundary[i] = 1
    else:
        # If a Chinese word segmenter is given, segment the text first and then
        # decide the word boundaries.
        # Recover the raw text without the CLS/SEP tokens.
        raw_tokens = []
        for t in tokens:
            if t != cls_id and t != sep_id:
                raw_tokens.append(t)
        raw_tokens = [vocab_id_to_token_dict[i] for i in raw_tokens]
        # Segment the text, then record for each starting character the length
        # of the longest segmented word that begins with it.
        word_list = set(zh_tokenizer(''.join(raw_tokens), HMM=True))
        word_length_dict = {}
        for w in word_list:
            if len(w) < 1:
                continue
            if w[0] not in word_length_dict:
                word_length_dict[w[0]] = len(w)
            elif word_length_dict[w[0]] < len(w):
                word_length_dict[w[0]] = len(w)
        i = 0
        # Match tokens against the segmented word set.
        while i < len(tokens):
            token_id = tokens[i]
            token = vocab_id_to_token_dict[token_id]
            if len(token) == 0 or token_id == cls_id or token_id == sep_id:
                token_boundary[i] = 1
                i += 1
                continue
            word_max_length = 1
            if token[0] in word_length_dict:
                word_max_length = word_length_dict[token[0]]
            j = 0
            word = ''
            word_end = i+1
            # Backward compatibility with the ## style: if the following pieces
            # start with ##, merge them into the current word directly.
            old_style = False
            while word_end < len(tokens) and vocab_id_to_token_dict[tokens[word_end]].startswith('##'):
                old_style = True
                word_end += 1
            if not old_style:
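                # Greedy longest match: extend the span one piece at a time and
                # keep the longest prefix that appears in the segmented word set
                # (bounded by the longest word starting with this character).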
                while j < word_max_length and i+j < len(tokens):
                    cur_token = tokens[i+j]
                    word += vocab_id_to_token_dict[cur_token]
                    j += 1
                    if word in word_list:
                        word_end = i+j
            cand_indexes.append(list(range(i, word_end)))
            token_boundary[i] = 1
            i = word_end

    output_tokens = list(tokens)

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    if not geometric_dist:
        # Note(mingdachen):
        # By default, we set the probabilities to favor shorter ngram sequences.
        pvals = 1. / np.arange(1, max_ngrams + 1)
        pvals /= pvals.sum(keepdims=True)
        if favor_longer_ngram:
            pvals = pvals[::-1]
    # For every candidate word, collect the ngram spans (lists of word-piece
    # index lists) that start at it, for n = 1..max_ngrams.
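    # e.g. with max_ngrams = 3, ngram_indexes[idx] is
    # [cand_indexes[idx:idx + 1], cand_indexes[idx:idx + 2], cand_indexes[idx:idx + 3]].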
    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)

    (masked_lms, masked_spans) = ([], [])
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Skip the current piece if it is already covered by LM masking or a
        # previous ngram.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        if not geometric_dist:
            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
        else:
            # Sampling "n" from the geometric distribution and clipping it to
            # the max_ngrams. Using p=0.2 default from the SpanBERT paper
            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
            n = min(np_rng.geometric(0.2), max_ngrams)

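        # cand_index_set[n - 1] holds the n consecutive candidate words starting
        # here; flatten their word-piece index lists into one list of positions.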
        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
        # Repeatedly look for a candidate that does not exceed the maximum
        # number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)
            masked_token = None
            token_id = tokens[index]
            if masking_style == "bert":
                # 80% of the time, replace with [MASK]
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep original
                    if np_rng.random() < 0.5:
                        masked_token = tokens[index]
                    # 10% of the time, replace with random word
                    else:
                        masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(index=index, label=token_id))

        masked_spans.append(MaskedLmInstance(
            index=index_set,
            label=[tokens[index] for index in index_set]))

    assert len(masked_lms) <= num_to_predict
    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
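    # Optionally select additional ngram spans (disjoint from the masked ones)
    # and permute their tokens among themselves; the original token at each
    # permuted position becomes its label.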
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Note(mingdachen):
            # Skip the current piece if it is already covered by LM masking or a
            # previous ngram.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    # Sort the spans by the index of their first token.
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)
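

if __name__ == "__main__":
    # A minimal usage sketch with a made-up toy vocabulary; the ids, tokens and
    # seed below are illustrative only.  A numpy RandomState is assumed because
    # the function calls np_rng.randint / np_rng.random.
    vocab = ["[CLS]", "[SEP]", "[MASK]", "the", "quick", "##est", "fox", "jumps"]
    vocab_id_to_token_dict = {i: t for i, t in enumerate(vocab)}
    vocab_id_list = list(vocab_id_to_token_dict.keys())
    cls_id, sep_id, mask_id = 0, 1, 2

    # "[CLS] the quick ##est fox jumps [SEP]" as vocab ids.
    token_ids = [cls_id, 3, 4, 5, 6, 7, sep_id]
    np_rng = np.random.RandomState(seed=1234)

    (output_tokens, masked_lm_positions, masked_lm_labels,
     token_boundary, masked_spans) = create_masked_lm_predictions(
        token_ids,
        vocab_id_list, vocab_id_to_token_dict,
        masked_lm_prob=0.15,
        cls_id=cls_id, sep_id=sep_id, mask_id=mask_id,
        max_predictions_per_seq=2,
        np_rng=np_rng)

    print("masked input:", [vocab_id_to_token_dict[t] for t in output_tokens])
    print("positions   :", masked_lm_positions)
    print("labels      :", [vocab_id_to_token_dict[t] for t in masked_lm_labels])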