File size: 3,653 Bytes
2a0bc63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import faiss


def get_invlist(invlists, l):
    """ returns the inverted lists content as a pair of (list_ids, list_codes).

    The codes are reshaped to a proper size

    Parameters:
      invlists: inverted lists object; it is passed through
        faiss.downcast_InvertedLists before any other call
      l: number of the inverted list to read

    Returns:
      list_ids: int64 array with list_size(l) entries
      list_codes: uint8 array. When code_size is not INVALID_CODE_SIZE its
        shape is (list_size, code_size). Otherwise the shape is
        (ceil(list_size / n_per_block), block_size // n_per_block,
        n_per_block).

    Both arrays are allocated here and filled with faiss.memcpy, so they
    remain usable after the ids / codes pointers have been released.

    """
    invlists = faiss.downcast_InvertedLists(invlists)
    ls = invlists.list_size(l)
    list_ids = np.zeros(ls, dtype='int64')
    # None means "not acquired yet": the finally clause uses this to decide
    # which of the two pointers has to be released
    ids = codes = None
    try:
        ids = invlists.get_ids(l)
        # the copy is skipped for an empty list
        if ls > 0:
            faiss.memcpy(faiss.swig_ptr(list_ids), ids, list_ids.nbytes)
        codes = invlists.get_codes(l)
        if invlists.code_size != faiss.InvertedLists.INVALID_CODE_SIZE:
            # one row of code_size bytes per list entry
            list_codes = np.zeros((ls, invlists.code_size), dtype='uint8')
        else:
            # it's a BlockInvertedLists
            npb = invlists.n_per_block
            bs = invlists.block_size
            # number of blocks needed to hold ls entries (rounded up)
            ls_round = (ls + npb - 1) // npb
            list_codes = np.zeros((ls_round, bs // npb, npb), dtype='uint8')
        if ls > 0:
            faiss.memcpy(faiss.swig_ptr(list_codes), codes, list_codes.nbytes)
    finally:
        # runs on success and on error: hand back whatever was obtained above
        if ids is not None:
            invlists.release_ids(l, ids)
        if codes is not None:
            invlists.release_codes(l, codes)
    return list_ids, list_codes


def get_invlist_sizes(invlists):
    """ collect the size of every inverted list

    Calls invlists.list_size for each list number in range(invlists.nlist)
    and returns the results, in that order, as an int64 array.
    """
    sizes = []
    for list_no in range(invlists.nlist):
        sizes.append(invlists.list_size(list_no))
    return np.array(sizes, dtype='int64')


def print_object_fields(obj):
    """ list values all fields of an object known to SWIG

    Prints one "name = value" line per field of obj.

    The field names come from the class-level ``__swig_getmethods__``
    table when the proxy class has one. Proxy classes without that table
    (believed to be the case for wrappers generated by SWIG >= 4.0, which
    expose fields as properties -- to be confirmed against the SWIG version
    in use) previously made this function raise AttributeError. For those,
    the names of all ``property`` attributes of the class are used instead,
    in the sorted order returned by dir().
    """
    cls = obj.__class__
    names = getattr(cls, "__swig_getmethods__", None)
    if names is None:
        # No legacy getter table: fall back to the class properties.
        # "thisown" is skipped because it is SWIG ownership bookkeeping
        # rather than a field of the wrapped object.
        names = [
            name for name in dir(cls)
            if name != "thisown"
            and isinstance(getattr(cls, name, None), property)
        ]

    for name in names:
        print(f"{name} = {getattr(obj, name)}")


def get_pq_centroids(pq):
    """ return the centroid table of a product quantizer as a 3D array

    pq.centroids is converted with faiss.vector_to_array and reshaped to
    (pq.M, pq.ksub, pq.dsub).
    """
    flat = faiss.vector_to_array(pq.centroids)
    shape = (pq.M, pq.ksub, pq.dsub)
    return flat.reshape(shape)


def get_LinearTransform_matrix(pca):
    """ extract the (matrix, bias) pair of a linear transform

    The argument is called pca, but this works for any linear transform
    (OPQ, random rotation, etc.)

    Returns (A, b): A is the pca.A vector reshaped to (pca.d_out, pca.d_in)
    and b is the pca.b vector converted to an array.
    """
    bias = faiss.vector_to_array(pca.b)
    flat_matrix = faiss.vector_to_array(pca.A)
    matrix = flat_matrix.reshape(pca.d_out, pca.d_in)
    return matrix, bias


def make_LinearTransform_matrix(A, b=None):
    """ build a faiss.LinearTransform from a matrix and an optional bias

    A is a 2D array of shape (d_out, d_in). When b is given it must have
    shape (d_out, ) and the transform is created with a bias term.

    The returned transform is flagged as trained and set_is_orthonormal()
    is called on it before returning.
    """
    n_out, n_in = A.shape
    with_bias = b is not None
    if with_bias:
        assert b.shape == (n_out, )

    transform = faiss.LinearTransform(n_in, n_out, with_bias)
    # the matrix is handed over flattened
    faiss.copy_array_to_vector(A.ravel(), transform.A)
    if with_bias:
        faiss.copy_array_to_vector(b, transform.b)

    transform.is_trained = True
    transform.set_is_orthonormal()
    return transform


def get_additive_quantizer_codebooks(aq):
    """ return the codebooks of an additive quantizer

    aq.codebooks is viewed as a table with aq.d columns. The result is a
    list of aq.M row slices of that table, slice i spanning rows
    codebook_offsets[i] to codebook_offsets[i + 1].
    """
    table = faiss.vector_to_array(aq.codebooks).reshape(-1, aq.d)
    offsets = faiss.vector_to_array(aq.codebook_offsets)
    per_codebook = []
    for m in range(aq.M):
        begin, end = offsets[m], offsets[m + 1]
        per_codebook.append(table[begin:end])
    return per_codebook


def get_flat_data(index):
    """ return the data matrix of an IndexFlat

    index.codes is converted with faiss.vector_to_array, reinterpreted as
    float32 and reshaped to (index.ntotal, index.d).
    """
    raw = faiss.vector_to_array(index.codes)
    as_floats = raw.view("float32")
    return as_floats.reshape(index.ntotal, index.d)


def get_NSG_neighbors(nsg):
    """ return the neighbor lists of the vectors stored in an NSG structure

    The result is an int32 matrix of shape (N, K), where N and K are read
    from the graph returned by nsg.get_final_graph(). Its content is
    copied from the graph's data pointer with faiss.memcpy.
    """
    final_graph = nsg.get_final_graph()
    shape = (final_graph.N, final_graph.K)
    table = np.zeros(shape, dtype='int32')
    dest = faiss.swig_ptr(table)
    faiss.memcpy(dest, final_graph.data, table.nbytes)
    return table