Spaces:
Runtime error
Runtime error
# Copyright 2021 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import unittest | |
import numpy as np | |
import timeout_decorator # noqa | |
from transformers import MarianConfig, is_flax_available | |
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow | |
from transformers.utils import cached_property | |
from ...generation.test_flax_utils import FlaxGenerationTesterMixin | |
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor | |
if is_flax_available(): | |
import os | |
# The slow tests are often failing with OOM error on GPU | |
# This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed | |
# but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html | |
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" | |
import jax | |
import jax.numpy as jnp | |
from transformers import MarianTokenizer | |
from transformers.models.marian.modeling_flax_marian import FlaxMarianModel, FlaxMarianMTModel, shift_tokens_right | |
def prepare_marian_inputs_dict(
    config,
    input_ids,
    decoder_input_ids=None,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
    """Build the keyword-argument dict fed to the Flax Marian models in these tests.

    Missing padding masks are derived from ``config.pad_token_id``: positions equal to the
    pad id get 0, every other position gets 1.

    Args:
        config: Model config; only ``pad_token_id`` is read.
        input_ids: Encoder token ids (numpy-compatible array).
        decoder_input_ids: Decoder token ids, or None.
        attention_mask: Encoder padding mask; derived from ``input_ids`` when None.
        decoder_attention_mask: Decoder padding mask; derived from ``decoder_input_ids``
            when None and decoder ids are given, otherwise left as None.
        head_mask, decoder_head_mask, cross_attn_head_mask: Accepted for backward
            compatibility but unused; they are not part of the returned dict.

    Returns:
        dict with keys ``input_ids``, ``decoder_input_ids``, ``attention_mask`` and
        ``decoder_attention_mask``.
    """
    if attention_mask is None:
        attention_mask = np.where(input_ids != config.pad_token_id, 1, 0)
    # Only derive a decoder mask from real decoder ids; comparing None to the pad id
    # would silently yield a meaningless scalar "mask".
    if decoder_attention_mask is None and decoder_input_ids is not None:
        decoder_attention_mask = np.where(decoder_input_ids != config.pad_token_id, 1, 0)
    return {
        "input_ids": input_ids,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        # Previously this returned the encoder mask by mistake.
        "decoder_attention_mask": decoder_attention_mask,
    }
class FlaxMarianModelTester:
    """Build a tiny random Marian config plus inputs, and run decoder-cache consistency checks.

    Args:
        parent: The ``unittest.TestCase`` driving this tester; its ``assertTrue`` is used to
            report check failures.
        The remaining keyword arguments size the random model and fix its special token ids.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=32,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
        initializer_range=0.02,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.initializer_range = initializer_range

    def prepare_config_and_inputs(self):
        """Return ``(config, inputs_dict)`` for a small random Marian model.

        Encoder ids are random tokens above every special id, with each row terminated by
        ``eos_token_id``. Decoder ids are the encoder ids shifted right, padded with
        ``pad_token_id`` and started with ``eos_token_id``.
        """
        # Keep random tokens clear of the special ids (3 with the default ids 0/1/2), so no
        # position is mistaken for padding or an early EOS.
        first_regular_id = max(self.bos_token_id, self.pad_token_id, self.eos_token_id) + 1
        input_ids = np.clip(
            ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size),
            first_regular_id,
            self.vocab_size,
        )
        eos_column = self.eos_token_id * np.ones((self.batch_size, 1), dtype=np.int64)
        input_ids = np.concatenate((input_ids, eos_column), -1)
        # Positional args: (input_ids, pad_token_id, decoder_start_token_id).
        decoder_input_ids = shift_tokens_right(input_ids, self.pad_token_id, self.eos_token_id)

        config = MarianConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            initializer_range=self.initializer_range,
            use_cache=False,
        )
        inputs_dict = prepare_marian_inputs_dict(config, input_ids, decoder_input_ids)
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        """Return the same ``(config, inputs_dict)`` pair, for the shared mixin tests."""
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def check_use_cache_forward(self, model_class_name, config, inputs_dict):
        """Assert that cached incremental decoding matches one full uncached decode.

        All but the last decoder token are decoded into a fresh cache. The last token is then
        decoded from that cache, and its outputs are compared with the last position of a
        full decode.
        """
        max_decoder_length = 20
        model = model_class_name(config)

        encoder_outputs = model.encode(inputs_dict["input_ids"])
        decoder_input_ids = inputs_dict["decoder_input_ids"]
        batch_size = decoder_input_ids.shape[0]
        prefix_length = decoder_input_ids.shape[-1] - 1

        past_key_values = model.init_cache(batch_size, max_decoder_length, encoder_outputs)
        # All-ones mask sized to the full cache length passed to init_cache.
        decoder_attention_mask = jnp.ones((batch_size, max_decoder_length), dtype="i4")

        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(prefix_length)[None, :],
            (batch_size, prefix_length),
        )
        outputs_cache = model.decode(
            decoder_input_ids[:, :-1],
            encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            decoder_position_ids=decoder_position_ids,
        )

        # The final token sits at index prefix_length in every row.
        decoder_position_ids = jnp.array(batch_size * [[prefix_length]], dtype="i4")
        outputs_cache_next = model.decode(
            decoder_input_ids[:, -1:],
            encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=outputs_cache.past_key_values,
            decoder_position_ids=decoder_position_ids,
        )

        outputs = model.decode(decoder_input_ids, encoder_outputs)

        diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
        self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")

    def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
        """Same check as ``check_use_cache_forward``, but driven by the inputs' decoder mask.

        The mask is zero-padded out to the cache length for the cached passes. The reference
        full decode receives the unpadded mask.
        """
        max_decoder_length = 20
        model = model_class_name(config)

        encoder_outputs = model.encode(inputs_dict["input_ids"])
        decoder_input_ids = inputs_dict["decoder_input_ids"]
        decoder_attention_mask = inputs_dict["decoder_attention_mask"]
        batch_size = decoder_input_ids.shape[0]
        prefix_length = decoder_input_ids.shape[-1] - 1

        # Zero-pad the real mask on the right so it spans all max_decoder_length cache slots.
        decoder_attention_mask_cache = jnp.concatenate(
            [
                decoder_attention_mask,
                jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])),
            ],
            axis=-1,
        )

        past_key_values = model.init_cache(batch_size, max_decoder_length, encoder_outputs)
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(prefix_length)[None, :],
            (batch_size, prefix_length),
        )

        outputs_cache = model.decode(
            decoder_input_ids[:, :-1],
            encoder_outputs,
            decoder_attention_mask=decoder_attention_mask_cache,
            past_key_values=past_key_values,
            decoder_position_ids=decoder_position_ids,
        )
        decoder_position_ids = jnp.array(batch_size * [[prefix_length]], dtype="i4")
        outputs_cache_next = model.decode(
            decoder_input_ids[:, -1:],
            encoder_outputs,
            past_key_values=outputs_cache.past_key_values,
            decoder_attention_mask=decoder_attention_mask_cache,
            decoder_position_ids=decoder_position_ids,
        )

        outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)

        diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
        self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
@require_flax
class FlaxMarianModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
    """Shape, JIT and cache tests for the Flax Marian model classes on a tiny random config."""

    is_encoder_decoder = True
    all_model_classes = (FlaxMarianModel, FlaxMarianMTModel) if is_flax_available() else ()
    all_generative_model_classes = (FlaxMarianMTModel,) if is_flax_available() else ()

    def setUp(self):
        self.model_tester = FlaxMarianModelTester(self)

    def test_use_cache_forward(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            self.model_tester.check_use_cache_forward(model_class, config, inputs_dict)

    def test_use_cache_forward_with_attn_mask(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict)

    def test_encode(self):
        """``model.encode`` must produce same-shaped outputs with and without JIT."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                # Without jax.jit the "JIT Enabled" subtest would compare eager against eager.
                @jax.jit
                def encode_jitted(input_ids, attention_mask=None, **kwargs):
                    return model.encode(input_ids=input_ids, attention_mask=attention_mask)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = encode_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_decode(self):
        """``model.decode`` must produce same-shaped outputs with and without JIT."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                model = model_class(config)
                encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])

                prepared_inputs_dict = {
                    "decoder_input_ids": inputs_dict["decoder_input_ids"],
                    "decoder_attention_mask": inputs_dict["decoder_attention_mask"],
                    "encoder_outputs": encoder_outputs,
                }

                # Without jax.jit the "JIT Enabled" subtest would compare eager against eager.
                @jax.jit
                def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs):
                    return model.decode(
                        decoder_input_ids=decoder_input_ids,
                        decoder_attention_mask=decoder_attention_mask,
                        encoder_outputs=encoder_outputs,
                    )

                with self.subTest("JIT Enabled"):
                    jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    @slow
    def test_model_from_pretrained(self):
        """Load the pretrained en-de checkpoint (network access) and run a one-token forward pass."""
        for model_class_name in self.all_model_classes:
            model = model_class_name.from_pretrained("Helsinki-NLP/opus-mt-en-de")
            # Minimal single-token input (the EOS id); this only checks that a forward pass
            # returns something.
            input_ids = np.ones((1, 1)) * model.config.eos_token_id
            outputs = model(input_ids)
            self.assertIsNotNone(outputs)
@require_flax
@require_sentencepiece
@require_tokenizers
class MarianIntegrationTest(unittest.TestCase):
    """Base class for translation checks against pretrained ``Helsinki-NLP/opus-mt-*`` models.

    Subclasses set ``src``/``tgt`` (language codes forming the checkpoint name), plus
    ``src_text`` and ``expected_text``.
    """

    src = None
    tgt = None

    @classmethod
    def setUpClass(cls) -> None:
        # unittest invokes setUpClass on the class, so it must be a classmethod.
        super().setUpClass()
        cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}"

    @cached_property
    def tokenizer(self):
        return MarianTokenizer.from_pretrained(self.model_name)

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_token_id

    @cached_property
    def model(self):
        model: FlaxMarianMTModel = FlaxMarianMTModel.from_pretrained(self.model_name)
        self.assertEqual(model.config.decoder_start_token_id, model.config.pad_token_id)
        return model

    def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
        """Translate ``src_text`` and assert the result equals ``expected_text`` exactly."""
        generated_words = self.translate_src_text(**tokenizer_kwargs)
        self.assertListEqual(self.expected_text, generated_words)

    def translate_src_text(self, **tokenizer_kwargs):
        """Tokenize ``src_text`` with padding, beam-search (2 beams) and decode to strings."""
        model_inputs = self.tokenizer(self.src_text, padding=True, return_tensors="np", **tokenizer_kwargs)
        generated_ids = self.model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            num_beams=2,
            max_length=128,
        ).sequences
        generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return generated_words
class TestMarian_EN_FR(MarianIntegrationTest):
    """English -> French batch translation."""

    src = "en"
    tgt = "fr"
    src_text = [
        "I am a small frog.",
        "Now I can forget the 100 words of german that I know.",
    ]
    expected_text = [
        "Je suis une petite grenouille.",
        "Maintenant, je peux oublier les 100 mots d'allemand que je connais.",
    ]

    # Downloads a pretrained checkpoint, so only run in slow test suites.
    @slow
    def test_batch_generation_en_fr(self):
        self._assert_generated_batch_equal_expected()
class TestMarian_FR_EN(MarianIntegrationTest):
    """French -> English batch translation, including accented input."""

    src = "fr"
    tgt = "en"
    src_text = [
        "Donnez moi le micro.",
        "Tom et Mary étaient assis à une table.",  # Accents
    ]
    expected_text = [
        "Give me the microphone.",
        "Tom and Mary were sitting at a table.",
    ]

    # Downloads a pretrained checkpoint, so only run in slow test suites.
    @slow
    def test_batch_generation_fr_en(self):
        self._assert_generated_batch_equal_expected()
class TestMarian_MT_EN(MarianIntegrationTest):
    """Cover the low-resource / high-perplexity setting.

    This breaks unless adjust_logits_generation is overridden.
    """

    src = "mt"
    tgt = "en"
    src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
    expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."]

    # Downloads a pretrained checkpoint, so only run in slow test suites.
    @slow
    def test_batch_generation_mt_en(self):
        self._assert_generated_batch_equal_expected()
class TestMarian_EN_DE(MarianIntegrationTest):
    """English -> German batch translation."""

    src = "en"
    tgt = "de"
    src_text = [
        "I am a small frog.",
        "Now I can forget the 100 words of german that I know.",
        "Tom asked his teacher for advice.",
        "That's how I would do it.",
        "Tom really admired Mary's courage.",
        "Turn around and close your eyes.",
    ]
    expected_text = [
        "Ich bin ein kleiner Frosch.",
        "Jetzt kann ich die 100 Wörter des Deutschen vergessen, die ich kenne.",
        "Tom bat seinen Lehrer um Rat.",
        "So würde ich das machen.",
        "Tom bewunderte Marias Mut wirklich.",
        "Drehen Sie sich um und schließen Sie die Augen.",
    ]

    # Downloads a pretrained checkpoint, so only run in slow test suites.
    @slow
    def test_batch_generation_en_de(self):
        self._assert_generated_batch_equal_expected()
class TestMarian_en_zh(MarianIntegrationTest):
    """English -> Chinese translation."""

    src = "en"
    tgt = "zh"
    src_text = ["My name is Wolfgang and I live in Berlin"]
    expected_text = ["我叫沃尔夫冈 我住在柏林"]

    # Downloads a pretrained checkpoint, so only run in slow test suites.
    @slow
    def test_batch_generation_eng_zho(self):
        self._assert_generated_batch_equal_expected()
class TestMarian_RU_FR(MarianIntegrationTest):
    """Russian -> French translation."""

    src = "ru"
    tgt = "fr"
    src_text = ["Он показал мне рукопись своей новой пьесы."]
    expected_text = ["Il m'a montré le manuscrit de sa nouvelle pièce."]

    # Downloads a pretrained checkpoint, so only run in slow test suites.
    @slow
    def test_batch_generation_ru_fr(self):
        self._assert_generated_batch_equal_expected()
class TestMarian_en_ROMANCE(MarianIntegrationTest):
    """Multilingual on the target side: a ``>>lang<<`` prefix selects the output language."""

    src = "en"
    tgt = "ROMANCE"
    src_text = [
        ">>fr<< Don't spend so much time watching TV.",
        ">>pt<< Your message has been sent.",
        ">>es<< He's two years older than me.",
    ]
    expected_text = [
        "Ne passez pas autant de temps à regarder la télé.",
        "A sua mensagem foi enviada.",
        "Es dos años más viejo que yo.",
    ]

    # Downloads a pretrained checkpoint, so only run in slow test suites.
    @slow
    def test_batch_generation_en_ROMANCE_multi(self):
        self._assert_generated_batch_equal_expected()