Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,6 +6,8 @@ from transformers.models.bert.modeling_bert import BertForMaskedLM
|
|
6 |
|
7 |
from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
|
8 |
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
|
|
|
|
|
9 |
|
10 |
from PIL import Image
|
11 |
|
@@ -35,6 +37,19 @@ spaBERT_model.load_state_dict(pre_trained_model, strict=False)
|
|
35 |
spaBERT_model.to(device)
|
36 |
spaBERT_model.eval()
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
#Get BERT Embedding for review
|
40 |
def get_bert_embedding(review_text):
|
|
|
6 |
|
7 |
from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
|
8 |
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
|
9 |
+
from models.spabert.datasets.osm_sample_loader import PbfMapDataset
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
|
12 |
from PIL import Image
|
13 |
|
|
|
37 |
spaBERT_model.to(device)
|
38 |
spaBERT_model.eval()
|
39 |
|
40 |
+
# Load data using SpatialDataset
|
41 |
+
spatialDataset = PbfMapDataset(data_file_path = data_file_path,
|
42 |
+
tokenizer = tokenizer,
|
43 |
+
#max_token_len = 256, #Originally 300
|
44 |
+
max_token_len = max_seq_length, #Originally 300
|
45 |
+
distance_norm_factor = 0.0001,
|
46 |
+
spatial_dist_fill = 20,
|
47 |
+
with_type = False,
|
48 |
+
sep_between_neighbors = True, #Initially false, play around with this potentially?
|
49 |
+
label_encoder = None, #Initially None, potentially change this because we do have real/fake reviews.
|
50 |
+
mode = None) #If set to None it will use the full dataset for mlm
|
51 |
+
|
52 |
+
data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) #issue needs to be fixed with num_workers not stopping after finished
|
53 |
|
54 |
#Get BERT Embedding for review
|
55 |
def get_bert_embedding(review_text):
|