import numpy as np
import torch
from torch.utils.data import Dataset

class SpatialDataset(Dataset):
    """Base dataset that linearizes a pivot entity and its spatial neighbors
    into a pseudo sentence with a normalized distance per token."""

    def __init__(self, tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors=False):
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len
        self.distance_norm_factor = distance_norm_factor
        self.sep_between_neighbors = sep_between_neighbors
    def parse_spatial_context(self, pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill, pivot_dist_fill=0):
        sep_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token)
        cls_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.cls_token)
        #mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)

        max_token_len = self.max_token_len

        # process the pivot entity
        pivot_name_tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(pivot_name))
        pivot_token_len = len(pivot_name_tokens)
        pivot_lng = pivot_pos[0]
        pivot_lat = pivot_pos[1]

        # prepare the entity mask
        entity_mask_arr = []
        # random numbers for masking entities (pivot plus neighbors)
        rand_entity = np.random.uniform(size=len(neighbor_name_list) + 1)
        # True for mask, False for unmask
        # check whether the pivot entity needs to be masked out (15% prob.)
        #if rand_entity[0] < 0.15:
        #    entity_mask_arr.extend([True] * pivot_token_len)
        #else:
        entity_mask_arr.extend([False] * pivot_token_len)

        # process neighbors
        neighbor_token_list = []
        neighbor_lng_list = []
        neighbor_lat_list = []

        # add a separator between pivot and neighbor tokens; skipped when
        # pivot_dist_fill != 0 so that no separator follows the class name
        # (class-name encoding for the margin-ranking loss)
        if self.sep_between_neighbors and pivot_dist_fill == 0:
            neighbor_lng_list.append(spatial_dist_fill)
            neighbor_lat_list.append(spatial_dist_fill)
            neighbor_token_list.append(sep_token_id)
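
        # Resulting layout (derived from the code above and the CLS/SEP wrap
        # below): with sep_between_neighbors=True the final pseudo sentence is
        # [CLS] pivot [SEP] neighbor_1 [SEP] neighbor_2 [SEP] ... [SEP]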

        for neighbor_name, neighbor_geometry, rnd in zip(neighbor_name_list, neighbor_geometry_list, rand_entity[1:]):
            if not neighbor_name[0].isalpha():
                # only consider neighbors whose names start with a letter
                continue

            neighbor_token = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(neighbor_name))
            neighbor_token_len = len(neighbor_token)

            # compute the relative distance from the neighbor to the pivot,
            # normalize it by distance_norm_factor, and repeat the value for
            # every subtoken of the neighbor
            if 'coordinates' in neighbor_geometry:  # handle different json dict structures
                neighbor_lng_list.extend([(neighbor_geometry['coordinates'][0] - pivot_lng) / self.distance_norm_factor] * neighbor_token_len)
                neighbor_lat_list.extend([(neighbor_geometry['coordinates'][1] - pivot_lat) / self.distance_norm_factor] * neighbor_token_len)
            else:
                neighbor_lng_list.extend([(neighbor_geometry[0] - pivot_lng) / self.distance_norm_factor] * neighbor_token_len)
                neighbor_lat_list.extend([(neighbor_geometry[1] - pivot_lat) / self.distance_norm_factor] * neighbor_token_len)
            neighbor_token_list.extend(neighbor_token)

            if self.sep_between_neighbors:
                neighbor_lng_list.append(spatial_dist_fill)
                neighbor_lat_list.append(spatial_dist_fill)
                neighbor_token_list.append(sep_token_id)
                entity_mask_arr.extend([False])

            # check whether this neighbor needs to be masked out (15% prob.)
            #if rnd < 0.15:
            #    # True: mask out, False: keep original token
            #    entity_mask_arr.extend([True] * neighbor_token_len)
            #else:
            entity_mask_arr.extend([False] * neighbor_token_len)

        pseudo_sentence = pivot_name_tokens + neighbor_token_list
        dist_lng_list = [pivot_dist_fill] * pivot_token_len + neighbor_lng_list
        dist_lat_list = [pivot_dist_fill] * pivot_token_len + neighbor_lat_list

        sent_len = len(pseudo_sentence)
        max_token_len_middle = max_token_len - 2  # reserve 2 positions for the CLS and SEP tokens

        # padding and truncation (final length, including CLS and SEP, is max_token_len)
        if sent_len > max_token_len_middle:
            pseudo_sentence = [cls_token_id] + pseudo_sentence[:max_token_len_middle] + [sep_token_id]
            dist_lat_list = [spatial_dist_fill] + dist_lat_list[:max_token_len_middle] + [spatial_dist_fill]
            dist_lng_list = [spatial_dist_fill] + dist_lng_list[:max_token_len_middle] + [spatial_dist_fill]
            attention_mask = [0] + [1] * max_token_len_middle + [0]  # make sure CLS and SEP are not attended to
        else:
            pad_len = max_token_len_middle - sent_len
            assert pad_len >= 0
            pseudo_sentence = [cls_token_id] + pseudo_sentence + [sep_token_id] + [pad_token_id] * pad_len
            dist_lat_list = [spatial_dist_fill] + dist_lat_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
            dist_lng_list = [spatial_dist_fill] + dist_lng_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
            attention_mask = [0] + [1] * sent_len + [0] * pad_len + [0]

        norm_lng_list = np.array(dist_lng_list)  # / 0.0001
        norm_lat_list = np.array(dist_lat_list)  # / 0.0001
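
        # Worked example (illustrative numbers, not from the original code):
        # with distance_norm_factor = 0.0001, a neighbor 0.01 degrees east of
        # the pivot contributes a normalized lng value of 100.0, repeated once
        # per subtoken of its name; pivot positions carry pivot_dist_fill and
        # CLS/SEP/PAD positions carry spatial_dist_fill.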

        ## mask entities in the pseudo sentence
        #entity_mask_indices = np.where(entity_mask_arr)[0]
        #masked_entity_input = [mask_token_id if i in entity_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]

        ## mask tokens in the pseudo sentence
        #rand_token = np.random.uniform(size=len(pseudo_sentence))
        ## do not mask out the cls, sep, and pad tokens. True: masked token, False: keep original token
        #token_mask_arr = (rand_token < 0.15) & (np.array(pseudo_sentence) != cls_token_id) & (np.array(pseudo_sentence) != sep_token_id) & (np.array(pseudo_sentence) != pad_token_id)
        #token_mask_indices = np.where(token_mask_arr)[0]
        #masked_token_input = [mask_token_id if i in token_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]

        ## yield the masked-token input with 50% prob., the masked-entity input with 50% prob.
        #if np.random.rand() > 0.5:
        #    masked_input = torch.tensor(masked_entity_input)
        #else:
        #    masked_input = torch.tensor(masked_token_input)

        # masking is currently disabled; the raw pseudo sentence is used as input
        masked_input = torch.tensor(pseudo_sentence)

        train_data = {}
        train_data['pivot_name'] = pivot_name
        train_data['pivot_token_len'] = pivot_token_len
        train_data['masked_input'] = masked_input
        train_data['sent_position_ids'] = torch.tensor(np.arange(0, len(pseudo_sentence)))
        train_data['attention_mask'] = torch.tensor(attention_mask)
        train_data['norm_lng_list'] = torch.tensor(norm_lng_list).to(torch.float32)
        train_data['norm_lat_list'] = torch.tensor(norm_lat_list).to(torch.float32)
        train_data['pseudo_sentence'] = torch.tensor(pseudo_sentence)

        return train_data

    def __len__(self):
        # subclasses must implement this
        raise NotImplementedError

    def __getitem__(self, index):
        # subclasses must implement this
        raise NotImplementedError
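

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): a minimal
# concrete subclass over an in-memory list of records, showing how
# parse_spatial_context is typically driven. The record layout, the fill
# values, and the bert-base-uncased tokenizer are assumptions for
# demonstration only.
# ---------------------------------------------------------------------------
class InMemorySpatialDataset(SpatialDataset):
    def __init__(self, records, tokenizer, max_token_len=64, distance_norm_factor=0.0001):
        super().__init__(tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors=True)
        self.records = records  # list of dicts describing a pivot and its neighbors

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        rec = self.records[index]
        return self.parse_spatial_context(
            pivot_name=rec['pivot_name'],
            pivot_pos=rec['pivot_pos'],  # (lng, lat)
            neighbor_name_list=rec['neighbor_names'],
            neighbor_geometry_list=rec['neighbor_geometries'],  # (lng, lat) tuples or {'coordinates': [lng, lat]} dicts
            spatial_dist_fill=100,  # fill value for CLS/SEP/PAD positions
        )


if __name__ == '__main__':
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    records = [{
        'pivot_name': 'Los Angeles',
        'pivot_pos': (-118.24, 34.05),
        'neighbor_names': ['Pasadena', 'Long Beach'],
        'neighbor_geometries': [(-118.14, 34.15), (-118.19, 33.77)],
    }]
    dataset = InMemorySpatialDataset(records, tokenizer)
    sample = dataset[0]
    print(sample['pseudo_sentence'].shape)      # torch.Size([64])
    print(int(sample['attention_mask'].sum()))  # number of attended (non-CLS/SEP/PAD) tokens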