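# Streamlit demo for SpaGAN: spaCy tags geo-entities (FAC/ORG/LOC/GPE) in sample
# reviews, SpaBERT supplies spatially-aware embeddings for those entities, and the
# result is concatenated with a plain BERT [CLS] embedding of the whole review.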
import streamlit as st
import spacy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel

from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM
from models.spabert.datasets.osm_sample_loader import PbfMapDataset

device = torch.device('cpu')


# spaCy Initialization Section
nlp = spacy.load("./models/en_core_web_sm")


#BERT Initialization Section
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
bert_model.to(device)
bert_model.eval()
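# eval() disables dropout so repeated runs over the same review produce
# deterministic embeddings; no gradients are needed anywhere in this demo.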


#SpaBERT Initialization Section
data_file_path = 'models/spabert/datasets/SpaBERTPivots.json'    # Sample file; the full dataset would take too long on CPU.
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'

config = SpatialBertConfig()
config.output_hidden_states = True
spaBERT_model = SpatialBertForMaskedLM(config)

# Initialize overlapping weights from vanilla BERT, then overlay the fine-tuned
# SpaBERT checkpoint; strict=False tolerates keys present in only one model.
pre_trained_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
spaBERT_model.load_state_dict(bert_model.state_dict(), strict=False)
spaBERT_model.load_state_dict(pre_trained_model, strict=False)

spaBERT_model.to(device)
spaBERT_model.eval()
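# Optional sanity check (a sketch, not part of the original pipeline):
# load_state_dict with strict=False returns the unmatched keys, worth inspecting once:
#   keys = spaBERT_model.load_state_dict(pre_trained_model, strict=False)
#   print(len(keys.missing_keys), "missing,", len(keys.unexpected_keys), "unexpected")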


# Load the sample data with PbfMapDataset
spatialDataset = PbfMapDataset(data_file_path=data_file_path,
                               tokenizer=bert_tokenizer,
                               max_token_len=256,           # Originally 300
                               distance_norm_factor=0.0001,
                               spatial_dist_fill=20,
                               with_type=False,
                               sep_between_neighbors=True,
                               label_encoder=None,
                               mode=None)                   # mode=None uses the full dataset for MLM

data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) 
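# Each batch is a dict of tensors produced by PbfMapDataset; with batch_size=1,
# every key used in process_entity below ('masked_input', 'attention_mask',
# 'norm_lng_list', 'norm_lat_list', 'sent_position_ids', 'pseudo_sentence')
# carries a leading batch dimension of 1.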

# Map lowercased entity names to dataset indices for case-insensitive lookup
entity_index_dict = {entity['pivot_name'].lower(): i for i, entity in enumerate(spatialDataset)}


# Pre-compute the SpaBERT embeddings for all geo-entities within our dataset
def process_entity(batch, model, device):
    input_ids = batch['masked_input'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    position_list_x = batch['norm_lng_list'].to(device)
    position_list_y = batch['norm_lat_list'].to(device)
    sent_position_ids = batch['sent_position_ids'].to(device)
    pseudo_sentence = batch['pseudo_sentence'].to(device)

    with torch.no_grad():
        # Feed the full pseudo-sentence (pivot entity plus spatial neighbors)
        # rather than the masked input, since we want embeddings, not an MLM loss.
        outputs = model(input_ids=pseudo_sentence,
                        attention_mask=attention_mask,
                        sent_position_ids=sent_position_ids,
                        position_list_x=position_list_x,
                        position_list_y=position_list_y)

    spaBERT_embedding = outputs.hidden_states[-1].to(device)

    # Extract the [CLS] token embedding (first token)
    spaBERT_embedding = spaBERT_embedding[:, 0, :].detach()  # [batch_size, hidden_size]

    return spaBERT_embedding, input_ids

spaBERT_embeddings = []
for batch in data_loader:
    spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
    spaBERT_embeddings.append(spaBERT_embedding)


#Get BERT Embedding for review
def get_bert_embedding(review_text):
    #tokenize review
    inputs = bert_tokenizer(review_text, return_tensors='pt', padding=True, truncation=True).to(device)
    
    # Forward pass through the BERT model
    with torch.no_grad():
        outputs = bert_model(**inputs)

    # Extract embeddings from the last hidden state
    bert_embedding = outputs.last_hidden_state[:, 0, :].detach()     #CLS Token
    return bert_embedding
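# For bert-base-uncased the returned [CLS] slice has shape (1, 768), e.g.
# (illustrative input): get_bert_embedding("great coffee near the river").shape
# -> torch.Size([1, 768])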


# Get SpaBERT embedding for a geo-entity
def get_spaBert_embedding(entity):
    entity_index = entity_index_dict.get(entity.lower(), None)
    if entity_index is None:
        #st.write("Got Bert embedding for: ", entity)
        return get_bert_embedding(entity)        # Fallback in case SpaBERT cannot resolve the entity; rare cases only.
    else:
        #st.write("Got SpaBert embedding for: ", entity)
        return spaBERT_embeddings[entity_index]

        
# Go through each review, identify all geo-entities, then extract their SpaBERT embeddings
def processSpatialEntities(review, nlp):
    doc = nlp(review)
    entity_spans = [(ent.start, ent.end, ent.text, ent.label_) for ent in doc.ents]
    token_embeddings = []

    # Iterate over each entity span and process only geo-entities
    for start, end, text, label in entity_spans:
        if label in ['FAC', 'ORG', 'LOC', 'GPE']:  # Filter to geo-entities
            spaBert_emb = get_spaBert_embedding(text)
            token_embeddings.append(spaBert_emb)
            st.write("Geo-Entity Found in review: ", text)

    if not token_embeddings:
        # No geo-entities found: fall back to a plain BERT embedding of the whole
        # review so the caller still receives a (1, 768) tensor (assumed fallback;
        # torch.stack would otherwise raise on an empty list).
        return get_bert_embedding(review)

    token_embeddings = torch.stack(token_embeddings, dim=0)
    processed_embedding = token_embeddings.mean(dim=0)  # Shape: (1, 768)
    return processed_embedding
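# Design note: mean-pooling treats every detected geo-entity equally; a weighted
# pooling (e.g., attention over entities) would be a drop-in alternative, but the
# simple mean keeps the output at a fixed (1, 768) shape regardless of entity count.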


#Discriminator Initialization section
class Discriminator(nn.Module):
    def __init__(self, input_size=512, hidden_sizes=[512], num_labels=2, dropout_rate=0.1):
        super(Discriminator, self).__init__()
        self.input_dropout = nn.Dropout(p=dropout_rate)
        layers = []
        hidden_sizes = [input_size] + hidden_sizes
        for i in range(len(hidden_sizes)-1):
            layers.extend([nn.Linear(hidden_sizes[i], hidden_sizes[i+1]), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(dropout_rate)])

        self.layers = nn.Sequential(*layers)  # stack the hidden layers into one sequential module
        self.logit = nn.Linear(hidden_sizes[-1], num_labels + 1)  # +1 for the probability of this sample being fake/real.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_rep):
        input_rep = self.input_dropout(input_rep)
        last_rep = self.layers(input_rep)
        logits = self.logit(last_rep)
        probs = self.softmax(logits)
        return last_rep, logits, probs

#dConfig = AutoConfig.from_pretrained("bert-base-uncased")
#hidden_size = int(dConfig.hidden_size)
#num_hidden_layers_d = 2; 
#hidden_levels_d = [hidden_size for i in range(0, num_hidden_layers_d)]
#label_list = ["1", "0"]
#label_list.append('UNL')
#discriminator = Discriminator(input_size=hidden_size*2, hidden_sizes=hidden_levels_d, num_labels=len(label_list), dropout_rate=out_dropout_rate).to(device)
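# Minimal sketch of wiring the discriminator to the concatenated BERT+SpaBERT
# embedding built below (768 + 768 = 1536 dims). Hidden sizes and dropout here
# are illustrative assumptions, not values from a trained SpaGAN checkpoint:
#   label_list = ["1", "0", "UNL"]   # sentiment labels plus an unlabeled class
#   discriminator = Discriminator(input_size=768 * 2, hidden_sizes=[768, 768],
#                                 num_labels=len(label_list), dropout_rate=0.1).to(device)
#   _, logits, probs = discriminator(combined_embedding)   # probs: (1, len(label_list) + 1)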



# Function to read reviews from a text file
def load_reviews_from_file(file_path):
    reviews = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            for i, line in enumerate(file):
                line = line.strip()
                if line:  # Ensure the line is not empty
                    reviews[f"Review {i + 1}"] = line
    except FileNotFoundError:
        st.error(f"File not found: {file_path}")
    return reviews
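# SampleReviews.txt is expected to hold one review per line; a hypothetical line:
#   The tacos at Grand Central Market in Los Angeles are worth the wait.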


st.title("SpaGAN Demo")
st.write("Enter a text, and the system will highlight the geo-entities within it.")

# Define a color map and descriptions for different entity types
COLOR_MAP = {
    'FAC': ('red', 'Facilities (e.g., buildings, airports)'),
    'ORG': ('blue', 'Organizations (e.g., companies, institutions)'),
    'LOC': ('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)')
}

# Display the color key
st.write("**Color Key:**")
for label, (color, description) in COLOR_MAP.items():
    st.markdown(f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}", unsafe_allow_html=True)

review_file_path = "models/spabert/datasets/SampleReviews.txt"
example_reviews = load_reviews_from_file(review_file_path)

# Dropdown for selecting an example review
user_input = st.selectbox("Select an example review", options=list(example_reviews.keys()))

# Get the selected review text
selected_review = example_reviews[user_input]

# Process the text when the button is clicked
if st.button("Highlight Geo-Entities"):
    if selected_review.strip():
        bert_embedding = get_bert_embedding(selected_review)
        st.write("Review Embedding Shape:", bert_embedding.shape)
        
        spaBert_embedding = processSpatialEntities(selected_review, nlp)
        st.write("Geo-Entities embedding shape: ", spaBert_embedding.shape)

        combined_embedding = torch.cat((bert_embedding, spaBert_embedding), dim=-1)
        st.write("Concatenated Embedding Shape:", combined_embedding.shape)
        st.write("Concatenated Embedding:", combined_embedding)
        
        # Process the text using spaCy
        doc = nlp(selected_review)
        
        # Highlight geo-entities with different colors
        highlighted_text = selected_review
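        # Replace entities in reverse document order so inserting HTML tags does
        # not invalidate the character offsets of entities yet to be processed.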
        for ent in reversed(doc.ents):
            if ent.label_ in COLOR_MAP:
                color = COLOR_MAP[ent.label_][0]
                highlighted_text = (
                    highlighted_text[:ent.start_char] +
                    f"<span style='color:{color}; font-weight:bold'>{ent.text}</span>" + 
                    highlighted_text[ent.end_char:]
                )

        # Display the highlighted text with HTML support
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        st.error("Please select a review.")