import numpy as np
from scipy.sparse import issparse
from scipy.sparse._sputils import convert_pydata_sparse_to_scipy
from scipy.sparse.csgraph._tools import (
    csgraph_to_dense, csgraph_from_dense,
    csgraph_masked_from_dense, csgraph_from_masked
)

DTYPE = np.float64


def validate_graph(csgraph, directed, dtype=DTYPE,
                   csr_output=True, dense_output=True,
                   copy_if_dense=False, copy_if_sparse=False,
                   null_value_in=0, null_value_out=np.inf,
                   infinity_null=True, nan_null=True):
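    """Validate a csgraph input and convert it to CSR or dense form.

    Sparse inputs (including pydata ``sparse`` arrays) become CSR when
    ``csr_output`` is True and a dense array otherwise; dense and masked
    inputs become a dense array when ``dense_output`` is True and CSR
    otherwise. For plain dense input, entries equal to ``null_value_in``
    (and, if ``infinity_null``/``nan_null`` are set, infinite or NaN
    entries) mark absent edges; for masked input the mask does. In dense
    output, absent edges are written as ``null_value_out``.

    Raises ``ValueError`` if both ``csr_output`` and ``dense_output`` are
    False, or if the result is not a square 2-D graph.
    """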
    if not (csr_output or dense_output):
        raise ValueError("Internal: dense or csr output must be true")

    accept_fv = [null_value_in]
    if infinity_null:
        accept_fv.append(np.inf)
    if nan_null:
        accept_fv.append(np.nan)
    csgraph = convert_pydata_sparse_to_scipy(csgraph, accept_fv=accept_fv)

    # For an undirected graph, G and G.T describe the same edges, so a CSC
    # input can be transposed (which yields CSR without copying the data)
    # instead of being converted to CSR later.
    if (not directed) and issparse(csgraph) and csgraph.format == "csc":
        csgraph = csgraph.T

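    if issparse(csgraph):
        # Sparse input: keep it as CSR, or densify it with absent edges
        # written as null_value_out.
        if csr_output:
            csgraph = csgraph.tocsr(copy=copy_if_sparse)
            csgraph = csgraph.astype(dtype, copy=False)
        else:
            csgraph = csgraph_to_dense(csgraph, null_value=null_value_out)
    elif np.ma.isMaskedArray(csgraph):
        # Masked input: the masked entries are the absent edges.
        if dense_output:
            mask = csgraph.mask
            # np.asarray copies only when needed (e.g. for a dtype change);
            # np.array(..., copy=False) raises in that case under NumPy >= 2.0.
            if copy_if_dense:
                csgraph = np.array(csgraph.data, dtype=dtype)
            else:
                csgraph = np.asarray(csgraph.data, dtype=dtype)
            csgraph[mask] = null_value_out
        else:
            csgraph = csgraph_from_masked(csgraph)
    else:
        # Plain dense input: entries equal to null_value_in (and optionally
        # inf/NaN) are the absent edges.
        if dense_output:
            csgraph = csgraph_masked_from_dense(csgraph,
                                                copy=copy_if_dense,
                                                null_value=null_value_in,
                                                nan_null=nan_null,
                                                infinity_null=infinity_null)
            mask = csgraph.mask
            csgraph = np.asarray(csgraph.data, dtype=dtype)
            csgraph[mask] = null_value_out
        else:
            csgraph = csgraph_from_dense(csgraph, null_value=null_value_in,
                                         infinity_null=infinity_null,
                                         nan_null=nan_null)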

    if csgraph.ndim != 2:
        raise ValueError("compressed-sparse graph must be 2-D")

    if csgraph.shape[0] != csgraph.shape[1]:
        raise ValueError("compressed-sparse graph must be shape (N, N)")

    return csgraph
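

# Illustrative sketch only, not part of SciPy: a NumPy-only approximation of
# what the dense-output path of validate_graph does for a plain ndarray under
# the default null handling (null_value_in=0, null_value_out=np.inf,
# infinity_null=True, nan_null=True). The helper name _sketch_dense_nulls is
# hypothetical, and it compares against null_value_in with exact equality; the
# real csgraph_masked_from_dense helper may differ in such details.
if __name__ == "__main__":
    import numpy as np

    def _sketch_dense_nulls(graph, null_value_in=0, null_value_out=np.inf):
        # Convert to float64 (always a copy here), mirroring DTYPE above.
        out = np.array(graph, dtype=np.float64)
        # Entries equal to null_value_in, +/-inf, or NaN mean "no edge".
        mask = (out == null_value_in) | np.isinf(out) | np.isnan(out)
        # Rewrite every absent edge with the output sentinel.
        out[mask] = null_value_out
        return out

    # A 3-node path graph 0-1-2 where 0 means "no edge".
    demo = _sketch_dense_nulls([[0, 1, 0],
                                [1, 0, 2],
                                [0, 2, 0]])
    print(demo)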