File size: 174,974 Bytes
46e0dd0
1
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"gpuType":"T4","authorship_tag":"ABX9TyNl9y7CuTJ2SdbjaaYlTn5a"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU","widgets":{"application/vnd.jupyter.widget-state+json":{"cc80d2ca9fa7420dab91cf3ff2a51f1e":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_bcf91ae4f4a2410cbadd53a7d6ffc48e","IPY_MODEL_d72605ed5f8949959ddf14e24b085444","IPY_MODEL_0a30f60bb8d94e34b0cb13e7b8824558"],"layout":"IPY_MODEL_399dad9e63714b06bae2627a732312a6"}},"bcf91ae4f4a2410cbadd53a7d6ffc48e":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_089f72c2d1da4834a1c72c2f9f00e155","placeholder":"​","style":"IPY_MODEL_a0fc0aec5c4a46a88c68d302e3dd9b51","value":"config.json: 
100%"}},"d72605ed5f8949959ddf14e24b085444":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_4180fb0c020c42479407045aa48377a6","max":570,"min":0,"orientation":"horizontal","style":"IPY_MODEL_829713166e88442b96998e9f9c8d33d5","value":570}},"0a30f60bb8d94e34b0cb13e7b8824558":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_c9f61fca47f44527b9fb31f7a0493048","placeholder":"​","style":"IPY_MODEL_3bae135b3d594d72a41a53c05acf3890","value":" 570/570 [00:00&lt;00:00, 
29.6kB/s]"}},"399dad9e63714b06bae2627a732312a6":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"089f72c2d1da4834a1c72c2f9f00e155":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"
overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"a0fc0aec5c4a46a88c68d302e3dd9b51":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"4180fb0c020c42479407045aa48377a6":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"829713166e88442b96998e9f9c8d33d5":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"c9f61fca47f44527b9fb31f7a0493048":{"model_m
odule":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"3bae135b3d594d72a41a53c05acf3890":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"9baa380ef84441c29cb888cfd7217bdd":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_16b982d3a3334499aa9115b4482d3ea4","IPY_MODEL_d5e7c582da6542879350f0ddc3302f4c","IPY_MODEL_c3e1afb4efc74d9eb55d2fc92cd193fb"],"layout":"IPY_MODEL_009be06baac3434e88e6566cfa3d0a95"
}},"16b982d3a3334499aa9115b4482d3ea4":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_ddd0b2a5124f4027b109bc1237ddc1fd","placeholder":"​","style":"IPY_MODEL_71d85e36c226475a8e5ddf35c8a4032f","value":"model.safetensors: 100%"}},"d5e7c582da6542879350f0ddc3302f4c":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_993c8521e06f4187ba7d81c4b04fde58","max":440449768,"min":0,"orientation":"horizontal","style":"IPY_MODEL_cace3b560f3a4aad8680750580c13c83","value":440449768}},"c3e1afb4efc74d9eb55d2fc92cd193fb":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_b2ef3ed771bc4029ba8686afeed7b3d1","placeholder":"​","style":"IPY_MODEL_ba2b04e360d7442885191849cc3bfedf","value":" 440M/440M [00:02&lt;00:00, 
236MB/s]"}},"009be06baac3434e88e6566cfa3d0a95":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"ddd0b2a5124f4027b109bc1237ddc1fd":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"o
verflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"71d85e36c226475a8e5ddf35c8a4032f":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"993c8521e06f4187ba7d81c4b04fde58":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"cace3b560f3a4aad8680750580c13c83":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"b2ef3ed771bc4029ba8686afeed7b3d1":{"model_mo
dule":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"ba2b04e360d7442885191849cc3bfedf":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"f5a77e1a0b3848c7b083af9ed12c60f9":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_8d3b59c822844b32a18f728f1f6ccc1d","IPY_MODEL_479c46a0cea4479db45f1bec4c519549","IPY_MODEL_a1af6531177240a2a4d9b00633e09a00"],"layout":"IPY_MODEL_8bff7f55fcb041d09b7f5b1fda669faa"}
},"8d3b59c822844b32a18f728f1f6ccc1d":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_133331fa90c24047a7e37ded3008407d","placeholder":"​","style":"IPY_MODEL_4f8519013f934942a50bb795bbb79f3c","value":"tokenizer_config.json: 100%"}},"479c46a0cea4479db45f1bec4c519549":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_f2b387c1d3394532847c82beb5a805fe","max":48,"min":0,"orientation":"horizontal","style":"IPY_MODEL_13af0eb33cb44239aa35d654a7697f9e","value":48}},"a1af6531177240a2a4d9b00633e09a00":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_473d8f4f54084732a88940fbdbcd0263","placeholder":"​","style":"IPY_MODEL_c31bc70cbe5943b585923bed400b732b","value":" 48.0/48.0 [00:00&lt;00:00, 
3.05kB/s]"}},"8bff7f55fcb041d09b7f5b1fda669faa":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"133331fa90c24047a7e37ded3008407d":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"
overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"4f8519013f934942a50bb795bbb79f3c":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"f2b387c1d3394532847c82beb5a805fe":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"13af0eb33cb44239aa35d654a7697f9e":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"473d8f4f54084732a88940fbdbcd0263":{"model_m
odule":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"c31bc70cbe5943b585923bed400b732b":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"509d34130ad047e485e552bb24aaf6cd":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_fbcb53f2515547beb3963a9982809c4d","IPY_MODEL_55b525378cd6401fba05a64406de1731","IPY_MODEL_6a8a12d558f443bea0b3afa4cad62706"],"layout":"IPY_MODEL_6a77611d249a40ea81069c89f7a51c83"
}},"fbcb53f2515547beb3963a9982809c4d":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_0b46da4d7fae43a7a31837d49829e99e","placeholder":"​","style":"IPY_MODEL_f7ec0ead995149848add22ed117f8ec5","value":"vocab.txt: 100%"}},"55b525378cd6401fba05a64406de1731":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_b3dddb1d37d642e38185ba480f3797bf","max":231508,"min":0,"orientation":"horizontal","style":"IPY_MODEL_c4600b06964b4963a7e0e5c01fb1d902","value":231508}},"6a8a12d558f443bea0b3afa4cad62706":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_71cae588996a44b995b25dfc9b139610","placeholder":"​","style":"IPY_MODEL_0208388dda1145afb692fb6c4d871c4f","value":" 232k/232k [00:00&lt;00:00, 
1.74MB/s]"}},"6a77611d249a40ea81069c89f7a51c83":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"0b46da4d7fae43a7a31837d49829e99e":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"
overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"f7ec0ead995149848add22ed117f8ec5":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"b3dddb1d37d642e38185ba480f3797bf":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"c4600b06964b4963a7e0e5c01fb1d902":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"71cae588996a44b995b25dfc9b139610":{"model_m
odule":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"0208388dda1145afb692fb6c4d871c4f":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"77fc854955f64509beafb122a7394df5":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_bf0f18e217f0412f86beb69c2e60002b","IPY_MODEL_7b01d86c894f47209b5741eb8dd53ade","IPY_MODEL_e5c37e67686f4ce9964c307bb7b5203c"],"layout":"IPY_MODEL_8bad0134480b481ea937bb240758f002"
}},"bf0f18e217f0412f86beb69c2e60002b":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_47916c8b12cd48348f5307fe5daf3d58","placeholder":"​","style":"IPY_MODEL_636e84e1654c4f4fb1a365484cfa0f35","value":"tokenizer.json: 100%"}},"7b01d86c894f47209b5741eb8dd53ade":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_18da796336e54e67b7f8523f9ffe3466","max":466062,"min":0,"orientation":"horizontal","style":"IPY_MODEL_b99a1fe6b98f4368957d2f77efbe77be","value":466062}},"e5c37e67686f4ce9964c307bb7b5203c":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_cb051e2148114a559341ed3ee27363e1","placeholder":"​","style":"IPY_MODEL_fd695ec60ca64355923ee9af1e0bcf87","value":" 466k/466k [00:00&lt;00:00, 
6.47MB/s]"}},"8bad0134480b481ea937bb240758f002":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"47916c8b12cd48348f5307fe5daf3d58":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"
overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"636e84e1654c4f4fb1a365484cfa0f35":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"18da796336e54e67b7f8523f9ffe3466":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"b99a1fe6b98f4368957d2f77efbe77be":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"cb051e2148114a559341ed3ee27363e1":{"model_m
odule":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"fd695ec60ca64355923ee9af1e0bcf87":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"cells":[{"cell_type":"markdown","source":[" This Notebook is to get the Embeddings from our Fine-Tuned SpaBERT model so that we can send them to the GAN-BERT Notebook in place of spatial data."],"metadata":{"id":"BqGM3v_bGGUU"}},{"cell_type":"markdown","source":["# Mount and Import"],"metadata":{"id":"Q_MLWWHJGqkE"}},{"cell_type":"code","execution_count":1,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Pwe1nsga9EUN","executionInfo":{"status":"ok","timestamp":1722474453697,"user_tz":420,"elapsed":16813,"user":{"displayName":"Jason 
Phillips","userId":"10136472498761089328"}},"outputId":"102141fa-5a6f-486c-9de4-518599f86c2b"},"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n","/content/drive\n"]}],"source":["#Mount Google Drive\n","from google.colab import drive\n","drive.mount('/content/drive')\n","%cd '/content/drive'"]},{"cell_type":"code","source":["import sys\n","models_path = '/content/drive/MyDrive/spaBERT/spabert'\n","sys.path.append(models_path)\n","sys.path.append('/content/drive/MyDrive/spaBERT/spabert/datasets')\n","sys.path.append(\"../\")\n","print(sys.path)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"L1jV-w8GF3cP","executionInfo":{"status":"ok","timestamp":1722474463934,"user_tz":420,"elapsed":217,"user":{"displayName":"Jason Phillips","userId":"10136472498761089328"}},"outputId":"d58a683b-6a31-44db-b061-3369269bd9c0"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["['/content', '/env/python', '/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/usr/local/lib/python3.10/dist-packages', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.10/dist-packages/IPython/extensions', '/usr/local/lib/python3.10/dist-packages/setuptools/_vendor', '/root/.ipython', '/content/drive/MyDrive/spaBERT/spabert', '/content/drive/MyDrive/spaBERT/spabert/datasets', '../']\n"]}]},{"cell_type":"markdown","source":["# Load SpaBERT with our pretrained weights\n"],"metadata":{"id":"7Sg4y6aEGwYm"}},{"cell_type":"code","source":["\n","import sys\n","import torch\n","from transformers.models.bert.modeling_bert import BertForMaskedLM\n","from transformers import BertTokenizer\n","from models.spatial_bert_model import SpatialBertConfig\n","from utils.common_utils import load_spatial_bert_pretrained_weights\n","from models.spatial_bert_model import  SpatialBertForMaskedLM\n","from models.spatial_bert_model import  SpatialBertModel\n","\n","\n","# load dataset we just 
created\n","data_file_path  = '/content/drive/MyDrive/Master_Project_2024_JP/Spacy Notebook/SPABERT_Coordinate_data_combined.json'\n","pretrained_model = '/content/drive/MyDrive/Master_Project_2024_JP/Spacy Notebook/fine-spabert-base-uncased-finetuned-osm-mn.pth'\n","#pretrained_model = '/content/drive/MyDrive/spaBERT/spabert/notebooks/tutorial_datasets/mlm_mem_keeppos_ep0_iter06000_0.2936.pth'\n","#pretrained_model = '/content/drive/MyDrive/spaBERT/spabert/notebooks/tutorial_datasets/spabert-base-uncased-finetuned-osm-mn.pth'\n","\n","# load bert model and tokenizer\n","bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')\n","tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n","\n","\n","# load pre-trained spabert model and its config\n","config = SpatialBertConfig()\n","config.output_hidden_states = True\n","\n","model = SpatialBertForMaskedLM(config)            #Should I be using masked or unmasked for the downstream tasks we are trying to perform?\n","#model = SpatialBertModel(config)                 #We fine-tuned the Masked version of the model so the weights won't load correctly\n","\n","model.load_state_dict(bert_model.state_dict() , strict = False)\n","\n","pre_trained_model = torch.load(pretrained_model)\n","\n","# load pretrained weights\n","model_keys = model.state_dict()\n","cnt_layers = 0\n","for key in model_keys:\n","    if key in pre_trained_model:\n","        model_keys[key] = pre_trained_model[key]\n","        cnt_layers += 1\n","    else:\n","        print(\"No weight for\", key)\n","print(cnt_layers, 'layers loaded')\n","\n","model.load_state_dict(model_keys)\n","\n","#Select a CPU or GPU\n","device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n","model.to(device)\n","\n","#Set the model to evaluation 
mode\n","model.eval()"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000,"referenced_widgets":["cc80d2ca9fa7420dab91cf3ff2a51f1e","bcf91ae4f4a2410cbadd53a7d6ffc48e","d72605ed5f8949959ddf14e24b085444","0a30f60bb8d94e34b0cb13e7b8824558","399dad9e63714b06bae2627a732312a6","089f72c2d1da4834a1c72c2f9f00e155","a0fc0aec5c4a46a88c68d302e3dd9b51","4180fb0c020c42479407045aa48377a6","829713166e88442b96998e9f9c8d33d5","c9f61fca47f44527b9fb31f7a0493048","3bae135b3d594d72a41a53c05acf3890","9baa380ef84441c29cb888cfd7217bdd","16b982d3a3334499aa9115b4482d3ea4","d5e7c582da6542879350f0ddc3302f4c","c3e1afb4efc74d9eb55d2fc92cd193fb","009be06baac3434e88e6566cfa3d0a95","ddd0b2a5124f4027b109bc1237ddc1fd","71d85e36c226475a8e5ddf35c8a4032f","993c8521e06f4187ba7d81c4b04fde58","cace3b560f3a4aad8680750580c13c83","b2ef3ed771bc4029ba8686afeed7b3d1","ba2b04e360d7442885191849cc3bfedf","f5a77e1a0b3848c7b083af9ed12c60f9","8d3b59c822844b32a18f728f1f6ccc1d","479c46a0cea4479db45f1bec4c519549","a1af6531177240a2a4d9b00633e09a00","8bff7f55fcb041d09b7f5b1fda669faa","133331fa90c24047a7e37ded3008407d","4f8519013f934942a50bb795bbb79f3c","f2b387c1d3394532847c82beb5a805fe","13af0eb33cb44239aa35d654a7697f9e","473d8f4f54084732a88940fbdbcd0263","c31bc70cbe5943b585923bed400b732b","509d34130ad047e485e552bb24aaf6cd","fbcb53f2515547beb3963a9982809c4d","55b525378cd6401fba05a64406de1731","6a8a12d558f443bea0b3afa4cad62706","6a77611d249a40ea81069c89f7a51c83","0b46da4d7fae43a7a31837d49829e99e","f7ec0ead995149848add22ed117f8ec5","b3dddb1d37d642e38185ba480f3797bf","c4600b06964b4963a7e0e5c01fb1d902","71cae588996a44b995b25dfc9b139610","0208388dda1145afb692fb6c4d871c4f","77fc854955f64509beafb122a7394df5","bf0f18e217f0412f86beb69c2e60002b","7b01d86c894f47209b5741eb8dd53ade","e5c37e67686f4ce9964c307bb7b5203c","8bad0134480b481ea937bb240758f002","47916c8b12cd48348f5307fe5daf3d58","636e84e1654c4f4fb1a365484cfa0f35","18da796336e54e67b7f8523f9ffe3466","b99a1fe6b98f4368957d2f77efbe77be","cb051e2148114a559341ed3ee2
7363e1","fd695ec60ca64355923ee9af1e0bcf87"]},"id":"loMR8XHdzdM8","executionInfo":{"status":"ok","timestamp":1722474490548,"user_tz":420,"elapsed":24122,"user":{"displayName":"Jason Phillips","userId":"10136472498761089328"}},"outputId":"3e6a2adc-a3b6-45bc-c068-e781119cb98f"},"execution_count":3,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n","The secret `HF_TOKEN` does not exist in your Colab secrets.\n","To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n","You will be able to reuse this secret in all of your notebooks.\n","Please note that authentication is recommended but still optional to access public models or datasets.\n","  warnings.warn(\n"]},{"output_type":"display_data","data":{"text/plain":["config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"cc80d2ca9fa7420dab91cf3ff2a51f1e"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"9baa380ef84441c29cb888cfd7217bdd"}},"metadata":{}},{"output_type":"stream","name":"stderr","text":["Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n","- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n","- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"]},{"output_type":"display_data","data":{"text/plain":["tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"f5a77e1a0b3848c7b083af9ed12c60f9"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"509d34130ad047e485e552bb24aaf6cd"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"77fc854955f64509beafb122a7394df5"}},"metadata":{}},{"output_type":"stream","name":"stdout","text":["205 layers loaded\n"]},{"output_type":"execute_result","data":{"text/plain":["SpatialBertForMaskedLM(\n","  (bert): SpatialBertModel(\n","    (embeddings): SpatialEmbedding(\n","      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n","      (position_embeddings): Embedding(512, 768)\n","      (sent_position_embedding): Embedding(512, 768)\n","      (spatial_position_embedding): ContinuousSpatialPositionalEmbedding()\n","      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n","      (dropout): Dropout(p=0.1, inplace=False)\n","    )\n","    (encoder): BertEncoder(\n","      (layer): ModuleList(\n","        (0-11): 12 x BertLayer(\n","          (attention): BertAttention(\n","            (self): BertSelfAttention(\n","              (query): Linear(in_features=768, out_features=768, bias=True)\n","              
(key): Linear(in_features=768, out_features=768, bias=True)\n","              (value): Linear(in_features=768, out_features=768, bias=True)\n","              (dropout): Dropout(p=0.1, inplace=False)\n","            )\n","            (output): BertSelfOutput(\n","              (dense): Linear(in_features=768, out_features=768, bias=True)\n","              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n","              (dropout): Dropout(p=0.1, inplace=False)\n","            )\n","          )\n","          (intermediate): BertIntermediate(\n","            (dense): Linear(in_features=768, out_features=3072, bias=True)\n","            (intermediate_act_fn): GELUActivation()\n","          )\n","          (output): BertOutput(\n","            (dense): Linear(in_features=3072, out_features=768, bias=True)\n","            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n","            (dropout): Dropout(p=0.1, inplace=False)\n","          )\n","        )\n","      )\n","    )\n","  )\n","  (cls): SpatialBertOnlyMLMHead(\n","    (predictions): SpatialBertLMPredictionHead(\n","      (transform): SpatialBertPredictionHeadTransform(\n","        (dense): Linear(in_features=768, out_features=768, bias=True)\n","        (transform_act_fn): GELUActivation()\n","        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n","      )\n","      (decoder): Linear(in_features=768, out_features=30522, bias=True)\n","    )\n","  )\n",")"]},"metadata":{},"execution_count":3}]},{"cell_type":"markdown","source":["# Load our dataset.\n","\n","Note: The model requires spatial coordinates for each token (entity)\n","\n","Questions:\n","\n","\n","*   Are we sending in pseudo sentences described in the paper?\n","  *   [CLS] University of Minnesota [SEP]\n","      Minneapolis [SEP] ### ### ### [SEP] Bloom\n","      Island Park [SEP] Bell Museum [SEP]\n","*   After we get the embedding for each entity, do we need to link this back to the 
review and send that to the GAN-BERT model?\n","  *   There is an option to include labels, do we include the real/fake labels from each review\n","\n","\n","\n"],"metadata":{"id":"BUbb3msLHA7X"}},{"cell_type":"code","source":["from datasets.osm_sample_loader import PbfMapDataset\n","from datasets.dataset_loader import SpatialDataset\n","from torch.utils.data import DataLoader\n","\n","# Load data using SpatialDataset\n","dataset = PbfMapDataset(data_file_path = data_file_path,\n","                                        tokenizer = tokenizer,\n","                                        max_token_len = 300,\n","                                        distance_norm_factor = 0.0001,\n","                                        spatial_dist_fill = 20,\n","                                        with_type = False,\n","                                        sep_between_neighbors = False,    #Initially false, play around with this potentially?\n","                                        label_encoder = None,             #Initially None, potentially change this because we do have real/fake reviews.\n","                                        mode = None)                      #If set to None it will use the full dataset for mlm\n","\n","data_loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=True) #issue needs to be fixed with num_workers not stopping after finished"],"metadata":{"id":"4VWgHg39BKWg","executionInfo":{"status":"ok","timestamp":1722474540789,"user_tz":420,"elapsed":1495,"user":{"displayName":"Jason Phillips","userId":"10136472498761089328"}}},"execution_count":4,"outputs":[]},{"cell_type":"code","source":["dataset[0]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"9YGbhJOqMmcC","executionInfo":{"status":"ok","timestamp":1722475050351,"user_tz":420,"elapsed":191,"user":{"displayName":"Jason 
Phillips","userId":"10136472498761089328"}},"outputId":"a381bbc5-24c8-4dba-82c7-1e9c25a3c73a"},"execution_count":13,"outputs":[{"output_type":"execute_result","data":{"text/plain":["{'pivot_name': 'kabuki',\n"," 'pivot_token_len': 3,\n"," 'masked_input': tensor([  101,   103,  8569,  3211, 10556,  8569,  3211, 19461, 15460, 19461,\n","         15460, 27166,  9818,   103,  3406,  7905,  7014,  3702, 22078, 13226,\n","           103,   103,  2395,   103, 13642,  2899, 13642,  2899,  3927, 13173,\n","         13173,  8529,  4886,  8529,   103, 19923, 25133, 19213,  6187, 27313,\n","          8953,  1996,  6842,  2314,  1996,  6842,   103,   102,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","   
          0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]),\n"," 'sent_position_ids': tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,\n","          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,\n","          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,\n","          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,\n","          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,\n","          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,\n","          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,\n","          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,\n","         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,\n","         126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,\n","         140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,\n","         154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,\n","         168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181,\n","         182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195,\n","         196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 
208, 209,\n","         210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,\n","         224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,\n","         238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251,\n","         252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265,\n","         266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279,\n","         280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293,\n","         294, 295, 296, 297, 298, 299]),\n"," 'attention_mask': tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n","         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n","         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n","         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n","         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n","         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n","         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n","         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n","         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n","         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n","         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n","         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n","         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n"," 'norm_lng_list': tensor([  20.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,\n","            0.0000,  -22.7640,  -22.7640,  -22.7640,  -22.7640, -147.5950,\n","         -147.5950,    5.4244,    5.4244, -178.6530,  247.4690,  247.4690,\n","           54.8590, -553.7769, -553.7769, -105.8370, -105.8370, -364.8710,\n","         -364.8710, 
-364.8710, -364.8710, -364.8710, -364.8710,  227.1890,\n","          227.1890, -290.7000, -290.7000, -290.7000, -290.7000, -439.1880,\n","          910.7584,  910.7584,  975.8950,  975.8950,  975.8950, 1019.9832,\n","         1019.9832, 1019.9832, 1019.9832, 1019.9832, 1019.9832,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   
20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000]),\n"," 'norm_lat_list': tensor([  20.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,\n","            0.0000,    5.7590,    5.7590,    5.7590,    5.7590, -153.2390,\n","         -153.2390, -319.6765, -319.6765, -484.0480, -453.8350, -453.8350,\n","          524.3650,  -77.6850,  -77.6850, -576.7880, -576.7880, -682.1400,\n","         -682.1400, -682.1400, -682.1400, 
-682.1400, -682.1400,  739.7750,\n","          739.7750, -764.6410, -764.6410, -764.6410, -764.6410, -773.0530,\n","         -235.0275, -235.0275, -319.4900, -319.4900, -319.4900, -310.2245,\n","         -310.2245, -310.2245, -310.2245, -310.2245, -310.2245,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   
20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000,\n","           20.0000,   20.0000,   20.0000,   20.0000,   20.0000,   20.0000]),\n"," 'pseudo_sentence': tensor([  101, 10556,  8569,  3211, 10556,  8569,  3211, 19461, 15460, 19461,\n","         15460, 27166,  9818, 12849,  3406,  7905,  7014,  3702, 22078, 13226,\n","         13226, 11458,  2395,  2899, 13642,  2899, 13642,  2899,  3927, 13173,\n","         13173,  8529,  4886,  8529,  4886, 19923, 25133, 19213,  6187, 27313,\n","          8953,  1996,  6842,  2314,  1996,  6842,  
2314,   102,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,   
  0,     0,     0,\n","             0,     0,     0,     0,     0,     0,     0,     0,     0,     0])}"]},"metadata":{},"execution_count":13}]},{"cell_type":"code","source":["def get_entity_index(name, max_entities=5):\n","    \"\"\"Return the index of the first dataset entity whose pivot_name equals `name`.\n","\n","    Only the first `max_entities` entities are scanned (default 5, the original\n","    debugging cap); pass None to scan the whole dataset. The match is printed and\n","    its index returned; None is returned when no scanned entity matches.\n","    \"\"\"\n","    for i, entity in enumerate(dataset):\n","        if max_entities is not None and i >= max_entities:\n","            break\n","        if entity['pivot_name'] == name:\n","            print(i, entity['pivot_name'])\n","            return i\n","    return None"],"metadata":{"id":"ncV1Xgi2wNfY","executionInfo":{"status":"ok","timestamp":1722474551566,"user_tz":420,"elapsed":186,"user":{"displayName":"Jason Phillips","userId":"10136472498761089328"}}},"execution_count":6,"outputs":[]},{"cell_type":"code","source":["entity_index = get_entity_index(\"kabuki\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"x9yo56zjwZg2","executionInfo":{"status":"ok","timestamp":1722474552761,"user_tz":420,"elapsed":183,"user":{"displayName":"Jason Phillips","userId":"10136472498761089328"}},"outputId":"a2abeb86-a6d2-4dc3-b7ca-150847080ddc"},"execution_count":7,"outputs":[{"output_type":"stream","name":"stdout","text":["0 kabuki\n"]}]},{"cell_type":"code","source":["print(entity_index)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"_nJnwJpWwnrZ","executionInfo":{"status":"ok","timestamp":1722474553821,"user_tz":420,"elapsed":2,"user":{"displayName":"Jason Phillips","userId":"10136472498761089328"}},"outputId":"b48ff1d7-02b0-4c64-df81-c304bda8f347"},"execution_count":8,"outputs":[{"output_type":"stream","name":"stdout","text":["0\n"]}]},{"cell_type":"code","source":["from tqdm import tqdm\n","\n","# Function to process each entity and get embeddings\n","def process_entity(batch, model, device):\n","    \"\"\"Run one entity through the model and return its pivot-token embeddings.\n","\n","    Returns (pivot_embeddings, input_ids) as NumPy arrays. Because of\n","    batch['pivot_token_len'].item(), the batch must hold exactly one entity.\n","    \"\"\"\n","    # NOTE(review): 'masked_input' contains [MASK] tokens (the recorded output shows\n","    # even pivot tokens masked) and the unmasked 'pseudo_sentence' is never passed to\n","    # the model. Confirm this is intended when extracting embeddings.\n","    input_ids = batch['masked_input'].to(device)\n","    attention_mask = batch['attention_mask'].to(device)\n","    position_list_x = batch['norm_lng_list'].to(device)\n","    position_list_y = batch['norm_lat_list'].to(device)\n","    sent_position_ids = batch['sent_position_ids'].to(device)\n","\n","    with 
torch.no_grad():\n","        outputs = model(input_ids=input_ids,\n","                        attention_mask=attention_mask,\n","                        sent_position_ids=sent_position_ids,\n","                        position_list_x=position_list_x,\n","                        position_list_y=position_list_y)\n","\n","    # Extract embeddings\n","    #embeddings = outputs[0]                # Extracting the last hidden state from outputs\n","    embeddings = outputs.hidden_states[-1]\n","\n","    # Position 0 is [CLS] (id 101 in 'pseudo_sentence'); the pivot tokens follow it,\n","    # so slice 1..pivot_token_len instead of 0..pivot_token_len-1.\n","    pivot_token_len = batch['pivot_token_len'].item()\n","    pivot_embeddings = embeddings[:, 1:1 + pivot_token_len, :]\n","\n","    return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()\n","\n","all_embeddings = []\n","# Process the first 5 rows and print embeddings\n","# NOTE: fix this to make actual batches instead of just one at a time.\n","for i, batch in enumerate(data_loader):\n","    if i >= 5:\n","        break\n","    embeddings, input_ids = process_entity(batch, model, device)\n","    sequence_length = input_ids.shape[1]\n","\n","    print(f\"Embeddings for entity {i+1}: {embeddings}\")\n","    print(f\"Shape for entity {i+1}: {embeddings.shape}\")\n","    print(f\"Sequence Length for entity {i+1}: {sequence_length}\")\n","    print(f\"Input IDs for entity {i+1}: {input_ids}\")\n","    print(f\"Decoded Tokens for entity {i+1}: {tokenizer.decode(input_ids[0])}\")\n","    all_embeddings.append(embeddings)\n","#process the entire dataset and store the embeddings (uncomment when ready)\n","#all_embeddings = []\n","#for batch in tqdm(data_loader):\n","#  embeddings, _ = process_entity(batch, model, device)\n","#  all_embeddings.append(embeddings)\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"KkbeEMEvHbug","executionInfo":{"status":"ok","timestamp":1722474659437,"user_tz":420,"elapsed":452,"user":{"displayName":"Jason 
Phillips","userId":"10136472498761089328"}},"outputId":"de30d974-e482-4268-ddcf-0ab8bcabd1d8"},"execution_count":11,"outputs":[{"output_type":"stream","name":"stdout","text":["Embeddings for entity 1: [[[-5.81115723e-01 -3.62999469e-01 -3.24680656e-01 ... -2.83270359e-01\n","    3.22587132e-01  2.28714406e-01]\n","  [-2.20670611e-01 -1.10545315e-01 -4.42134071e-04 ...  1.89075321e-01\n","    1.37033060e-01  6.88519329e-02]\n","  [-2.28437528e-01 -1.30716190e-01 -2.46341452e-02 ...  1.89642012e-01\n","    1.10516712e-01  8.37075785e-02]]]\n","Shape for entity 1: (1, 3, 768)\n","Sequence Length for entity 1: 300\n","Input IDs for entity 1: [[  103   103   103   103   103   103  3211 19461 15460 19461 15460 27166\n","   9818 12849  3406  7905  7014   103 22078 13226 13226 11458  2395  2899\n","  13642  2899 13642  2899  3927 13173 13173  8529  4886  8529  4886 19923\n","  25133   103   103   103  8953  1996  6842  2314  1996  6842  2314   102\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0   
  0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Decoded Tokens for entity 1: [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]ki quiznos quiznos cnbc koto arch beth [MASK] moe vanessa vanessa herman street washington ave washington ave washington avenue kara kara umai umai provence clover [MASK] [MASK] [MASK]ugh the hudson river the hudson rivern","Embeddings for entity 2: [[[-0.12602906 -0.32964182  0.0087087  ... -0.6199797   0.38162172\n","    0.06165807]\n","  [ 0.5201469  -0.17047702  0.49288183 ... 
-0.1474094   0.22941227\n","   -0.02108728]]]\n","Shape for entity 2: (1, 2, 768)\n","Sequence Length for entity 2: 300\n","Input IDs for entity 2: [[  101  9763  2479  9763  2479  9763  2479 20829 18996  2050  6384  3077\n","  18641 18641  6222  6222   103 15477  2395 15544 25970  4580  2675 15544\n","  25970  4580  5318 14132 14425  1037   103   103 12674  3790 15544  5753\n","   7570  5092  7520  7570  5092  7520 10090   103  2103  3006   102     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0   
  0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Decoded Tokens for entity 2: [CLS] rhode island rhode island rhode island stacks napa morrisville penelope penelope baltimore baltimore [MASK] bourbon street rittenhouse square rittenhouse brickyard lafayette a [MASK] [MASK] bergenfield ritz hoboken hoboken ruby [MASK] city marketn","Embeddings for entity 3: [[[-0.30242354 -0.1258468  -0.34332785 ... -0.39644253  0.21635099\n","    0.08158564]\n","  [ 0.03023065  1.0184598  -0.6203913  ...  0.00816763  0.23181963\n","    0.8504028 ]]]\n","Shape for entity 3: (1, 2, 768)\n","Sequence Length for entity 3: 300\n","Input IDs for entity 3: [[  101 15845  2358 15845  2358 15544 21827 11462  2015 15845  7668  2032\n","   7911   103 19668  2618  3870  2395  5292   103  5292 19445  2310 12190\n","    103  2310 12190 18175   103 23528  2063  2395   103 25676  5413  2080\n","   7668 10645  3900  1996  2896  2264  2217  1996  2896  2264  2217  1996\n","   2264  2217   103  4644  1038 10559  9102  2395   102     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     
0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Decoded Tokens for entity 3: [CLS] stanton st stanton st rivington freemans stanton cafe himala [MASK] slainte elizabeth street ha [MASK] habana verl [MASK] verlaine [MASK] broome street [MASK] barrio chino cafe katja the lower east side the lower east side the east side [MASK] lane bleecker streetn","Embeddings for entity 4: [[[-4.20583300e-02 -1.42036140e-01  2.97396481e-01 -2.66784206e-02\n","    7.87389040e-01 -1.41787067e-01  4.91385423e-02  6.45073593e-01\n","    1.08488336e-01  7.19574749e-01  7.35178664e-02  1.88348725e-01\n","   -6.21508300e-01  6.39314890e-01 -2.15266809e-01  3.57338116e-02\n","    6.63326919e-01 -5.23383096e-02  2.69820154e-01  5.00185728e-01\n","    4.84644413e-01  2.94570178e-01  3.83226089e-02  5.90295911e-01\n","    2.72509068e-01 -7.37917960e-01 -3.82818729e-01  1.03852473e-01\n","   -3.77366662e-01  5.45314491e-01 -3.04280728e-01  1.94835871e-01\n","   -1.19518153e-01 -1.67141885e-01 -5.94381630e-01  1.48739830e-01\n","    8.69034111e-01  5.50171614e-01 -2.29134813e-01 -3.72935981e-01\n","   -6.85629308e-01  3.89814198e-01 -4.42785978e-01  5.12354016e-01\n","    1.74551770e-01  1.77871585e-02 -1.38427067e+00  6.37979448e-01\n","    3.42206120e-01 -9.29677427e-01  1.76999852e-01 -2.26679265e-01\n","   -3.35535675e-01 -2.98362374e-01  1.58995822e-01  7.35969782e-01\n","   -1.72390535e-01 
-3.24331284e-01  1.52257785e-01 -5.13611734e-01\n","    2.43234664e-01 -2.62337297e-01 -5.01992293e-02 -2.19223812e-01\n","   -1.74417615e-01  3.91051590e-01 -3.31758261e-01  4.79678690e-01\n","   -3.50416780e-01 -1.16087861e-01  3.79473180e-01  5.45883365e-02\n","    1.06938958e-01  4.18085791e-03  2.79758632e-01 -3.74172330e-01\n","   -4.49101835e-01  6.63903713e-01  4.66531217e-02  1.01018809e-01\n","   -3.95578980e-01 -3.11598718e-01  1.58134066e-02 -8.07352588e-02\n","   -3.33755948e-02 -6.00086033e-01  6.17646694e-01  5.87905049e-01\n","    3.14644724e-01  5.14653027e-01 -2.95794517e-01 -5.55559635e-01\n","    3.62779856e-01  2.05220893e-01 -7.39268363e-01 -6.30587637e-02\n","    2.01417524e-02  4.68644381e-01 -1.40118361e-01 -1.29960731e-01\n","   -9.45359290e-01 -9.21104372e-01 -1.04621410e-01  2.87229270e-01\n","    6.03421152e-01  9.20968968e-03  5.82647860e-01  5.19698858e-01\n","    4.60989475e-01  8.73905301e-01  5.99700093e-01 -1.64137512e-01\n","    6.52999640e-01 -6.26775026e-01 -1.97182178e-01 -7.93291688e-01\n","    5.57582259e-01 -2.66212523e-01  1.85239583e-01 -7.03943849e-01\n","   -1.32579774e-01 -6.79957807e-01 -3.89716178e-01  5.69840550e-01\n","    2.41307169e-01 -3.36103365e-02 -5.96074425e-02 -1.42137468e-01\n","   -1.18225373e-01  5.31107113e-02  4.67748120e-02  6.44003868e-01\n","    1.48059174e-01 -1.54107526e-01 -6.44745767e-01  5.72740734e-01\n","   -2.67564096e-02 -7.82303140e-02 -1.93561748e-01 -8.09487626e-02\n","   -5.02913594e-02 -2.91087836e-01 -3.19948465e-01 -7.50052854e-02\n","    1.27361283e-01 -3.71302962e-01  1.69466026e-02  2.03009546e-01\n","    3.21476817e-01 -8.63668919e-01 -9.12763104e-02 -4.94651318e-01\n","   -4.03996110e-01 -5.22427857e-01  5.51181100e-02 -2.35734507e-01\n","   -4.82882470e-01 -1.74573883e-01 -9.88620967e-02  4.35078442e-01\n","   -1.23188801e-01  1.92663353e-02  1.84086129e-01 -3.42794448e-01\n","   -1.22237265e-01  4.03788507e-01  2.03192443e-01  1.01162446e+00\n","   -1.03663445e+00 
-2.09009856e-01  3.41059029e-01  6.93131387e-01\n","    9.01372015e-01  4.99209881e-01  1.87253579e-01  5.88751078e-01\n","    6.18605912e-01 -2.54198015e-01  2.30199844e-01 -2.57602751e-01\n","   -7.22788930e-01 -1.00487590e-01 -1.71059728e-01  2.15765417e-01\n","    4.25757647e-01 -7.80765474e-01  5.86444557e-01 -4.66398031e-01\n","   -7.23035112e-02 -4.74276990e-01 -3.01024139e-01 -3.25428061e-02\n","   -1.09744772e-01  1.65339619e-01 -1.33192345e-01 -8.04299787e-02\n","   -4.69601393e-01  3.68288428e-01 -5.32886304e-04  1.69890687e-01\n","    4.00348604e-01  5.46520233e-01  5.90638638e-01  1.30353153e-01\n","   -5.35967469e-01  4.10263717e-01  4.51095179e-02 -1.19078286e-01\n","    5.46599746e-01  4.70121801e-01  1.36164710e-01  2.24372745e-02\n","   -2.18612805e-01  9.13304448e-01 -4.32012767e-01 -6.45081773e-02\n","    3.21728349e-01 -7.26598024e-01  2.35831752e-01 -2.82883227e-01\n","   -7.42756069e-01  5.82528152e-02 -1.57105654e-01  7.69988820e-02\n","   -1.20272361e-01  1.14355040e+00  3.59906524e-01 -4.35017586e-01\n","    2.18458205e-01  9.31276083e-01 -4.41959202e-02 -6.19252622e-01\n","    6.43106222e-01  1.78501680e-02 -6.38990164e-01 -2.87731793e-02\n","    8.93166125e-01 -4.49879855e-01  3.44193518e-01 -3.27608705e-01\n","    2.91704059e-01  7.99393430e-02  8.21115613e-01 -6.56152427e-01\n","   -1.95214644e-01 -8.16270709e-01  4.06732075e-02  2.24004969e-01\n","   -3.34382616e-02 -3.86220694e-04  2.43046641e-01 -5.44572651e-01\n","    3.97967279e-01  2.43538264e-02 -6.52384758e-01 -7.50286102e-01\n","   -1.22988865e-01  4.72665161e-01 -6.39030814e-01 -9.14427459e-01\n","    3.49373758e-01 -4.54001635e-01  5.17096996e-01 -1.92048356e-01\n","    8.97237752e-03  1.92620307e-01 -1.11385608e+00  1.36035411e-02\n","    1.22096680e-01 -1.51080698e-01 -1.40820742e-01  2.73605764e-01\n","   -7.88595974e-02  1.10922825e+00  3.04361224e-01 -1.63392186e-01\n","    2.00043246e-01  9.61500928e-02  7.69400835e-01 -7.40555584e-01\n","   -3.31326067e-01  
5.52415252e-01 -9.64408144e-02 -5.71601868e-01\n","   -4.02094036e-01 -4.02092725e-01 -3.74877304e-01  4.10249263e-01\n","    1.70451626e-01  3.76181036e-01  1.27399221e-01 -2.94819862e-01\n","    1.55886739e-01 -1.09123319e-01  2.55696803e-01 -9.60553363e-02\n","    3.06817234e-01  3.98070514e-01 -2.74544895e-01 -1.07325697e+00\n","    1.90127477e-01 -5.43784857e-01  2.48867441e-02  1.49757594e-01\n","    1.10319018e-01 -1.85347833e-02  1.67229231e-02 -7.73108065e-01\n","   -5.21152544e+00 -1.23167917e-01 -9.36981365e-02 -1.06195509e+00\n","    2.96960957e-02 -4.36177164e-01  6.05232455e-02 -1.26250684e-01\n","   -2.50521004e-01  1.88304454e-01 -8.31437111e-02 -1.59184054e-01\n","    4.84843493e-01 -1.83688372e-01 -1.18995667e-01 -1.79928824e-01\n","    2.40778357e-01  2.97979731e-02  3.28282416e-02  4.33100075e-01\n","   -6.13244355e-01 -6.68086559e-02  3.61234039e-01 -2.02883184e-01\n","    1.87461913e-01  5.97187638e-01 -4.75949913e-01 -3.54348361e-01\n","    2.38796383e-01 -6.55548930e-01  1.53477639e-01 -7.43759274e-01\n","    1.94866687e-01 -1.00768462e-01  2.72680938e-01 -2.40816195e-02\n","    1.00861453e-01 -1.33742571e+00 -1.01808086e-01 -6.36162981e-02\n","    1.50354028e-01 -1.00004494e-01 -3.45053971e-01 -1.97568417e-01\n","    2.56787896e-01 -8.19468424e-02  3.28859597e-01  3.47068727e-01\n","   -8.65792036e-01 -1.77831814e-01  3.24072897e-01 -1.35814101e-02\n","   -1.25597775e-01  3.28854740e-01  4.10919577e-01  5.57219274e-02\n","    4.38250661e-01  6.61741719e-02  1.02439687e-01 -1.52181283e-01\n","    8.28154206e-01 -1.63404301e-01 -7.52954841e-01  5.85423470e-01\n","    4.27647650e-01 -2.86366165e-01 -3.45892996e-01  5.37865579e-01\n","   -6.90653682e-01  1.78042144e-01 -2.34049246e-01  1.24150239e-01\n","   -1.96845159e-01 -7.80481458e-01 -8.51793122e-03 -4.00982559e-01\n","   -4.22099791e-02 -3.04418892e-01  2.30242372e-01  1.33076772e-01\n","   -4.29064721e-01 -7.99559653e-02  6.48758948e-01 -1.36074066e-01\n","   -3.71639162e-01 
-5.11302233e-01 -1.75792471e-01  2.66938448e-01\n","   -6.66414618e-01 -4.31538373e-01  1.14751041e-01  6.42273724e-01\n","    4.85300183e-01 -6.31901026e-02  2.62882680e-01  6.13391817e-01\n","    1.53421238e-01 -1.76589459e-01  1.99390844e-01  6.98335826e-01\n","    3.52326244e-01  1.28748333e-02  7.38424242e-01 -1.01524556e+00\n","    1.45643756e-01  5.22581577e-01 -5.27043700e-01  9.32311416e-01\n","    3.16564858e-01  2.45800480e-01 -8.52773666e-01 -1.12393871e-01\n","    2.61995435e-01 -3.10291201e-01  1.01854414e-01 -7.20303133e-02\n","    5.69662631e-01  2.79316276e-01  1.74242601e-01  6.08565509e-02\n","   -2.90081464e-02  7.88410977e-02  7.00094104e-01  3.33214104e-01\n","   -5.43681800e-01 -9.83734876e-02 -6.88329101e-01 -2.57365346e-01\n","    2.47446582e-01 -2.90095896e-01  2.36264616e-01 -1.94189176e-01\n","    3.57600562e-02  3.47871006e-01  5.96337497e-01 -2.95354035e-02\n","   -7.56269842e-02 -1.58221006e+00  1.30559117e-01  2.56578475e-01\n","   -4.59155649e-01 -1.19588636e-02  3.56034070e-01 -2.27924079e-01\n","   -3.92733783e-01  1.37594506e-01  4.26186413e-01  2.96334028e-01\n","    3.21575284e-01  3.11855674e-01 -1.46409404e+00 -5.16062737e-01\n","   -7.40337819e-02 -3.17381442e-01  4.63878185e-01  5.22419691e-01\n","    1.15563661e-01 -4.31120992e-01 -2.53781497e-01 -4.08371806e-01\n","    6.26711607e-01 -1.21255115e-01 -6.08719707e-01  4.76371676e-01\n","   -1.08336222e+00  1.97172627e-01  6.48893714e-01  3.92040879e-01\n","   -4.26494122e-01  2.69766420e-01 -1.78497881e-01 -2.60389477e-01\n","   -3.12909514e-01 -5.13347685e-01 -9.19822812e-01  3.45755756e-01\n","   -3.24326962e-01 -5.20674706e-01  3.00229371e-01  6.22540638e-02\n","    8.90340135e-02 -2.61194915e-01 -4.81542081e-01 -3.10510676e-03\n","   -7.30873942e-02  4.49387521e-01  7.01649725e-01  1.13326028e-01\n","    2.90782541e-01  2.43255347e-01  4.31055278e-02  5.79595603e-02\n","    2.53597498e-01  2.62962520e-01 -9.00446296e-01 -4.67785001e-01\n","    3.85944128e-01  
7.18332291e-01  3.46215218e-01  4.57312435e-01\n","    4.10991199e-02 -2.16120481e-01  3.92579548e-02  4.00523305e-01\n","    2.59365886e-01  1.47923017e+00  7.76369125e-02  3.54021713e-02\n","    1.26978904e-01 -6.96261302e-02  3.08408260e-01 -4.04556990e-01\n","   -1.48088321e-01  1.26081958e-01  2.90125936e-01 -5.12668133e-01\n","   -4.66430187e-02  5.77867329e-02 -6.85970128e-01  3.39248180e-01\n","   -7.65926600e-01 -1.78885192e-01  1.78952202e-01  4.92331415e-01\n","   -5.51603973e-01 -5.35822809e-01  2.31002700e-02  3.38994674e-02\n","    2.74416059e-02 -3.78680736e-01  7.03071011e-03  3.09685916e-01\n","   -3.20844859e-01  2.38409176e-01 -4.87325341e-01  1.29701602e+00\n","    5.08064032e-02 -5.39849579e-01  6.07332438e-02  1.62858218e-01\n","    3.25906515e-01  3.00551832e-01  5.91792017e-02 -8.02731156e-01\n","   -2.62057960e-01  2.21734136e-01 -3.59922677e-01  4.03454781e-01\n","    4.07700807e-01  6.87892660e-02  4.47493047e-01  3.03660452e-01\n","    9.78790075e-02  4.57405746e-02  6.74565315e-01  4.57604183e-03\n","   -6.70423985e-01 -1.34428665e-01 -1.85283065e-01  3.13133180e-01\n","   -5.98921061e-01 -2.60862172e-01 -1.05057731e-02 -6.91389799e-01\n","    1.03930748e+00 -5.32819852e-02  8.30078900e-01  5.38406789e-01\n","   -6.01827428e-02 -7.02574313e-01  2.08321378e-01 -5.81434309e-01\n","    3.61241639e-01 -1.24040820e-01  1.57193050e-01  1.25188440e-01\n","    1.01535246e-01 -3.00486207e-01  2.11111426e-01 -3.72906355e-03\n","   -1.03542995e+00 -2.05558985e-01 -3.48530382e-01 -7.05764517e-02\n","    6.86762333e-01 -4.68826562e-01  2.56001741e-01  4.53331202e-01\n","   -4.12840039e-01  2.04840496e-01  1.46800131e-01  4.22048390e-01\n","   -4.73021805e-01  2.86704034e-01 -2.34030351e-01  4.43962783e-01\n","    8.03722918e-01  5.94853282e-01  3.39484774e-02 -1.43713474e-01\n","   -6.31936789e-02  3.27885240e-01 -1.72328249e-01 -2.11650789e-01\n","   -1.05722807e-01 -3.30590814e-01 -2.68604666e-01  8.00766200e-02\n","   -2.63644643e-02  
3.20827097e-01  6.89812779e-01 -7.90950134e-02\n","    3.31547529e-01  1.55034482e-01 -2.12033130e-02 -1.52814403e-01\n","    1.57635231e-02  7.40086846e-03 -1.42272532e-01 -1.67701244e-01\n","   -2.37606183e-01 -9.60177109e-02  1.23363018e+00  4.16106761e-01\n","    8.04355025e-01  1.53998807e-01 -1.75366506e-01 -3.09329659e-01\n","    7.20371723e-01 -4.77514230e-02 -1.87625200e-01 -4.92584221e-02\n","    5.36262989e-02 -1.07319780e-01  1.61969185e-01 -2.98262715e-01\n","   -2.07707714e-02 -2.22531855e-01 -9.41263214e-02  6.14095330e-01\n","    1.83075905e-01  3.58970985e-02 -3.09905976e-01 -1.89835235e-01\n","   -1.38945788e-01  2.11028457e-01  4.42767650e-01 -8.21059644e-02\n","    4.70530987e-02 -1.56839833e-01 -2.55679078e-02  1.11203933e+00\n","   -1.18499383e-01  1.08745009e-01  2.72996753e-01 -6.22277081e-01\n","    6.55273497e-02  7.35994354e-02 -3.80188376e-01  1.07568634e+00\n","    4.40968513e-01  1.06190455e+00  7.34862924e-01 -2.62038946e-01\n","   -4.68788564e-01  6.32562995e-01 -3.22465658e-01 -1.97138771e-01\n","   -2.81998128e-01  1.40144989e-01  3.41533154e-01  3.00975412e-01\n","   -6.97766066e-01  1.08797640e-01 -1.04424350e-01 -6.07960105e-01\n","   -7.52848685e-02  2.58946538e-01  1.66767269e-01  5.78080177e-01\n","   -4.79367435e-01  3.25656533e-01 -1.67153880e-01 -2.02475622e-01\n","   -1.34713441e-01 -2.36295506e-01 -4.88268621e-02 -3.46375629e-02\n","   -9.07131910e-01 -6.63629651e-01 -1.94161162e-01 -1.26584530e-01\n","    3.00141126e-01 -2.70453678e-03  5.83868444e-01 -6.83810234e-01\n","    6.22739159e-02 -2.55683422e-01 -2.98886877e-02 -7.85691217e-02\n","    4.43341911e-01 -4.38892931e-01  4.26329859e-02  7.68738911e-02\n","    3.41413528e-01 -4.89330202e-01 -4.44997162e-01 -4.61175703e-02\n","   -7.08571434e-01  1.09976971e+00 -1.75607607e-01  5.96689939e-01\n","   -5.58088064e-01  3.41733336e-01 -3.73069644e-01 -3.49682719e-01\n","    3.91268343e-01  2.87273586e-01 -3.74164432e-01 -4.94824499e-01\n","   -2.87091643e-01 
-7.61020541e-01  3.10847014e-01 -3.27813119e-01\n","   -4.53797579e-01 -3.25889140e-01 -1.13156904e-02  4.22079756e-04\n","   -1.13053811e+00 -2.23255083e-01  1.28881767e-01  3.16387951e-01\n","    3.11034411e-01  4.66525376e-01 -1.59004256e-01 -5.17673790e-01\n","   -2.19223216e-01 -3.11245441e-01 -2.17956886e-01 -6.53032601e-01\n","    7.67981708e-01  4.29491371e-01 -7.00244844e-01  6.39716625e-01\n","    1.54739940e+00  3.31085473e-02 -5.72192948e-03  6.86710253e-02\n","   -7.47919679e-01 -9.01228130e-01 -8.83753061e-01  7.08807856e-02\n","   -1.47135496e-01  1.14804566e-01 -1.19330809e-01 -9.35699224e-01\n","   -1.75705269e-01 -2.37347186e-01  2.71038175e-01  3.73970091e-01]]]\n","Shape for entity 4: (1, 1, 768)\n","Sequence Length for entity 4: 300\n","Input IDs for entity 4: [[  101 24001 24001 24001 16392  2352 24001   103   103 22200  3067  5946\n","  13090  3067  5946 13090 13433 20058  2015 10424  7616 11320  4502  3067\n","   5946  3067  5946  1996 15451  2102  2160  9378  5490  2225 15854  6238\n","   2395 15854  6238   103   103 10557  2630   103   103  2072  8670 10322\n","   2080  8670 10322  2080  2899  2675  2380   102     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0 
    0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Decoded Tokens for entity 4: [CLS] saigon saigon saigon nyc village saigon [MASK] [MASK] shack minetta tavern minetta tavern pommes frites lupa minetta minetta the malt house wash sq west wooster street wooster [MASK] [MASK] ribbon blue [MASK] [MASK]i babbo babbo washington square parkn","Embeddings for entity 5: [[[-0.28091115 -0.28395256  0.17784207 ... -0.5726868   0.5278725\n","    0.23836625]\n","  [-0.70479566 -0.15419573 -0.29911867 ... 
-1.4865197   0.552794\n","    0.27017713]]]\n","Shape for entity 5: (1, 2, 768)\n","Sequence Length for entity 5: 300\n","Input IDs for entity 5: [[  101  2225  3077  2225  3077 11831  5951  5951 24547   103 13770  2064\n","  13770  7273 12846 17223  5365   103   103   103  6921  7207   103   103\n","    103 16271 16271  5735  7304   102     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     
0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Decoded Tokens for entity 5: [CLS] westville westville bombay franklin franklin wal [MASK]tina cantina thai cuisine heavens hollywood [MASK] [MASK] [MASK] madrid clinton [MASK] [MASK] [MASK] clifton clifton russell perun"]}]},{"cell_type":"code","source":["all_embeddings[3].shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"DJAymb8IsojE","executionInfo":{"status":"ok","timestamp":1722479415313,"user_tz":420,"elapsed":213,"user":{"displayName":"Jason Phillips","userId":"10136472498761089328"}},"outputId":"1d5465a8-cd25-4dbe-fee3-e6ad30184bd5"},"execution_count":18,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(1, 1, 768)"]},"metadata":{},"execution_count":18}]},{"cell_type":"code","source":["from tqdm import tqdm\n","\n","# Function to process each entity and get embeddings\n","def process_entity(batch, model, device):\n","    input_ids = batch['masked_input'].to(device)\n","    attention_mask = batch['attention_mask'].to(device)\n","    position_list_x = batch['norm_lng_list'].to(device)\n","    position_list_y = batch['norm_lat_list'].to(device)\n","    sent_position_ids = batch['sent_position_ids'].to(device)\n","\n","    print(\"Input IDs before model:\", input_ids.cpu().numpy())\n","    print(\"Tokens before model:\", [tokenizer.convert_ids_to_tokens(ids) for ids in input_ids.cpu().numpy()])\n","\n","    with torch.no_grad():\n","        outputs = model(input_ids=input_ids,\n","                        attention_mask=attention_mask,\n","                        sent_position_ids=sent_position_ids,\n","                        position_list_x=position_list_x,\n","                        position_list_y=position_list_y)\n","\n","    # Extract embeddings\n","    embeddings = outputs.hidden_states[-1]\n","\n","    pivot_token_len = batch['pivot_token_len'].item()\n","    pivot_embeddings = embeddings[:, :pivot_token_len, 
:]\n","\n","    return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()\n","\n","# Process the first 5 rows and print embeddings\n","for i, batch in enumerate(data_loader):\n","    if i >= 5:\n","        break\n","    embeddings, input_ids = process_entity(batch, model, device)\n","    sequence_length = input_ids.shape[1]\n","\n","    print(f\"Embeddings for entity {i+1}: {embeddings}\")\n","    print(f\"Shape for entity {i+1}: {embeddings.shape}\")\n","    print(f\"Sequence Length for entity {i+1}: {sequence_length}\")\n","    print(f\"Input IDs for entity {i+1}: {input_ids}\")\n","    print(f\"Decoded Tokens for entity {i+1}: {tokenizer.decode(input_ids[0], skip_special_tokens=False)}\")\n","\n","# Assuming the tokenizer is available and initialized\n","print(\"Tokenizer vocabulary size:\", tokenizer.vocab_size)\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"PohvF_YmySrY","executionInfo":{"status":"ok","timestamp":1720753914013,"user_tz":420,"elapsed":533,"user":{"displayName":"Jason Phillips","userId":"10136472498761089328"}},"outputId":"707e631d-ce03-49a8-ec86-3b6c72d618c4"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Input IDs before model: [[  101   103  8569   103 10556  8569   103 19461 15460 19461 15460 27166\n","   9818 12849  3406  7905  7014  3702   103 13226 13226 11458   103   103\n","    103  2899 13642  2899  3927 13173 13173  8529   103  8529  4886 19923\n","    103 19213  6187 27313  8953  1996  6842  2314  1996  6842   103   102\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     
0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Tokens before model: [['[CLS]', '[MASK]', '##bu', '[MASK]', 'ka', '##bu', '[MASK]', 'quiz', '##nos', 'quiz', '##nos', 'cn', '##bc', 'ko', '##to', 'arch', 'beth', '##wood', '[MASK]', 'vanessa', 'vanessa', 'herman', '[MASK]', '[MASK]', '[MASK]', 'washington', 'ave', 'washington', 'avenue', 'kara', 'kara', 'um', '[MASK]', 'um', '##ai', 'provence', '[MASK]', '##leaf', 'ca', '##vana', '##ugh', 'the', 'hudson', 'river', 'the', 'hudsonn","Embeddings for entity 1: [[[-0.38162905  0.0761855  -0.14762239 ... -0.30674034  0.12618616\n","    0.17785785]\n","  [-0.43710983  0.02649811 -0.9639295  ... -0.13571215  0.5441345\n","   -0.11535973]\n","  [-0.5984212  -0.38061848 -0.10242525 ... 
-0.10456782 -0.27874386\n","    0.5300068 ]]]\n","Shape for entity 1: (1, 3, 768)\n","Sequence Length for entity 1: 300\n","Input IDs for entity 1: [[  101   103  8569   103 10556  8569   103 19461 15460 19461 15460 27166\n","   9818 12849  3406  7905  7014  3702   103 13226 13226 11458   103   103\n","    103  2899 13642  2899  3927 13173 13173  8529   103  8529  4886 19923\n","    103 19213  6187 27313  8953  1996  6842  2314  1996  6842   103   102\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0   
  0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Decoded Tokens for entity 1: [CLS] [MASK]bu [MASK] kabu [MASK] quiznos quiznos cnbc koto arch bethwood [MASK] vanessa vanessa herman [MASK] [MASK] [MASK] washington ave washington avenue kara kara um [MASK] umai provence [MASK]leaf cavanaugh the hudson river the hudsonn","Input IDs before model: [[  103   103  2479  9763  2479  9763  2479 20829 18996  2050  6384  3077\n","  18641 18641  6222  6222 11265 15477  2395 15544 25970  4580  2675 15544\n","  25970  4580  5318 14132   103   103   103 19095 12674  3790 15544  5753\n","   7570  5092  7520  7570  5092  7520 10090  9857  2103  3006   102     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     
0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Tokens before model: [['[MASK]', '[MASK]', 'island', 'rhode', 'island', 'rhode', 'island', 'stacks', 'nap', '##a', 'morris', '##ville', 'penelope', 'penelope', 'baltimore', 'baltimore', 'ne', 'bourbon', 'street', 'ri', '##tten', '##house', 'square', 'ri', '##tten', '##house', 'brick', '##yard', '[MASK]', '[MASK]', '[MASK]', 'yorker', 'bergen', '##field', 'ri', '##tz', 'ho', '##bo', '##ken', 'ho', '##bo', '##ken', 'ruby', 'tuesday', 'city', 'marketn","Embeddings for entity 2: [[[-0.59880525 -0.609724    0.05645313 ... -0.6180918   0.60087276\n","    0.11218308]\n","  [-0.03553682 -0.31881958  0.1326758  ...  0.2680149   0.2834256\n","    0.05813964]]]\n","Shape for entity 2: (1, 2, 768)\n","Sequence Length for entity 2: 300\n","Input IDs for entity 2: [[  103   103  2479  9763  2479  9763  2479 20829 18996  2050  6384  3077\n","  18641 18641  6222  6222 11265 15477  2395 15544 25970  4580  2675 15544\n","  25970  4580  5318 14132   103   103   103 19095 12674  3790 15544  5753\n","   7570  5092  7520  7570  5092  7520 10090  9857  2103  3006   102     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     
0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Decoded Tokens for entity 2: [MASK] [MASK] island rhode island rhode island stacks napa morrisville penelope penelope baltimore baltimore ne bourbon street rittenhouse square rittenhouse brickyard [MASK] [MASK] [MASK] yorker bergenfield ritz hoboken hoboken ruby tuesday city marketn","Input IDs before model: [[  101 15845  2358 15845   103 15544 21827 11462   103 15845  7668  2032\n","   7911  3148 19668  2618  3870  2395  5292 19445  5292 19445  2310 12190\n","  18175  2310 12190 18175 21250 23528  2063  2395 22814 25676  5413  2080\n","   7668   103  3900  1996  2896  2264  2217  1996  2896  2264  2217  1996\n","   2264  2217  5318   103  1038   103  9102  2395   102     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","  
    0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Tokens before model: [['[CLS]', 'stanton', 'st', 'stanton', '[MASK]', 'ri', '##vington', 'freeman', '[MASK]', 'stanton', 'cafe', 'him', '##ala', '##ya', 'slain', '##te', 'elizabeth', 'street', 'ha', '##bana', 'ha', '##bana', 've', '##rl', '##aine', 've', '##rl', '##aine', 'cheers', 'broom', '##e', 'street', 'gotham', 'barrio', 'chin', '##o', 'cafe', '[MASK]', '##ja', 'the', 'lower', 'east', 'side', 'the', 'lower', 'east', 'side', 'the', 'east', 'side', 'brick', '[MASK]', 'b', '[MASK]', '##cker', 'streetn","Embeddings for entity 3: [[[-0.28443173 -0.21180399 -0.28999195 ... -0.45421076  0.22001196\n","    0.33645397]\n","  [-0.01877349  1.2883748  -0.7619651  ... 
-0.05091103  0.467824\n","    0.61257905]]]\n","Shape for entity 3: (1, 2, 768)\n","Sequence Length for entity 3: 300\n","Input IDs for entity 3: [[  101 15845  2358 15845   103 15544 21827 11462   103 15845  7668  2032\n","   7911  3148 19668  2618  3870  2395  5292 19445  5292 19445  2310 12190\n","  18175  2310 12190 18175 21250 23528  2063  2395 22814 25676  5413  2080\n","   7668   103  3900  1996  2896  2264  2217  1996  2896  2264  2217  1996\n","   2264  2217  5318   103  1038   103  9102  2395   102     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     
0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Decoded Tokens for entity 3: [CLS] stanton st stanton [MASK] rivington freeman [MASK] stanton cafe himalaya slainte elizabeth street habana habana verlaine verlaine cheers broome street gotham barrio chino cafe [MASK]ja the lower east side the lower east side the east side brick [MASK] b [MASK]cker streetn","Input IDs before model: [[  101 24001 24001 24001 16392  2352 24001 22200 24001 22200  3067  5946\n","  13090  3067   103 13090 13433 20058   103 10424  7616 11320  4502  3067\n","    103   103  5946  1996 15451  2102  2160  9378  5490  2225   103  6238\n","    103 15854  6238   103  2630 10557  2630 10557  2019  2072  8670 10322\n","   2080  8670 10322  2080  2899   103  2380   102     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0 
    0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Tokens before model: [['[CLS]', 'saigon', 'saigon', 'saigon', 'nyc', 'village', 'saigon', 'shack', 'saigon', 'shack', 'mine', '##tta', 'tavern', 'mine', '[MASK]', 'tavern', 'po', '##mme', '[MASK]', 'fr', '##ites', 'lu', '##pa', 'mine', '[MASK]', '[MASK]', '##tta', 'the', 'mal', '##t', 'house', 'wash', 'sq', 'west', '[MASK]', '##ster', '[MASK]', 'woo', '##ster', '[MASK]', 'blue', 'ribbon', 'blue', 'ribbon', 'an', '##i', 'ba', '##bb', '##o', 'ba', '##bb', '##o', 'washington', '[MASK]', 'parkn","Embeddings for entity 4: [[[-2.32127458e-01 -9.18930769e-02  2.19136432e-01  2.96555981e-02\n","    7.25247204e-01 -5.32851368e-02  9.94483903e-02  1.86052948e-01\n","   -3.08993846e-01  5.15926659e-01  1.20227590e-01  1.86715692e-01\n","   -4.41497207e-01  2.99828053e-01 -4.47100401e-01 -6.72887936e-02\n","    5.82771480e-01  3.92545797e-02  2.07321435e-01  4.50910300e-01\n","    7.55428791e-01  1.82055593e-01 -4.91308793e-02  7.83410549e-01\n","    2.92378783e-01 -7.31824040e-01 -4.05730903e-01 -3.61942984e-02\n","   -2.11488619e-01  4.78587121e-01 -3.06102931e-01  1.76212519e-01\n","    2.77128518e-01 -1.03135465e-03 -6.32077038e-01 -2.61811912e-01\n","    5.83295703e-01  3.67664188e-01  3.86036187e-02 -6.20798290e-01\n","   -6.43285513e-01  1.37913093e-01 -3.14311922e-01  2.48697117e-01\n","    2.68114090e-01  6.90909922e-02 -1.54573929e+00  4.49200958e-01\n","    4.70977217e-01 -9.16715384e-01  2.18298048e-01 -2.51081228e-01\n","   -4.84505326e-01 -1.13845646e-01  3.58800530e-01  6.11110628e-01\n","    1.73217356e-01 -3.94780159e-01  2.23506913e-01 -5.62340319e-01\n","    4.45486844e-01 
-4.27358551e-03 -1.59257993e-01 -3.85173589e-01\n","   -2.34834135e-01  4.60041136e-01 -2.23349303e-01  4.07035410e-01\n","   -4.03714389e-01 -2.92997897e-01  4.04580027e-01  1.24050513e-01\n","    4.21054542e-01  3.23103309e-01  2.99802780e-01 -3.06068093e-01\n","   -4.99940693e-01  5.23851395e-01 -6.14880584e-02 -1.70648918e-01\n","   -3.58882666e-01 -5.97434223e-01  9.55330431e-02 -7.74227306e-02\n","    1.90355644e-01 -7.42989123e-01  7.27307498e-01  3.81276786e-01\n","    2.69985110e-01  4.85077113e-01 -4.89446551e-01 -1.74129173e-01\n","    1.64864421e-01  1.88954532e-01 -8.66878152e-01 -2.62359194e-02\n","    2.11557299e-01  3.36699694e-01  1.20699406e-01 -1.90134600e-01\n","   -7.11980939e-01 -9.89891171e-01 -6.68323115e-02  4.24929857e-01\n","    3.83059144e-01 -2.70120353e-01  7.97938943e-01  3.64006639e-01\n","    5.88814735e-01  8.31399322e-01  1.86285153e-01 -5.58987081e-01\n","    7.59784937e-01 -7.55049646e-01  4.36736383e-02 -4.94223148e-01\n","    5.55506766e-01 -3.09891701e-01 -8.62461254e-02 -9.29422617e-01\n","    3.18029225e-02 -4.23724502e-01 -1.12866074e-01  3.99434298e-01\n","    4.68819588e-01 -8.20364580e-02  7.40628615e-02 -3.52922916e-01\n","   -1.12787008e-01 -5.45717552e-02 -9.34484974e-02  7.09755123e-01\n","    6.92794397e-02  9.64301266e-03 -3.62071544e-01  6.27914727e-01\n","    5.74458614e-02  3.78374904e-01 -4.81822550e-01 -9.19076130e-02\n","    5.21590859e-02 -2.59090122e-02 -1.68424428e-01  1.27533942e-01\n","   -2.05665193e-02 -3.33064228e-01  6.63775504e-02  3.02597225e-01\n","    3.11101973e-01 -6.53546989e-01  1.16736487e-01 -1.07996181e-01\n","   -4.98000443e-01 -5.21951199e-01  3.30040336e-01 -1.23607919e-01\n","   -4.40430611e-01 -4.49122220e-01 -1.26617640e-01  4.19581115e-01\n","   -7.27578774e-02 -1.51742786e-01 -1.88369960e-01 -3.90673995e-01\n","   -5.15018106e-01  6.49584770e-01 -2.82208342e-02  7.18623996e-01\n","   -1.09455752e+00 -5.86448871e-02  3.56408358e-01  5.20350516e-01\n","    7.86418021e-01  
2.20755070e-01  8.70168507e-02  5.83842099e-01\n","    5.70335388e-01 -2.07845330e-01 -1.31426767e-01 -1.88749209e-01\n","   -6.71206951e-01 -3.02740186e-01 -3.18212122e-01  2.78418630e-01\n","    5.98317981e-01 -6.82105064e-01  8.30883324e-01 -2.56897628e-01\n","   -5.52377515e-02 -4.53176975e-01 -1.92191601e-01 -2.52950215e-03\n","   -3.30457628e-01  1.77060395e-01  5.28001264e-02 -4.54444997e-02\n","   -8.57218429e-02  2.56544411e-01  1.68761551e-01 -1.15616053e-01\n","    3.50096226e-01  5.14472187e-01  3.86454999e-01  3.88289183e-01\n","   -5.27854741e-01  2.16889977e-01  3.85154448e-02 -1.18602421e-02\n","    3.76549035e-01  5.26992500e-01  8.63661692e-02 -5.73871695e-02\n","   -1.68572336e-01  5.35125911e-01 -2.60125488e-01 -3.40530649e-02\n","    7.53940523e-01 -7.67609596e-01  2.59812802e-01 -2.32273042e-01\n","   -7.08810747e-01  2.43516862e-01 -2.87564486e-01 -1.72863156e-01\n","    5.22883311e-02  1.28039455e+00  3.22267264e-01 -4.01855767e-01\n","    2.96064496e-01  8.73078108e-01  8.17452595e-02 -8.16755176e-01\n","    9.09669578e-01  9.65073612e-03 -8.06252122e-01  5.25929146e-02\n","    4.24942315e-01 -5.44106841e-01  2.14754149e-01 -2.82383353e-01\n","    4.41166848e-01  7.22160712e-02  6.96993291e-01 -5.06527781e-01\n","   -1.05047755e-01 -9.04193938e-01  3.47809084e-02  3.32321137e-01\n","   -8.49012658e-02 -2.75587529e-01  4.31346714e-01 -2.09260434e-01\n","    7.63776720e-01  7.64579698e-02 -4.60940957e-01 -5.22871315e-01\n","   -1.81263238e-01  1.88924998e-01 -4.57495630e-01 -9.82526422e-01\n","    3.51930887e-01 -3.33721638e-01  2.72491693e-01  1.04421666e-02\n","   -2.89329886e-01  3.17154467e-01 -1.15669954e+00 -4.13663983e-02\n","   -1.39939159e-01 -1.27101049e-01 -1.10045694e-01  2.95935601e-01\n","    5.64581566e-02  9.26031291e-01  3.21645319e-01 -2.71831214e-01\n","    3.89375418e-01  2.83110559e-01  8.48079443e-01 -9.14455354e-01\n","   -3.63323480e-01  6.81037128e-01 -6.06349409e-02 -7.10844278e-01\n","   -3.51043314e-01 
-1.55432820e-01 -5.91060102e-01  5.02832353e-01\n","    4.64336872e-01  1.67976752e-01  3.44814032e-01 -2.06974909e-01\n","   -1.65562272e-01  1.00891590e-01  1.16840959e-01  1.21708706e-01\n","    5.10155439e-01  5.08186758e-01 -1.57256052e-01 -8.64657044e-01\n","   -1.45590510e-02 -4.41058457e-01  2.06434689e-02 -1.21783063e-01\n","    2.36847579e-01  1.83577538e-01  1.02910586e-01 -7.35641897e-01\n","   -5.24626541e+00 -1.47581592e-01  1.33208379e-01 -1.23400605e+00\n","    7.92669430e-02 -6.18095458e-01  2.70189866e-02 -1.78242072e-01\n","   -2.05801502e-01  3.81381214e-01  3.94054130e-02 -3.71136785e-01\n","    7.45533824e-01 -1.52756542e-01  2.99570169e-02 -2.98850060e-01\n","    2.00396121e-01  1.67178556e-01  9.65733230e-02  4.54347610e-01\n","   -2.31699228e-01 -2.61062942e-02  2.62523830e-01 -3.23136538e-01\n","    2.25508109e-01  4.26325083e-01 -4.59381372e-01 -4.60122973e-01\n","    5.38587291e-03 -8.37337732e-01  4.20886248e-01 -5.26366591e-01\n","    4.13217574e-01  1.67089373e-01  5.49848899e-02 -1.30149126e-01\n","   -2.22641930e-01 -9.42414880e-01 -8.10550973e-02 -1.20146990e-01\n","    2.22168062e-02 -2.90578604e-02 -4.22715157e-01 -5.56364208e-02\n","    2.23644897e-01 -2.15020459e-02  2.28317901e-01  2.65085727e-01\n","   -7.28800774e-01 -1.21461749e-01  1.58559635e-01  3.25366378e-01\n","   -2.05111682e-01  4.83622640e-01  4.82733339e-01 -1.86771348e-01\n","    6.58763587e-01  1.63742796e-01 -1.21733956e-01 -2.87434936e-01\n","    8.75151873e-01 -1.18269145e-01 -5.41791439e-01  6.04858398e-01\n","    3.21258634e-01 -3.88020962e-01 -5.54738522e-01  3.71391892e-01\n","   -4.45316166e-01  6.76476583e-02 -3.50067645e-01  4.24743667e-02\n","    2.62260251e-02 -7.52956212e-01  8.46208632e-02 -6.85882270e-01\n","   -3.05912904e-02 -2.12879717e-01  4.36664402e-01 -1.27411503e-02\n","   -3.39583725e-01 -1.87498271e-01  5.24547100e-01  1.17620982e-01\n","   -4.23407406e-01 -4.42694753e-01 -2.74582684e-01  4.17836040e-01\n","   -3.64087284e-01 
-3.49417746e-01 -1.17424928e-01  6.38979673e-01\n","    3.58592451e-01  5.06961681e-02  4.13037360e-01  3.78887624e-01\n","   -5.66198863e-02 -1.15354352e-01  3.89262177e-02  5.96774042e-01\n","    5.01773298e-01  6.14198409e-02  6.02574348e-01 -1.09670341e+00\n","    1.19834162e-01  4.81998891e-01 -4.64814574e-01  7.39254355e-01\n","    4.17518109e-01  1.12306356e-01 -4.77655143e-01 -2.21514344e-01\n","   -1.89910144e-01 -6.52637064e-01  9.20041502e-02  1.97890297e-01\n","    3.78159016e-01  4.19413298e-01  3.98006529e-01  1.19805522e-01\n","    1.82827637e-01  2.08881255e-02  5.97022414e-01  1.67657286e-01\n","   -2.99780875e-01 -2.65357167e-01 -4.31526393e-01 -2.49820575e-01\n","    6.33335114e-02 -3.47334266e-01  2.37650007e-01 -1.20795839e-01\n","    2.14175791e-01  1.73247337e-01  5.16895235e-01 -4.82260324e-02\n","   -2.79049516e-01 -1.27826440e+00  1.93881035e-01  1.11989148e-01\n","   -4.67525035e-01  1.20040372e-01  3.80351424e-01 -5.78683794e-01\n","   -5.13348401e-01 -1.45789474e-01  3.03920537e-01  9.09982771e-02\n","    1.46969870e-01  1.14659496e-01 -1.38986647e+00 -4.45276529e-01\n","   -3.18292856e-01 -3.43734324e-01  1.95033327e-01  5.16402006e-01\n","    3.31339091e-01 -2.53450990e-01 -2.80474722e-01 -5.61761200e-01\n","    5.97164750e-01 -2.39306733e-01 -4.17944103e-01  2.64335930e-01\n","   -8.24504912e-01  3.25888604e-01  2.93572217e-01  3.22955221e-01\n","   -3.03191513e-01  2.45280206e-01 -1.26565292e-01 -2.75402725e-01\n","   -3.87265712e-01 -4.56490844e-01 -1.16372550e+00  6.54272810e-02\n","   -2.41214558e-01 -5.31952202e-01  4.11137015e-01  3.13046128e-01\n","    2.32168809e-01 -1.26493454e-01 -3.97229671e-01  1.23795293e-01\n","   -2.18757018e-01  6.02300465e-01  4.65756148e-01 -4.54503708e-02\n","    4.68754351e-01  2.49049067e-01 -6.19977526e-03 -9.53918975e-03\n","    5.38234949e-01  4.46773529e-01 -8.64204466e-01 -5.95894098e-01\n","    3.01001161e-01  6.81683242e-01  7.83131242e-01  3.93483669e-01\n","   -1.88484564e-01 
-6.66944861e-01  6.39800727e-02  4.21733141e-01\n","    3.75889868e-01  1.12136650e+00 -6.24863580e-02  5.32989949e-02\n","    5.77582568e-02  4.49248701e-02  3.45943153e-01 -4.22091454e-01\n","   -2.64078349e-01  3.48746508e-01  2.67953962e-01 -6.59206271e-01\n","   -1.32132098e-01  6.68191293e-04 -7.88972735e-01  4.61224258e-01\n","   -7.54186630e-01 -2.69988179e-01  2.35952027e-02  2.18563303e-01\n","   -5.98371267e-01 -5.82349062e-01 -1.31415695e-01 -3.01632077e-01\n","    7.62087703e-02 -6.15611196e-01  1.38914615e-01  5.00038028e-01\n","   -9.59163457e-02  4.39847231e-01 -3.89594436e-01  1.49660861e+00\n","   -2.77075171e-02 -5.28882623e-01  2.69330263e-01  2.86625177e-01\n","    2.52468556e-01  6.19645119e-01  5.69440387e-02 -1.04844320e+00\n","   -1.93931609e-01 -9.11211688e-03 -2.90297180e-01  1.93473130e-01\n","    7.73154378e-01 -1.38327748e-01  2.92048275e-01  3.49459462e-02\n","   -9.24798325e-02  6.59925938e-02  7.54584372e-01  8.22292641e-02\n","   -8.22204769e-01 -3.25041771e-01 -8.61675739e-02  3.47461626e-02\n","   -6.12557709e-01  1.34840712e-01  2.79068239e-02 -3.59331250e-01\n","    7.77637720e-01 -1.40012190e-01  6.35521293e-01  8.11932266e-01\n","    1.31203914e-02 -8.88220429e-01  1.53721929e-01 -4.90582585e-01\n","    5.28307259e-01 -1.70463726e-01 -1.31836794e-02  2.04324186e-01\n","    2.01147139e-01 -6.57060266e-01  1.16611801e-01 -4.01594400e-01\n","   -9.24448311e-01 -4.08888936e-01 -4.30269688e-01 -1.84269235e-01\n","    1.00098908e+00 -2.12632939e-01  1.13407299e-01  2.37189204e-01\n","   -2.24308521e-01  4.03059542e-01  2.42508426e-02  3.67209017e-01\n","   -5.16948581e-01  3.10835242e-01 -3.00943404e-01  4.39933121e-01\n","    1.01037085e+00  5.10283887e-01  1.59436405e-01 -6.57830834e-02\n","    1.89033464e-01  5.90162218e-01 -2.24603012e-01  4.03840169e-02\n","   -5.36061525e-02 -1.35151312e-01 -2.63676643e-01 -1.14582680e-01\n","   -2.54939473e-03  3.76228899e-01  7.56380320e-01 -4.80822846e-02\n","    4.65031087e-01  
3.13537605e-02 -3.45779687e-01  6.92100897e-02\n","   -3.29069793e-02 -7.03116134e-03  1.11795226e-02 -2.13431776e-01\n","    2.46553612e-03 -1.07456803e-01  1.42936516e+00  3.84415299e-01\n","    1.08065140e+00  4.10582781e-01 -3.28331798e-01 -6.76242232e-01\n","    8.84066582e-01  1.31225780e-01 -3.55111003e-01  5.37703894e-02\n","    2.55652785e-01 -1.24962129e-01  1.00433066e-01 -9.07299668e-02\n","    9.93123800e-02 -2.25957721e-01  1.48931697e-01  4.36302334e-01\n","    3.61645743e-02  1.53136238e-01 -4.26383138e-01 -1.04319178e-01\n","   -5.80002442e-02  1.96728289e-01  3.55804950e-01 -4.64399680e-02\n","   -1.13301806e-01 -1.69528246e-01 -1.93541527e-01  9.86101985e-01\n","   -2.26272881e-01 -1.32977039e-01  1.57398209e-01 -7.07822561e-01\n","    4.77300538e-03 -1.60121769e-02 -4.43615735e-01  8.17510724e-01\n","    4.26374465e-01  8.53859484e-01  8.92099798e-01 -2.20892414e-01\n","   -2.55746245e-01  5.04031718e-01 -3.98926407e-01  2.87004024e-01\n","   -3.01377892e-01 -1.33867025e-01  1.74790367e-01  2.01532006e-01\n","   -9.13695276e-01  4.98295575e-01 -5.84768541e-02 -3.61467630e-01\n","   -2.88227826e-01  9.36566517e-02  7.73162171e-02  3.84741366e-01\n","   -4.05206710e-01  1.64408937e-01 -2.27388948e-01 -2.81376272e-01\n","   -9.52850804e-02 -4.09162611e-01  4.07696515e-02  8.54342133e-02\n","   -1.16895831e+00 -8.20526838e-01 -3.28012228e-01 -2.17720807e-01\n","    5.53152502e-01 -1.64649427e-01  3.71408701e-01 -5.93737066e-01\n","   -7.25642964e-02 -2.13964388e-01  1.36529446e-01  7.43945688e-02\n","    4.36803877e-01 -2.30315685e-01  2.99450606e-01  2.09624469e-01\n","    3.97413343e-01 -5.87478042e-01 -2.04276100e-01 -2.77292758e-01\n","   -5.40722728e-01  1.02845681e+00 -1.03554748e-01  6.24136627e-01\n","   -5.85075855e-01  2.21690238e-01 -7.13158309e-01 -1.82012260e-01\n","    3.05482507e-01  2.14402586e-01 -2.53102005e-01 -3.56745750e-01\n","   -2.13511437e-01 -9.00834739e-01  4.02278066e-01 -4.62024748e-01\n","   -2.51226604e-01 
-2.31557265e-01  1.21537335e-01  1.36342585e-01\n","   -9.52898085e-01 -1.06946789e-01  4.38074768e-01  2.97690511e-01\n","    9.93487388e-02  5.76441586e-01 -2.17864022e-01 -3.39290261e-01\n","   -1.04374520e-01 -5.08866191e-01 -2.68830091e-01 -5.44010222e-01\n","    8.76054049e-01  5.63209474e-01 -5.08262277e-01  5.43685138e-01\n","    1.42189920e+00  3.79516244e-01  3.66262764e-01  2.73080152e-02\n","   -9.20679927e-01 -6.34345829e-01 -7.93236494e-01 -5.89900613e-02\n","   -2.95962870e-01 -4.21347879e-02  1.59806550e-01 -6.78540111e-01\n","   -9.10421908e-02 -1.48749292e-01  2.27457732e-01  1.55423284e-01]]]\n","Shape for entity 4: (1, 1, 768)\n","Sequence Length for entity 4: 300\n","Input IDs for entity 4: [[  101 24001 24001 24001 16392  2352 24001 22200 24001 22200  3067  5946\n","  13090  3067   103 13090 13433 20058   103 10424  7616 11320  4502  3067\n","    103   103  5946  1996 15451  2102  2160  9378  5490  2225   103  6238\n","    103 15854  6238   103  2630 10557  2630 10557  2019  2072  8670 10322\n","   2080  8670 10322  2080  2899   103  2380   102     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     
0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Decoded Tokens for entity 4: [CLS] saigon saigon saigon nyc village saigon shack saigon shack minetta tavern mine [MASK] tavern pomme [MASK] frites lupa mine [MASK] [MASK]tta the malt house wash sq west [MASK]ster [MASK] wooster [MASK] blue ribbon blue ribbon ani babbo babbo washington [MASK] parkn","Input IDs before model: [[  101  2225   103  2225  3077 11831   103  5951 24547  2064 13770  2064\n","  13770  7273 12846   103  5365  5365 11942  6921  6921  7207  6396  2063\n","    103   103 16271  5735  7304   102     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0   
  0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Tokens before model: [['[CLS]', 'west', '[MASK]', 'west', '##ville', 'bombay', '[MASK]', 'franklin', 'wal', 'can', '##tina', 'can', '##tina', 'thai', 'cuisine', '[MASK]', 'hollywood', 'hollywood', 'cane', 'madrid', 'madrid', 'clinton', 'ny', '##e', '[MASK]', '[MASK]', 'clifton', 'russell', 'perun"]},{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:1052: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.\n","  warnings.warn(\n"]},{"output_type":"stream","name":"stdout","text":["Embeddings for entity 5: [[[ 0.04580183 -0.3487388   0.07472646 ... -0.504709    0.51350564\n","    0.5396417 ]\n","  [-0.63502467  0.133689   -0.26431495 ... 
-1.0011739   0.59217864\n","    0.14183448]]]\n","Shape for entity 5: (1, 2, 768)\n","Sequence Length for entity 5: 300\n","Input IDs for entity 5: [[  101  2225   103  2225  3077 11831   103  5951 24547  2064 13770  2064\n","  13770  7273 12846   103  5365  5365 11942  6921  6921  7207  6396  2063\n","    103   103 16271  5735  7304   102     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0\n","      0     0     0     0     0     0     0     0     0   
  0     0     0\n","      0     0     0     0     0     0     0     0     0     0     0     0]]\n","Decoded Tokens for entity 5: [CLS] west [MASK] westville bombay [MASK] franklin wal cantina cantina thai cuisine [MASK] hollywood hollywood cane madrid madrid clinton nye [MASK] [MASK] clifton russell perun","Tokenizer vocabulary size: 30522\n"]}]},{"cell_type":"code","source":["embeddings.shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"dk5QikqoIz0l","executionInfo":{"status":"ok","timestamp":1718998349345,"user_tz":420,"elapsed":194,"user":{"displayName":"Jason Phillips","userId":"10136472498761089328"}},"outputId":"2a9471c2-7071-4ffe-d2bb-25936b92e7fe"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(1, 2, 768)"]},"metadata":{},"execution_count":28}]},{"cell_type":"code","source":["#all_embeddings[0].shape"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"OqFj7sSXLHaT","executionInfo":{"status":"ok","timestamp":1718762536551,"user_tz":420,"elapsed":364,"user":{"displayName":"Jason Phillips","userId":"10136472498761089328"}},"outputId":"433882f7-7732-453a-d062-ce96dfa68206"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(1, 3, 768)"]},"metadata":{},"execution_count":27}]},{"cell_type":"code","source":[],"metadata":{"id":"Iet7FCF6OxW8"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Testing the model with various inputs\n"],"metadata":{"id":"0CBELAbxFqpv"}},{"cell_type":"code","source":["# Prepare a dummy input for testing\n","input_text = \"This is a test sentence.\"\n","encoded_input = tokenizer(input_text, return_tensors='pt', padding='max_length', max_length=300, truncation=True)\n","input_ids = encoded_input['input_ids'].to(device)\n","attention_mask = encoded_input['attention_mask'].to(device)\n","\n","# Create dummy position lists\n","batch_size = input_ids.shape[0]\n","max_len = 
input_ids.shape[1]\n","position_list_x = torch.zeros((batch_size, max_len), dtype=torch.float32).to(device)\n","position_list_y = torch.zeros((batch_size, max_len), dtype=torch.float32).to(device)\n","sent_position_ids = torch.arange(max_len).unsqueeze(0).expand(batch_size, -1).to(device)\n","\n","# Forward pass\n","with torch.no_grad():\n","    outputs = model(input_ids=input_ids,\n","                    attention_mask=attention_mask,\n","                    sent_position_ids=sent_position_ids,\n","                    position_list_x=position_list_x,\n","                    position_list_y=position_list_y)\n","\n","# Check the output\n","if config.output_hidden_states:\n","    hidden_states = outputs.hidden_states\n","    print(f\"Number of layers (including embedding layer): {len(hidden_states)}\")\n","    print(f\"Shape of hidden states for each layer: {[state.shape for state in hidden_states]}\")\n","else:\n","    print(\"Hidden states not outputted\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"JamZgfEP-6K8","executionInfo":{"status":"ok","timestamp":1718993333983,"user_tz":420,"elapsed":464,"user":{"displayName":"Jason Phillips","userId":"10136472498761089328"}},"outputId":"1a78ac8d-e202-4f6f-bdf0-065ccaeec313"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Number of layers (including embedding layer): 13\n","Shape of hidden states for each layer: [torch.Size([1, 300, 768]), torch.Size([1, 300, 768]), torch.Size([1, 300, 768]), torch.Size([1, 300, 768]), torch.Size([1, 300, 768]), torch.Size([1, 300, 768]), torch.Size([1, 300, 768]), torch.Size([1, 300, 768]), torch.Size([1, 300, 768]), torch.Size([1, 300, 768]), torch.Size([1, 300, 768]), torch.Size([1, 300, 768]), torch.Size([1, 300, 768])]\n"]}]},{"cell_type":"code","source":["# Access the hidden states\n","hidden_states = outputs.hidden_states\n","\n","# Get the hidden states of the last layer:\n","last_hidden_state = hidden_states[-1]  # Shape: 
[1, 300, 768]\n","\n","# Print the tokens and their corresponding embeddings for the first sequence in the batch\n","input_ids_batch = input_ids[0].cpu().numpy()  # Convert to numpy for easy indexing\n","\n","for token_index in range(len(input_ids_batch)):\n","    token_id = input_ids_batch[token_index]\n","    token_str = tokenizer.decode(token_id)\n","    token_embedding = last_hidden_state[0, token_index, :]  # Shape: [768]\n","\n","    print(f\"Token: {token_str} | Token ID: {token_id} | Embedding shape: {token_embedding.shape}\")"],"metadata":{"id":"BUVgcSCA_msw"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Build one hand-made example batch (pivot entity + two neighbors) in the dict layout used for model input.\n","import torch\n","from transformers import BertTokenizer\n","\n","# Initialize the tokenizer\n","tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n","\n","# Example input text\n","pivot_name = \"Central Park\"\n","pivot_pos = [40.785091, -73.968285]  # Latitude, Longitude\n","neighbor_names = [\"Museum of Natural History\", \"Metropolitan Museum of Art\"]\n","neighbor_positions = [[40.781324, -73.973988], [40.779437, -73.963244]]\n","\n","# Tokenize pivot and neighbors\n","pivot_tokens = tokenizer.tokenize(pivot_name)\n","pivot_token_ids = tokenizer.convert_tokens_to_ids(pivot_tokens)\n","neighbor_token_ids = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(name)) for name in neighbor_names]\n","\n","# Flatten neighbor token IDs\n","neighbor_token_ids_flat = [token_id for sublist in neighbor_token_ids for token_id in sublist]\n","\n","# Create the full token ID list with special tokens: [CLS] pivot [SEP] neighbors [SEP]\n","input_tokens = [tokenizer.cls_token_id] + pivot_token_ids + [tokenizer.sep_token_id] + neighbor_token_ids_flat + [tokenizer.sep_token_id]\n","# input_tokens already holds vocabulary IDs; convert_tokens_to_ids expects token strings, so it must not be applied again.\n","# Copy the list so that the padding below does not also extend input_tokens, whose length defines the attention mask.\n","input_ids = list(input_tokens)\n","\n","# Pad the input IDs to the max length\n","max_token_len = 300\n","padding_length = max_token_len - len(input_ids)\n","input_ids += [tokenizer.pad_token_id] * padding_length\n","\n","# Create attention mask\n",
"attention_mask = [1] * len(input_tokens) + [0] * padding_length\n","\n","# Create sentence position IDs\n","sent_position_ids = list(range(max_token_len))\n","\n","# Normalize positions (offsets from the pivot, divided by distance_norm_factor)\n","# NOTE(review): these lists hold one value per entity (pivot + neighbors), not one per token - confirm whether the model expects token-aligned positions\n","distance_norm_factor = 0.0001\n","norm_lng_list = [(pos[1] - pivot_pos[1]) / distance_norm_factor for pos in [pivot_pos] + neighbor_positions]\n","norm_lat_list = [(pos[0] - pivot_pos[0]) / distance_norm_factor for pos in [pivot_pos] + neighbor_positions]\n","\n","# Pad the position lists to the max length\n","norm_lng_list += [0.0] * (max_token_len - len(norm_lng_list))\n","norm_lat_list += [0.0] * (max_token_len - len(norm_lat_list))\n","\n","# Create the batch dictionary\n","batch = {\n","    'masked_input': torch.tensor([input_ids], dtype=torch.long),\n","    'attention_mask': torch.tensor([attention_mask], dtype=torch.long),\n","    'sent_position_ids': torch.tensor([sent_position_ids], dtype=torch.long),\n","    'norm_lng_list': torch.tensor([norm_lng_list], dtype=torch.float),\n","    'norm_lat_list': torch.tensor([norm_lat_list], dtype=torch.float),\n","    'pivot_token_len': torch.tensor([len(pivot_token_ids)], dtype=torch.long)\n","}\n","\n","# Display the example batch\n","print(batch)"],"metadata":{"id":"g6NTtY4RDvfA"},"execution_count":null,"outputs":[]}]}