from functools import partial
import faiss
import numpy as np
from pathlib import Path
from typing import Iterable
from utils.f import memoize
from transformers import AutoConfig

@memoize
def get_config(model_name):
    return AutoConfig.from_pretrained(model_name)

FAISS_LAYER_PATTERN = 'layer_*.faiss'
LAYER_TEMPLATE = 'layer_{:02d}.faiss' 
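# Index files on disk are expected to follow this naming scheme, e.g.
# 'layer_00.faiss', 'layer_01.faiss', ... (inferred from the glob pattern and
# template above).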

def create_mask(head_size: int, n_heads: int, selected_heads: Iterable[int]):
    """Create a mask vector of size (head_size * n_heads), where 0 indicates that the
    contribution of that head is ignored and 1 indicates that it is kept.
    
    Parameters:
    -----------
        head_size: Hidden dimension of each head
        n_heads: Number of heads the model has
        selected_heads: Which heads to keep (i.e., not zero out)
    """

    mask = np.zeros(n_heads)
    for h in selected_heads:
        mask[int(h)] = 1
        
    return np.repeat(mask, head_size)
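
# Illustrative example (not part of the original module): with head_size=2,
# n_heads=4, and selected_heads=[0, 3],
#   create_mask(2, 4, [0, 3])
# returns array([1., 1., 0., 0., 0., 0., 1., 1.]), i.e. only the slices
# belonging to heads 0 and 3 are kept.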

class Indexes:
    """Wrapper around the faiss indices to make searching for a vector simpler and faster.
    
    Assumes the folder contains files matching the given pattern
    """
    def __init__(self, folder, pattern=FAISS_LAYER_PATTERN):
        self.base_dir = Path(folder)
        self.n_layers = len(list(self.base_dir.glob(pattern))) - 1 # Subtract final output
        self.indexes = [None] * (self.n_layers + 1) # Initialize empty list, adding 1 for input
        self.pattern = pattern
        self.__init_indexes()

        # Extract model name from folder hierarchy
        self.model_name = self.base_dir.parent.parent.stem
        self.config = get_config(self.model_name)
        self.nheads = self.config.num_attention_heads
        self.hidden_size = self.config.hidden_size
        assert (self.hidden_size % self.nheads) == 0, "Number of heads does not divide cleanly into the hidden size. Aborting"
        self.head_size = self.hidden_size // self.nheads

        
    def __getitem__(self, v):
        """Index by integer only; slices are not supported"""
        return self.indexes[v]
    
    def __init_indexes(self):
        for fname in self.base_dir.glob(self.pattern):
            print(fname)
            idx = fname.stem.split('_')[-1]
            self.indexes[int(idx)] = faiss.read_index(str(fname))

    def search(self, layer, query, k):
        """Search a given layer for the query vector. Return k results"""
        return self[layer].search(query, k)
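
# Illustrative usage (the folder layout is an assumption; the code only requires
# that the model name sit two directories above the index folder, e.g.
# '<model_name>/<corpus>/embedding_faiss'):
#
#   indexes = Indexes('data/gpt2/wiki/embedding_faiss')
#   q = hidden_state.reshape(1, -1).astype(np.float32)  # faiss expects 2D float32 queries
#   distances, ids = indexes.search(layer=3, query=q, k=10)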


class ContextIndexes(Indexes):
    """Special index enabling masking of particular heads before searching"""

    def __init__(self, folder, pattern=FAISS_LAYER_PATTERN):
        super().__init__(folder, pattern)

        self.head_mask = partial(create_mask, self.head_size, self.nheads)

    # (layer: int, heads: [int], query: np.ndarray, k: int) -> (np.ndarray, np.ndarray)
    def search(self, layer:int, heads:list, query:np.ndarray, k:int):
        """Search the embeddings for the context layer, masking by selected heads"""
        assert max(heads) < self.nheads, "max of selected heads must be less than nheads. Are you indexing from 1 instead of 0?"
        assert min(heads) >= 0, "Head indices must be non-negative"
        
        unique_heads = list(set(heads))
        mask_vector = self.head_mask(unique_heads)
        mask_vector = mask_vector.reshape(query.shape)

        new_query = (query * mask_vector).astype(np.float32)

        return self[layer].search(new_query, k)
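
# Illustrative usage (same assumptions as above): restrict the context search at
# layer 5 to the contributions of heads 0, 1, and 7. The query is multiplied
# elementwise by the head mask, so the head_size-sized slices of every other
# head are zeroed before the nearest-neighbor lookup.
#
#   ctx_indexes = ContextIndexes('data/gpt2/wiki/context_faiss')
#   distances, ids = ctx_indexes.search(5, [0, 1, 7], q, 10)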