{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "import pandas as pd\n", "\n", "# LOCATION OF THE OSM DATA FOR FINE-TUNING\n", "data = 'tutorial_datasets/osm_mn.csv'\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## CONSTRUCT DATASET FOR FINE TUNING ##\n", "\n", "# Read data from .csv data file\n", "\n", "state_frame = pd.read_csv(data)\n", "\n", "\n", "# construct list of names and coordinates from data\n", "name_list = []\n", "coordinate_list = []\n", "for i, item in state_frame.iterrows():\n", " name = item[1]\n", " lat = item[2]\n", " lng =item[3]\n", " name_list.append(name)\n", " coordinate_list.append([lng,lat])\n", "\n", "\n", "# construct KDTree out of coordinates list for when we make the neighbor lists\n", "import scipy.spatial as scp\n", "\n", "ordered_neighbor_coordinate_list = scp.KDTree(coordinate_list)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "state_frame" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "# Get top 20 nearest neighbors for each entity in dataset\n", "with open('tutorial_datasets/SPABERT_finetuning_data.json', 'w') as out_f:\n", " for i, item in state_frame.iterrows():\n", " name = item[1]\n", " lat = item[2]\n", " lng = item[3]\n", " coordinates = [lng,lat]\n", "\n", " _, nearest_neighbors_idx = ordered_neighbor_coordinate_list.query([coordinates], k=21)\n", "\n", " # we want to store their names and coordinates\n", "\n", " nearest_neighbors_name = []\n", " nearest_neighbors_coords = []\n", " \n", " # iterate over nearest neighbors list\n", " for idx in nearest_neighbors_idx[0]:\n", " # get name and coordinate of neighbor\n", " neighbor_name = name_list[idx]\n", " neighbor_coords = coordinate_list[idx]\n", " nearest_neighbors_name.append(neighbor_name)\n", " nearest_neighbors_coords.append({\"coordinates\": neighbor_coords})\n", " \n", " # construct neighbor info dictionary object for SpaBERT embedding construction\n", " neighbor_info = {\"name_list\":nearest_neighbors_name, \"geometry_list\":nearest_neighbors_coords}\n", "\n", "\n", " # construct full dictionary object for SpaBERT embedding construction\n", " place = {\"info\":{\"name\":name, \"geometry\":{\"coordinates\": coordinates}}, \"neighbor_info\":neighbor_info}\n", "\n", " out_f.write(json.dumps(place))\n", " out_f.write('\\n')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "### FINE-TUNE SPABERT\n", "import sys\n", "from transformers.models.bert.modeling_bert import BertForMaskedLM\n", "from transformers import BertTokenizer\n", "sys.path.append(\"../\")\n", "from models.spatial_bert_model import SpatialBertConfig\n", "from utils.common_utils import load_spatial_bert_pretrained_weights\n", "from models.spatial_bert_model import SpatialBertForMaskedLM\n", "\n", "# load dataset we just created\n", "\n", "dataset = 'tutorial_datasets/SPABERT_finetuning_data.json'\n", "\n", "# load pre-trained spabert model\n", "\n", "pretrained_model = 'tutorial_datasets/mlm_mem_keeppos_ep0_iter06000_0.2936.pth'\n", "\n", "\n", "# load bert model and tokenizer as well as the SpaBERT config\n", "bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')\n", "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", "config = SpatialBertConfig()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ "# load pre-trained spabert model\n", "import torch\n", "model = SpatialBertForMaskedLM(config)\n", "\n", "model.load_state_dict(bert_model.state_dict() , strict = False) \n", "\n", "pre_trained_model = torch.load(pretrained_model)\n", "\n", "model_keys = model.state_dict()\n", "cnt_layers = 0\n", "for key in model_keys:\n", " if key in pre_trained_model:\n", " model_keys[key] = pre_trained_model[key]\n", " cnt_layers += 1\n", " else:\n", " print(\"No weight for\", key)\n", "print(cnt_layers, 'layers loaded')\n", "\n", "model.load_state_dict(model_keys)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets.osm_sample_loader import PbfMapDataset\n", "from torch.utils.data import DataLoader\n", "# load fine-tning dataset with data loader\n", "\n", "fine_tune_dataset = PbfMapDataset(data_file_path = dataset, \n", " tokenizer = tokenizer, \n", " max_token_len = 300, \n", " distance_norm_factor = 0.0001, \n", " spatial_dist_fill = 20, \n", " with_type = False,\n", " sep_between_neighbors = False, \n", " label_encoder = None,\n", " mode = None)\n", "#initialize data loader\n", "train_loader = DataLoader(fine_tune_dataset, batch_size=12, num_workers=5, shuffle=False, pin_memory=True, drop_last=True)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "# cast our loaded model to a gpu if one is available, otherwise use the cpu\n", "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", "model.to(device)\n", "\n", "# set model to training mode\n", "model.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "### FINE TUNING PROCEDURE ###\n", "from tqdm import tqdm \n", "from transformers import AdamW\n", "# initialize optimizer\n", "optim = AdamW(model.parameters(), lr = 5e-5)\n", "\n", "# setup loop with TQDM and dataloader\n", "epoch = tqdm(train_loader, leave=True)\n", "iter = 0\n", "for batch in epoch:\n", " # initialize calculated gradients from previous step\n", " optim.zero_grad()\n", "\n", " # pull all tensor batches required for training\n", " input_ids = batch['masked_input'].to(device)\n", " attention_mask = batch['attention_mask'].to(device)\n", " position_list_x = batch['norm_lng_list'].to(device)\n", " position_list_y = batch['norm_lat_list'].to(device)\n", " sent_position_ids = batch['sent_position_ids'].to(device)\n", "\n", " labels = batch['pseudo_sentence'].to(device)\n", "\n", " # get outputs of model\n", " outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids,\n", " position_list_x = position_list_x, position_list_y = position_list_y, labels = labels)\n", " \n", "\n", " # calculate loss\n", " loss = outputs.loss\n", "\n", " # perform backpropigation\n", " loss.backward()\n", "\n", " optim.step()\n", " epoch.set_postfix({'loss':loss.item()})\n", "\n", "\n", " iter += 1\n", "torch.save(model.state_dict(), \"tutorial_datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth\")\n" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }