File size: 3,047 Bytes
7885a28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import numpy as np
import pytest
from scipy.sparse.csgraph import connected_components

from sklearn.metrics.pairwise import pairwise_distances
from sklearn.neighbors import kneighbors_graph
from sklearn.utils.graph import _fix_connected_components


def test_fix_connected_components():
    # Check that _fix_connected_components brings the number of connected
    # components of a k-neighbors graph down to exactly one.
    points = np.array([0, 1, 2, 5, 6, 7]).reshape(-1, 1)
    knn = kneighbors_graph(points, n_neighbors=2, mode="distance")

    # Precondition: the raw k-neighbors graph must be disconnected.
    n_before, labels_before = connected_components(knn)
    assert n_before > 1

    fixed = _fix_connected_components(points, knn, n_before, labels_before)

    # Only the component count of the fixed graph matters here.
    n_after, _ = connected_components(fixed)
    assert n_after == 1


def test_fix_connected_components_precomputed():
    # Check that _fix_connected_components accepts a precomputed (dense)
    # distance matrix in place of the raw samples.
    points = np.array([0, 1, 2, 5, 6, 7]).reshape(-1, 1)
    knn = kneighbors_graph(points, n_neighbors=2, mode="distance")

    # Precondition: the raw k-neighbors graph must be disconnected.
    n_before, labels_before = connected_components(knn)
    assert n_before > 1

    dense_distances = pairwise_distances(points)
    fixed = _fix_connected_components(
        dense_distances,
        knn,
        n_before,
        labels_before,
        metric="precomputed",
    )

    n_after, labels_after = connected_components(fixed)
    assert n_after == 1

    # A sparse neighbors graph passed as the precomputed input is rejected.
    with pytest.raises(RuntimeError, match="does not work with a sparse"):
        _fix_connected_components(
            fixed, fixed, n_after, labels_after, metric="precomputed"
        )


def test_fix_connected_components_wrong_mode():
    # Test that an error is raised if the mode string is incorrect.
    X = np.array([0, 1, 2, 5, 6, 7])[:, None]
    graph = kneighbors_graph(X, n_neighbors=2, mode="distance")
    n_connected_components, labels = connected_components(graph)

    # The call is expected to raise, so its return value is deliberately not
    # bound to a name.
    with pytest.raises(ValueError, match="Unknown mode"):
        _fix_connected_components(
            X, graph, n_connected_components, labels, mode="foo"
        )


def test_fix_connected_components_connectivity_mode():
    # Check that, in connectivity mode, every stored entry of the fixed graph
    # (including the newly added connections) equals one.
    points = np.array([0, 1, 6, 7]).reshape(-1, 1)
    knn = kneighbors_graph(points, n_neighbors=1, mode="connectivity")
    n_components, component_labels = connected_components(knn)

    fixed = _fix_connected_components(
        points, knn, n_components, component_labels, mode="connectivity"
    )

    stored_values = fixed.data
    assert (stored_values == 1).all()


def test_fix_connected_components_distance_mode():
    # Check that, in distance mode, the fixed graph holds at least one stored
    # entry different from one, i.e. new connections are not filled with ones.
    points = np.array([0, 1, 6, 7]).reshape(-1, 1)
    knn = kneighbors_graph(points, n_neighbors=1, mode="distance")

    # Precondition: every stored distance in the initial graph equals one, so
    # any other value seen afterwards was introduced by the fix.
    assert (knn.data == 1).all()

    n_components, component_labels = connected_components(knn)
    fixed = _fix_connected_components(
        points, knn, n_components, component_labels, mode="distance"
    )

    assert np.any(fixed.data != 1)