Spaces:
Running
Running
| import unittest | |
| from unittest.mock import patch | |
| import pandas as pd | |
| import src.backend.evaluate_model as evaluate_model | |
class TestSummaryGenerator(unittest.TestCase):
    """Unit tests for ``evaluate_model.SummaryGenerator``.

    Each test swaps the model/tokenizer loaders for mocks so the real
    loaders are never invoked while a ``SummaryGenerator`` is constructed.

    NOTE(review): the patch targets below assume ``evaluate_model`` exposes
    the module-level names ``AutoTokenizer``, ``AutoModelForCausalLM`` and
    ``nlp`` -- confirm against ``src/backend/evaluate_model.py`` and adjust
    the three constants if the names differ.
    """

    # Patch targets are kept in one place so a rename only needs one edit.
    _TOKENIZER_TARGET = "src.backend.evaluate_model.AutoTokenizer"
    _MODEL_TARGET = "src.backend.evaluate_model.AutoModelForCausalLM"
    _NLP_TARGET = "src.backend.evaluate_model.nlp"

    def setUp(self):
        """Store the model identifier and revision shared by all tests."""
        self.model_id = "test_model"
        self.revision = "test_revision"

    @staticmethod
    def _single_summary_frame():
        """Return a one-row frame with 'source', 'summary', 'dataset' columns."""
        return pd.DataFrame({'source': ['text'], 'summary': ['This is a test.'],
                             'dataset': ['dataset']})

    # Stacked @patch decorators inject mocks bottom-up: the decorator closest
    # to the function supplies the first mock argument.
    @patch(_TOKENIZER_TARGET)
    @patch(_MODEL_TARGET)
    def test_init(self, mock_model, mock_tokenizer):
        """Constructing a generator loads tokenizer and model exactly once."""
        evaluate_model.SummaryGenerator(self.model_id, self.revision)
        mock_tokenizer.from_pretrained.assert_called_once_with(self.model_id,
                                                               self.revision)
        mock_model.from_pretrained.assert_called_once_with(self.model_id,
                                                           self.revision)

    @patch(_NLP_TARGET)
    @patch(_TOKENIZER_TARGET)
    @patch(_MODEL_TARGET)
    def test_generate_summaries(self, mock_model, mock_tokenizer, mock_nlp):
        """``generate_summaries`` yields one summary row per input row."""
        df = pd.DataFrame({'text': ['text1', 'text2'],
                           'dataset': ['dataset1', 'dataset2']})
        generator = evaluate_model.SummaryGenerator(self.model_id, self.revision)
        generator.generate_summaries(df)
        self.assertEqual(len(generator.summaries_df), len(df))

    @patch(_TOKENIZER_TARGET)
    @patch(_MODEL_TARGET)
    def test_compute_avg_length(self, mock_model, mock_tokenizer):
        """``_compute_avg_length`` reports 4 for the summary 'This is a test.'."""
        generator = evaluate_model.SummaryGenerator(self.model_id, self.revision)
        generator.summaries_df = self._single_summary_frame()
        generator._compute_avg_length()
        self.assertEqual(generator.avg_length, 4)

    @patch(_TOKENIZER_TARGET)
    @patch(_MODEL_TARGET)
    def test_compute_answer_rate(self, mock_model, mock_tokenizer):
        """``_compute_answer_rate`` reports 1 for a single non-empty summary."""
        generator = evaluate_model.SummaryGenerator(self.model_id, self.revision)
        generator.summaries_df = self._single_summary_frame()
        generator._compute_answer_rate()
        self.assertEqual(generator.answer_rate, 1)

    @patch(_TOKENIZER_TARGET)
    @patch(_MODEL_TARGET)
    def test_error_rate(self, mock_model, mock_tokenizer):
        """``_compute_error_rate(0)`` leaves ``error_rate`` at 0."""
        generator = evaluate_model.SummaryGenerator(self.model_id, self.revision)
        generator.summaries_df = self._single_summary_frame()
        generator._compute_error_rate(0)
        self.assertEqual(generator.error_rate, 0)
# Run the unittest runner only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()