File size: 4,671 Bytes
63858e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import h5py
import numpy as np
from functools import partial
from utils.gen_utils import map_nlist, vround
import regex as re
from spacyface.simple_spacy_token import SimpleSpacyToken
from data_processing.sentence_data_wrapper import SentenceH5Data, TokenH5Data
from utils.f import ifnone
    
# Every top-level item in the HDF5 file is stored under a zero-padded decimal key.
ZERO_BUFFER = 12  # Number of decimal places each index takes
main_key = f"{{:0{ZERO_BUFFER}}}"  # renders as "{:012}"

def to_idx(idx: int) -> str:
    """Return the zero-padded HDF5 key string for integer position `idx`.

    The width comes from ``ZERO_BUFFER``; e.g. ``to_idx(5) == "000000000005"``.
    """
    key = main_key.format(idx)
    return key

def zip_len_check(*iters):
    """Zip two or more sized iterables, requiring that they all share one length.

    Each argument must support ``len()``. The returned object is the plain
    ``zip`` of the arguments.

    Raises:
        ValueError: if fewer than two iterables are passed, or if any argument's
            length differs from that of the first argument.
    """
    count = len(iters)
    if count < 2:
        raise ValueError(f"Expected at least 2 iterables to combine. Got {count} iterables")

    expected = len(iters[0])
    # map() is lazy, so lengths are checked in order and we stop at the first mismatch.
    for actual in map(len, iters):
        if actual != expected:
            raise ValueError(f"Expected all iterations to have len {expected} but found {actual}")

    return zip(*iters)

class CorpusDataWrapper:
    """A wrapper for both the token embeddings and the head context.

    This class allows access into an HDF5 file designed according to the
    data/processing module's contents as if it were an in-memory sequence of
    sentences: it supports ``len()``, integer and slice indexing, and iteration.
    Top-level items are looked up under the key ``to_idx(i)``.
    """

    def __init__(self, fname, name=None):
        """Open an HDF5 file of the expected format and index its contents.

        Args:
            fname: Path to the HDF5 file; it is opened read-only.
            name: Optional display name used by ``repr``; defaults to "CorpusData".
        """
        # Cursor used only by the legacy ``__next__`` protocol; ``__iter__``
        # does not touch it.
        self.__curr = 0

        self.__name = ifnone(name, "CorpusData")
        self.fname = fname
        self.data = h5py.File(fname, 'r')

        main_keys = self.data.keys()
        self.__len = len(main_keys)

        assert self.__len > 0, "Cannot process an empty file"

        embeds = self[0].embeddings
        self.embedding_dim = embeds.shape[-1]
        self.n_layers = embeds.shape[0] - 1  # 1 was added for the input layer
        self.refmap, self.total_vectors = self._init_vector_map()

    def __del__(self):
        """Close the underlying HDF5 file, tolerating a partially-built instance."""
        try:
            self.data.close()

        # NOTE(review): original comment says closing fails with ImportError when
        # run as a script (presumably during interpreter shutdown) — kept as-is.
        except ImportError:
            pass

        # ``self.data`` is missing when ``h5py.File`` failed in ``__init__``.
        # Use getattr so this handler cannot itself raise if ``fname`` was never set.
        except AttributeError:
            print(f"Never successfully loaded {getattr(self, 'fname', '<unknown>')}")

    def __iter__(self):
        """Yield every sentence in index order.

        A fresh generator is returned on each call, so breaking out of one loop
        or nesting loops does not affect later iterations.
        """
        return (self[i] for i in range(self.__len))

    def __len__(self):
        """Number of top-level items in the HDF5 file."""
        return self.__len

    def __next__(self):
        """Legacy cursor-based iteration, kept for callers that call ``next()`` directly.

        The cursor is reset to 0 when it reaches the end.
        """
        if self.__curr >= self.__len:
            self.__curr = 0
            raise StopIteration

        out = self[self.__curr]
        self.__curr += 1
        return out

    def __getitem__(self, idx):
        """Index into the embeddings.

        Args:
            idx: A Python or NumPy integer position, or a slice. Negative integers
                count from the end.

        Returns:
            A ``SentenceH5Data`` for an integer index, or a list of them for a slice.

        Raises:
            NotImplementedError: for any other index type.
            ValueError: for a slice with step 0, via ``slice.indices``.

        Integer positions are not bounds-checked here. An out-of-range position
        fails in the HDF5 key lookup.
        """
        if isinstance(idx, slice):
            # slice.indices applies Python's standard clamping, negative-index and
            # default start/stop rules, so ``self[:]`` includes the last item and
            # ``self[:0]`` is empty.
            return [self[i] for i in range(*idx.indices(self.__len))]

        if isinstance(idx, (int, np.integer)):
            i = int(idx)
            if i < 0:
                i += self.__len
            return SentenceH5Data(self.data[to_idx(i)])

        raise NotImplementedError

    def __repr__(self):
        return f"{self.__name}: containing {self.__len} items"

    def _init_vector_map(self):
        """Create the main map from global vector id to that vector's metadata.

        Vector ids are assigned consecutively, starting at 0, over every token of
        every sentence in index order.

        Returns:
            A tuple ``(refmap, n_vec)``. ``refmap[k]`` is the ``TokenH5Data`` for
            vector ``k``, and ``n_vec`` is the total number of vectors.

        TODO: initialization is slow because it walks every sentence. Consider
        caching this map in a separate HDF5 file.
        """
        refmap = {}
        print("Initializing reference map for embedding vector...")
        n_vec = 0
        for sentence in self:
            for i in range(len(sentence)):
                refmap[n_vec] = TokenH5Data(sentence, i)
                n_vec += 1

        return refmap, n_vec

    def extract(self, layer):
        """Extract embeddings from a particular layer for all examples in the dataset.

        Returns:
            ``np.vstack`` of ``sentence[layer]`` for each sentence, in index order.
            NOTE(review): assumes ``SentenceH5Data.__getitem__(layer)`` returns an
            array-like with a consistent trailing dimension — confirm.
        """
        return np.vstack([sentence[layer] for sentence in self])

    def find(self, vec_num):
        """Return the metadata (``TokenH5Data``) of the vector with global id `vec_num`."""
        return self.refmap[vec_num]

    def find2d(self, idxs):
        """Map a nested list of vector ids to a nested list of their metadata."""
        return [[self.refmap[i] for i in idx] for idx in idxs]