Spaces:
Running
Running
#!/usr/bin/env python | |
# coding: utf-8 | |
import sys | |
import os | |
import numpy as np | |
import pdb | |
import json | |
import scipy.spatial as sp | |
import argparse | |
import torch | |
from torch.utils.data import DataLoader | |
from transformers import AdamW | |
from transformers import BertTokenizer | |
from tqdm import tqdm # for our progress bar | |
sys.path.append('../../../') | |
from datasets.usgs_os_sample_loader import USGS_MapDataset | |
from datasets.wikidata_sample_loader import Wikidata_Geocoord_Dataset, Wikidata_Random_Dataset | |
from models.spatial_bert_model import SpatialBertModel | |
from models.spatial_bert_model import SpatialBertConfig | |
from utils.find_closest import find_ref_closest_match, sort_ref_closest_match | |
from utils.common_utils import load_spatial_bert_pretrained_weights, get_spatialbert_embedding, get_bert_embedding, write_to_csv | |
from utils.baseline_utils import get_baseline_model | |
from transformers import BertModel | |
sys.path.append('/home/zekun/spatial_bert/spatial_bert/datasets') | |
from dataset_loader import SpatialDataset | |
from osm_sample_loader import PbfMapDataset | |
# Encoders accepted by --model_name: the two SpatialBERT variants first, then
# the baseline encoders that entity_linking_func hands to get_baseline_model.
MODEL_OPTIONS = [
    'spatial_bert-base', 'spatial_bert-large',
    'bert-base', 'bert-large',
    'roberta-base', 'roberta-large',
    'spanbert-base', 'spanbert-large',
    'luke-base', 'luke-large',
    'simcse-bert-base', 'simcse-bert-large',
    'simcse-roberta-base', 'simcse-roberta-large',
]

# How the linking candidate set is built (from all maps or from one map);
# only the all-maps mode is currently available.
CANDSET_MODES = ['all_map']
def recall_at_k(rank_list, k = 1):
    """Return the fraction of queries whose rank is ``<= k``.

    Args:
        rank_list: one 1-based rank per query.
        k: cut-off rank (inclusive).

    Returns:
        NumPy float in [0, 1]; an empty ``rank_list`` yields NaN (0/0).
    """
    hits = (np.asarray(rank_list) <= k).sum()
    return hits * 1.0 / len(rank_list)
def reciprocal_rank(all_rank_list):
    """Compute reciprocal ranks and their mean.

    Args:
        all_rank_list: one 1-based (positive) rank per query.

    Returns:
        Tuple ``(mean_reciprocal_rank, per_query_reciprocal_ranks)``.
    """
    reciprocals = []
    for rank in all_rank_list:
        reciprocals.append(1. / rank)
    return np.mean(reciprocals), reciprocals
def link_to_itself(source_embedding_ogc_list, target_embedding_ogc_list):
    """Rank every source entity against all targets and score self-retrieval.

    For each source entry, cosine similarity to every target embedding is
    computed, targets are ordered with ``sort_ref_closest_match``, and the
    1-based position of the target whose ``'ogc_fid'`` equals the source's
    ``'ogc_fid'`` is that query's rank.

    Args:
        source_embedding_ogc_list: list of dicts with keys ``'emb'`` and
            ``'ogc_fid'`` (as built by ``get_embedding_and_ogc``).
        target_embedding_ogc_list: list of dicts with the same keys.

    Returns:
        ``(mean_reciprocal_rank, recall@1, recall@5, recall@10)``.

    Raises:
        ValueError: if a source ``'ogc_fid'`` does not appear among the
            targets (raised by ``list.index``).
    """
    target_ogc_list = [target_dict['ogc_fid'] for target_dict in target_embedding_ogc_list]
    # The target matrix is the same for every query: build it once instead of
    # re-stacking all target embeddings on each loop iteration.
    target_matrix = np.array([target_dict['emb'] for target_dict in target_embedding_ogc_list])

    rank_list = []
    for source_dict in source_embedding_ogc_list:
        source_emb = source_dict['emb']
        source_ogc = source_dict['ogc_fid']
        # One similarity column (n_target x 1) for this query; assumes each
        # 'emb' is a 1-D vector -- TODO confirm against the embedding helpers.
        sim_matrix = 1 - sp.distance.cdist(target_matrix, np.array([source_emb]), 'cosine')
        # NOTE(review): presumably returns (ogc_fid, score) pairs ordered
        # best-first; only the ogc_fid (element 0) is used here.
        closest_match_ogc = [a[0] for a in sort_ref_closest_match(sim_matrix, target_ogc_list)]
        rank_list.append(closest_match_ogc.index(source_ogc) + 1)

    mean_recip, _ = reciprocal_rank(rank_list)
    r1 = recall_at_k(rank_list, k = 1)
    r5 = recall_at_k(rank_list, k = 5)
    r10 = recall_at_k(rank_list, k = 10)
    return mean_recip, r1, r5, r10
def get_embedding_and_ogc(dataset, model_name, model):
    """Embed every sample of ``dataset`` and pair it with its ``ogc_fid``.

    SpatialBERT models go through ``get_spatialbert_embedding``; every other
    model name goes through ``get_bert_embedding``.

    Returns:
        A list of ``{'emb': embedding, 'ogc_fid': sample['ogc_fid']}`` dicts,
        in dataset iteration order.
    """
    is_spatial = model_name in ('spatial_bert-base', 'spatial_bert-large')

    def _embed(sample):
        # Pick the embedding helper matching the model family.
        if is_spatial:
            return get_spatialbert_embedding(sample, model)
        return get_bert_embedding(sample, model)

    return [{'emb': _embed(sample), 'ogc_fid': sample['ogc_fid']} for sample in dataset]
def _load_model(model_name, spatial_bert_weight_dir, spatial_bert_weight_name, device):
    """Build the encoder and tokenizer for ``model_name``.

    SpatialBERT variants are built from ``SpatialBertConfig``, moved to
    ``device``, put in eval mode and then loaded from
    ``os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)``.
    Every other name is delegated to ``get_baseline_model`` and then moved to
    ``device`` and put in eval mode.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: if a SpatialBERT model is requested without both the
            weight directory and the weight file name.
    """
    if model_name not in ('spatial_bert-base', 'spatial_bert-large'):
        model, tokenizer = get_baseline_model(model_name)
        model.to(device)
        model.eval()
        return model, tokenizer

    # Fail with a clear message instead of a TypeError from os.path.join(None, ...).
    if spatial_bert_weight_dir is None or spatial_bert_weight_name is None:
        raise ValueError('%s requires --spatial_bert_weight_dir and '
                         '--spatial_bert_weight_name' % model_name)

    if model_name == 'spatial_bert-base':
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        config = SpatialBertConfig()
    else:
        tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
        config = SpatialBertConfig(hidden_size = 1024, intermediate_size = 4096,
                                   num_attention_heads = 16, num_hidden_layers = 24)
    model = SpatialBertModel(config)
    model.to(device)
    model.eval()
    # load pretrained weights
    weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
    model = load_spatial_bert_pretrained_weights(model, weight_path)
    return model, tokenizer


def _build_osm_dataset(data_file_path, tokenizer, distance_norm_factor, spatial_dist_fill,
                       sep_between_neighbors, random_remove_neighbor):
    """Construct a PbfMapDataset with the settings shared by the source and target sets."""
    return PbfMapDataset(data_file_path = data_file_path,
                         tokenizer = tokenizer,
                         max_token_len = 512,
                         distance_norm_factor = distance_norm_factor,
                         spatial_dist_fill = spatial_dist_fill,
                         with_type = False,
                         sep_between_neighbors = sep_between_neighbors,
                         mode = None,
                         random_remove_neighbor = random_remove_neighbor,
                         )


def entity_linking_func(args):
    """Run the self-linking ablation on the OSM point file and print the scores.

    The same file is loaded twice. The source copy uses
    ``args.random_remove_neighbor``; the target copy keeps every neighbor
    (``random_remove_neighbor=0.``). Each source entity is then ranked against
    all target entities by cosine similarity (see ``link_to_itself``).

    Prints a blank line followed by
    ``random_remove_neighbor mean_reciprocal_rank recall@1 recall@5 recall@10``.

    Raises:
        ValueError: for an unknown ``model_name`` / ``candset_mode``, or a
            SpatialBERT model without weight paths.
    """
    model_name = args.model_name
    candset_mode = args.candset_mode
    distance_norm_factor = args.distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    sep_between_neighbors = args.sep_between_neighbors
    random_remove_neighbor = args.random_remove_neighbor
    # NOTE(review): args.no_spatial_distance is parsed by main() but was never
    # used by this function; wire it in if the ablation should honour it.

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if model_name not in MODEL_OPTIONS:
        raise ValueError('model_name must be one of %s, got %r' % (MODEL_OPTIONS, model_name))
    if candset_mode not in CANDSET_MODES:
        raise ValueError('candset_mode must be one of %s, got %r' % (CANDSET_MODES, candset_mode))

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model, tokenizer = _load_model(model_name, args.spatial_bert_weight_dir,
                                   args.spatial_bert_weight_name, device)

    source_file_path = '../data/osm-point-minnesota-full.json'
    source_dataset = _build_osm_dataset(source_file_path, tokenizer, distance_norm_factor,
                                        spatial_dist_fill, sep_between_neighbors,
                                        random_remove_neighbor)
    target_dataset = _build_osm_dataset(source_file_path, tokenizer, distance_norm_factor,
                                        spatial_dist_fill, sep_between_neighbors,
                                        0.)  # keep all neighbors

    # Inference only: no autograd graph is needed while computing embeddings.
    with torch.no_grad():
        source_embedding_ogc_list = get_embedding_and_ogc(source_dataset, model_name, model)
        target_embedding_ogc_list = get_embedding_and_ogc(target_dataset, model_name, model)

    mean_recip, r1, r5, r10 = link_to_itself(source_embedding_ogc_list, target_embedding_ogc_list)
    print('\n')
    print(random_remove_neighbor, mean_recip, r1, r5, r10)
def main():
    """Parse command-line options and run ``entity_linking_func``."""
    parser = argparse.ArgumentParser(
        description='Link each OSM entity back to itself and report MRR and recall@1/5/10.')
    # `choices` makes argparse reject unknown values up front with a usage
    # message instead of failing later inside entity_linking_func.
    parser.add_argument('--model_name', type=str, default='spatial_bert-base',
                        choices=MODEL_OPTIONS)
    parser.add_argument('--candset_mode', type=str, default='all_map',
                        choices=CANDSET_MODES,
                        help='how the candidate set is constructed')
    parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
    parser.add_argument('--spatial_dist_fill', type=float, default = 20)
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--no_spatial_distance', default=False, action='store_true',
                        help='currently not used by entity_linking_func')
    parser.add_argument('--spatial_bert_weight_dir', type = str, default = None,
                        help='directory holding SpatialBERT weights (spatial_bert-* models only)')
    parser.add_argument('--spatial_bert_weight_name', type = str, default = None,
                        help='weight file name inside --spatial_bert_weight_dir')
    parser.add_argument('--random_remove_neighbor', type = float, default = 0.,
                        help='random_remove_neighbor for the source dataset; '
                             'the target dataset always keeps all neighbors')
    args = parser.parse_args()
    entity_linking_func(args)


# Example invocations (weight paths are from the original author's machine):
# CUDA_VISIBLE_DEVICES='1' python3 linking_ablation.py --sep_between_neighbors --model_name='spatial_bert-base' --spatial_bert_weight_dir='/data/zekun/spatial_bert_weights/typing_lr5e-05_sep_bert-base_nofreeze_london_california_bsize12/ep0_iter06000_0.2936/' --spatial_bert_weight_name='keeppos_ep0_iter02000_0.4879.pth' --random_remove_neighbor=0.1
# CUDA_VISIBLE_DEVICES='1' python3 linking_ablation.py --sep_between_neighbors --model_name='spatial_bert-large' --spatial_bert_weight_dir='/data/zekun/spatial_bert_weights/typing_lr1e-06_sep_bert-large_nofreeze_london_california_bsize12/ep2_iter02000_0.3921/' --spatial_bert_weight_name='keeppos_ep8_iter03568_0.2661_val0.2284.pth' --random_remove_neighbor=0.1
if __name__ == '__main__':
    main()