#!/usr/bin/python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# LASER Language-Agnostic SEntence Representations
# is a toolkit to calculate multilingual sentence embeddings
# and to use them for document classification, bitext filtering
# and mining
#
# --------------------------------------------------------
#
# tools for indexing and search with FAISS

import faiss
import os.path
import sys
import numpy as np

#-------------------------------------------------------------
# Get list of fnames:
# - we loop over the list of given languages
# - for each language, we also check if there are split files .%03d
def SplitFnames(par_fname, langs):
    fnames = []
    for l in langs:
        fname = par_fname + '.' + l
        if os.path.isfile(fname):
            fnames.append(fname)
        for i in range(1000):
            fname = par_fname + '.' + l + '.{:03d}'.format(i)
            if os.path.isfile(fname):
                fnames.append(fname)
    if len(fnames) == 0:
        print("ERROR: no embeddings found in {:s}*".format(par_fname))
        sys.exit(1)
    return fnames
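
# Illustrative usage (not part of the original module; the file prefix and
# language codes are assumptions): collect the per-language embedding files,
# including any numbered splits such as dev.enc.fr.000:
#   fnames = SplitFnames('embed/dev.enc', ['en', 'fr'])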

# open the (possibly split) embedding files of all languages as memory mapped arrays
def SplitOpen(par_fname, langs, dim, dtype, verbose=False):
    M = []
    nf = 0
    nc = 0
    print('Reading sentence embeddings')
    print(' - memory mapped files {:s}'.format(par_fname))
    for fname in SplitFnames(par_fname, langs):
        n = int(os.path.getsize(fname) / dim / np.dtype(dtype).itemsize)
        if verbose:
            print(' - {:s}: {:d} x {:d}'.format(fname, n, dim))
        Mi = np.memmap(fname, mode='r', dtype=dtype, shape=(n, dim))
        nc += n
        nf += 1
        M.append(Mi)
    print(' - total of {:d} files: {:d} x {:d}'.format(nf, nc, dim))
    return M

# access one row of the virtually concatenated memory mapped embeddings
def SplitAccess(M, idx):
    i = idx
    for Mi in M:
        n = Mi.shape[0]
        if i < n:
            return Mi[i, :]
        i -= n
    print('ERROR: index {:d} is too large for the memory mapped files'.format(idx))
    sys.exit(1)
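
# Illustrative usage (not part of the original module; file prefix, languages
# and dimension are assumptions): open the split embeddings as memory mapped
# arrays and fetch global row 12345 across the splits:
#   M = SplitOpen('embed/dev.enc', ['en', 'fr'], 1024, np.float32)
#   vec = SplitAccess(M, 12345)   # one 1024-dim vector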

###############################################################################
# create a FAISS index on the given data
def IndexCreate(dname, idx_type,
                verbose=False, normalize=True, save_index=False, dim=1024):
    assert idx_type == 'FlatL2', 'only FlatL2 index is currently supported'
    x = np.fromfile(dname, dtype=np.float32, count=-1)
    nbex = x.shape[0] // dim
    print(' - embedding: {:s} {:d} examples of dim {:d}'
          .format(dname, nbex, dim))
    x.resize(nbex, dim)
    print(' - creating FAISS index')
    idx = faiss.IndexFlatL2(dim)
    if normalize:
        faiss.normalize_L2(x)
    idx.add(x)
    if save_index:
        iname = 'TODO'
        print(' - saving index into ' + iname)
        faiss.write_index(idx, iname)
    return x, idx
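
# Illustrative usage (not part of the original module; the embedding file name
# is an assumption): build a flat L2 index over 1024-dim float32 embeddings,
# L2-normalizing them so that L2 search behaves like cosine similarity:
#   data_en, idx_en = IndexCreate('embed/dev.enc.en', 'FlatL2', normalize=True)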

###############################################################################
# search the closest vector for all language pairs and calculate the error rate
def IndexSearchMultiple(data, idx, langs, verbose=False, texts=None, print_errors=False):
    nl = len(data)
    nbex = data[0].shape[0]
    err = np.zeros((nl, nl)).astype(float)
    ref = np.linspace(0, nbex - 1, nbex).astype(int)  # [0, nbex)
    if verbose:
        if texts is None:
            print('Calculating similarity error (indices):')
        else:
            print('Calculating similarity error (textual):')
    for i1 in range(nl):
        for i2 in range(nl):
            if i1 != i2:
                D, I = idx[i2].search(data[i1], 1)
                if texts:  # do textual comparison
                    e1 = 0
                    for p in range(I.shape[0]):
                        if texts[i2][p] != texts[i2][I[p, 0]]:
                            e1 += 1
                            if print_errors:
                                print('Error {:s}\n {:s}'
                                      .format(texts[i2][p].strip(),
                                              texts[i2][I[p, 0]].strip()))
                    err[i1, i2] = e1 / nbex
                else:  # do index based comparison
                    err[i1, i2] \
                        = (nbex - np.equal(I.reshape(nbex), ref)
                           .astype(int).sum()) / nbex
                if verbose:
                    print(' - similarity error {:s}/{:s}: {:5.2f}%'
                          .format(langs[i1], langs[i2],
                                  100.0 * err[i1, i2]))
    return err
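
# Illustrative usage (not part of the original module; file names are
# assumptions): build one index per language with IndexCreate, then compute
# the pairwise similarity error rates:
#   langs = ['en', 'fr']
#   data, idx = zip(*[IndexCreate('embed/dev.enc.' + l, 'FlatL2') for l in langs])
#   err = IndexSearchMultiple(data, idx, langs, verbose=True)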

###############################################################################
# print confusion matrix
def IndexPrintConfusionMatrix(err, langs):
    nl = len(langs)
    assert nl == err.shape[0], 'size of error matrix does not match'
    print('Confusion matrix:')
    print('{:8s}'.format('langs'), end='')
    for i2 in range(nl):
        print('{:8s} '.format(langs[i2]), end='')
    print('{:8s}'.format('avg'))
    for i1 in range(nl):
        print('{:3s}'.format(langs[i1]), end='')
        for i2 in range(nl):
            print('{:8.2f}%'.format(100 * err[i1, i2]), end='')
        print('{:8.2f}%'.format(100 * err[i1, :].sum() / (nl - 1)))

    print('avg', end='')
    for i2 in range(nl):
        print('{:8.2f}%'.format(100 * err[:, i2].sum() / (nl - 1)), end='')

    # global average
    print('{:8.2f}%'.format(100 * err.sum() / (nl - 1) / nl))
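
# Illustrative usage (not part of the original module): print the error matrix
# returned by IndexSearchMultiple, one row and column per language plus averages:
#   IndexPrintConfusionMatrix(err, langs)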

###############################################################################
# Load a FAISS index
def IndexLoad(idx_name, nprobe, gpu=False):
    print('Reading FAISS index')
    print(' - index: {:s}'.format(idx_name))
    index = faiss.read_index(idx_name)
    print(' - found {:d} sentences of dim {:d}'.format(index.ntotal, index.d))
    print(' - setting nprobe to {:d}'.format(nprobe))
    if gpu:
        print(' - transfer index to %d GPUs ' % faiss.get_num_gpus())
        # co = faiss.GpuMultipleClonerOptions()
        # co.shard = True
        index = faiss.index_cpu_to_all_gpus(index)  # co=co
        faiss.GpuParameterSpace().set_index_parameter(index, 'nprobe', nprobe)
    return index
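
# Illustrative usage (not part of the original module; the index file name and
# nprobe value are assumptions): load a previously written FAISS index,
# optionally cloning it onto all visible GPUs:
#   index = IndexLoad('bitexts.idx', nprobe=128, gpu=False)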

###############################################################################
# Opens a text file with the sentences corresponding to the indices used
# by a FAISS index
# We also need the reference files with the byte offsets to the beginning
# of each sentence
# optionally: array with number of words per sentence
# All arrays are memory mapped
def IndexTextOpen(txt_fname):
    print('Reading text corpus')
    print(' - texts: {:s}'.format(txt_fname))
    txt_mmap = np.memmap(txt_fname, mode='r', dtype=np.uint8)
    fname = txt_fname.replace('.txt', '.ref.bin32')
    if os.path.isfile(fname):
        print(' - sentence start offsets (32 bit): {}'.format(fname))
        ref_mmap = np.memmap(fname, mode='r', dtype=np.uint32)
    else:
        fname = txt_fname.replace('.txt', '.ref.bin64')
        if os.path.isfile(fname):
            print(' - sentence start offsets (64 bit): {}'.format(fname))
            ref_mmap = np.memmap(fname, mode='r', dtype=np.uint64)
        else:
            print('ERROR: no file with sentence start offsets found')
            sys.exit(1)
    print(' - found {:d} sentences'.format(ref_mmap.shape[0]))

    nbw_mmap = None
    fname = txt_fname.replace('.txt', '.nw.bin8')
    if os.path.isfile(fname):
        print(' - word counts: {:s}'.format(fname))
        nbw_mmap = np.memmap(fname, mode='r', dtype=np.uint8)

    M = None
    fname = txt_fname.replace('.txt', '.meta')
    if os.path.isfile(fname):
        M = []
        n = 0
        print(' - metafile: {:s}'.format(fname))
        with open(fname, 'r') as fp:
            for line in fp:
                fields = line.strip().split()
                if len(fields) != 2:
                    print('ERROR: format error in meta file')
                    sys.exit(1)
                n += int(fields[1])
                M.append({'lang': fields[0], 'n': n})
        print(' - found {:d} languages:'.format(len(M)), end='')
        for L in M:
            print(' {:s}'.format(L['lang']), end='')
        print('')

    return txt_mmap, ref_mmap, nbw_mmap, M
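
# Illustrative usage (not part of the original module; file names are
# assumptions): expects sentences.txt plus sentences.ref.bin32 or
# sentences.ref.bin64 with the byte offsets, and optionally sentences.nw.bin8
# and sentences.meta next to it:
#   T, R, W, M = IndexTextOpen('sentences.txt')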

###############################################################################
# Return the text for the given index
def IndexTextQuery(txt_mmap, ref_mmap, idx):
    p = int(ref_mmap[idx])  # get starting byte position
    i = 0
    dim = 10000  # max sentence length in bytes
    b = bytearray(dim)
    # find EOL (check the length bound before reading the byte)
    while i < dim and txt_mmap[p + i] != 10:
        b[i] = txt_mmap[p + i]
        i += 1
    return b[0:i].decode('utf-8')
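
# Illustrative usage (not part of the original module): fetch the sentence
# stored at position 42 of the memory mapped corpus opened with IndexTextOpen:
#   line = IndexTextQuery(T, R, 42)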

###############################################################################
# Search the [k] nearest vectors of [x] in the given index
# and return the text lines
def IndexSearchKNN(index, x, T, R, kmax=1, Dmax=1.0, dedup=True):
    D, I = index.search(x, kmax)
    prev = {}  # for deduplication
    res = []
    for n in range(x.shape[0]):
        for i in range(kmax):
            txt = IndexTextQuery(T, R, I[n, i])
            # keep the hit if it is close enough and, when dedup is requested,
            # if the same text has not been returned before
            if (not dedup or txt not in prev) and D[n, i] <= Dmax:
                prev[txt] = 1
                res.append([txt, D[n, i]])
    return res
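
# Illustrative end-to-end sketch (not part of the original module; file names,
# nprobe and the 1024-dim query embedding are assumptions): load an index and
# its text corpus, then map the nearest neighbors of a query back to text:
#   index = IndexLoad('bitexts.idx', nprobe=128)
#   T, R, W, M = IndexTextOpen('bitexts.txt')
#   q = np.fromfile('query.enc', dtype=np.float32).reshape(1, 1024)
#   faiss.normalize_L2(q)
#   hits = IndexSearchKNN(index, q, T, R, kmax=5, Dmax=1.0)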