Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
| from examples.speech_recognition.criterions.cross_entropy_acc import ( | |
| CrossEntropyWithAccCriterion, | |
| ) | |
| from .asr_test_base import CrossEntropyCriterionTestBase | |
class CrossEntropyWithAccCriterionTest(CrossEntropyCriterionTestBase):
    """Check the accuracy counters that CrossEntropyWithAccCriterion logs.

    Each test runs the criterion on a sample produced by the base class's
    ``get_test_sample`` and inspects the returned logging output.
    """

    def setUp(self):
        # The criterion class is assigned before delegating to the base
        # setUp; presumably the base class reads it to build
        # ``self.criterion`` -- confirm against asr_test_base.
        self.criterion_cls = CrossEntropyWithAccCriterion
        super().setUp()

    def _run_criterion_on_sample(self, correct):
        """Run the criterion on a hard-target, non-aggregated sample.

        ``correct`` is forwarded unchanged to ``get_test_sample``. Returns
        only the third element of the criterion's result (the logging
        output); the first two elements are discarded.
        """
        batch = self.get_test_sample(
            correct=correct, soft_target=False, aggregate=False
        )
        _, _, stats = self.criterion(self.model, batch, "sum", log_probs=True)
        return stats

    def test_cross_entropy_all_correct(self):
        """With a ``correct=True`` sample, every counter equals 20."""
        stats = self._run_criterion_on_sample(correct=True)
        expected = {
            "correct": 20,
            "total": 20,
            "sample_size": 20,
            "ntokens": 20,
        }
        for key, value in expected.items():
            assert stats[key] == value

    def test_cross_entropy_all_wrong(self):
        """With a ``correct=False`` sample, ``correct`` is 0; the rest stay 20."""
        stats = self._run_criterion_on_sample(correct=False)
        expected = {
            "correct": 0,
            "total": 20,
            "sample_size": 20,
            "ntokens": 20,
        }
        for key, value in expected.items():
            assert stats[key] == value