JasonTPhillipsJr commited on
Commit
3d8cd48
·
verified ·
1 Parent(s): e72b522

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -0
app.py CHANGED
@@ -6,6 +6,8 @@ from transformers.models.bert.modeling_bert import BertForMaskedLM
6
 
7
  from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
8
  from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
 
 
9
 
10
  from PIL import Image
11
 
@@ -35,6 +37,19 @@ spaBERT_model.load_state_dict(pre_trained_model, strict=False)
35
  spaBERT_model.to(device)
36
  spaBERT_model.eval()
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  #Get BERT Embedding for review
40
  def get_bert_embedding(review_text):
 
6
 
7
  from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
8
  from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
9
+ from models.spabert.datasets.osm_sample_loader import PbfMapDataset
10
+ from torch.utils.data import DataLoader
11
 
12
  from PIL import Image
13
 
 
37
  spaBERT_model.to(device)
38
  spaBERT_model.eval()
39
 
40
+ # Load data using SpatialDataset
41
+ spatialDataset = PbfMapDataset(data_file_path = data_file_path,
42
+ tokenizer = tokenizer,
43
+ #max_token_len = 256, #Originally 300
44
+ max_token_len = max_seq_length, #Originally 300
45
+ distance_norm_factor = 0.0001,
46
+ spatial_dist_fill = 20,
47
+ with_type = False,
48
+ sep_between_neighbors = True, #Initially false, play around with this potentially?
49
+ label_encoder = None, #Initially None, potentially change this because we do have real/fake reviews.
50
+ mode = None) #If set to None it will use the full dataset for mlm
51
+
52
+ data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) #issue needs to be fixed with num_workers not stopping after finished
53
 
54
  #Get BERT Embedding for review
55
  def get_bert_embedding(review_text):