File size: 8,857 Bytes
7885a28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
"""
Tests for the birch clustering algorithm.
"""

import numpy as np
import pytest

from sklearn.cluster import AgglomerativeClustering, Birch
from sklearn.cluster.tests.common import generate_clustered_data
from sklearn.datasets import make_blobs
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics import pairwise_distances_argmin, v_measure_score
from sklearn.utils._testing import assert_allclose, assert_array_equal
from sklearn.utils.fixes import CSR_CONTAINERS


def test_n_samples_leaves_roots(global_random_seed, global_dtype):
    """Every sample must be counted exactly once at the root and in the leaves."""
    X, _ = make_blobs(n_samples=10, random_state=global_random_seed)
    X = X.astype(global_dtype, copy=False)
    birch = Birch()
    birch.fit(X)

    root_total = sum(sub.n_samples_ for sub in birch.root_.subclusters_)
    leaf_total = sum(
        sub.n_samples_ for leaf in birch._get_leaves() for sub in leaf.subclusters_
    )

    expected = X.shape[0]
    assert leaf_total == expected
    assert root_total == expected


def test_partial_fit(global_random_seed, global_dtype):
    """Fitting in two partial_fit chunks must match a single call to fit."""
    X, _ = make_blobs(n_samples=100, random_state=global_random_seed)
    X = X.astype(global_dtype, copy=False)

    full = Birch(n_clusters=3)
    full.fit(X)

    incremental = Birch(n_clusters=None)
    for chunk in (X[:50], X[50:]):
        incremental.partial_fit(chunk)
    assert_allclose(incremental.subcluster_centers_, full.subcluster_centers_)

    # Once n_clusters is set, calling partial_fit with None must yield the
    # same global subcluster labels as the one-shot fit.
    incremental.set_params(n_clusters=3)
    incremental.partial_fit(None)
    assert_array_equal(incremental.subcluster_labels_, full.subcluster_labels_)


def test_birch_predict(global_random_seed, global_dtype):
    """Check that ``predict`` on the training data reproduces ``labels_`` and
    that each sample gets the label of its nearest subcluster centroid."""
    rng = np.random.RandomState(global_random_seed)
    X = generate_clustered_data(n_clusters=3, n_features=3, n_samples_per_cluster=10)
    X = X.astype(global_dtype, copy=False)

    # X holds n_clusters * n_samples_per_cluster rows. Shuffle them with the
    # seeded RNG so the CF-tree sees a seed-dependent insertion order; the
    # permutation length is taken from X rather than hard-coded.
    shuffle_indices = rng.permutation(X.shape[0])
    X_shuffle = X[shuffle_indices, :]
    brc = Birch(n_clusters=4, threshold=1.0)
    brc.fit(X_shuffle)

    # Birch must preserve inputs' dtype
    assert brc.subcluster_centers_.dtype == global_dtype

    assert_array_equal(brc.labels_, brc.predict(X_shuffle))

    # Map each sample to its closest subcluster centroid, then to that
    # subcluster's global label; this must define the same partition as
    # labels_ (a v-measure of exactly 1).
    centroids = brc.subcluster_centers_
    nearest_centroid = brc.subcluster_labels_[
        pairwise_distances_argmin(X_shuffle, centroids)
    ]
    assert_allclose(v_measure_score(nearest_centroid, brc.labels_), 1.0)


def test_n_clusters(global_random_seed, global_dtype):
    """Check the different ways of specifying ``n_clusters``."""
    X, _ = make_blobs(n_samples=100, centers=10, random_state=global_random_seed)
    X = X.astype(global_dtype, copy=False)

    # Integer n_clusters: the tree keeps more than 10 subclusters, but the
    # final labelling uses exactly 10 distinct labels.
    birch_int = Birch(n_clusters=10)
    birch_int.fit(X)
    assert len(birch_int.subcluster_centers_) > 10
    assert np.unique(birch_int.labels_).size == 10

    # An AgglomerativeClustering instance with the same number of clusters
    # must produce identical subcluster and sample labels.
    birch_agg = Birch(n_clusters=AgglomerativeClustering(n_clusters=10))
    birch_agg.fit(X)
    assert_array_equal(birch_int.subcluster_labels_, birch_agg.subcluster_labels_)
    assert_array_equal(birch_int.labels_, birch_agg.labels_)

    # A huge threshold is expected to leave too few subclusters, which Birch
    # reports with a ConvergenceWarning.
    birch_coarse = Birch(threshold=10000.0)
    with pytest.warns(ConvergenceWarning):
        birch_coarse.fit(X)


@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_sparse_X(global_random_seed, global_dtype, csr_container):
    """Dense and CSR inputs must lead to the same clustering."""
    X, _ = make_blobs(n_samples=100, centers=10, random_state=global_random_seed)
    X = X.astype(global_dtype, copy=False)

    dense_model = Birch(n_clusters=10)
    dense_model.fit(X)

    sparse_model = Birch(n_clusters=10)
    sparse_model.fit(csr_container(X))

    # The fitted centers must keep the dtype of the (sparse) input.
    assert sparse_model.subcluster_centers_.dtype == global_dtype

    assert_array_equal(dense_model.labels_, sparse_model.labels_)
    assert_allclose(dense_model.subcluster_centers_, sparse_model.subcluster_centers_)


def test_partial_fit_second_call_error_checks():
    """A second partial_fit call must reject X whose number of features differs
    from the first call."""
    # Seed the data so the test is deterministic; only the number of features
    # (2, as the expected error message below states) matters for the check.
    X, y = make_blobs(n_samples=100, random_state=0)
    brc = Birch(n_clusters=3)
    brc.partial_fit(X, y)

    # Keep a single column so the second call sees 1 feature instead of 2.
    msg = "X has 1 features, but Birch is expecting 2 features"
    with pytest.raises(ValueError, match=msg):
        brc.partial_fit(X[:, [0]], y)


def check_branching_factor(node, branching_factor):
    """Assert that ``node`` and every node reachable through a subcluster's
    ``child_`` hold at most ``branching_factor`` subclusters each.

    Subclusters whose ``child_`` is falsy are treated as having no child.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        subclusters = current.subclusters_
        assert len(subclusters) <= branching_factor
        # Queue up every child node still to be inspected.
        pending.extend(sub.child_ for sub in subclusters if sub.child_)


def test_branching_factor(global_random_seed, global_dtype):
    """No node may hold more than ``branching_factor`` subclusters, with or
    without a global clustering step."""
    X, _ = make_blobs(random_state=global_random_seed)
    X = X.astype(global_dtype, copy=False)
    branching_factor = 9

    # A deliberately tiny threshold maximizes the number of subclusters.
    for n_clusters in (None, 3):
        brc = Birch(
            n_clusters=n_clusters, branching_factor=branching_factor, threshold=0.01
        )
        brc.fit(X)
        check_branching_factor(brc.root_, branching_factor)


def check_threshold(birch_instance, threshold):
    """Assert that every leaf subcluster has a radius of at most ``threshold``.

    Leaves are visited by following the ``next_leaf_`` links, starting from
    the node after ``birch_instance.dummy_leaf_``, until a falsy link is hit.
    """
    leaf = birch_instance.dummy_leaf_.next_leaf_
    while leaf:
        for subcluster in leaf.subclusters_:
            assert subcluster.radius <= threshold
        leaf = leaf.next_leaf_


def test_threshold(global_random_seed, global_dtype):
    """Check that the radius of every leaf subcluster is at most ``threshold``
    (i.e. ``threshold`` bounds the radius, not the other way around)."""
    X, _ = make_blobs(n_samples=80, centers=4, random_state=global_random_seed)
    X = X.astype(global_dtype, copy=False)

    # Exercise both a tight and a loose threshold with the same data.
    for threshold in (0.5, 5.0):
        brc = Birch(threshold=threshold, n_clusters=None)
        brc.fit(X)
        check_threshold(brc, threshold)


def test_birch_n_clusters_long_int():
    """Birch must accept an ``np.int64`` value for ``n_clusters`` (#16484).

    Such values commonly come out of NumPy, e.g. from ``np.arange``.
    """
    X, _ = make_blobs(random_state=0)
    Birch(n_clusters=np.int64(5)).fit(X)


def test_feature_names_out():
    """Check `get_feature_names_out` for `Birch`."""
    X, _ = make_blobs(n_samples=80, n_features=4, random_state=0)
    brc = Birch(n_clusters=4)
    brc.fit(X)

    # One output name per subcluster: birch0, birch1, ...
    n_subclusters = brc.subcluster_centers_.shape[0]
    expected = [f"birch{idx}" for idx in range(n_subclusters)]
    assert_array_equal(expected, brc.get_feature_names_out())


def test_transform_match_across_dtypes(global_random_seed):
    """fit_transform on float64 data and on its float32 cast must agree."""
    X, _ = make_blobs(n_samples=80, n_features=4, random_state=global_random_seed)
    estimator = Birch(n_clusters=4, threshold=1.1)

    # The same estimator is refit once per dtype.
    transformed_f64 = estimator.fit_transform(X)
    transformed_f32 = estimator.fit_transform(X.astype(np.float32))

    assert_allclose(transformed_f64, transformed_f32, atol=1e-6)


def test_subcluster_dtype(global_dtype):
    """The fitted ``subcluster_centers_`` must keep the dtype of the input."""
    X, _ = make_blobs(n_samples=80, n_features=4, random_state=0)
    X = X.astype(global_dtype, copy=False)
    fitted = Birch(n_clusters=4).fit(X)
    assert fitted.subcluster_centers_.dtype == global_dtype


def test_both_subclusters_updated():
    """Check that both subclusters are updated when a node is split, even when
    there are duplicated data points. Non-regression test for #23269.
    """

    # Float32 points with many exact duplicates (e.g. [-2.3724914, -1.3438171]);
    # together with the tiny threshold and small branching factor below, this
    # input reproduces the failure reported in #23269.
    X = np.array(
        [
            [-2.6192791, -1.5053215],
            [-2.9993038, -1.6863596],
            [-2.3724914, -1.3438171],
            [-2.336792, -1.3417323],
            [-2.4089134, -1.3290224],
            [-2.3724914, -1.3438171],
            [-3.364009, -1.8846745],
            [-2.3724914, -1.3438171],
            [-2.617677, -1.5003285],
            [-2.2960556, -1.3260119],
            [-2.3724914, -1.3438171],
            [-2.5459878, -1.4533926],
            [-2.25979, -1.3003055],
            [-2.4089134, -1.3290224],
            [-2.3724914, -1.3438171],
            [-2.4089134, -1.3290224],
            [-2.5459878, -1.4533926],
            [-2.3724914, -1.3438171],
            [-2.9720619, -1.7058647],
            [-2.336792, -1.3417323],
            [-2.3724914, -1.3438171],
        ],
        dtype=np.float32,
    )

    # The test passes as long as fitting completes without raising.
    Birch(branching_factor=5, threshold=1e-5, n_clusters=None).fit(X)


# TODO(1.8): Remove
def test_birch_copy_deprecated():
    """Fitting with the deprecated ``copy`` parameter must emit a FutureWarning."""
    X, _ = make_blobs(n_samples=80, n_features=4, random_state=0)
    estimator = Birch(n_clusters=4, copy=True)
    # The warning is expected from fit, not from the constructor.
    with pytest.warns(FutureWarning, match="`copy` was deprecated"):
        estimator.fit(X)