# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import timeout_decorator

from ..file_utils import cached_property, is_torch_available
from ..testing_utils import require_torch


if is_torch_available():
    import torch

    from ..models.marian import MarianConfig, MarianMTModel


@require_torch
class GenerationUtilsTest(unittest.TestCase):
    @cached_property
    def config(self):
        config = MarianConfig.from_pretrained("sshleifer/tiny-marian-en-de")
        return config

    @cached_property
    def model(self):
        return MarianMTModel(self.config)

    def test_postprocess_next_token_scores(self):
        config = self.config
        model = self.model

        # Initialize an input id tensor with batch size 8 and sequence length 12
        input_ids = torch.arange(0, 96, 1).view((8, 12))
        eos = config.eos_token_id

        # Each bad words list is paired with the (batch_index, token_id) positions
        # in `masked_scores` whose scores are expected to be set to -inf
        bad_words_ids_test_cases = [[[299]], [[23, 24], [54]], [[config.eos_token_id]], []]
        masked_scores = [
            [(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)],
            [(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)],
            [(0, eos), (1, eos), (2, eos), (3, eos), (4, eos), (5, eos), (6, eos), (7, eos)],
            [],
        ]

        for test_case_index, bad_words_ids in enumerate(bad_words_ids_test_cases):
            # Initialize a scores tensor with batch size 8 and vocabulary size 300
            scores = torch.rand((8, 300))
            output = model.postprocess_next_token_scores(
                scores,
                input_ids,
                0,  # no_repeat_ngram_size
                bad_words_ids,
                13,  # cur_len
                15,  # min_length
                config.max_length,
                config.eos_token_id,
                config.repetition_penalty,
                32,  # batch_size
                5,  # num_beams
            )

            # Scores of banned tokens must have been masked to -inf
            for masked_score in masked_scores[test_case_index]:
                self.assertTrue(output[masked_score[0], masked_score[1]] == -float("inf"))

    @timeout_decorator.timeout(10)
    def test_postprocess_next_token_scores_large_bad_words_list(self):
        config = self.config
        model = self.model

        # Initialize an input id tensor with batch size 8 and sequence length 12
        input_ids = torch.arange(0, 96, 1).view((8, 12))

        # Build 100 random bad words, each 1 to 4 tokens long
        bad_words_ids = []
        for _ in range(100):
            length_bad_word = random.randint(1, 4)
            bad_words_ids.append(random.sample(range(1, 300), length_bad_word))

        scores = torch.rand((8, 300))
        # Only completion within the timeout is checked: postprocessing must not
        # become pathologically slow with a large bad words list
        _ = model.postprocess_next_token_scores(
            scores,
            input_ids,
            0,  # no_repeat_ngram_size
            bad_words_ids,
            13,  # cur_len
            15,  # min_length
            config.max_length,
            config.eos_token_id,
            config.repetition_penalty,
            32,  # batch_size
            5,  # num_beams
        )