File size: 5,268 Bytes
2a0bc63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# @nolint

# not linting this file because it imports * from swigfaiss, which
# causes a ton of useless warnings.

import numpy as np
import array
import warnings

from faiss.loader import *

###########################################
# Utility to add a deprecation warning to
# classes from the SWIG interface
###########################################

def _make_deprecated_swig_class(deprecated_name, base_name):
    """
    Dynamically construct deprecated classes as wrappers around renamed ones.

    The generated class subclasses the replacement class and emits a
    DeprecationWarning from its __new__-method every time an instance is
    constructed. Whether (and how often) that warning is actually displayed
    is decided by the active ``warnings`` filters.

    We do this on the Python side because the base classes are defined in
    the SWIG interface, making it cumbersome to add the deprecation there.

    Parameters
    ----------
    deprecated_name : string
        Name of the class to be deprecated; _not_ present in SWIG interface.
    base_name : string
        Name of the class that is replacing deprecated_name; must already be
        imported into the current namespace.

    Returns
    -------
    None
        However, the deprecated class gets added to the faiss namespace
    """
    base_class = globals()[base_name]

    def new_meth(cls, *args, **kwargs):
        msg = f"The class faiss.{deprecated_name} is deprecated in favour of faiss.{base_name}!"
        # stacklevel=2 attributes the warning to the code constructing the
        # instance rather than to this helper.
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        parent_new = super(base_class, cls).__new__
        # object.__new__ rejects extra arguments once a class overrides
        # __new__ (as this one does), so constructor arguments are only
        # forwarded to a custom __new__; __init__ still receives them.
        if parent_new is object.__new__:
            return parent_new(cls)
        return parent_new(cls, *args, **kwargs)

    # three-argument version of "type" uses (name, tuple-of-bases, dict-of-attributes)
    klazz = type(deprecated_name, (base_class,), {"__new__": new_meth})

    # this ends up adding the class to the "faiss" namespace, in a way that it
    # is available both through "import faiss" and "from faiss import *"
    globals()[deprecated_name] = klazz


###########################################
# numpy array / std::vector conversions
###########################################

# Size in bytes of the platform's C ``long``; it decides whether the legacy
# 'Long' prefix maps onto the 32-bit or the 64-bit vector classes.
sizeof_long = array.array('l').itemsize

deprecated_name_map = {
    # deprecated: replacement
    'Float': 'Float32',
    'Double': 'Float64',
    'Char': 'Int8',
    'Int': 'Int32',
    'Long': 'Int64' if sizeof_long != 4 else 'Int32',
    'LongLong': 'Int64',
    'Byte': 'UInt8',
    # previously misspelled variant
    'Uint64': 'UInt64',
}

# Register a deprecated alias for every <prefix>Vector class, plus a
# <prefix>VectorVector alias for the three legacy *VectorVector classes.
for depr_prefix, base_prefix in deprecated_name_map.items():
    _make_deprecated_swig_class(f"{depr_prefix}Vector", f"{base_prefix}Vector")
    if depr_prefix not in ('Float', 'Long', 'Byte'):
        continue
    _make_deprecated_swig_class(f"{depr_prefix}VectorVector",
                                f"{base_prefix}VectorVector")

# mapping from vector names in swigfaiss.swig and the numpy dtype names
# TODO: once deprecated classes are removed, remove the dict and just use .lower() below
vector_name_map = {
    name: name.lower()
    for name in (
        'Float32', 'Float64',
        'Int8', 'Int16', 'Int32', 'Int64',
        'UInt8', 'UInt16', 'UInt32', 'UInt64',
    )
}
# deprecated prefixes resolve to the dtype name of their replacement class
vector_name_map.update(
    (old, new.lower()) for old, new in deprecated_name_map.items()
)


def vector_to_array(v):
    """Copy the contents of a SWIG-wrapped C++ vector into a new numpy array.

    Objects whose class name starts with ``AlignedTable`` are delegated to
    ``AlignedTable_to_array``. For vector classes, the numpy dtype is looked
    up in ``vector_name_map`` using the class name minus its trailing
    ``Vector`` (e.g. ``Float32Vector`` -> ``'float32'``).

    Parameters
    ----------
    v : SWIG vector object or AlignedTable
        Source container.

    Returns
    -------
    numpy.ndarray
        1-D array of length ``v.size()`` holding a copy of the elements.
    """
    cls_name = v.__class__.__name__
    if cls_name.startswith('AlignedTable'):
        return AlignedTable_to_array(v)
    assert cls_name.endswith('Vector')
    element_dtype = np.dtype(vector_name_map[cls_name[:-len('Vector')]])
    count = v.size()
    result = np.empty(count, dtype=element_dtype)
    if count > 0:
        # nothing to copy for an empty vector; presumably this also avoids
        # handing a possibly-null v.data() pointer to memcpy
        memcpy(swig_ptr(result), v.data(), result.nbytes)
    return result


def vector_float_to_array(v):
    """Convert a C++ vector to a numpy array.

    Named for float vectors, but it simply delegates to ``vector_to_array``,
    which derives the dtype from the vector's class name.
    """
    return vector_to_array(v=v)


def copy_array_to_vector(a, v):
    """Copy a 1-D numpy array into a SWIG-wrapped C++ vector.

    ``v`` is resized to the length of ``a`` and its contents overwritten.

    Parameters
    ----------
    a : numpy.ndarray
        1-D source array; its dtype must equal the element dtype of ``v``.
    v : SWIG vector object (class name ending in ``Vector``)
        Destination vector.

    Raises
    ------
    ValueError
        If ``a`` is not 1-D (unpacking ``a.shape`` fails).
    KeyError
        If the vector class prefix is not in ``vector_name_map``.
    AssertionError
        If ``v`` is not a ``*Vector`` class or the dtypes differ
        (these checks are ``assert`` statements, skipped under ``python -O``).
    """
    length, = a.shape
    vec_cls = v.__class__.__name__
    assert vec_cls.endswith('Vector')
    target_dtype = np.dtype(vector_name_map[vec_cls[:-len('Vector')]])
    assert target_dtype == a.dtype, (
        'cannot copy a %s array to a %s (should be %s)'
        % (a.dtype, vec_cls, target_dtype))
    v.resize(length)
    if length > 0:
        memcpy(v.data(), swig_ptr(a), a.nbytes)

# Conversions between numpy arrays and AlignedTable objects


def copy_array_to_AlignedTable(a, v):
    """Copy a 1-D numpy array into an AlignedTable, resizing the table first.

    Only the element size is validated (``v.itemsize()`` against
    ``a.itemsize``); the table's element type itself is not checked.

    Raises
    ------
    ValueError
        If ``a`` is not 1-D (unpacking ``a.shape`` fails).
    AssertionError
        If the element sizes differ (an ``assert``, skipped under ``-O``).
    """
    length, = a.shape
    # TODO check class name
    assert v.itemsize() == a.itemsize
    v.resize(length)
    if length:
        memcpy(v.get(), swig_ptr(a), a.nbytes)


def array_to_AlignedTable(a):
    """Create a new AlignedTable holding a copy of the 1-D numpy array ``a``.

    Supported dtypes are ``uint16`` (-> ``AlignedTableUint16``) and
    ``uint8`` (-> ``AlignedTableUint8``).

    Parameters
    ----------
    a : numpy.ndarray
        1-D source array.

    Returns
    -------
    AlignedTableUint16 or AlignedTableUint8
        A freshly allocated table filled with the contents of ``a``.

    Raises
    ------
    AssertionError
        If ``a.dtype`` is not supported. This is raised explicitly rather than
        through ``assert False``, so the check still fires under ``python -O``
        (where the old code fell through to an unbound ``v``). The exception
        type is kept for backward compatibility.
    """
    if a.dtype == 'uint16':
        v = AlignedTableUint16(a.size)
    elif a.dtype == 'uint8':
        v = AlignedTableUint8(a.size)
    else:
        raise AssertionError(
            f"array_to_AlignedTable: unsupported dtype {a.dtype}, "
            "expected uint8 or uint16")
    copy_array_to_AlignedTable(a, v)
    return v


def AlignedTable_to_array(v):
    """Copy the contents of an AlignedTable into a new numpy array.

    The dtype is taken from the class-name suffix after ``AlignedTable``,
    lower-cased (e.g. ``AlignedTableUint16`` -> ``'uint16'``).

    Returns
    -------
    numpy.ndarray
        1-D array of length ``v.size()`` holding a copy of the elements.
    """
    prefix = 'AlignedTable'
    cls_name = v.__class__.__name__
    assert cls_name.startswith(prefix)
    out = np.empty(v.size(), dtype=cls_name[len(prefix):].lower())
    if out.size:
        memcpy(swig_ptr(out), v.data(), out.nbytes)
    return out