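"""Dataset utilities for spatially contextualized pseudo-sentence construction.

A pivot entity and its neighbors are tokenized into one token sequence, and
every subtoken carries the neighbor's (lng, lat) offset from the pivot,
normalized by ``distance_norm_factor``.
"""
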
import numpy as np
import torch
from torch.utils.data import Dataset

class SpatialDataset(Dataset):
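    """Base dataset that converts a pivot entity and its spatial neighbors into
    a fixed-length pseudo-sentence with per-token normalized (lng, lat) offsets.

    Subclasses supply the data source by implementing ``__len__`` and
    ``__getitem__``.
    """
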
    def __init__(self, tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors=False):
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len
        self.distance_norm_factor = distance_norm_factor
        self.sep_between_neighbors = sep_between_neighbors
        

    def parse_spatial_context(self, pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill,  pivot_dist_fill = 0):
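        """Build one training example from a pivot entity and its neighbors.

        Args:
            pivot_name: surface name of the pivot entity.
            pivot_pos: (lng, lat) position of the pivot.
            neighbor_name_list: neighbor surface names.
            neighbor_geometry_list: neighbor geometries, either (lng, lat)
                pairs or GeoJSON-style dicts with a 'coordinates' key.
            spatial_dist_fill: distance fill value for special/padding tokens.
            pivot_dist_fill: distance value assigned to the pivot's own tokens.

        Returns:
            dict of tensors: token ids, attention mask, position ids, and
            normalized per-token lng/lat offsets.
        """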

        sep_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token)
        cls_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.cls_token)
        #mask_token_id  = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
        max_token_len = self.max_token_len

       
        # process pivot
        pivot_name_tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(pivot_name))
        pivot_token_len = len(pivot_name_tokens)
            
        pivot_lng = pivot_pos[0]
        pivot_lat = pivot_pos[1]

        # prepare entity mask: True = mask out, False = keep original token
        entity_mask_arr = []
        rand_entity = np.random.uniform(size=len(neighbor_name_list) + 1)  # one random draw per entity (pivot + all neighbors)

        # pivot-entity masking is currently disabled; the original code masked
        # the pivot tokens with 15% probability:
        # if rand_entity[0] < 0.15:
        #     entity_mask_arr.extend([True] * pivot_token_len)
        # else:
        entity_mask_arr.extend([False] * pivot_token_len)

        # process neighbors
        neighbor_token_list = []
        neighbor_lng_list = []
        neighbor_lat_list = []

        # add a separator between pivot and neighbor tokens;
        # skipped when pivot_dist_fill != 0 as a trick to avoid appending a separator
        # after a class name (used when encoding class names for the margin-ranking loss)
        if self.sep_between_neighbors and pivot_dist_fill == 0:
            neighbor_lng_list.append(spatial_dist_fill)
            neighbor_lat_list.append(spatial_dist_fill)
            neighbor_token_list.append(sep_token_id)

        for neighbor_name, neighbor_geometry, rnd in zip(neighbor_name_list, neighbor_geometry_list, rand_entity[1:]):

            if not neighbor_name or not neighbor_name[0].isalpha():
                # only consider neighbors whose name starts with a letter
                # (the emptiness check guards against empty name strings)
                continue

            neighbor_token = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(neighbor_name))
            neighbor_token_len = len(neighbor_token)

            # compute the relative (lng, lat) offset from the neighbor to the pivot,
            # normalize it by distance_norm_factor, and repeat the offset for every
            # subtoken of the neighbor name
            if 'coordinates' in neighbor_geometry:  # GeoJSON-style dict
                neighbor_lng, neighbor_lat = neighbor_geometry['coordinates'][0], neighbor_geometry['coordinates'][1]
            else:  # plain (lng, lat) sequence
                neighbor_lng, neighbor_lat = neighbor_geometry[0], neighbor_geometry[1]

            neighbor_lng_list.extend([(neighbor_lng - pivot_lng) / self.distance_norm_factor] * neighbor_token_len)
            neighbor_lat_list.extend([(neighbor_lat - pivot_lat) / self.distance_norm_factor] * neighbor_token_len)
            neighbor_token_list.extend(neighbor_token)

            if self.sep_between_neighbors:
                neighbor_lng_list.append(spatial_dist_fill)
                neighbor_lat_list.append(spatial_dist_fill)
                neighbor_token_list.append(sep_token_id)
                entity_mask_arr.append(False)  # the separator token is never masked

            # neighbor-entity masking is currently disabled; the original code
            # masked each neighbor's tokens with 15% probability:
            # if rnd < 0.15:
            #     entity_mask_arr.extend([True] * neighbor_token_len)
            # else:
            entity_mask_arr.extend([False] * neighbor_token_len)


        pseudo_sentence = pivot_name_tokens + neighbor_token_list
        dist_lng_list = [pivot_dist_fill] * pivot_token_len + neighbor_lng_list
        dist_lat_list = [pivot_dist_fill] * pivot_token_len + neighbor_lat_list

        # sentence length before CLS and SEP are added
        sent_len = len(pseudo_sentence)

        max_token_len_middle = max_token_len - 2  # reserve 2 slots for the CLS and SEP tokens

        # padding and truncation
        if sent_len > max_token_len_middle:
            pseudo_sentence = [cls_token_id] + pseudo_sentence[:max_token_len_middle] + [sep_token_id]
            dist_lat_list = [spatial_dist_fill] + dist_lat_list[:max_token_len_middle] + [spatial_dist_fill]
            dist_lng_list = [spatial_dist_fill] + dist_lng_list[:max_token_len_middle] + [spatial_dist_fill]
            attention_mask = [0] + [1] * max_token_len_middle + [0]  # make sure CLS and SEP are not attended to
        else:
            pad_len = max_token_len_middle - sent_len
            assert pad_len >= 0

            pseudo_sentence = [cls_token_id] + pseudo_sentence + [sep_token_id] + [pad_token_id] * pad_len
            dist_lat_list = [spatial_dist_fill] + dist_lat_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
            dist_lng_list = [spatial_dist_fill] + dist_lng_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
            attention_mask = [0] + [1] * sent_len + [0] * (pad_len + 1)  # zero out CLS, SEP, and padding


        

        norm_lng_list = np.array(dist_lng_list)
        norm_lat_list = np.array(dist_lat_list)


        # token- and entity-level masking is currently disabled; the original
        # code built a masked copy of the pseudo-sentence as follows:
        #
        # # mask entities in the pseudo sentence
        # entity_mask_indices = np.where(entity_mask_arr)[0]
        # masked_entity_input = [mask_token_id if i in entity_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]
        #
        # # mask tokens in the pseudo sentence; do not mask out CLS, SEP, or PAD.
        # # True: mask out, False: keep original token
        # rand_token = np.random.uniform(size=len(pseudo_sentence))
        # token_mask_arr = (rand_token < 0.15) & (np.array(pseudo_sentence) != cls_token_id) & (np.array(pseudo_sentence) != sep_token_id) & (np.array(pseudo_sentence) != pad_token_id)
        # token_mask_indices = np.where(token_mask_arr)[0]
        # masked_token_input = [mask_token_id if i in token_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]
        #
        # # yield masked_token with 50% prob., masked_entity with 50% prob.
        # if np.random.rand() > 0.5:
        #     masked_input = torch.tensor(masked_entity_input)
        # else:
        #     masked_input = torch.tensor(masked_token_input)
        masked_input = torch.tensor(pseudo_sentence)
        
        train_data = {}
        train_data['pivot_name'] = pivot_name
        train_data['pivot_token_len'] = pivot_token_len
        train_data['masked_input'] = masked_input
        train_data['sent_position_ids'] = torch.tensor(np.arange(0, len(pseudo_sentence)))
        train_data['attention_mask'] = torch.tensor(attention_mask)
        train_data['norm_lng_list'] = torch.tensor(norm_lng_list).to(torch.float32)
        train_data['norm_lat_list'] = torch.tensor(norm_lat_list).to(torch.float32)
        train_data['pseudo_sentence'] = torch.tensor(pseudo_sentence)

        return train_data



    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, index):
        raise NotImplementedError
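

# Illustrative sketch (an assumption, not part of the original repository):
# one way a concrete subclass might wire a data source into the abstract hooks
# above. The `examples` list of (pivot_name, pivot_pos, neighbor_names,
# neighbor_geometries) tuples and the `spatial_dist_fill` constructor argument
# are hypothetical choices for this example.
class ExampleSpatialDataset(SpatialDataset):
    def __init__(self, examples, tokenizer, max_token_len, distance_norm_factor,
                 spatial_dist_fill, sep_between_neighbors=False):
        super().__init__(tokenizer, max_token_len, distance_norm_factor,
                         sep_between_neighbors)
        self.examples = examples
        self.spatial_dist_fill = spatial_dist_fill

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        pivot_name, pivot_pos, neighbor_names, neighbor_geometries = self.examples[index]
        return self.parse_spatial_context(pivot_name, pivot_pos, neighbor_names,
                                          neighbor_geometries, self.spatial_dist_fill)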