Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 3,368 Bytes
63858e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import logging
from functools import partial
from pathlib import Path
from typing import Iterable

import faiss
import numpy as np
from transformers import AutoConfig

from utils.f import memoize
@memoize
def get_config(model_name):
    """Load the Hugging Face ``AutoConfig`` for ``model_name``.

    Wrapped in the project's ``memoize`` helper, presumably so that repeated
    lookups of the same model name reuse the first result (confirm in utils.f).
    """
    config = AutoConfig.from_pretrained(model_name)
    return config
# Glob used to discover every per-layer faiss index file inside a folder;
# filenames produced by LAYER_TEMPLATE match it.
FAISS_LAYER_PATTERN = "layer_*.faiss"
# Filename for one layer's index, with the layer number zero-padded to two digits.
LAYER_TEMPLATE = "layer_{:02d}.faiss"
def create_mask(head_size: int, n_heads: int, selected_heads: Iterable[int]):
    """Build a flat 0/1 mask of length ``head_size * n_heads`` that selects whole heads.

    Head ``h`` owns the contiguous slice ``[h * head_size, (h + 1) * head_size)``.
    Entries belonging to a selected head are 1.0 (we care about that head's
    contribution); all other entries are 0.0 (we don't).

    Parameters:
    -----------
    head_size: Hidden dimension of each head
    n_heads: Number of heads the model has
    selected_heads: Heads to keep; every other head is zeroed out. Each value is
        converted with ``int`` and used as an array index, so duplicates are
        harmless, negative values count back from the last head, and values
        ``>= n_heads`` raise ``IndexError``.
    """
    per_head = np.zeros(n_heads)
    head_ids = [int(h) for h in selected_heads]
    per_head[head_ids] = 1
    # Expand each per-head flag into head_size consecutive copies.
    return np.repeat(per_head, head_size)
class Indexes:
    """Wrapper around per-layer faiss indices to make searching for a vector simpler.

    ``folder`` is expected to contain one faiss index file per layer whose names
    match ``pattern`` and whose stem ends in ``_<layer number>`` (e.g.
    ``layer_00.faiss``, as produced by ``LAYER_TEMPLATE``). Layer numbers must be
    contiguous starting at 0. ``n_layers`` is one less than the number of matching
    files; presumably the extra file holds the input/embedding layer (TODO confirm).

    The model name is taken from the directory two levels above ``folder`` and is
    used to load the model config (number of attention heads, hidden size).
    """

    def __init__(self, folder, pattern=FAISS_LAYER_PATTERN):
        self.base_dir = Path(folder)
        # One list slot per matching file; n_layers excludes the one extra file.
        self.n_layers = len(list(self.base_dir.glob(pattern))) - 1
        self.indexes = [None] * (self.n_layers + 1)
        self.pattern = pattern
        self.__init_indexes()

        # Extract the model name from the folder hierarchy. ``name`` (not ``stem``)
        # keeps directory names that contain a dot, e.g. "gpt-neo-1.3B", intact.
        self.model_name = self.base_dir.parent.parent.name
        self.config = get_config(self.model_name)
        self.nheads = self.config.num_attention_heads
        self.hidden_size = self.config.hidden_size
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if self.hidden_size % self.nheads != 0:
            raise ValueError(
                "Number of heads does not divide cleanly into the hidden size. Aborting"
            )
        self.head_size = self.hidden_size // self.nheads

    def __getitem__(self, v):
        """Return the faiss index loaded for layer ``v``.

        Intended for integer layer numbers. Slicing is not part of the API, even
        though it happens to return a list because ``self.indexes`` is a list.
        """
        return self.indexes[v]

    def __init_indexes(self):
        """Read every file matching ``self.pattern`` into the slot for its layer number.

        Raises:
            ValueError: if a file's layer number falls outside ``0..n_layers``,
                i.e. the files are not numbered contiguously from 0.
        """
        for fname in self.base_dir.glob(self.pattern):
            logging.getLogger(__name__).debug("Loading faiss index %s", fname)
            # The layer number is the last "_"-separated token of the file stem.
            idx = int(fname.stem.split('_')[-1])
            if not 0 <= idx < len(self.indexes):
                raise ValueError(
                    f"Index file {fname} has layer number {idx}, expected 0..{len(self.indexes) - 1}; "
                    "layer files must be numbered contiguously from 0"
                )
            self.indexes[idx] = faiss.read_index(str(fname))

    def search(self, layer, query, k):
        """Search layer ``layer`` for the ``k`` nearest neighbours of ``query``.

        Delegates to the layer's faiss index ``search``; for standard faiss
        indexes that returns a ``(distances, ids)`` pair.
        """
        return self[layer].search(query, k)
class ContextIndexes(Indexes):
    """Index over context-layer embeddings that can mask out attention heads before searching.

    The query's hidden dimension is treated as ``nheads`` contiguous chunks of
    ``head_size`` values (the layout ``create_mask`` builds); chunks belonging to
    heads that are not selected are zeroed before the nearest-neighbour search.
    """

    def __init__(self, folder, pattern=FAISS_LAYER_PATTERN):
        super().__init__(folder, pattern)
        # Bind head_size and nheads so callers only supply the selected heads.
        self.head_mask = partial(create_mask, self.head_size, self.nheads)

    # Int -> [Int] -> np.Array -> Int -> (np.Array(), )
    def search(self, layer: int, heads: list, query: np.ndarray, k: int):
        """Search the embeddings for the context layer, masking by selected heads.

        Parameters:
        -----------
        layer: Layer whose index is searched
        heads: 0-based head numbers to keep; duplicates are ignored and every other
            head's slice of the query is zeroed
        query: One vector of length ``hidden_size`` (shape ``(d,)`` or ``(1, d)``)
            or a batch of shape ``(n, d)``
        k: Number of results per query vector

        Returns whatever the layer's faiss index ``search`` returns.

        Raises:
        -------
        ValueError: if ``heads`` is empty or contains a head outside ``0..nheads-1``
        """
        # Materialize once so generators are not exhausted by the checks below.
        heads = list(heads)
        if not heads:
            raise ValueError("At least one head must be selected")
        if max(heads) >= self.nheads:
            raise ValueError(
                "max of selected heads must be less than nheads. Are you indexing by 1 instead of 0?"
            )
        if min(heads) < 0:
            raise ValueError("What is a negative head?")

        mask_vector = self.head_mask(sorted(set(heads)))
        # faiss expects a 2-D (n, d) query batch; promote a single vector to (1, d).
        # The 1-D mask then broadcasts across every row of the batch.
        queries = np.atleast_2d(query)
        new_query = (queries * mask_vector).astype(np.float32)
        return self[layer].search(new_query, k)
|