# Authors: The scikit-learn developers | |
# SPDX-License-Identifier: BSD-3-Clause | |
import numpy as np | |
import pytest | |
from sklearn.mixture import BayesianGaussianMixture, GaussianMixture | |
@pytest.mark.parametrize("estimator", [GaussianMixture(), BayesianGaussianMixture()])
def test_gaussian_mixture_n_iter(estimator):
    """Check that ``n_iter_`` is the number of iterations performed by ``fit``.

    The estimator is capped at ``max_iter=1``; the test expects the fitted
    ``n_iter_`` attribute to report exactly that single iteration.
    """
    rng = np.random.RandomState(0)
    X = rng.rand(10, 5)
    max_iter = 1
    estimator.set_params(max_iter=max_iter)
    estimator.fit(X)
    assert estimator.n_iter_ == max_iter
@pytest.mark.parametrize("estimator", [GaussianMixture(), BayesianGaussianMixture()])
def test_mixture_n_components_greater_than_n_samples_error(estimator):
    """Check error when n_components > n_samples.

    Fitting 10 samples with ``n_components=12`` must raise a ``ValueError``
    whose message mentions the ``n_samples >= n_components`` requirement.
    """
    rng = np.random.RandomState(0)
    X = rng.rand(10, 5)
    estimator.set_params(n_components=12)
    msg = "Expected n_samples >= n_components"
    with pytest.raises(ValueError, match=msg):
        estimator.fit(X)